diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 52911d3b34d6..4d150e93655b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,8 @@ * @ritchie46 -/.github/ @ritchie46 @stinodego +/.github/ @ritchie46 @stinodego /crates/ @ritchie46 @orlp /crates/polars-sql/ @ritchie46 @orlp @universalmind303 /crates/polars-time/ @ritchie46 @orlp @MarcoGorelli /py-polars/ @ritchie46 @stinodego @alexander-beedie +/docs/ @ritchie46 @c-peters @braaannigan diff --git a/.github/ISSUE_TEMPLATE/documentation.yml b/.github/ISSUE_TEMPLATE/documentation.yml index 4a5c23267ebd..3594bdb6a40e 100644 --- a/.github/ISSUE_TEMPLATE/documentation.yml +++ b/.github/ISSUE_TEMPLATE/documentation.yml @@ -13,11 +13,11 @@ body: required: true - type: input - id: location + id: link attributes: - label: Location + label: Link description: > Provide a link to the existing documentation, if applicable. - placeholder: https://pola-rs.github.io/polars/docs/python/dev/reference/api/polars.read_csv.html + placeholder: ex. https://pola-rs.github.io/polars/docs/python/dev/... validations: required: false diff --git a/.github/workflows/docs-global.yml b/.github/workflows/docs-global.yml new file mode 100644 index 000000000000..801449af6e02 --- /dev/null +++ b/.github/workflows/docs-global.yml @@ -0,0 +1,85 @@ +name: Build documentation + +on: + pull_request: + paths: + - docs/** + - mkdocs.yml + - .github/workflows/docs-global.yml + push: + tags: + - py-** + +jobs: + markdown-link-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: gaurav-nelson/github-action-markdown-link-check@v1 + with: + folder-path: docs + + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: psf/black@stable + with: + src: docs/src/python + version: "23.7.0" + + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Create virtual environment + run: | + python -m venv .venv + echo "$GITHUB_WORKSPACE/.venv/bin" >> $GITHUB_PATH + + - name: Install dependencies + run: | + pip install -r py-polars/requirements-dev.txt + pip install -r docs/requirements.txt + + - name: Set up Rust + run: rustup show + + - name: Cache Rust + uses: Swatinem/rust-cache@v2 + with: + workspaces: py-polars + save-if: ${{ github.ref_name == 'main' }} + + - name: Install Polars + working-directory: py-polars + run: | + source activate + maturin develop + + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v1 + + - name: Build documentation + run: mkdocs build + + - name: Add .nojekyll + if: ${{ github.ref_type == 'tag' }} + working-directory: site + run: touch .nojekyll + + - name: Deploy docs + if: ${{ github.ref_type == 'tag' }} + uses: JamesIves/github-pages-deploy-action@v4 + with: + folder: site + clean-exclude: | + docs/ + py-polars/ + single-commit: true diff --git a/.github/workflows/lint-rust.yml b/.github/workflows/lint-rust.yml index de78f4e2d505..1c5b1e314033 100644 --- a/.github/workflows/lint-rust.yml +++ b/.github/workflows/lint-rust.yml @@ -38,7 +38,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Run cargo clippy with all features enabled - run: cargo clippy --workspace --all-targets --all-features -- -D warnings + run: cargo clippy -p polars --all-features -- -D warnings # Default feature set should compile on the stable toolchain clippy-stable: @@ -58,7 +58,7 @@ jobs: save-if: ${{ github.ref_name == 'main' }} - name: Run cargo clippy - run: cargo clippy 
--workspace --all-targets -- -D warnings + run: cargo clippy -p polars -- -D warnings rustfmt: if: github.ref_name != 'main' @@ -90,7 +90,6 @@ jobs: POLARS_ALLOW_EXTENSION: '1' run: > cargo miri test - --no-default-features --features object -p polars-core -p polars-arrow diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 1d9517d0ed5c..02386075157d 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -41,6 +41,9 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Set up Graphviz + uses: ts-graphviz/setup-graphviz@v1 + - name: Create virtual environment run: | python -m venv .venv @@ -65,11 +68,13 @@ jobs: - name: Run tests and report coverage if: github.ref_name != 'main' - run: pytest --cov -n auto --dist loadgroup -m "not benchmark" + run: pytest --cov -n auto --dist loadgroup -m "not benchmark and not docs" - name: Run doctests if: github.ref_name != 'main' - run: python tests/docs/run_doctest.py + run: | + python tests/docs/run_doctest.py + pytest tests/docs/test_user_guide.py -m docs - name: Check import without optional dependencies if: github.ref_name != 'main' @@ -80,6 +85,7 @@ jobs: "matplotlib" "backports.zoneinfo" "connectorx" + "pyiceberg" "deltalake" "xlsx2csv" ) @@ -125,7 +131,7 @@ jobs: - name: Run tests if: github.ref_name != 'main' - run: pytest -n auto --dist loadgroup -m "not benchmark" + run: pytest -n auto --dist loadgroup -m "not benchmark and not docs" - name: Check import without optional dependencies if: github.ref_name != 'main' diff --git a/.gitignore b/.gitignore index 1dd5ecb4236f..5eb602ae7f52 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,37 @@ *.iml *.so *.ipynb -.DS_Store .ENV -.coverage .env -.hypothesis/ -.idea/ .ipynb_checkpoints/ -.mypy_cache/ -.pytest_cache/ .python-version .yarn/ -.vscode/ -__pycache__/ -AUTO_CHANGELOG.md -Cargo.lock coverage.lcov coverage.xml data/ -node_modules/ polars/vendor -target/ -venv*/ -.venv*/ + +# OS +.DS_Store + +# IDE +.idea/ +.vscode/ .vim + +# Python +.hypothesis/ +.mypy_cache/ +.pytest_cache/ +.venv/ +__pycache__/ +.coverage + +# Rust +target/ +Cargo.lock + +# Project +/docs/data/ +/docs/images/ +/docs/people.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 315ac4c8acd8..cc0993ff47e3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -148,12 +148,12 @@ If you are stuck or unsure about your solution, feel free to open a draft pull r ## Contributing to documentation -The most important components of Polars documentation are the [user guide](https://pola-rs.github.io/polars-book/user-guide/), the API references, and the database of questions on [StackOverflow](https://stackoverflow.com/). +The most important components of Polars documentation are the [user guide](https://pola-rs.github.io/polars/user-guide/), the API references, and the database of questions on [StackOverflow](https://stackoverflow.com/). ### User guide -The user guide is maintained in the [polars-book](https://github.com/pola-rs/polars-book) repository. -For contributing to the user guide, please refer to the [contributing guide](https://github.com/pola-rs/polars-book/blob/master/CONTRIBUTING.md) in that repository. +The user guide is maintained in the `docs` folder. +Further contributing information will be added shortly. 
### API reference diff --git a/Cargo.toml b/Cargo.toml index 2f22c6309025..f4bfc6d14cac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ exclude = [ ] [workspace.package] -version = "0.32.0" +version = "0.33.2" authors = ["Ritchie Vink "] edition = "2021" homepage = "https://www.pola.rs/" @@ -28,7 +28,7 @@ bytemuck = { version = "1", features = ["derive", "extern_crate_alloc"] } chrono = { version = "0.4", default-features = false, features = ["std"] } chrono-tz = "0.8.1" ciborium = "0.2" -either = "1.8" +either = "1.9" futures = "0.3.25" hashbrown = { version = "0.14", features = ["rayon", "ahash"] } indexmap = { version = "2", features = ["std"] } @@ -50,14 +50,36 @@ strum_macros = "0.25" thiserror = "1" url = "2.3.1" version_check = "0.9.4" +simdutf8 = "0.1.4" +hex = "0.4.3" +base64 = "0.21.2" +fallible-streaming-iterator = "0.1.9" +streaming-iterator = "0.1.9" +itoa = "1.0.6" +ryu = "1.0.13" +lexical-core = "0.8.5" xxhash-rust = { version = "0.8.6", features = ["xxh3"] } +polars-core = { version = "0.33.2", path = "crates/polars-core", default-features = false } +polars-arrow = { version = "0.33.2", path = "crates/polars-arrow", default-features = false } +polars-plan = { version = "0.33.2", path = "crates/polars-plan", default-features = false } +polars-lazy = { version = "0.33.2", path = "crates/polars-lazy", default-features = false } +polars-pipe = { version = "0.33.2", path = "crates/polars-pipe", default-features = false } +polars-row = { version = "0.33.2", path = "crates/polars-row", default-features = false } +polars-ffi = { version = "0.33.2", path = "crates/polars-ffi", default-features = false } +polars-ops = { version = "0.33.2", path = "crates/polars-ops", default-features = false } +polars-sql = { version = "0.33.2", path = "crates/polars-sql", default-features = false } +polars-algo = { version = "0.33.2", path = "crates/polars-algo", default-features = false } +polars-time = { version = "0.33.2", path = "crates/polars-time", default-features = false } +polars-utils = { version = "0.33.2", path = "crates/polars-utils", default-features = false } +polars-io = { version = "0.33.2", path = "crates/polars-io", default-features = false } +polars-error = { version = "0.33.2", path = "crates/polars-error", default-features = false } +polars-json = { version = "0.33.2", path = "crates/polars-json", default-features = false } +polars = { version = "0.33.2", path = "crates/polars", default-features = false } [workspace.dependencies.arrow] -package = "arrow2" -git = "https://github.com/jorgecarleitao/arrow2" -rev = "7c93e358fc400bf3c0c0219c22eefc6b38fc2d12" -# branch = "" -# version = "0.17.4" +package = "nano-arrow" +version = "0.1.0" +path = "crates/nano-arrow" default-features = false features = [ "compute_aggregate", diff --git a/Makefile b/Makefile index 532342913f97..54d1bd6d4404 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,11 @@ requirements: .venv ## Install/refresh Python project requirements $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-dev.txt $(VENV_BIN)/pip install --upgrade -r py-polars/requirements-lint.txt $(VENV_BIN)/pip install --upgrade -r py-polars/docs/requirements-docs.txt + $(VENV_BIN)/pip install --upgrade -r docs/requirements.txt + +.PHONY: build-python +build-python: .venv ## Compile and install Python Polars for development + @$(MAKE) -s -C py-polars build .PHONY: clean clean: ## Clean up caches and build artifacts diff --git a/README.md b/README.md index 80857d303ab3..f917c1c9bdb8 100644 --- a/README.md +++ b/README.md @@ 
-40,7 +40,7 @@ - R | - User Guide + User Guide | Discord

@@ -58,7 +58,7 @@ Polars is a DataFrame interface on top of an OLAP Query Engine implemented in Ru - Hybrid Streaming (larger than RAM datasets) - Rust | Python | NodeJS | R | ... -To learn more, read the [User Guide](https://pola-rs.github.io/polars-book/). +To learn more, read the [User Guide](https://pola-rs.github.io/polars/). ## Python @@ -208,6 +208,7 @@ You can also install the dependencies directly. | xlsx2csv | Support for reading from Excel files | | openpyxl | Support for reading from Excel files with native types | | deltalake | Support for reading from Delta Lake Tables | +| pyiceberg | Support for reading from Apache Iceberg tables | | timezone | Timezone support, only needed if are on Python<3.9 or you are on Windows | Releases happen quite often (weekly / every few days) at the moment, so updating polars regularly to get the latest bugfixes / features might not be a bad idea. diff --git a/_typos.toml b/_typos.toml index 12406b2f4ea8..4d9ec510b278 100644 --- a/_typos.toml +++ b/_typos.toml @@ -7,6 +7,7 @@ extend-ignore-identifiers-re = [ ba = "ba" Fo = "Fo" nd = "nd" +ND = "ND" opt_nd = "opt_nd" ser = "ser" strat = "strat" diff --git a/crates/Makefile b/crates/Makefile index 9c548833079c..e84da6774e46 100644 --- a/crates/Makefile +++ b/crates/Makefile @@ -10,22 +10,22 @@ fmt: ## Run rustfmt and dprint .PHONY: check check: ## Run cargo check with all features - cargo check --workspace --all-targets --all-features + cargo check --workspace --all-targets --exclude nano-arrow --all-features .PHONY: clippy clippy: ## Run clippy with all features - cargo clippy --workspace --all-targets --all-features + cargo clippy -p polars --all-features .PHONY: clippy-default clippy-default: ## Run clippy with default features - cargo clippy --workspace --all-targets + cargo clippy -p polars .PHONY: pre-commit pre-commit: fmt clippy clippy-default ## Run autoformatting and linting .PHONY: check-features check-features: ## Run cargo check for feature flag combinations (warning: slow) - cargo hack check --each-feature --no-dev-deps + cargo hack check -p polars --each-feature --no-dev-deps .PHONY: miri miri: ## Run miri @@ -35,7 +35,6 @@ miri: ## Run miri MIRIFLAGS="-Zmiri-disable-isolation -Zmiri-ignore-leaks -Zmiri-disable-stacked-borrows" \ POLARS_ALLOW_EXTENSION=1 \ cargo miri test \ - --no-default-features \ --features object \ -p polars-core \ -p polars-arrow @@ -109,6 +108,7 @@ publish: ## Publish Polars crates cargo publish --allow-dirty -p polars-arrow cargo publish --allow-dirty -p polars-json cargo publish --allow-dirty -p polars-core + cargo publish --allow-dirty -p polars-ffi cargo publish --allow-dirty -p polars-ops cargo publish --allow-dirty -p polars-time cargo publish --allow-dirty -p polars-io diff --git a/crates/nano-arrow/Cargo.toml b/crates/nano-arrow/Cargo.toml new file mode 100644 index 000000000000..f51110ea29ca --- /dev/null +++ b/crates/nano-arrow/Cargo.toml @@ -0,0 +1,197 @@ +[package] +name = "nano-arrow" +version = "0.1.0" +authors = ["Jorge C. Leitao ", "Apache Arrow ", "Ritchie Vink"] +edition.workspace = true +homepage.workspace = true +licence = "Apache 2.0 and MIT" +repository.workspace = true +description = "Minimal implementation of the Arrow specification forked from arrow2." 
+ +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytemuck.workspace = true +chrono.workspace = true +# for timezone support +chrono-tz = { workspace = true, optional = true } +dyn-clone = "1" +either.workspace = true +foreign_vec = "0.1.0" +hashbrown.workspace = true +num-traits.workspace = true +simdutf8.workspace = true + +# for decimal i256 +ethnum = "1" + +# To efficiently cast numbers to strings +lexical-core = { workspace = true, optional = true } + +fallible-streaming-iterator = { workspace = true, optional = true } +regex = { workspace = true, optional = true } +regex-syntax = { version = "0.7", optional = true } +streaming-iterator = { workspace = true } + +indexmap = { workspace = true, optional = true } + +arrow-format = { version = "0.8", optional = true, features = ["ipc"] } + +hex = { workspace = true, optional = true } + +# for IPC compression +lz4 = { version = "1.24", optional = true } +zstd = { version = "0.12", optional = true } + +base64 = { workspace = true, optional = true } + +# to write to parquet as a stream +futures = { version = "0.3", optional = true } + +# to read IPC as a stream +async-stream = { version = "0.3.2", optional = true } + +# avro support +avro-schema = { version = "0.3", optional = true } + +# for division/remainder optimization at runtime +strength_reduce = { version = "0.2", optional = true } + +# For instruction multiversioning +multiversion = { workspace = true, optional = true } + +# Faster hashing +ahash.workspace = true + +# Support conversion to/from arrow-rs +arrow-array = { version = ">=40", optional = true } +arrow-buffer = { version = ">=40", optional = true } +arrow-data = { version = ">=40", optional = true } +arrow-schema = { version = ">=40", optional = true } + +[target.wasm32-unknown-unknown.dependencies] +getrandom = { version = "0.2", features = ["js"] } + +# parquet support +[dependencies.parquet2] +version = "0.17" +optional = true +default_features = false +features = ["async"] + +[dev-dependencies] +avro-rs = { version = "0.13", features = ["snappy"] } +criterion = "0.4" +crossbeam-channel = "0.5.1" +doc-comment = "0.3" +flate2 = "1" +# used to run formal property testing +proptest = { version = "1", default_features = false, features = ["std"] } +# use for flaky testing +rand = "0.8" +# use for generating and testing random data samples +sample-arrow2 = "0.1" +sample-std = "0.1" +sample-test = "0.1" +# used to test async readers +tokio = { version = "1", features = ["macros", "rt", "fs", "io-util"] } +tokio-util = { version = "0.7", features = ["compat"] } + +[package.metadata.docs.rs] +features = ["full"] +rustdoc-args = ["--cfg", "docsrs"] + +[features] +default = [] +full = [ + "arrow", + "io_ipc", + "io_flight", + "io_ipc_write_async", + "io_ipc_read_async", + "io_ipc_compression", + "io_parquet", + "io_parquet_compression", + "io_avro", + "io_avro_compression", + "io_avro_async", + "regex-syntax", + "compute", + # parses timezones used in timestamp conversions + "chrono-tz", +] +arrow = ["arrow-buffer", "arrow-schema", "arrow-data", "arrow-array"] +io_ipc = ["arrow-format"] +io_ipc_write_async = ["io_ipc", "futures"] +io_ipc_read_async = ["io_ipc", "futures", "async-stream"] +io_ipc_compression = ["lz4", "zstd"] +io_flight = ["io_ipc", "arrow-format/flight-data"] + +# base64 + io_ipc because arrow schemas are stored as base64-encoded ipc format. 
+io_parquet = ["parquet2", "io_ipc", "base64", "futures", "fallible-streaming-iterator"] + +io_parquet_compression = [ + "io_parquet_zstd", + "io_parquet_gzip", + "io_parquet_snappy", + "io_parquet_lz4", + "io_parquet_brotli", +] + +# sample testing of generated arrow data +io_parquet_sample_test = ["io_parquet"] + +# compression backends +io_parquet_zstd = ["parquet2/zstd"] +io_parquet_snappy = ["parquet2/snappy"] +io_parquet_gzip = ["parquet2/gzip"] +io_parquet_lz4_flex = ["parquet2/lz4_flex"] +io_parquet_lz4 = ["parquet2/lz4"] +io_parquet_brotli = ["parquet2/brotli"] + +# parquet bloom filter functions +io_parquet_bloom_filter = ["parquet2/bloom_filter"] + +io_avro = ["avro-schema"] +io_avro_compression = [ + "avro-schema/compression", +] +io_avro_async = ["avro-schema/async"] + +# the compute kernels. Disabling this significantly reduces compile time. +compute_aggregate = ["multiversion"] +compute_arithmetics_decimal = ["strength_reduce"] +compute_arithmetics = ["strength_reduce", "compute_arithmetics_decimal"] +compute_bitwise = [] +compute_boolean = [] +compute_boolean_kleene = [] +compute_cast = ["lexical-core", "compute_take"] +compute_comparison = ["compute_take", "compute_boolean"] +compute_concatenate = [] +compute_filter = [] +compute_hash = ["multiversion"] +compute_if_then_else = [] +compute_take = [] +compute_temporal = [] +compute = [ + "compute_aggregate", + "compute_arithmetics", + "compute_bitwise", + "compute_boolean", + "compute_boolean_kleene", + "compute_cast", + "compute_comparison", + "compute_concatenate", + "compute_filter", + "compute_hash", + "compute_if_then_else", + "compute_take", + "compute_temporal", +] +simd = [] + +[build-dependencies] +rustc_version = "0.4.0" + +[package.metadata.cargo-all-features] +allowlist = ["compute", "compute_sort", "compute_hash", "compute_nullif"] diff --git a/crates/nano-arrow/src/README.md b/crates/nano-arrow/src/README.md new file mode 100644 index 000000000000..d6371ebc8741 --- /dev/null +++ b/crates/nano-arrow/src/README.md @@ -0,0 +1,32 @@ +# Crate's design + +This document describes the design of this module, and thus the overall crate. +Each module MAY have its own design document, that concerns specifics of that module, and if yes, +it MUST be on each module's `README.md`. + +## Equality + +Array equality is not defined in the Arrow specification. This crate follows the intent of the specification, but there is no guarantee that this no verification that this equals e.g. C++'s definition. + +There is a single source of truth about whether two arrays are equal, and that is via their +equality operators, defined on the module [`array/equal`](array/equal/mod.rs). + +Implementation MUST use these operators for asserting equality, so that all testing follows the same definition of array equality. + +## Error handling + +- Errors from an external dependency MUST be encapsulated on `External`. +- Errors from IO MUST be encapsulated on `Io`. +- This crate MAY return `NotYetImplemented` when the functionality does not exist, or it MAY panic with `unimplemented!`. + +## Logical and physical types + +There is a strict separation between physical and logical types: + +- physical types MUST be implemented via generics +- logical types MUST be implemented via variables (whose value is e.g. an `enum`) +- logical types MUST be declared and implemented on the `datatypes` module + +## Source of undefined behavior + +There is one, and only one, acceptable source of undefined behavior: FFI. 
It is impossible to prove that data passed via pointers are safe for consumption (only a promise from the specification). diff --git a/crates/nano-arrow/src/array/README.md b/crates/nano-arrow/src/array/README.md new file mode 100644 index 000000000000..af21f91e02ef --- /dev/null +++ b/crates/nano-arrow/src/array/README.md @@ -0,0 +1,73 @@ +# Array module + +This document describes the overall design of this module. + +## Notation: + +- "array" in this module denotes any struct that implements the trait `Array`. +- "mutable array" in this module denotes any struct that implements the trait `MutableArray`. +- words in `code` denote existing terms on this implementation. + +## Arrays: + +- Every arrow array with a different physical representation MUST be implemented as a struct or generic struct. + +- An array MAY have its own module. E.g. `primitive/mod.rs` + +- An array with a null bitmap MUST implement it as `Option` + +- An array MUST be `#[derive(Clone)]` + +- The trait `Array` MUST only be implemented by structs in this module. + +- Every child array on the struct MUST be `Box`. + +- An array MUST implement `try_new(...) -> Self`. This method MUST error iff + the data does not follow the arrow specification, including any sentinel types such as utf8. + +- An array MAY implement `unsafe try_new_unchecked` that skips validation steps that are `O(N)`. + +- An array MUST implement either `new_empty()` or `new_empty(DataType)` that returns a zero-len of `Self`. + +- An array MUST implement either `new_null(length: usize)` or `new_null(DataType, length: usize)` that returns a valid array of length `length` whose all elements are null. + +- An array MAY implement `value(i: usize)` that returns the value at slot `i` ignoring the validity bitmap. + +- functions to create new arrays from native Rust SHOULD be named as follows: + - `from`: from a slice of optional values (e.g. `AsRef<[Option]` for `BooleanArray`) + - `from_slice`: from a slice of values (e.g. `AsRef<[bool]>` for `BooleanArray`) + - `from_trusted_len_iter` from an iterator of trusted len of optional values + - `from_trusted_len_values_iter` from an iterator of trusted len of values + - `try_from_trusted_len_iter` from an fallible iterator of trusted len of optional values + +### Slot offsets + +- An array MUST have a `offset: usize` measuring the number of slots that the array is currently offsetted by if the specification requires. + +- An array MUST implement `fn slice(&self, offset: usize, length: usize) -> Self` that returns an offsetted and/or truncated clone of the array. This function MUST increase the array's offset if it exists. + +- Conversely, `offset` MUST only be changed by `slice`. + +The rational of the above is that it enable us to be fully interoperable with the offset logic supported by the C data interface, while at the same time easily perform array slices +within Rust's type safety mechanism. + +### Mutable Arrays + +- An array MAY have a mutable counterpart. E.g. `MutablePrimitiveArray` is the mutable counterpart of `PrimitiveArray`. + +- Arrays with mutable counterparts MUST have its own module, and have the mutable counterpart declared in `{module}/mutable.rs`. + +- The trait `MutableArray` MUST only be implemented by mutable arrays in this module. + +- A mutable array MUST be `#[derive(Debug)]` + +- A mutable array with a null bitmap MUST implement it as `Option` + +- Converting a `MutableArray` to its immutable counterpart MUST be `O(1)`. 
Specifically: + - it must not allocate + - it must not cause `O(N)` data transformations + + This is achieved by converting mutable versions to immutable counterparts (e.g. `MutableBitmap -> Bitmap`). + + The rational is that `MutableArray`s can be used to perform in-place operations under + the arrow spec. diff --git a/crates/nano-arrow/src/array/binary/data.rs b/crates/nano-arrow/src/array/binary/data.rs new file mode 100644 index 000000000000..56835dec0c42 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/data.rs @@ -0,0 +1,43 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, BinaryArray}; +use crate::bitmap::Bitmap; +use crate::offset::{Offset, OffsetsBuffer}; + +impl Arrow2Arrow for BinaryArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.offsets().len_proxy()) + .buffers(vec![ + self.offsets.clone().into_inner().into(), + self.values.clone().into(), + ]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let buffers = data.buffers(); + + // Safety: ArrayData is valid + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type, + offsets, + values: buffers[1].clone().into(), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/binary/ffi.rs b/crates/nano-arrow/src/array/binary/ffi.rs new file mode 100644 index 000000000000..3ba66cc130da --- /dev/null +++ b/crates/nano-arrow/src/array/binary/ffi.rs @@ -0,0 +1,63 @@ +use super::BinaryArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for BinaryArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for BinaryArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let values = unsafe { array.buffer::(2) }?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new(data_type, offsets, values, validity)) + } +} diff --git a/crates/nano-arrow/src/array/binary/fmt.rs b/crates/nano-arrow/src/array/binary/fmt.rs new file mode 100644 index 000000000000..d2a6788ce4d8 --- /dev/null +++ 
b/crates/nano-arrow/src/array/binary/fmt.rs @@ -0,0 +1,26 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BinaryArray; +use crate::offset::Offset; + +pub fn write_value(array: &BinaryArray, index: usize, f: &mut W) -> Result { + let bytes = array.value(index); + let writer = |f: &mut W, index| write!(f, "{}", bytes[index]); + + write_vec(f, writer, None, bytes.len(), "None", false) +} + +impl Debug for BinaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + let head = if O::IS_LARGE { + "LargeBinaryArray" + } else { + "BinaryArray" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/binary/from.rs b/crates/nano-arrow/src/array/binary/from.rs new file mode 100644 index 000000000000..73df03531594 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/from.rs @@ -0,0 +1,11 @@ +use std::iter::FromIterator; + +use super::{BinaryArray, MutableBinaryArray}; +use crate::offset::Offset; + +impl> FromIterator> for BinaryArray { + #[inline] + fn from_iter>>(iter: I) -> Self { + MutableBinaryArray::::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/binary/iterator.rs b/crates/nano-arrow/src/array/binary/iterator.rs new file mode 100644 index 000000000000..3fccec58eb50 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/iterator.rs @@ -0,0 +1,42 @@ +use super::{BinaryArray, MutableBinaryValuesArray}; +use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for BinaryArray { + type Item = &'a [u8]; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`BinaryArray`]. +pub type BinaryValueIter<'a, O> = ArrayValuesIter<'a, BinaryArray>; + +impl<'a, O: Offset> IntoIterator for &'a BinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], BinaryValueIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +/// Iterator of values of an [`MutableBinaryValuesArray`]. 
+pub type MutableBinaryValuesIter<'a, O> = ArrayValuesIter<'a, MutableBinaryValuesArray>; + +impl<'a, O: Offset> IntoIterator for &'a MutableBinaryValuesArray { + type Item = &'a [u8]; + type IntoIter = MutableBinaryValuesIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/nano-arrow/src/array/binary/mod.rs b/crates/nano-arrow/src/array/binary/mod.rs new file mode 100644 index 000000000000..ccd58f22d869 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/mod.rs @@ -0,0 +1,423 @@ +use either::Either; + +use super::specification::try_check_offsets_bounds; +use super::{Array, GenericBinaryArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; + +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod from; +mod mutable_values; +pub use mutable_values::*; +mod mutable; +pub use mutable::*; + +#[cfg(feature = "arrow")] +mod data; + +/// A [`BinaryArray`] is Arrow's semantically equivalent of an immutable `Vec>>`. +/// It implements [`Array`]. +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. +/// # Example +/// ``` +/// use arrow2::array::BinaryArray; +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// +/// let array = BinaryArray::::from([Some([1, 2].as_ref()), None, Some([3].as_ref())]); +/// assert_eq!(array.value(0), &[1, 2]); +/// assert_eq!(array.iter().collect::>(), vec![Some([1, 2].as_ref()), None, Some([3].as_ref())]); +/// assert_eq!(array.values_iter().collect::>(), vec![[1, 2].as_ref(), &[], &[3]]); +/// // the underlying representation: +/// assert_eq!(array.values(), &Buffer::from(vec![1, 2, 3])); +/// assert_eq!(array.offsets().buffer(), &Buffer::from(vec![0, 2, 2, 3])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// ``` +/// +/// # Generic parameter +/// The generic parameter [`Offset`] can only be `i32` or `i64` and tradeoffs maximum array length with +/// memory usage: +/// * the sum of lengths of all elements cannot exceed `Offset::MAX` +/// * the total size of the underlying data is `array.len() * size_of::() + sum of lengths of all elements` +/// +/// # Safety +/// The following invariants hold: +/// * Two consecutives `offsets` casted (`as`) to `usize` are valid slices of `values`. +/// * `len` is equal to `validity.len()`, when defined. +#[derive(Clone)] +pub struct BinaryArray { + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, +} + +impl BinaryArray { + /// Returns a [`BinaryArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. 
+ /// # Implementation + /// This function is `O(1)` + pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "BinaryArray can only be initialized with DataType::Binary or DataType::LargeBinary", + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Creates a new [`BinaryArray`] from slices of `&[u8]`. + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter()) + } + + /// Creates a new [`BinaryArray`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + MutableBinaryArray::::from(slice).into() + } + + /// Returns an iterator of `Option<&[u8]>` over every element of this array. + pub fn iter(&self) -> ZipValidity<&[u8], BinaryValueIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity.as_ref()) + } + + /// Returns an iterator of `&[u8]` over every element of this array, ignoring the validity + pub fn values_iter(&self) -> BinaryValueIter { + BinaryValueIter::new(self) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the element at index `i` + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + + // soundness: the invariant of the struct + self.values.get_unchecked(start..end) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&[u8]> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns the [`DataType`] of this array. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the values of this [`BinaryArray`]. + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the offsets of this [`BinaryArray`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Slices this [`BinaryArray`]. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`BinaryArray`]. + /// # Implementation + /// This function is `O(1)`. 
+ /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, OffsetsBuffer, Buffer, Option) { + let Self { + data_type, + offsets, + values, + validity, + } = self; + (data_type, offsets, values, validity) + } + + /// Try to convert this `BinaryArray` to a `MutableBinaryArray` + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + // Safety: invariants are preserved + Left(bitmap) => Left(BinaryArray::new( + self.data_type, + self.offsets, + self.values, + Some(bitmap), + )), + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets, + values, + Some(mutable_bitmap.into()), + )), + (Left(values), Right(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets.into(), + values, + Some(mutable_bitmap.into()), + )), + (Right(values), Left(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets, + values.into(), + Some(mutable_bitmap.into()), + )), + (Right(values), Right(offsets)) => Right( + MutableBinaryArray::try_new( + self.data_type, + offsets, + values, + Some(mutable_bitmap), + ) + .unwrap(), + ), + }, + } + } else { + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(BinaryArray::new(self.data_type, offsets, values, None)) + }, + (Left(values), Right(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets.into(), + values, + None, + )), + (Right(values), Left(offsets)) => Left(BinaryArray::new( + self.data_type, + offsets, + values.into(), + None, + )), + (Right(values), Right(offsets)) => Right( + MutableBinaryArray::try_new(self.data_type, offsets, values, None).unwrap(), + ), + } + } + } + + /// Creates an empty [`BinaryArray`], i.e. whose `.len` is zero. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, OffsetsBuffer::new(), Buffer::new(), None) + } + + /// Creates an null [`BinaryArray`], i.e. whose `.null_count() == .len()`. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new( + data_type, + Offsets::new_zeroed(length).into(), + Buffer::new(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns the default [`DataType`], `DataType::Binary` or `DataType::LargeBinary` + pub fn default_data_type() -> DataType { + if O::IS_LARGE { + DataType::LargeBinary + } else { + DataType::Binary + } + } + + /// Alias for unwrapping [`Self::try_new`] + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, values, validity).unwrap() + } + + /// Returns a [`BinaryArray`] from an iterator of trusted length. + /// + /// The [`BinaryArray`] is guaranteed to not have a validity + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableBinaryArray::::from_trusted_len_values_iter(iterator).into() + } + + /// Returns a new [`BinaryArray`] from a [`Iterator`] of `&[u8]`. 
+ /// + /// The [`BinaryArray`] is guaranteed to not have a validity + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableBinaryArray::::from_iter_values(iterator).into() + } + + /// Creates a [`BinaryArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + MutableBinaryArray::::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`BinaryArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`BinaryArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + MutableBinaryArray::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) + } + + /// Creates a [`BinaryArray`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } +} + +impl Array for BinaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +unsafe impl GenericBinaryArray for BinaryArray { + #[inline] + fn values(&self) -> &[u8] { + self.values() + } + + #[inline] + fn offsets(&self) -> &[O] { + self.offsets().buffer() + } +} diff --git a/crates/nano-arrow/src/array/binary/mutable.rs b/crates/nano-arrow/src/array/binary/mutable.rs new file mode 100644 index 000000000000..92521b400323 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/mutable.rs @@ -0,0 +1,469 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{BinaryArray, MutableBinaryValuesArray, MutableBinaryValuesIter}; +use crate::array::physical_binary::*; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// The Arrow's equivalent to `Vec>>`. +/// Converting a [`MutableBinaryArray`] into a [`BinaryArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). 
+#[derive(Debug, Clone)] +pub struct MutableBinaryArray { + values: MutableBinaryValuesArray, + validity: Option, +} + +impl From> for BinaryArray { + fn from(other: MutableBinaryArray) -> Self { + let validity = other.validity.and_then(|x| { + let validity: Option = x.into(); + validity + }); + let array: BinaryArray = other.values.into(); + array.with_validity(validity) + } +} + +impl Default for MutableBinaryArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBinaryArray { + /// Creates a new empty [`MutableBinaryArray`]. + /// # Implementation + /// This allocates a [`Vec`] of one element + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Returns a [`MutableBinaryArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + data_type: DataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Result { + let values = MutableBinaryValuesArray::try_new(data_type, offsets, values)?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity's length must be equal to the number of values", + )); + } + + Ok(Self { values, validity }) + } + + /// Creates a new [`MutableBinaryArray`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } + + fn default_data_type() -> DataType { + BinaryArray::::default_data_type() + } + + /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableBinaryArray`] with a pre-allocated capacity of slots and values. + /// # Implementation + /// This does not allocate the validity. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + values: MutableBinaryValuesArray::with_capacities(capacity, values), + validity: None, + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.values.reserve(additional, additional_values); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Pushes a new element to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + pub fn push>(&mut self, value: Option) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableBinaryArray`]. + /// This function returns `None` iff this array is empty + pub fn pop(&mut self) -> Option> { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + fn try_from_iter, I: IntoIterator>>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut primitive = Self::with_capacity(lower); + for item in iterator { + primitive.try_push(item.as_ref())? 
+ } + Ok(primitive) + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity); + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: BinaryArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableBinaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + impl_mutable_array_mut_validity!(); +} + +impl MutableBinaryArray { + /// returns its values. + pub fn values(&self) -> &Vec { + self.values.values() + } + + /// returns its offsets. + pub fn offsets(&self) -> &Offsets { + self.values.offsets() + } + + /// Returns an iterator of `Option<&[u8]>` + pub fn iter(&self) -> ZipValidity<&[u8], MutableBinaryValuesIter, BitmapIter> { + ZipValidity::new(self.values_iter(), self.validity.as_ref().map(|x| x.iter())) + } + + /// Returns an iterator over the values of this array + pub fn values_iter(&self) -> MutableBinaryValuesIter { + self.values.iter() + } +} + +impl MutableArray for MutableBinaryArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: BinaryArray = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: BinaryArray = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + self.values.data_type() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&[u8]>(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator> for MutableBinaryArray { + fn from_iter>>(iter: I) -> Self { + Self::try_from_iter(iter).unwrap() + } +} + +impl MutableBinaryArray { + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator>, + { + let (validity, offsets, values) = trusted_len_unzip(iterator); + + Self::try_new(Self::default_data_type(), offsets, values, validity).unwrap() + } + + /// Creates a [`MutableBinaryArray`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. 
+ #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked, I: Iterator>( + iterator: I, + ) -> Self { + let (offsets, values) = trusted_len_values_iter(iterator); + Self::try_new(Self::default_data_type(), offsets, values, None).unwrap() + } + + /// Creates a new [`BinaryArray`] from a [`TrustedLen`] of `&[u8]`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_values_iter_unchecked(iterator) } + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef<[u8]>, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + // soundness: assumed trusted len + let (mut validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + if validity.as_mut().unwrap().unset_bits() == 0 { + validity = None; + } + + Ok(Self::try_new(Self::default_data_type(), offsets, values, validity).unwrap()) + } + + /// Creates a [`MutableBinaryArray`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef<[u8]>, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of trusted length. + /// This differs from `extend_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen, + { + // Safety: The iterator is `TrustedLen` + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of values. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + let length = self.values.len(); + self.values.extend(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableBinaryArray`] from an `iterator` of values of trusted length. + /// This differs from `extend_trusted_len_unchecked` which accepts iterator of optional + /// values. 
+ /// # Safety + /// The `iterator` must be [`TrustedLen`] + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + let length = self.values.len(); + self.values.extend_trusted_len_unchecked(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen>, + { + // Safety: The iterator is `TrustedLen` + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutableBinaryArray`] from an iterator of [`TrustedLen`] + /// # Safety + /// The `iterator` must be [`TrustedLen`] + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + self.values + .extend_from_trusted_len_iter(self.validity.as_mut().unwrap(), iterator); + } + + /// Creates a new [`MutableBinaryArray`] from a [`Iterator`] of `&[u8]`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + let (offsets, values) = values_iter(iterator); + Self::try_new(Self::default_data_type(), offsets, values, None).unwrap() + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator, E>>, + T: AsRef<[u8]>, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend> for MutableBinaryArray { + fn extend>>(&mut self, iter: I) { + self.try_extend(iter).unwrap(); + } +} + +impl> TryExtend> for MutableBinaryArray { + fn try_extend>>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush> for MutableBinaryArray { + fn try_push(&mut self, value: Option) -> Result<()> { + match value { + Some(value) => { + self.values.try_push(value.as_ref())?; + + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(""); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } +} + +impl PartialEq for MutableBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableBinaryArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/nano-arrow/src/array/binary/mutable_values.rs b/crates/nano-arrow/src/array/binary/mutable_values.rs new file mode 100644 index 000000000000..e73f0223ec44 --- /dev/null +++ b/crates/nano-arrow/src/array/binary/mutable_values.rs @@ -0,0 +1,374 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{BinaryArray, MutableBinaryArray}; +use crate::array::physical_binary::*; +use crate::array::specification::try_check_offsets_bounds; +use crate::array::{ + Array, ArrayAccessor, ArrayValuesIter, MutableArray, TryExtend, 
TryExtendFromSelf, TryPush, +}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`BinaryArray`]. It differs +/// from [`MutableBinaryArray`] in that it builds non-null [`BinaryArray`]. +#[derive(Debug, Clone)] +pub struct MutableBinaryValuesArray { + data_type: DataType, + offsets: Offsets, + values: Vec, +} + +impl From> for BinaryArray { + fn from(other: MutableBinaryValuesArray) -> Self { + BinaryArray::::new( + other.data_type, + other.offsets.into(), + other.values.into(), + None, + ) + } +} + +impl From> for MutableBinaryArray { + fn from(other: MutableBinaryValuesArray) -> Self { + MutableBinaryArray::::try_new(other.data_type, other.offsets, other.values, None) + .expect("MutableBinaryValuesArray is consistent with MutableBinaryArray") + } +} + +impl Default for MutableBinaryValuesArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBinaryValuesArray { + /// Returns an empty [`MutableBinaryValuesArray`]. + pub fn new() -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::new(), + values: Vec::::new(), + } + } + + /// Returns a [`MutableBinaryValuesArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Binary` or `LargeBinary`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new(data_type: DataType, offsets: Offsets, values: Vec) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "MutableBinaryValuesArray can only be initialized with DataType::Binary or DataType::LargeBinary", + )); + } + + Ok(Self { + data_type, + offsets, + values, + }) + } + + /// Returns the default [`DataType`] of this container: [`DataType::Utf8`] or [`DataType::LargeUtf8`] + /// depending on the generic [`Offset`]. + pub fn default_data_type() -> DataType { + BinaryArray::::default_data_type() + } + + /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableBinaryValuesArray`] with a pre-allocated capacity of items and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::::with_capacity(capacity), + values: Vec::::with_capacity(values), + } + } + + /// returns its values. + #[inline] + pub fn values(&self) -> &Vec { + &self.values + } + + /// returns its offsets. + #[inline] + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// Reserves `additional` elements and `additional_values` on the values. + #[inline] + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.offsets.reserve(additional); + self.values.reserve(additional_values); + } + + /// Returns the capacity in number of items + pub fn capacity(&self) -> usize { + self.offsets.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Pushes a new item to the array. 
+ /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: T) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableBinaryValuesArray`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option> { + if self.len() == 0 { + return None; + } + self.offsets.pop()?; + let start = self.offsets.last().to_usize(); + let value = self.values.split_off(start); + Some(value.to_vec()) + } + + /// Returns the value of the element at index `i`. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end(i); + + // soundness: the invariant of the struct + self.values.get_unchecked(start..end) + } + + /// Returns an iterator of `&[u8]` + pub fn iter(&self) -> ArrayValuesIter { + ArrayValuesIter::new(self) + } + + /// Shrinks the capacity of the [`MutableBinaryValuesArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + } + + /// Extract the low-end APIs from the [`MutableBinaryValuesArray`]. + pub fn into_inner(self) -> (DataType, Offsets, Vec) { + (self.data_type, self.offsets, self.values) + } +} + +impl MutableArray for MutableBinaryValuesArray { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let (data_type, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(data_type, offsets.into(), values.into(), None).boxed() + } + + fn as_arc(&mut self) -> Arc { + let (data_type, offsets, values) = std::mem::take(self).into_inner(); + BinaryArray::new(data_type, offsets.into(), values.into(), None).arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&[u8]>(b"") + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator

for MutableBinaryValuesArray { + fn from_iter>(iter: I) -> Self { + let (offsets, values) = values_iter(iter.into_iter()); + Self::try_new(Self::default_data_type(), offsets, values).unwrap() + } +} + +impl MutableBinaryValuesArray { + pub(crate) unsafe fn extend_from_trusted_len_iter( + &mut self, + validity: &mut MutableBitmap, + iterator: I, + ) where + P: AsRef<[u8]>, + I: Iterator>, + { + extend_from_trusted_len_iter(&mut self.offsets, &mut self.values, validity, iterator); + } + + /// Extends the [`MutableBinaryValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableBinaryValuesArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef<[u8]>, + I: Iterator, + { + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + } + + /// Creates a [`MutableBinaryValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: TrustedLen, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Returns a new [`MutableBinaryValuesArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef<[u8]>, + I: Iterator, + { + let (offsets, values) = trusted_len_values_iter(iterator); + Self::try_new(Self::default_data_type(), offsets, values).unwrap() + } + + /// Returns a new [`MutableBinaryValuesArray`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). 
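+ /// # Example (editorial sketch, not part of the original change)
+ /// A minimal sketch of `try_from_iter`, assuming `i32` offsets.
+ /// ```
+ /// use arrow2::array::MutableBinaryValuesArray;
+ ///
+ /// let array = MutableBinaryValuesArray::<i32>::try_from_iter([&b"hello"[..], &b"arrow"[..]]).unwrap();
+ /// assert_eq!(array.len(), 2);
+ /// assert_eq!(array.value(1), &b"arrow"[..]);
+ /// ```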
+ pub fn try_from_iter, I: IntoIterator>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator>, + T: AsRef<[u8]>, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend for MutableBinaryValuesArray { + fn extend>(&mut self, iter: I) { + extend_from_values_iter(&mut self.offsets, &mut self.values, iter.into_iter()); + } +} + +impl> TryExtend for MutableBinaryValuesArray { + fn try_extend>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush for MutableBinaryValuesArray { + #[inline] + fn try_push(&mut self, value: T) -> Result<()> { + let bytes = value.as_ref(); + self.values.extend_from_slice(bytes); + self.offsets.try_push_usize(bytes.len()) + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableBinaryValuesArray { + type Item = &'a [u8]; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +impl TryExtendFromSelf for MutableBinaryValuesArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + self.values.extend_from_slice(&other.values); + self.offsets.try_extend_from_self(&other.offsets) + } +} diff --git a/crates/nano-arrow/src/array/boolean/data.rs b/crates/nano-arrow/src/array/boolean/data.rs new file mode 100644 index 000000000000..e93aeb3b8d2b --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/data.rs @@ -0,0 +1,36 @@ +use arrow_buffer::{BooleanBuffer, NullBuffer}; +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, BooleanArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::DataType; + +impl Arrow2Arrow for BooleanArray { + fn to_data(&self) -> ArrayData { + let buffer = NullBuffer::from(self.values.clone()); + + let builder = ArrayDataBuilder::new(arrow_schema::DataType::Boolean) + .len(buffer.len()) + .offset(buffer.offset()) + .buffers(vec![buffer.into_inner().into_inner()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + assert_eq!(data.data_type(), &arrow_schema::DataType::Boolean); + + let buffers = data.buffers(); + let buffer = BooleanBuffer::new(buffers[0].clone(), data.offset(), data.len()); + // Use NullBuffer to compute set count + let values = Bitmap::from_null_buffer(NullBuffer::new(buffer)); + + Self { + data_type: DataType::Boolean, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/boolean/ffi.rs b/crates/nano-arrow/src/array/boolean/ffi.rs new file mode 100644 index 000000000000..64f22de81d5d --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/ffi.rs @@ -0,0 +1,54 @@ +use super::BooleanArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for BooleanArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| 
x.as_ptr()), + Some(self.values.as_ptr()), + ] + } + + fn offset(&self) -> Option { + let offset = self.values.offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for BooleanArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.bitmap(1) }?; + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/boolean/fmt.rs b/crates/nano-arrow/src/array/boolean/fmt.rs new file mode 100644 index 000000000000..229a01cd3e03 --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/fmt.rs @@ -0,0 +1,17 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::BooleanArray; + +pub fn write_value(array: &BooleanArray, index: usize, f: &mut W) -> Result { + write!(f, "{}", array.value(index)) +} + +impl Debug for BooleanArray { + fn fmt(&self, f: &mut Formatter) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + write!(f, "BooleanArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/boolean/from.rs b/crates/nano-arrow/src/array/boolean/from.rs new file mode 100644 index 000000000000..81a5395ccc06 --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/from.rs @@ -0,0 +1,15 @@ +use std::iter::FromIterator; + +use super::{BooleanArray, MutableBooleanArray}; + +impl]>> From

for BooleanArray { + fn from(slice: P) -> Self { + MutableBooleanArray::from(slice).into() + } +} + +impl>> FromIterator for BooleanArray { + fn from_iter>(iter: I) -> Self { + MutableBooleanArray::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/boolean/iterator.rs b/crates/nano-arrow/src/array/boolean/iterator.rs new file mode 100644 index 000000000000..8e914c98faab --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/iterator.rs @@ -0,0 +1,55 @@ +use super::super::MutableArray; +use super::{BooleanArray, MutableBooleanArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::IntoIter; + +impl<'a> IntoIterator for &'a BooleanArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl IntoIterator for BooleanArray { + type Item = Option; + type IntoIter = ZipValidity; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + let (_, values, validity) = self.into_inner(); + let values = values.into_iter(); + let validity = + validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.into_iter())); + ZipValidity::new(values, validity) + } +} + +impl<'a> IntoIterator for &'a MutableBooleanArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MutableBooleanArray { + /// Returns an iterator over the optional values of this [`MutableBooleanArray`]. + #[inline] + pub fn iter(&'a self) -> ZipValidity, BitmapIter<'a>> { + ZipValidity::new( + self.values().iter(), + self.validity().as_ref().map(|x| x.iter()), + ) + } + + /// Returns an iterator over the values of this [`MutableBooleanArray`] + #[inline] + pub fn values_iter(&'a self) -> BitmapIter<'a> { + self.values().iter() + } +} diff --git a/crates/nano-arrow/src/array/boolean/mod.rs b/crates/nano-arrow/src/array/boolean/mod.rs new file mode 100644 index 000000000000..93d484120faf --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/mod.rs @@ -0,0 +1,383 @@ +use either::Either; + +use super::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod from; +mod iterator; +mod mutable; + +pub use iterator::*; +pub use mutable::*; + +/// A [`BooleanArray`] is Arrow's semantically equivalent of an immutable `Vec>`. +/// It implements [`Array`]. +/// +/// One way to think about a [`BooleanArray`] is `(DataType, Arc>, Option>>)` +/// where: +/// * the first item is the array's logical type +/// * the second is the immutable values +/// * the third is the immutable validity (whether a value is null or not as a bitmap). +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. 
+/// # Example +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// +/// let array = BooleanArray::from([Some(true), None, Some(false)]); +/// assert_eq!(array.value(0), true); +/// assert_eq!(array.iter().collect::>(), vec![Some(true), None, Some(false)]); +/// assert_eq!(array.values_iter().collect::>(), vec![true, false, false]); +/// // the underlying representation +/// assert_eq!(array.values(), &Bitmap::from([true, false, false])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// +/// ``` +#[derive(Clone)] +pub struct BooleanArray { + data_type: DataType, + values: Bitmap, + validity: Option, +} + +impl BooleanArray { + /// The canonical method to create a [`BooleanArray`] out of low-end APIs. + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + pub fn try_new( + data_type: DataType, + values: Bitmap, + validity: Option, + ) -> Result { + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != PhysicalType::Boolean { + return Err(Error::oos( + "BooleanArray can only be initialized with a DataType whose physical type is Boolean", + )); + } + + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Alias to `Self::try_new().unwrap()` + pub fn new(data_type: DataType, values: Bitmap, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Returns an iterator over the optional values of this [`BooleanArray`]. + #[inline] + pub fn iter(&self) -> ZipValidity { + ZipValidity::new_with_validity(self.values().iter(), self.validity()) + } + + /// Returns an iterator over the values of this [`BooleanArray`]. + #[inline] + pub fn values_iter(&self) -> BitmapIter { + self.values().iter() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// The values [`Bitmap`]. + /// Values on null slots are undetermined (they can be anything). + #[inline] + pub fn values(&self) -> &Bitmap { + &self.values + } + + /// Returns the optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the arrays' [`DataType`]. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the value at index `i` + /// # Panic + /// This function panics iff `i >= self.len()`. + #[inline] + pub fn value(&self, i: usize) -> bool { + self.values.get_bit(i) + } + + /// Returns the element at index `i` as bool + /// # Safety + /// Caller must be sure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> bool { + self.values.get_bit_unchecked(i) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Slices this [`BooleanArray`]. + /// # Implementation + /// This operation is `O(1)` as it amounts to increase up to two ref counts. + /// # Panic + /// This function panics iff `offset + length > self.len()`. 
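+ /// # Example (editorial sketch, not part of the original change)
+ /// ```
+ /// use arrow2::array::BooleanArray;
+ ///
+ /// let mut array = BooleanArray::from([Some(true), None, Some(false)]);
+ /// array.slice(1, 2); // O(1): only offsets and reference counts are touched
+ /// assert_eq!(array.len(), 2);
+ /// assert_eq!(array.iter().collect::<Vec<_>>(), vec![None, Some(false)]);
+ /// ```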
+ #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`BooleanArray`]. + /// # Implementation + /// This operation is `O(1)` as it amounts to increase two ref counts. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.values.slice_unchecked(offset, length); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns a clone of this [`BooleanArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(&self, values: Bitmap) -> Self { + let mut out = self.clone(); + out.set_values(values); + out + } + + /// Sets the values of this [`BooleanArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + pub fn set_values(&mut self, values: Bitmap) { + assert_eq!( + values.len(), + self.len(), + "values length must be equal to this arrays length" + ); + self.values = values; + } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics if the function modifies the length of the [`MutableBitmap`]. + pub fn apply_values_mut(&mut self, f: F) { + let values = std::mem::take(&mut self.values); + let mut values = values.make_mut(); + f(&mut values); + if let Some(validity) = &self.validity { + assert_eq!(validity.len(), values.len()); + } + self.values = values.into(); + } + + /// Try to convert this [`BooleanArray`] to a [`MutableBooleanArray`] + pub fn into_mut(self) -> Either { + use Either::*; + + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + Left(bitmap) => Left(BooleanArray::new(self.data_type, self.values, Some(bitmap))), + Right(mutable_bitmap) => match self.values.into_mut() { + Left(immutable) => Left(BooleanArray::new( + self.data_type, + immutable, + Some(mutable_bitmap.into()), + )), + Right(mutable) => Right( + MutableBooleanArray::try_new(self.data_type, mutable, Some(mutable_bitmap)) + .unwrap(), + ), + }, + } + } else { + match self.values.into_mut() { + Left(immutable) => Left(BooleanArray::new(self.data_type, immutable, None)), + Right(mutable) => { + Right(MutableBooleanArray::try_new(self.data_type, mutable, None).unwrap()) + }, + } + } + } + + /// Returns a new empty [`BooleanArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, Bitmap::new(), None) + } + + /// Returns a new [`BooleanArray`] whose all slots are null / `None`. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let bitmap = Bitmap::new_zeroed(length); + Self::new(data_type, bitmap.clone(), Some(bitmap)) + } + + /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. 
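+ /// # Example (editorial sketch, not part of the original change)
+ /// A copied slice iterator is used here because `from_slice` in this module feeds the
+ /// same iterator shape to this constructor; the concrete input is an assumption.
+ /// ```
+ /// use arrow2::array::BooleanArray;
+ ///
+ /// let array = BooleanArray::from_trusted_len_values_iter([true, false, true].iter().copied());
+ /// assert_eq!(array.len(), 3);
+ /// // A values-only constructor sets no validity.
+ /// assert_eq!(array.validity(), None);
+ /// ```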
+ #[inline] + pub fn from_trusted_len_values_iter>(iterator: I) -> Self { + MutableBooleanArray::from_trusted_len_values_iter(iterator).into() + } + + /// Creates a new [`BooleanArray`] from an [`TrustedLen`] of `bool`. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked>( + iterator: I, + ) -> Self { + MutableBooleanArray::from_trusted_len_values_iter_unchecked(iterator).into() + } + + /// Creates a new [`BooleanArray`] from a slice of `bool`. + #[inline] + pub fn from_slice>(slice: P) -> Self { + MutableBooleanArray::from_slice(slice).into() + } + + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + MutableBooleanArray::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + MutableBooleanArray::from_trusted_len_iter(iterator).into() + } + + /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked(iterator: I) -> Result + where + P: std::borrow::Borrow, + I: Iterator, E>>, + { + Ok(MutableBooleanArray::try_from_trusted_len_iter_unchecked(iterator)?.into()) + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + Ok(MutableBooleanArray::try_from_trusted_len_iter(iterator)?.into()) + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, Bitmap, Option) { + let Self { + data_type, + values, + validity, + } = self; + (data_type, values, validity) + } + + /// Creates a `[BooleanArray]` from its internal representation. + /// This is the inverted from `[BooleanArray::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. 
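+ /// # Example (editorial sketch, not part of the original change)
+ /// Round-tripping through `into_inner` keeps the invariants intact, so the unsafe
+ /// constructor is sound here.
+ /// ```
+ /// use arrow2::array::BooleanArray;
+ ///
+ /// let array = BooleanArray::from([Some(true), None]);
+ /// let (data_type, values, validity) = array.into_inner();
+ /// // Safety: the parts come from a valid BooleanArray, so all invariants hold.
+ /// let roundtrip = unsafe { BooleanArray::from_inner_unchecked(data_type, values, validity) };
+ /// assert_eq!(roundtrip.len(), 2);
+ /// ```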
+ pub unsafe fn from_inner_unchecked( + data_type: DataType, + values: Bitmap, + validity: Option, + ) -> Self { + Self { + data_type, + values, + validity, + } + } +} + +impl Array for BooleanArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/boolean/mutable.rs b/crates/nano-arrow/src/array/boolean/mutable.rs new file mode 100644 index 000000000000..9961cadcb2fd --- /dev/null +++ b/crates/nano-arrow/src/array/boolean/mutable.rs @@ -0,0 +1,564 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::BooleanArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +/// The Arrow's equivalent to `Vec>`, but with `1/16` of its size. +/// Converting a [`MutableBooleanArray`] into a [`BooleanArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableBooleanArray { + data_type: DataType, + values: MutableBitmap, + validity: Option, +} + +impl From for BooleanArray { + fn from(other: MutableBooleanArray) -> Self { + BooleanArray::new( + other.data_type, + other.values.into(), + other.validity.map(|x| x.into()), + ) + } +} + +impl]>> From

for MutableBooleanArray { + /// Creates a new [`MutableBooleanArray`] out of a slice of Optional `bool`. + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } +} + +impl Default for MutableBooleanArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableBooleanArray { + /// Creates an new empty [`MutableBooleanArray`]. + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// The canonical method to create a [`MutableBooleanArray`] out of low-end APIs. + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Boolean`]. + pub fn try_new( + data_type: DataType, + values: MutableBitmap, + validity: Option, + ) -> Result { + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != PhysicalType::Boolean { + return Err(Error::oos( + "MutableBooleanArray can only be initialized with a DataType whose physical type is Boolean", + )); + } + + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Creates an new [`MutableBooleanArray`] with a capacity of values. + pub fn with_capacity(capacity: usize) -> Self { + Self { + data_type: DataType::Boolean, + values: MutableBitmap::with_capacity(capacity), + validity: None, + } + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Pushes a new entry to [`MutableBooleanArray`]. + pub fn push(&mut self, value: Option) { + match value { + Some(value) => { + self.values.push(value); + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(false); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + } + + /// Pop an entry from [`MutableBooleanArray`]. + /// Note If the values is empty, this method will return None. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| value)) + .unwrap_or_else(|| Some(value)) + } + + /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len` which accepts in iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + { + // Safety: `I` is `TrustedLen` + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableBooleanArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len_unchecked`, which accepts in iterator of optional values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + I: Iterator, + { + let (_, upper) = iterator.size_hint(); + let additional = + upper.expect("extend_trusted_len_values_unchecked requires an upper limit"); + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + + self.values.extend_from_trusted_len_iter_unchecked(iterator) + } + + /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. 
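+ /// # Example (editorial sketch, not part of the original change)
+ /// This assumes a copied slice iterator of `Option<bool>` is accepted as trusted-len by
+ /// this crate, mirroring what `from_slice` does for plain `bool`.
+ /// ```
+ /// use arrow2::array::MutableBooleanArray;
+ ///
+ /// let mut array = MutableBooleanArray::new();
+ /// array.extend_trusted_len([Some(true), None, Some(false)].iter().copied());
+ /// assert_eq!(array.values().len(), 3);
+ /// ```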
+ #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + // Safety: `I` is `TrustedLen` + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutableBooleanArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: Iterator>, + { + if let Some(validity) = self.validity.as_mut() { + extend_trusted_len_unzip(iterator, validity, &mut self.values); + } else { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + + extend_trusted_len_unzip(iterator, &mut validity, &mut self.values); + + if validity.unset_bits() > 0 { + self.validity = Some(validity); + } + } + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: BooleanArray = self.into(); + Arc::new(a) + } +} + +/// Getters +impl MutableBooleanArray { + /// Returns its values. + pub fn values(&self) -> &MutableBitmap { + &self.values + } +} + +/// Setters +impl MutableBooleanArray { + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Panic + /// Panics iff index is larger than `self.len()`. + pub fn set(&mut self, index: usize, value: Option) { + self.values.set(index, value.unwrap_or_default()); + + if value.is_none() && self.validity.is_none() { + // When the validity is None, all elements so far are valid. When one of the elements is set of null, + // the validity must be initialized. + self.validity = Some(MutableBitmap::from_trusted_len_iter( + std::iter::repeat(true).take(self.len()), + )); + } + if let Some(x) = self.validity.as_mut() { + x.set(index, value.is_some()) + } + } +} + +/// From implementations +impl MutableBooleanArray { + /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. + #[inline] + pub fn from_trusted_len_values_iter>(iterator: I) -> Self { + Self::try_new( + DataType::Boolean, + MutableBitmap::from_trusted_len_iter(iterator), + None, + ) + .unwrap() + } + + /// Creates a new [`MutableBooleanArray`] from an [`TrustedLen`] of `bool`. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked>( + iterator: I, + ) -> Self { + let mut mutable = MutableBitmap::new(); + mutable.extend_from_trusted_len_iter_unchecked(iterator); + MutableBooleanArray::try_new(DataType::Boolean, mutable, None).unwrap() + } + + /// Creates a new [`MutableBooleanArray`] from a slice of `bool`. + #[inline] + pub fn from_slice>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter().copied()) + } + + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// Use this over [`BooleanArray::from_trusted_len_iter`] when the iterator is trusted len + /// but this crate does not mark it as such. 
+ /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + let (validity, values) = trusted_len_unzip(iterator); + + Self::try_new(DataType::Boolean, values, validity).unwrap() + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + // Safety: `I` is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`BooleanArray`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: std::borrow::Borrow, + I: Iterator, E>>, + { + let (validity, values) = try_trusted_len_unzip(iterator)?; + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + Ok(Self::try_new(DataType::Boolean, values, validity).unwrap()) + } + + /// Creates a [`BooleanArray`] from a [`TrustedLen`]. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + // Safety: `I` is `TrustedLen` + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Shrinks the capacity of the [`MutableBooleanArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +/// Creates a Bitmap and an optional [`MutableBitmap`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, MutableBitmap) +where + P: std::borrow::Borrow, + I: Iterator>, +{ + let mut validity = MutableBitmap::new(); + let mut values = MutableBitmap::new(); + + extend_trusted_len_unzip(iterator, &mut validity, &mut values); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + (validity, values) +} + +/// Extends validity [`MutableBitmap`] and values [`MutableBitmap`] from an iterator of `Option`. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
+#[inline] +pub(crate) unsafe fn extend_trusted_len_unzip( + iterator: I, + validity: &mut MutableBitmap, + values: &mut MutableBitmap, +) where + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_trusted_len_unzip requires an upper limit"); + + // Length of the array before new values are pushed, + // variable created for assertion post operation + let pre_length = values.len(); + + validity.reserve(additional); + values.reserve(additional); + + for item in iterator { + let item = if let Some(item) = item { + validity.push_unchecked(true); + *item.borrow() + } else { + validity.push_unchecked(false); + bool::default() + }; + values.push_unchecked(item); + } + + debug_assert_eq!( + values.len(), + pre_length + additional, + "Trusted iterator length was not accurately reported" + ); +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(MutableBitmap, MutableBitmap), E> +where + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut values = MutableBitmap::with_capacity(len); + + for item in iterator { + let item = if let Some(item) = item? { + null.push(true); + *item.borrow() + } else { + null.push(false); + false + }; + values.push(item); + } + assert_eq!( + values.len(), + len, + "Trusted iterator length was not accurately reported" + ); + values.set_len(len); + null.set_len(len); + + Ok((null, values)) +} + +impl>> FromIterator for MutableBooleanArray { + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: MutableBitmap = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + false + } + }) + .collect(); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + MutableBooleanArray::try_new(DataType::Boolean, values, validity).unwrap() + } +} + +impl MutableArray for MutableBooleanArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: BooleanArray = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: BooleanArray = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl Extend> for MutableBooleanArray { + fn extend>>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + iter.for_each(|x| self.push(x)) + } +} + +impl TryExtend> for MutableBooleanArray { + /// This is infalible and is implemented for consistency with all other types + fn try_extend>>(&mut self, iter: I) -> Result<(), Error> { + self.extend(iter); + Ok(()) + } +} + +impl TryPush> for MutableBooleanArray { + /// This is infalible and is implemented 
for consistency with all other types + fn try_push(&mut self, item: Option) -> Result<(), Error> { + self.push(item); + Ok(()) + } +} + +impl PartialEq for MutableBooleanArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableBooleanArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { + self.values + .extend_from_slice_unchecked(slice, 0, other.values.len()); + } + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/data.rs b/crates/nano-arrow/src/array/dictionary/data.rs new file mode 100644 index 000000000000..ecc763c350b3 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/data.rs @@ -0,0 +1,49 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{ + from_data, to_data, Arrow2Arrow, DictionaryArray, DictionaryKey, PrimitiveArray, +}; +use crate::datatypes::{DataType, PhysicalType}; + +impl Arrow2Arrow for DictionaryArray { + fn to_data(&self) -> ArrayData { + let keys = self.keys.to_data(); + let builder = keys + .into_builder() + .data_type(self.data_type.clone().into()) + .child_data(vec![to_data(self.values.as_ref())]); + + // Safety: Dictionary is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let key = match data.data_type() { + arrow_schema::DataType::Dictionary(k, _) => k.as_ref(), + d => panic!("unsupported dictionary type {d}"), + }; + + let data_type = DataType::from(data.data_type().clone()); + assert_eq!( + data_type.to_physical_type(), + PhysicalType::Dictionary(K::KEY_TYPE) + ); + + let key_builder = ArrayDataBuilder::new(key.clone()) + .buffers(vec![data.buffers()[0].clone()]) + .offset(data.offset()) + .len(data.len()) + .nulls(data.nulls().cloned()); + + // Safety: Dictionary is valid + let key_data = unsafe { key_builder.build_unchecked() }; + let keys = PrimitiveArray::from_data(&key_data); + let values = from_data(&data.child_data()[0]); + + Self { + data_type, + keys, + values, + } + } +} diff --git a/crates/nano-arrow/src/array/dictionary/ffi.rs b/crates/nano-arrow/src/array/dictionary/ffi.rs new file mode 100644 index 000000000000..946c850c48b1 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/ffi.rs @@ -0,0 +1,41 @@ +use super::{DictionaryArray, DictionaryKey}; +use crate::array::{FromFfi, PrimitiveArray, ToFfi}; +use crate::error::Error; +use crate::ffi; + +unsafe impl ToFfi for DictionaryArray { + fn buffers(&self) -> Vec> { + self.keys.buffers() + } + + fn offset(&self) -> Option { + self.keys.offset() + } + + fn to_ffi_aligned(&self) -> Self { + Self { + data_type: self.data_type.clone(), + keys: self.keys.to_ffi_aligned(), + values: self.values.clone(), + } + } +} + +impl FromFfi for DictionaryArray { + unsafe fn try_from_ffi(array: A) -> Result { + // keys: similar to PrimitiveArray, but the datatype is the inner one + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + let data_type = array.data_type().clone(); + + let keys = PrimitiveArray::::try_new(K::PRIMITIVE.into(), values, validity)?; + let values = array + .dictionary()? 
+ .ok_or_else(|| Error::oos("Dictionary Array must contain a dictionary in ffi"))?; + let values = ffi::try_from(values)?; + + // the assumption of this trait + DictionaryArray::::try_new_unchecked(data_type, keys, values) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/fmt.rs b/crates/nano-arrow/src/array/dictionary/fmt.rs new file mode 100644 index 000000000000..b3ce55515902 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/fmt.rs @@ -0,0 +1,31 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::{DictionaryArray, DictionaryKey}; +use crate::array::Array; + +pub fn write_value( + array: &DictionaryArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let keys = array.keys(); + let values = array.values(); + + if keys.is_valid(index) { + let key = array.key_value(index); + get_display(values.as_ref(), null)(f, key) + } else { + write!(f, "{null}") + } +} + +impl Debug for DictionaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "DictionaryArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/iterator.rs b/crates/nano-arrow/src/array/dictionary/iterator.rs new file mode 100644 index 000000000000..68e95ca86fed --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/iterator.rs @@ -0,0 +1,67 @@ +use super::{DictionaryArray, DictionaryKey}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::scalar::Scalar; +use crate::trusted_len::TrustedLen; + +/// Iterator of values of an `ListArray`. +pub struct DictionaryValuesIter<'a, K: DictionaryKey> { + array: &'a DictionaryArray, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey> DictionaryValuesIter<'a, K> { + #[inline] + pub fn new(array: &'a DictionaryArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a, K: DictionaryKey> Iterator for DictionaryValuesIter<'a, K> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(self.array.value(old)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a, K: DictionaryKey> TrustedLen for DictionaryValuesIter<'a, K> {} + +impl<'a, K: DictionaryKey> DoubleEndedIterator for DictionaryValuesIter<'a, K> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(self.array.value(self.end)) + } + } +} + +type ValuesIter<'a, K> = DictionaryValuesIter<'a, K>; +type ZipIter<'a, K> = ZipValidity, ValuesIter<'a, K>, BitmapIter<'a>>; + +impl<'a, K: DictionaryKey> IntoIterator for &'a DictionaryArray { + type Item = Option>; + type IntoIter = ZipIter<'a, K>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/nano-arrow/src/array/dictionary/mod.rs b/crates/nano-arrow/src/array/dictionary/mod.rs new file mode 100644 index 000000000000..2ffb08c01c40 --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/mod.rs @@ -0,0 +1,413 @@ +use std::hash::Hash; +use std::hint::unreachable_unchecked; + +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, IntegerType}; +use crate::error::Error; +use crate::scalar::{new_scalar, 
Scalar}; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +use crate::array::specification::check_indexes_unchecked; +mod typed_iterator; +mod value_map; + +pub use iterator::*; +pub use mutable::*; + +use super::primitive::PrimitiveArray; +use super::specification::check_indexes; +use super::{new_empty_array, new_null_array, Array}; +use crate::array::dictionary::typed_iterator::{DictValue, DictionaryValuesIterTyped}; + +/// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. +/// # Safety +/// +/// Any implementation of this trait must ensure that `always_fits_usize` only +/// returns `true` if all values succeeds on `value::try_into::().unwrap()`. +pub unsafe trait DictionaryKey: NativeType + TryInto + TryFrom + Hash { + /// The corresponding [`IntegerType`] of this key + const KEY_TYPE: IntegerType; + + /// Represents this key as a `usize`. + /// # Safety + /// The caller _must_ have checked that the value can be casted to `usize`. + #[inline] + unsafe fn as_usize(self) -> usize { + match self.try_into() { + Ok(v) => v, + Err(_) => unreachable_unchecked(), + } + } + + /// If the key type always can be converted to `usize`. + fn always_fits_usize() -> bool { + false + } +} + +unsafe impl DictionaryKey for i8 { + const KEY_TYPE: IntegerType = IntegerType::Int8; +} +unsafe impl DictionaryKey for i16 { + const KEY_TYPE: IntegerType = IntegerType::Int16; +} +unsafe impl DictionaryKey for i32 { + const KEY_TYPE: IntegerType = IntegerType::Int32; +} +unsafe impl DictionaryKey for i64 { + const KEY_TYPE: IntegerType = IntegerType::Int64; +} +unsafe impl DictionaryKey for u8 { + const KEY_TYPE: IntegerType = IntegerType::UInt8; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u16 { + const KEY_TYPE: IntegerType = IntegerType::UInt16; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u32 { + const KEY_TYPE: IntegerType = IntegerType::UInt32; + + fn always_fits_usize() -> bool { + true + } +} +unsafe impl DictionaryKey for u64 { + const KEY_TYPE: IntegerType = IntegerType::UInt64; + + #[cfg(target_pointer_width = "64")] + fn always_fits_usize() -> bool { + true + } +} + +/// An [`Array`] whose values are stored as indices. This [`Array`] is useful when the cardinality of +/// values is low compared to the length of the [`Array`]. +/// +/// # Safety +/// This struct guarantees that each item of [`DictionaryArray::keys`] is castable to `usize` and +/// its value is smaller than [`DictionaryArray::values`]`.len()`. 
In other words, you can safely +/// use `unchecked` calls to retrieve the values +#[derive(Clone)] +pub struct DictionaryArray { + data_type: DataType, + keys: PrimitiveArray, + values: Box, +} + +fn check_data_type( + key_type: IntegerType, + data_type: &DataType, + values_data_type: &DataType, +) -> Result<(), Error> { + if let DataType::Dictionary(key, value, _) = data_type.to_logical_type() { + if *key != key_type { + return Err(Error::oos( + "DictionaryArray must be initialized with a DataType::Dictionary whose integer is compatible to its keys", + )); + } + if value.as_ref().to_logical_type() != values_data_type.to_logical_type() { + return Err(Error::oos( + "DictionaryArray must be initialized with a DataType::Dictionary whose value is equal to its values", + )); + } + } else { + return Err(Error::oos( + "DictionaryArray must be initialized with logical DataType::Dictionary", + )); + } + Ok(()) +} + +impl DictionaryArray { + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * the `data_type`'s logical type is not a `DictionaryArray` + /// * the `data_type`'s keys is not compatible with `keys` + /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_new( + data_type: DataType, + keys: PrimitiveArray, + values: Box, + ) -> Result { + check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + + if keys.null_count() != keys.len() { + if K::always_fits_usize() { + // safety: we just checked that conversion to `usize` always + // succeeds + unsafe { check_indexes_unchecked(keys.values(), values.len()) }?; + } else { + check_indexes(keys.values(), values.len())?; + } + } + + Ok(Self { + data_type, + keys, + values, + }) + } + + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_from_keys(keys: PrimitiveArray, values: Box) -> Result { + let data_type = Self::default_data_type(values.data_type().clone()); + Self::try_new(data_type, keys, values) + } + + /// Returns a new [`DictionaryArray`]. + /// # Errors + /// This function errors iff + /// * the `data_type`'s logical type is not a `DictionaryArray` + /// * the `data_type`'s keys is not compatible with `keys` + /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// # Safety + /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` + pub unsafe fn try_new_unchecked( + data_type: DataType, + keys: PrimitiveArray, + values: Box, + ) -> Result { + check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + + Ok(Self { + data_type, + keys, + values, + }) + } + + /// Returns a new empty [`DictionaryArray`]. 
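+ /// # Example (editorial sketch, not part of the original change)
+ /// The `i32` keys and `Utf8` values below are assumptions for illustration.
+ /// ```
+ /// use arrow2::array::DictionaryArray;
+ /// use arrow2::datatypes::{DataType, IntegerType};
+ ///
+ /// let data_type = DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false);
+ /// let array = DictionaryArray::<i32>::new_empty(data_type);
+ /// assert_eq!(array.len(), 0);
+ /// ```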
+ pub fn new_empty(data_type: DataType) -> Self { + let values = Self::try_get_child(&data_type).unwrap(); + let values = new_empty_array(values.clone()); + Self::try_new( + data_type, + PrimitiveArray::::new_empty(K::PRIMITIVE.into()), + values, + ) + .unwrap() + } + + /// Returns an [`DictionaryArray`] whose all elements are null + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + let values = Self::try_get_child(&data_type).unwrap(); + let values = new_null_array(values.clone(), 1); + Self::try_new( + data_type, + PrimitiveArray::::new_null(K::PRIMITIVE.into(), length), + values, + ) + .unwrap() + } + + /// Returns an iterator of [`Option>`]. + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn iter(&self) -> ZipValidity, DictionaryValuesIter, BitmapIter> { + ZipValidity::new_with_validity(DictionaryValuesIter::new(self), self.keys.validity()) + } + + /// Returns an iterator of [`Box`] + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn values_iter(&self) -> DictionaryValuesIter { + DictionaryValuesIter::new(self) + } + + /// Returns an iterator over the the values [`V::IterValue`]. + /// + /// # Panics + /// + /// Panics if the keys of this [`DictionaryArray`] have any null types. + /// If they do [`DictionaryArray::iter_typed`] should be called + pub fn values_iter_typed( + &self, + ) -> Result, Error> { + let keys = &self.keys; + assert_eq!(keys.null_count(), 0); + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + Ok(unsafe { DictionaryValuesIterTyped::new(keys, values) }) + } + + /// Returns an iterator over the the optional values of [`Option`]. + /// + /// # Panics + /// + /// This function panics if the `values` array + pub fn iter_typed( + &self, + ) -> Result, DictionaryValuesIterTyped, BitmapIter>, Error> + { + let keys = &self.keys; + let values = self.values.as_ref(); + let values = V::downcast_values(values)?; + let values_iter = unsafe { DictionaryValuesIterTyped::new(keys, values) }; + Ok(ZipValidity::new_with_validity(values_iter, self.validity())) + } + + /// Returns the [`DataType`] of this [`DictionaryArray`] + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns whether the values of this [`DictionaryArray`] are ordered + #[inline] + pub fn is_ordered(&self) -> bool { + match self.data_type.to_logical_type() { + DataType::Dictionary(_, _, is_ordered) => *is_ordered, + _ => unreachable!(), + } + } + + pub(crate) fn default_data_type(values_datatype: DataType) -> DataType { + DataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false) + } + + /// Slices this [`DictionaryArray`]. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + self.keys.slice(offset, length); + } + + /// Slices this [`DictionaryArray`]. + /// # Safety + /// Safe iff `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.keys.slice_unchecked(offset, length); + } + + impl_sliced!(); + + /// Returns this [`DictionaryArray`] with a new validity. + /// # Panic + /// This function panics iff `validity.len() != self.len()`. 
+ #[must_use] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of the keys of this [`DictionaryArray`]. + /// # Panics + /// This function panics iff `validity.len() != self.len()`. + pub fn set_validity(&mut self, validity: Option) { + self.keys.set_validity(validity); + } + + impl_into_array!(); + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.keys.len() + } + + /// The optional validity. Equivalent to `self.keys().validity()`. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.keys.validity() + } + + /// Returns the keys of the [`DictionaryArray`]. These keys can be used to fetch values + /// from `values`. + #[inline] + pub fn keys(&self) -> &PrimitiveArray { + &self.keys + } + + /// Returns an iterator of the keys' values of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_values_iter(&self) -> impl TrustedLen + Clone + '_ { + // safety - invariant of the struct + self.keys.values_iter().map(|x| unsafe { x.as_usize() }) + } + + /// Returns an iterator of the keys' of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_iter(&self) -> impl TrustedLen> + Clone + '_ { + // safety - invariant of the struct + self.keys.iter().map(|x| x.map(|x| unsafe { x.as_usize() })) + } + + /// Returns the keys' value of the [`DictionaryArray`] as `usize` + /// # Panics + /// This function panics iff `index >= self.len()` + #[inline] + pub fn key_value(&self, index: usize) -> usize { + // safety - invariant of the struct + unsafe { self.keys.values()[index].as_usize() } + } + + /// Returns the values of the [`DictionaryArray`]. + #[inline] + pub fn values(&self) -> &Box { + &self.values + } + + /// Returns the value of the [`DictionaryArray`] at position `i`. + /// # Implementation + /// This function will allocate a new [`Scalar`] and is usually not performant. + /// Consider calling `keys` and `values`, downcasting `values`, and iterating over that. 
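+ /// # Example of the suggested pattern (editorial sketch, not part of the original change)
+ /// Downcast the values once, then resolve every key against them; this assumes `i32` keys,
+ /// `Utf8` values, and non-null keys.
+ /// ```
+ /// use arrow2::array::{Array, DictionaryArray, PrimitiveArray, Utf8Array};
+ ///
+ /// let keys = PrimitiveArray::from_vec(vec![0i32, 1, 0]);
+ /// let values = Utf8Array::<i32>::from_slice(["a", "b"]).boxed();
+ /// let array = DictionaryArray::try_from_keys(keys, values).unwrap();
+ ///
+ /// let utf8 = array.values().as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
+ /// let decoded: Vec<&str> = array.keys_values_iter().map(|k| utf8.value(k)).collect();
+ /// assert_eq!(decoded, vec!["a", "b", "a"]);
+ /// ```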
+ /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn value(&self, index: usize) -> Box { + // safety - invariant of this struct + let index = unsafe { self.keys.value(index).as_usize() }; + new_scalar(self.values.as_ref(), index) + } + + pub(crate) fn try_get_child(data_type: &DataType) -> Result<&DataType, Error> { + Ok(match data_type.to_logical_type() { + DataType::Dictionary(_, values, _) => values.as_ref(), + _ => { + return Err(Error::oos( + "Dictionaries must be initialized with DataType::Dictionary", + )) + }, + }) + } +} + +impl Array for DictionaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.keys.validity() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/mutable.rs b/crates/nano-arrow/src/array/dictionary/mutable.rs new file mode 100644 index 000000000000..dedd6ead0eaa --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/mutable.rs @@ -0,0 +1,241 @@ +use std::hash::Hash; +use std::sync::Arc; + +use super::value_map::ValueMap; +use super::{DictionaryArray, DictionaryKey}; +use crate::array::indexable::{AsIndexed, Indexable}; +use crate::array::primitive::MutablePrimitiveArray; +use crate::array::{Array, MutableArray, TryExtend, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +/// A mutable, strong-typed version of [`DictionaryArray`]. +/// +/// # Example +/// Building a UTF8 dictionary with `i32` keys. +/// ``` +/// # use arrow2::array::{MutableDictionaryArray, MutableUtf8Array, TryPush}; +/// # fn main() -> Result<(), Box> { +/// let mut array: MutableDictionaryArray> = MutableDictionaryArray::new(); +/// array.try_push(Some("A"))?; +/// array.try_push(Some("B"))?; +/// array.push_null(); +/// array.try_push(Some("C"))?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct MutableDictionaryArray { + data_type: DataType, + map: ValueMap, + // invariant: `max(keys) < map.values().len()` + keys: MutablePrimitiveArray, +} + +impl From> for DictionaryArray { + fn from(other: MutableDictionaryArray) -> Self { + // Safety - the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new_unchecked( + other.data_type, + other.keys.into(), + other.map.into_values().as_box(), + ) + .unwrap() + } + } +} + +impl MutableDictionaryArray { + /// Creates an empty [`MutableDictionaryArray`]. + pub fn new() -> Self { + Self::try_empty(M::default()).unwrap() + } +} + +impl Default for MutableDictionaryArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableDictionaryArray { + /// Creates an empty [`MutableDictionaryArray`] from a given empty values array. + /// # Errors + /// Errors if the array is non-empty. + pub fn try_empty(values: M) -> Result { + Ok(Self::from_value_map(ValueMap::::try_empty(values)?)) + } + + /// Creates an empty [`MutableDictionaryArray`] preloaded with a given dictionary of values. + /// Indices associated with those values are automatically assigned based on the order of + /// the values. + /// # Errors + /// Errors if there's more values than the maximum value of `K` or if values are not unique. 
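+ /// # Example (editorial sketch, not part of the original change)
+ /// A dictionary preloaded with two UTF8 values; the key and value types, and the way the
+ /// values array is built, are assumptions for illustration.
+ /// ```
+ /// use arrow2::array::{MutableDictionaryArray, MutableUtf8Array, TryPush};
+ ///
+ /// let mut values = MutableUtf8Array::<i32>::new();
+ /// values.push(Some("A"));
+ /// values.push(Some("B"));
+ /// let mut array: MutableDictionaryArray<i32, MutableUtf8Array<i32>> =
+ ///     MutableDictionaryArray::from_values(values).unwrap();
+ /// // "B" is already in the dictionary, so its existing key is reused.
+ /// array.try_push(Some("B")).unwrap();
+ /// ```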
+ pub fn from_values(values: M) -> Result + where + M: Indexable, + M::Type: Eq + Hash, + { + Ok(Self::from_value_map(ValueMap::::from_values(values)?)) + } + + fn from_value_map(value_map: ValueMap) -> Self { + let keys = MutablePrimitiveArray::::new(); + let data_type = + DataType::Dictionary(K::KEY_TYPE, Box::new(value_map.data_type().clone()), false); + Self { + data_type, + map: value_map, + keys, + } + } + + /// Creates an empty [`MutableDictionaryArray`] retaining the same dictionary as the current + /// mutable dictionary array, but with no data. This may come useful when serializing the + /// array into multiple chunks, where there's a requirement that the dictionary is the same. + /// No copying is performed, the value map is moved over to the new array. + pub fn into_empty(self) -> Self { + Self::from_value_map(self.map) + } + + /// Same as `into_empty` but clones the inner value map instead of taking full ownership. + pub fn to_empty(&self) -> Self + where + M: Clone, + { + Self::from_value_map(self.map.clone()) + } + + /// pushes a null value + pub fn push_null(&mut self) { + self.keys.push(None) + } + + /// returns a reference to the inner values. + pub fn values(&self) -> &M { + self.map.values() + } + + /// converts itself into [`Arc`] + pub fn into_arc(self) -> Arc { + let a: DictionaryArray = self.into(); + Arc::new(a) + } + + /// converts itself into [`Box`] + pub fn into_box(self) -> Box { + let a: DictionaryArray = self.into(); + Box::new(a) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.keys.reserve(additional); + } + + /// Shrinks the capacity of the [`MutableDictionaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.map.shrink_to_fit(); + self.keys.shrink_to_fit(); + } + + /// Returns the dictionary keys + pub fn keys(&self) -> &MutablePrimitiveArray { + &self.keys + } + + fn take_into(&mut self) -> DictionaryArray { + DictionaryArray::::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + self.map.take_into(), + ) + .unwrap() + } +} + +impl MutableArray for MutableDictionaryArray { + fn len(&self) -> usize { + self.keys.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.keys.validity() + } + + fn as_box(&mut self) -> Box { + Box::new(self.take_into()) + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.take_into()) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.keys.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl TryExtend> for MutableDictionaryArray +where + K: DictionaryKey, + M: MutableArray + Indexable + TryExtend>, + T: AsIndexed, + M::Type: Eq + Hash, +{ + fn try_extend>>(&mut self, iter: II) -> Result<()> { + for value in iter { + if let Some(value) = value { + let key = self + .map + .try_push_valid(value, |arr, v| arr.try_extend(std::iter::once(Some(v))))?; + self.keys.try_push(Some(key))?; + } else { + self.push_null(); + } + } + Ok(()) + } +} + +impl TryPush> for MutableDictionaryArray +where + K: DictionaryKey, + M: MutableArray + Indexable + TryPush>, + T: AsIndexed, + M::Type: Eq + Hash, +{ + fn try_push(&mut self, item: Option) -> Result<()> { + if let Some(value) = item { + let key = self + .map + .try_push_valid(value, |arr, v| 
arr.try_push(Some(v)))?; + self.keys.try_push(Some(key))?; + } else { + self.push_null(); + } + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/dictionary/typed_iterator.rs b/crates/nano-arrow/src/array/dictionary/typed_iterator.rs new file mode 100644 index 000000000000..5c528beb251b --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/typed_iterator.rs @@ -0,0 +1,110 @@ +use super::DictionaryKey; +use crate::array::{Array, PrimitiveArray, Utf8Array}; +use crate::error::{Error, Result}; +use crate::trusted_len::TrustedLen; +use crate::types::Offset; + +pub trait DictValue { + type IterValue<'this> + where + Self: 'this; + + /// # Safety + /// Will not do any bound checks but must check validity. + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_>; + + /// Take a [`dyn Array`] an try to downcast it to the type of `DictValue`. + fn downcast_values(array: &dyn Array) -> Result<&Self> + where + Self: Sized; +} + +impl DictValue for Utf8Array { + type IterValue<'a> = &'a str; + + unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> { + self.value_unchecked(item) + } + + fn downcast_values(array: &dyn Array) -> Result<&Self> + where + Self: Sized, + { + array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + Error::InvalidArgumentError("could not convert array to dictionary value".into()) + }) + .map(|arr| { + assert_eq!( + arr.null_count(), + 0, + "null values in values not supported in iteration" + ); + arr + }) + } +} + +/// Iterator of values of an `ListArray`. +pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> { + keys: &'a PrimitiveArray, + values: &'a V, + index: usize, + end: usize, +} + +impl<'a, K: DictionaryKey, V: DictValue> DictionaryValuesIterTyped<'a, K, V> { + pub(super) unsafe fn new(keys: &'a PrimitiveArray, values: &'a V) -> Self { + Self { + keys, + values, + index: 0, + end: keys.len(), + } + } +} + +impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped<'a, K, V> { + type Item = V::IterValue<'a>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + unsafe { + let key = self.keys.value_unchecked(old); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryValuesIterTyped<'a, K, V> {} + +impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator + for DictionaryValuesIterTyped<'a, K, V> +{ + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + unsafe { + let key = self.keys.value_unchecked(self.end); + let idx = key.as_usize(); + Some(self.values.get_unchecked(idx)) + } + } + } +} diff --git a/crates/nano-arrow/src/array/dictionary/value_map.rs b/crates/nano-arrow/src/array/dictionary/value_map.rs new file mode 100644 index 000000000000..5a12534766bd --- /dev/null +++ b/crates/nano-arrow/src/array/dictionary/value_map.rs @@ -0,0 +1,171 @@ +use std::borrow::Borrow; +use std::fmt::{self, Debug}; +use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; + +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; + +use super::DictionaryKey; +use crate::array::indexable::{AsIndexed, Indexable}; +use crate::array::{Array, MutableArray}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// 
Hasher for pre-hashed values; similar to `hash_hasher` but with native endianness. +/// +/// We know that we'll only use it for `u64` values, so we can avoid endian conversion. +/// +/// Invariant: hash of a u64 value is always equal to itself. +#[derive(Copy, Clone, Default)] +pub struct PassthroughHasher(u64); + +impl Hasher for PassthroughHasher { + #[inline] + fn write_u64(&mut self, value: u64) { + self.0 = value; + } + + fn write(&mut self, _: &[u8]) { + unreachable!(); + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +#[derive(Clone)] +pub struct Hashed { + hash: u64, + key: K, +} + +#[inline] +fn ahash_hash(value: &T) -> u64 { + let mut hasher = BuildHasherDefault::::default().build_hasher(); + value.hash(&mut hasher); + hasher.finish() +} + +impl Hash for Hashed { + #[inline] + fn hash(&self, state: &mut H) { + self.hash.hash(state) + } +} + +#[derive(Clone)] +pub struct ValueMap { + pub values: M, + pub map: HashMap, (), BuildHasherDefault>, // NB: *only* use insert_hashed_nocheck() and no other hashmap API +} + +impl ValueMap { + pub fn try_empty(values: M) -> Result { + if !values.is_empty() { + return Err(Error::InvalidArgumentError( + "initializing value map with non-empty values array".into(), + )); + } + Ok(Self { + values, + map: HashMap::default(), + }) + } + + pub fn from_values(values: M) -> Result + where + M: Indexable, + M::Type: Eq + Hash, + { + let mut map = HashMap::, _, _>::with_capacity_and_hasher( + values.len(), + BuildHasherDefault::::default(), + ); + for index in 0..values.len() { + let key = K::try_from(index).map_err(|_| Error::Overflow)?; + // safety: we only iterate within bounds + let value = unsafe { values.value_unchecked_at(index) }; + let hash = ahash_hash(value.borrow()); + match map.raw_entry_mut().from_hash(hash, |item| { + // safety: invariant of the struct, it's always in bounds since we maintain it + let stored_value = unsafe { values.value_unchecked_at(item.key.as_usize()) }; + stored_value.borrow() == value.borrow() + }) { + RawEntryMut::Occupied(_) => { + return Err(Error::InvalidArgumentError( + "duplicate value in dictionary values array".into(), + )) + }, + RawEntryMut::Vacant(entry) => { + // NB: don't use .insert() here! + entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); + }, + } + } + Ok(Self { values, map }) + } + + pub fn data_type(&self) -> &DataType { + self.values.data_type() + } + + pub fn into_values(self) -> M { + self.values + } + + pub fn take_into(&mut self) -> Box { + let arr = self.values.as_box(); + self.map.clear(); + arr + } + + #[inline] + pub fn values(&self) -> &M { + &self.values + } + + /// Try to insert a value and return its index (it may or may not get inserted). + pub fn try_push_valid( + &mut self, + value: V, + mut push: impl FnMut(&mut M, V) -> Result<()>, + ) -> Result + where + M: Indexable, + V: AsIndexed, + M::Type: Eq + Hash, + { + let hash = ahash_hash(value.as_indexed()); + Ok( + match self.map.raw_entry_mut().from_hash(hash, |item| { + // safety: we've already checked (the inverse) when we pushed it, so it should be ok? 
+ let index = unsafe { item.key.as_usize() }; + // safety: invariant of the struct, it's always in bounds since we maintain it + let stored_value = unsafe { self.values.value_unchecked_at(index) }; + stored_value.borrow() == value.as_indexed() + }) { + RawEntryMut::Occupied(entry) => entry.key().key, + RawEntryMut::Vacant(entry) => { + let index = self.values.len(); + let key = K::try_from(index).map_err(|_| Error::Overflow)?; + entry.insert_hashed_nocheck(hash, Hashed { hash, key }, ()); // NB: don't use .insert() here! + push(&mut self.values, value)?; + debug_assert_eq!(self.values.len(), index + 1); + key + }, + }, + ) + } + + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + } +} + +impl Debug for ValueMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.values.fmt(f) + } +} diff --git a/crates/nano-arrow/src/array/equal/binary.rs b/crates/nano-arrow/src/array/equal/binary.rs new file mode 100644 index 000000000000..bed8588efb59 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/binary.rs @@ -0,0 +1,6 @@ +use crate::array::BinaryArray; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &BinaryArray, rhs: &BinaryArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/boolean.rs b/crates/nano-arrow/src/array/equal/boolean.rs new file mode 100644 index 000000000000..d9c6af9b0276 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/boolean.rs @@ -0,0 +1,5 @@ +use crate::array::BooleanArray; + +pub(super) fn equal(lhs: &BooleanArray, rhs: &BooleanArray) -> bool { + lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/dictionary.rs b/crates/nano-arrow/src/array/equal/dictionary.rs new file mode 100644 index 000000000000..d65634095fb3 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/dictionary.rs @@ -0,0 +1,14 @@ +use crate::array::{DictionaryArray, DictionaryKey}; + +pub(super) fn equal(lhs: &DictionaryArray, rhs: &DictionaryArray) -> bool { + if !(lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()) { + return false; + }; + + // if x is not valid and y is but its child is not, the slots are equal. 
+ lhs.iter().zip(rhs.iter()).all(|(x, y)| match (&x, &y) { + (None, Some(y)) => !y.is_valid(), + (Some(x), None) => !x.is_valid(), + _ => x == y, + }) +} diff --git a/crates/nano-arrow/src/array/equal/fixed_size_binary.rs b/crates/nano-arrow/src/array/equal/fixed_size_binary.rs new file mode 100644 index 000000000000..883d5739778b --- /dev/null +++ b/crates/nano-arrow/src/array/equal/fixed_size_binary.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, FixedSizeBinaryArray}; + +pub(super) fn equal(lhs: &FixedSizeBinaryArray, rhs: &FixedSizeBinaryArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/fixed_size_list.rs b/crates/nano-arrow/src/array/equal/fixed_size_list.rs new file mode 100644 index 000000000000..aaf77910013f --- /dev/null +++ b/crates/nano-arrow/src/array/equal/fixed_size_list.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, FixedSizeListArray}; + +pub(super) fn equal(lhs: &FixedSizeListArray, rhs: &FixedSizeListArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/list.rs b/crates/nano-arrow/src/array/equal/list.rs new file mode 100644 index 000000000000..26faa1598faf --- /dev/null +++ b/crates/nano-arrow/src/array/equal/list.rs @@ -0,0 +1,6 @@ +use crate::array::{Array, ListArray}; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &ListArray, rhs: &ListArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/map.rs b/crates/nano-arrow/src/array/equal/map.rs new file mode 100644 index 000000000000..e150fb4a4b41 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/map.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, MapArray}; + +pub(super) fn equal(lhs: &MapArray, rhs: &MapArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/mod.rs b/crates/nano-arrow/src/array/equal/mod.rs new file mode 100644 index 000000000000..91fd0c2f464f --- /dev/null +++ b/crates/nano-arrow/src/array/equal/mod.rs @@ -0,0 +1,287 @@ +use super::*; +use crate::offset::Offset; +use crate::types::NativeType; + +mod binary; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod union; +mod utf8; + +impl PartialEq for dyn Array + '_ { + fn eq(&self, that: &dyn Array) -> bool { + equal(self, that) + } +} + +impl PartialEq for std::sync::Arc { + fn eq(&self, that: &dyn Array) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for Box { + fn eq(&self, that: &dyn Array) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for NullArray { + fn eq(&self, other: &Self) -> bool { + null::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for NullArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq<&dyn Array> for PrimitiveArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &PrimitiveArray) -> bool { + equal(*self, other) + } +} + +impl PartialEq> for PrimitiveArray { + fn eq(&self, other: &Self) -> bool { + primitive::equal::(self, other) + } +} + +impl PartialEq for BooleanArray { + fn eq(&self, other: &Self) -> bool { + equal(self, other) + } +} + +impl PartialEq<&dyn 
Array> for BooleanArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for Utf8Array { + fn eq(&self, other: &Self) -> bool { + utf8::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for Utf8Array { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &Utf8Array) -> bool { + equal(*self, other) + } +} + +impl PartialEq> for BinaryArray { + fn eq(&self, other: &Self) -> bool { + binary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for BinaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for &dyn Array { + fn eq(&self, other: &BinaryArray) -> bool { + equal(*self, other) + } +} + +impl PartialEq for FixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + fixed_size_binary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for FixedSizeBinaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for ListArray { + fn eq(&self, other: &Self) -> bool { + list::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for ListArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for FixedSizeListArray { + fn eq(&self, other: &Self) -> bool { + fixed_size_list::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for FixedSizeListArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for StructArray { + fn eq(&self, other: &Self) -> bool { + struct_::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for StructArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq> for DictionaryArray { + fn eq(&self, other: &Self) -> bool { + dictionary::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for DictionaryArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for UnionArray { + fn eq(&self, other: &Self) -> bool { + union::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for UnionArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +impl PartialEq for MapArray { + fn eq(&self, other: &Self) -> bool { + map::equal(self, other) + } +} + +impl PartialEq<&dyn Array> for MapArray { + fn eq(&self, other: &&dyn Array) -> bool { + equal(self, *other) + } +} + +/// Logically compares two [`Array`]s. 
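// --- Illustrative sketch (not part of this patch) ---------------------------
// The blanket `PartialEq` impls above let arrays be compared logically with
// `==`, both between concrete types and against `&dyn Array`. `Int32Array`
// (an alias for PrimitiveArray<i32>) and arrow2-style paths are assumed.
use arrow2::array::{Array, Int32Array};

fn logically_equal() -> bool {
    let a = Int32Array::from([Some(1), None, Some(3)]);
    let b = Int32Array::from([Some(1), None, Some(3)]);
    let b_dyn: &dyn Array = &b;
    // Validity is part of the comparison: a null slot only equals a null slot.
    a == b && a == b_dyn
}
// -----------------------------------------------------------------------------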
+/// Two arrays are logically equal if and only if: +/// * their data types are equal +/// * each of their items are equal +pub fn equal(lhs: &dyn Array, rhs: &dyn Array) -> bool { + if lhs.data_type() != rhs.data_type() { + return false; + } + + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { + Null => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + null::equal(lhs, rhs) + }, + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + boolean::equal(lhs, rhs) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::equal::<$T>(lhs, rhs) + }), + Utf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::equal::(lhs, rhs) + }, + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::equal::(lhs, rhs) + }, + Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::equal::(lhs, rhs) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::equal::(lhs, rhs) + }, + List => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + list::equal::(lhs, rhs) + }, + LargeList => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + list::equal::(lhs, rhs) + }, + Struct => { + let lhs = lhs.as_any().downcast_ref::().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + struct_::equal(lhs, rhs) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + dictionary::equal::<$T>(lhs, rhs) + }) + }, + FixedSizeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + fixed_size_binary::equal(lhs, rhs) + }, + FixedSizeList => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + fixed_size_list::equal(lhs, rhs) + }, + Union => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + union::equal(lhs, rhs) + }, + Map => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + map::equal(lhs, rhs) + }, + } +} diff --git a/crates/nano-arrow/src/array/equal/null.rs b/crates/nano-arrow/src/array/equal/null.rs new file mode 100644 index 000000000000..11ad6cc133bb --- /dev/null +++ b/crates/nano-arrow/src/array/equal/null.rs @@ -0,0 +1,6 @@ +use crate::array::{Array, NullArray}; + +#[inline] +pub(super) fn equal(lhs: &NullArray, rhs: &NullArray) -> bool { + lhs.len() == rhs.len() +} diff --git a/crates/nano-arrow/src/array/equal/primitive.rs b/crates/nano-arrow/src/array/equal/primitive.rs new file mode 100644 index 000000000000..dc90bb15da5e --- /dev/null +++ b/crates/nano-arrow/src/array/equal/primitive.rs @@ -0,0 +1,6 @@ +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +pub(super) fn equal(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff 
--git a/crates/nano-arrow/src/array/equal/struct_.rs b/crates/nano-arrow/src/array/equal/struct_.rs new file mode 100644 index 000000000000..a1741e36368c --- /dev/null +++ b/crates/nano-arrow/src/array/equal/struct_.rs @@ -0,0 +1,54 @@ +use crate::array::{Array, StructArray}; + +pub(super) fn equal(lhs: &StructArray, rhs: &StructArray) -> bool { + lhs.data_type() == rhs.data_type() + && lhs.len() == rhs.len() + && match (lhs.validity(), rhs.validity()) { + (None, None) => lhs.values().iter().eq(rhs.values().iter()), + (Some(l_validity), Some(r_validity)) => lhs + .values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + l_validity.iter().zip(r_validity.iter()).enumerate().all( + |(i, (lhs_is_valid, rhs_is_valid))| { + if lhs_is_valid && rhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + lhs_is_valid == rhs_is_valid + } + }, + ) + }), + (Some(l_validity), None) => { + lhs.values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + l_validity.iter().enumerate().all(|(i, lhs_is_valid)| { + if lhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + // rhs is always valid => different + false + } + }) + }) + }, + (None, Some(r_validity)) => { + lhs.values() + .iter() + .zip(rhs.values().iter()) + .all(|(lhs, rhs)| { + r_validity.iter().enumerate().all(|(i, rhs_is_valid)| { + if rhs_is_valid { + lhs.sliced(i, 1) == rhs.sliced(i, 1) + } else { + // lhs is always valid => different + false + } + }) + }) + }, + } +} diff --git a/crates/nano-arrow/src/array/equal/union.rs b/crates/nano-arrow/src/array/equal/union.rs new file mode 100644 index 000000000000..51b9d960feac --- /dev/null +++ b/crates/nano-arrow/src/array/equal/union.rs @@ -0,0 +1,5 @@ +use crate::array::{Array, UnionArray}; + +pub(super) fn equal(lhs: &UnionArray, rhs: &UnionArray) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/equal/utf8.rs b/crates/nano-arrow/src/array/equal/utf8.rs new file mode 100644 index 000000000000..1327221ca331 --- /dev/null +++ b/crates/nano-arrow/src/array/equal/utf8.rs @@ -0,0 +1,6 @@ +use crate::array::Utf8Array; +use crate::offset::Offset; + +pub(super) fn equal(lhs: &Utf8Array, rhs: &Utf8Array) -> bool { + lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len() && lhs.iter().eq(rhs.iter()) +} diff --git a/crates/nano-arrow/src/array/ffi.rs b/crates/nano-arrow/src/array/ffi.rs new file mode 100644 index 000000000000..0e9629d4fdf0 --- /dev/null +++ b/crates/nano-arrow/src/array/ffi.rs @@ -0,0 +1,86 @@ +use crate::array::*; +use crate::datatypes::PhysicalType; +use crate::error::Result; +use crate::ffi; + +/// Trait describing how a struct presents itself to the +/// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +/// # Safety +/// Implementing this trait incorrect will lead to UB +pub(crate) unsafe trait ToFfi { + /// The pointers to the buffers. + fn buffers(&self) -> Vec>; + + /// The children + fn children(&self) -> Vec> { + vec![] + } + + /// The offset + fn offset(&self) -> Option; + + /// return a partial clone of self with an offset. + fn to_ffi_aligned(&self) -> Self; +} + +/// Trait describing how a struct imports into itself from the +/// [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub(crate) trait FromFfi: Sized { + /// Convert itself from FFI. 
+ /// # Safety + /// This function is intrinsically `unsafe` as it requires the FFI to be made according + /// to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) + unsafe fn try_from_ffi(array: T) -> Result; +} + +macro_rules! ffi_dyn { + ($array:expr, $ty:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + ( + array.offset().unwrap(), + array.buffers(), + array.children(), + None, + ) + }}; +} + +type BuffersChildren = ( + usize, + Vec>, + Vec>, + Option>, +); + +pub fn offset_buffers_children_dictionary(array: &dyn Array) -> BuffersChildren { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => ffi_dyn!(array, NullArray), + Boolean => ffi_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + ffi_dyn!(array, PrimitiveArray<$T>) + }), + Binary => ffi_dyn!(array, BinaryArray), + LargeBinary => ffi_dyn!(array, BinaryArray), + FixedSizeBinary => ffi_dyn!(array, FixedSizeBinaryArray), + Utf8 => ffi_dyn!(array, Utf8Array::), + LargeUtf8 => ffi_dyn!(array, Utf8Array::), + List => ffi_dyn!(array, ListArray::), + LargeList => ffi_dyn!(array, ListArray::), + FixedSizeList => ffi_dyn!(array, FixedSizeListArray), + Struct => ffi_dyn!(array, StructArray), + Union => ffi_dyn!(array, UnionArray), + Map => ffi_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + ( + array.offset().unwrap(), + array.buffers(), + array.children(), + Some(array.values().clone()), + ) + }) + }, + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/data.rs b/crates/nano-arrow/src/array/fixed_size_binary/data.rs new file mode 100644 index 000000000000..6eb025d91623 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/data.rs @@ -0,0 +1,37 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, FixedSizeBinaryArray}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; + +impl Arrow2Arrow for FixedSizeBinaryArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.values.clone().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type: DataType = data.data_type().clone().into(); + let size = match data_type { + DataType::FixedSizeBinary(size) => size, + _ => unreachable!("must be FixedSizeBinary"), + }; + + let mut values: Buffer = data.buffers()[0].clone().into(); + values.slice(data.offset() * size, data.len() * size); + + Self { + size, + data_type, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/ffi.rs b/crates/nano-arrow/src/array/fixed_size_binary/ffi.rs new file mode 100644 index 000000000000..ee6e6a030df0 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/ffi.rs @@ -0,0 +1,56 @@ +use super::FixedSizeBinaryArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for FixedSizeBinaryArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let 
offset = self.values.offset() / self.size; + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset() / self.size; + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + size: self.size, + data_type: self.data_type.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for FixedSizeBinaryArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/fmt.rs b/crates/nano-arrow/src/array/fixed_size_binary/fmt.rs new file mode 100644 index 000000000000..c5f9e2dd3293 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/fmt.rs @@ -0,0 +1,20 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::FixedSizeBinaryArray; + +pub fn write_value(array: &FixedSizeBinaryArray, index: usize, f: &mut W) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| write!(f, "{}", values[index]); + + write_vec(f, writer, None, values.len(), "None", false) +} + +impl Debug for FixedSizeBinaryArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + write!(f, "{:?}", self.data_type)?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/iterator.rs b/crates/nano-arrow/src/array/fixed_size_binary/iterator.rs new file mode 100644 index 000000000000..4c885c591943 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/iterator.rs @@ -0,0 +1,49 @@ +use super::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray}; +use crate::array::MutableArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +impl<'a> IntoIterator for &'a FixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> FixedSizeBinaryArray { + /// constructs a new iterator + pub fn iter( + &'a self, + ) -> ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>> { + ZipValidity::new_with_validity(self.values_iter(), self.validity()) + } + + /// Returns iterator over the values of [`FixedSizeBinaryArray`] + pub fn values_iter(&'a self) -> std::slice::ChunksExact<'a, u8> { + self.values().chunks_exact(self.size) + } +} + +impl<'a> IntoIterator for &'a MutableFixedSizeBinaryArray { + type Item = Option<&'a [u8]>; + type IntoIter = ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MutableFixedSizeBinaryArray { + /// constructs a new iterator + pub fn iter( + &'a self, + ) -> ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>> { + ZipValidity::new(self.iter_values(), self.validity().map(|x| x.iter())) + } + + /// Returns iterator over the values of [`MutableFixedSizeBinaryArray`] + pub fn iter_values(&'a self) -> std::slice::ChunksExact<'a, u8> { + self.values().chunks_exact(self.size()) 
+ } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/mod.rs b/crates/nano-arrow/src/array/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..f7f82c0a3ef0 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/mod.rs @@ -0,0 +1,286 @@ +use super::Array; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::Error; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +pub use mutable::*; + +/// The Arrow's equivalent to an immutable `Vec>`. +/// Cloning and slicing this struct is `O(1)`. +#[derive(Clone)] +pub struct FixedSizeBinaryArray { + size: usize, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. + data_type: DataType, + values: Buffer, + validity: Option, +} + +impl FixedSizeBinaryArray { + /// Creates a new [`FixedSizeBinaryArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Result { + let size = Self::maybe_get_size(&data_type)?; + + if values.len() % size != 0 { + return Err(Error::oos(format!( + "values (of len {}) must be a multiple of size ({}) in FixedSizeBinaryArray.", + values.len(), + size + ))); + } + let len = values.len() / size; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "validity mask length must be equal to the number of values divided by size", + )); + } + + Ok(Self { + size, + data_type, + values, + validity, + }) + } + + /// Creates a new [`FixedSizeBinaryArray`]. + /// # Panics + /// This function panics iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn new(data_type: DataType, values: Buffer, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Returns a new empty [`FixedSizeBinaryArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, Buffer::new(), None) + } + + /// Returns a new null [`FixedSizeBinaryArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let size = Self::maybe_get_size(&data_type).unwrap(); + Self::new( + data_type, + vec![0u8; length * size].into(), + Some(Bitmap::new_zeroed(length)), + ) + } +} + +// must use +impl FixedSizeBinaryArray { + /// Slices this [`FixedSizeBinaryArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`FixedSizeBinaryArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. 
+ pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.values + .slice_unchecked(offset * self.size, length * self.size); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// accessors +impl FixedSizeBinaryArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the values allocated on this [`FixedSizeBinaryArray`]. + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns value at position `i`. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + // soundness: invariant of the function. + self.values + .get_unchecked(i * self.size..(i + 1) * self.size) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&[u8]> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns a new [`FixedSizeBinaryArray`] with a different logical type. + /// This is `O(1)`. + /// # Panics + /// Panics iff the data_type is not supported for the physical type. + #[inline] + pub fn to(self, data_type: DataType) -> Self { + match ( + data_type.to_logical_type(), + self.data_type().to_logical_type(), + ) { + (DataType::FixedSizeBinary(size_a), DataType::FixedSizeBinary(size_b)) + if size_a == size_b => {}, + _ => panic!("Wrong DataType"), + } + + Self { + size: self.size, + data_type, + values: self.values, + validity: self.validity, + } + } + + /// Returns the size + pub fn size(&self) -> usize { + self.size + } +} + +impl FixedSizeBinaryArray { + pub(crate) fn maybe_get_size(data_type: &DataType) -> Result { + match data_type.to_logical_type() { + DataType::FixedSizeBinary(size) => { + if *size == 0 { + return Err(Error::oos("FixedSizeBinaryArray expects a positive size")); + } + Ok(*size) + }, + _ => Err(Error::oos( + "FixedSizeBinaryArray expects DataType::FixedSizeBinary", + )), + } + } + + pub(crate) fn get_size(data_type: &DataType) -> usize { + Self::maybe_get_size(data_type).unwrap() + } +} + +impl Array for FixedSizeBinaryArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +impl FixedSizeBinaryArray { + /// Creates a [`FixedSizeBinaryArray`] from an fallible iterator of optional `[u8]`. + pub fn try_from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Result { + MutableFixedSizeBinaryArray::try_from_iter(iter, size).map(|x| x.into()) + } + + /// Creates a [`FixedSizeBinaryArray`] from an iterator of optional `[u8]`. 
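// --- Illustrative sketch (not part of this patch) ---------------------------
// Building an immutable array from optional fixed-width items. `try_from_iter`
// fails if any non-null item is not exactly `size` bytes long. The arrow2-style
// path is assumed; `three_byte_tags` is a hypothetical helper.
use arrow2::array::FixedSizeBinaryArray;

fn three_byte_tags() -> arrow2::error::Result<FixedSizeBinaryArray> {
    FixedSizeBinaryArray::try_from_iter(
        vec![Some([0xAAu8, 0xBB, 0xCC]), None, Some([0x01, 0x02, 0x03])],
        3, // every non-null value must be exactly `size` bytes long
    )
}
// -----------------------------------------------------------------------------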
+ pub fn from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Self { + MutableFixedSizeBinaryArray::try_from_iter(iter, size) + .unwrap() + .into() + } + + /// Creates a [`FixedSizeBinaryArray`] from a slice of arrays of bytes + pub fn from_slice>(a: P) -> Self { + let values = a.as_ref().iter().flatten().copied().collect::>(); + Self::new(DataType::FixedSizeBinary(N), values.into(), None) + } + + /// Creates a new [`FixedSizeBinaryArray`] from a slice of optional `[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from]>>(slice: P) -> Self { + MutableFixedSizeBinaryArray::from(slice).into() + } +} + +pub trait FixedSizeBinaryValues { + fn values(&self) -> &[u8]; + fn size(&self) -> usize; +} + +impl FixedSizeBinaryValues for FixedSizeBinaryArray { + #[inline] + fn values(&self) -> &[u8] { + &self.values + } + + #[inline] + fn size(&self) -> usize { + self.size + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_binary/mutable.rs b/crates/nano-arrow/src/array/fixed_size_binary/mutable.rs new file mode 100644 index 000000000000..f5a68facf681 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_binary/mutable.rs @@ -0,0 +1,321 @@ +use std::sync::Arc; + +use super::{FixedSizeBinaryArray, FixedSizeBinaryValues}; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtendFromSelf}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Error; + +/// The Arrow's equivalent to a mutable `Vec>`. +/// Converting a [`MutableFixedSizeBinaryArray`] into a [`FixedSizeBinaryArray`] is `O(1)`. +/// # Implementation +/// This struct does not allocate a validity until one is required (i.e. push a null to it). +#[derive(Debug, Clone)] +pub struct MutableFixedSizeBinaryArray { + data_type: DataType, + size: usize, + values: Vec, + validity: Option, +} + +impl From for FixedSizeBinaryArray { + fn from(other: MutableFixedSizeBinaryArray) -> Self { + FixedSizeBinaryArray::new( + other.data_type, + other.values.into(), + other.validity.map(|x| x.into()), + ) + } +} + +impl MutableFixedSizeBinaryArray { + /// Creates a new [`MutableFixedSizeBinaryArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeBinary`] + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + data_type: DataType, + values: Vec, + validity: Option, + ) -> Result { + let size = FixedSizeBinaryArray::maybe_get_size(&data_type)?; + + if values.len() % size != 0 { + return Err(Error::oos(format!( + "values (of len {}) must be a multiple of size ({}) in FixedSizeBinaryArray.", + values.len(), + size + ))); + } + let len = values.len() / size; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "validity mask length must be equal to the number of values divided by size", + )); + } + + Ok(Self { + size, + data_type, + values, + validity, + }) + } + + /// Creates a new empty [`MutableFixedSizeBinaryArray`]. + pub fn new(size: usize) -> Self { + Self::with_capacity(size, 0) + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] with capacity for `capacity` entries. 
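// --- Illustrative sketch (not part of this patch) ---------------------------
// Filling the mutable builder and freezing it. A null slot still reserves
// `size` bytes in the values buffer, so offsets stay implicit. arrow2-style
// paths are assumed; `build_checksums` is a hypothetical helper.
use arrow2::array::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray};

fn build_checksums() -> FixedSizeBinaryArray {
    let mut builder = MutableFixedSizeBinaryArray::with_capacity(4, 3);
    builder.push(Some([0xDEu8, 0xAD, 0xBE, 0xEF]));
    builder.push::<&[u8]>(None); // same trick the `push_null` impl below uses
    builder.push(Some([0u8; 4]));
    builder.into() // O(1) conversion into the immutable FixedSizeBinaryArray
}
// -----------------------------------------------------------------------------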
+ pub fn with_capacity(size: usize, capacity: usize) -> Self { + Self::try_new( + DataType::FixedSizeBinary(size), + Vec::::with_capacity(capacity * size), + None, + ) + .unwrap() + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] from a slice of optional `[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from]>>(slice: P) -> Self { + let values = slice + .as_ref() + .iter() + .copied() + .flat_map(|x| x.unwrap_or([0; N])) + .collect::>(); + let validity = slice + .as_ref() + .iter() + .map(|x| x.is_some()) + .collect::(); + Self::try_new(DataType::FixedSizeBinary(N), values, validity.into()).unwrap() + } + + /// tries to push a new entry to [`MutableFixedSizeBinaryArray`]. + /// # Error + /// Errors iff the size of `value` is not equal to its own size. + #[inline] + pub fn try_push>(&mut self, value: Option
) -> Result<(), Error> { + match value { + Some(bytes) => { + let bytes = bytes.as_ref(); + if self.size != bytes.len() { + return Err(Error::InvalidArgumentError( + "FixedSizeBinaryArray requires every item to be of its length".to_string(), + )); + } + self.values.extend_from_slice(bytes); + + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.resize(self.values.len() + self.size, 0); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } + + /// pushes a new entry to [`MutableFixedSizeBinaryArray`]. + /// # Panics + /// Panics iff the size of `value` is not equal to its own size. + #[inline] + pub fn push>(&mut self, value: Option
) { + self.try_push(value).unwrap() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// Pop the last entry from [`MutableFixedSizeBinaryArray`]. + /// This function returns `None` iff this array is empty + pub fn pop(&mut self) -> Option> { + if self.values.len() < self.size { + return None; + } + let value_start = self.values.len() - self.size; + let value = self.values.split_off(value_start); + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + /// Creates a new [`MutableFixedSizeBinaryArray`] from an iterator of values. + /// # Errors + /// Errors iff the size of any of the `value` is not equal to its own size. + pub fn try_from_iter, I: IntoIterator>>( + iter: I, + size: usize, + ) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut primitive = Self::with_capacity(size, lower); + for item in iterator { + primitive.try_push(item)? + } + Ok(primitive) + } + + /// returns the (fixed) size of the [`MutableFixedSizeBinaryArray`]. + #[inline] + pub fn size(&self) -> usize { + self.size + } + + /// Returns the capacity of this array + pub fn capacity(&self) -> usize { + self.values.capacity() / self.size + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Returns the element at index `i` as `&[u8]` + #[inline] + pub fn value(&self, i: usize) -> &[u8] { + &self.values[i * self.size..(i + 1) * self.size] + } + + /// Returns the element at index `i` as `&[u8]` + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { + std::slice::from_raw_parts(self.values.as_ptr().add(i * self.size), self.size) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional * self.size); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableFixedSizeBinaryArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +/// Accessors +impl MutableFixedSizeBinaryArray { + /// Returns its values. + pub fn values(&self) -> &Vec { + &self.values + } + + /// Returns a mutable slice of values. 
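// --- Illustrative sketch (not part of this patch) ---------------------------
// `values_mut_slice` exposes the raw backing bytes, so a slot can be patched
// in place without reallocating. `zero_first_slot` is a hypothetical helper.
use arrow2::array::MutableFixedSizeBinaryArray;

fn zero_first_slot(builder: &mut MutableFixedSizeBinaryArray) {
    let size = builder.size();
    if builder.len() > 0 {
        builder.values_mut_slice()[..size].fill(0);
    }
}
// -----------------------------------------------------------------------------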
+ pub fn values_mut_slice(&mut self) -> &mut [u8] { + self.values.as_mut_slice() + } +} + +impl MutableArray for MutableFixedSizeBinaryArray { + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(self.size), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(self.size), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push::<&[u8]>(None); + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl FixedSizeBinaryValues for MutableFixedSizeBinaryArray { + #[inline] + fn values(&self) -> &[u8] { + &self.values + } + + #[inline] + fn size(&self) -> usize { + self.size + } +} + +impl PartialEq for MutableFixedSizeBinaryArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableFixedSizeBinaryArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + self.values.extend_from_slice(slice); + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/data.rs b/crates/nano-arrow/src/array/fixed_size_list/data.rs new file mode 100644 index 000000000000..966504bf3b6c --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/data.rs @@ -0,0 +1,36 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, FixedSizeListArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::DataType; + +impl Arrow2Arrow for FixedSizeListArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(vec![to_data(self.values.as_ref())]); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type: DataType = data.data_type().clone().into(); + let size = match data_type { + DataType::FixedSizeList(_, size) => size, + _ => unreachable!("must be FixedSizeList type"), + }; + + let mut values = from_data(&data.child_data()[0]); + values.slice(data.offset() * size, data.len() * size); + + Self { + size, + data_type, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/ffi.rs b/crates/nano-arrow/src/array/fixed_size_list/ffi.rs new file mode 100644 index 000000000000..237001809598 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/ffi.rs @@ -0,0 +1,39 @@ +use super::FixedSizeListArray; +use crate::array::ffi::{FromFfi, ToFfi}; +use crate::array::Array; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for FixedSizeListArray { + fn buffers(&self) -> Vec> { + vec![self.validity.as_ref().map(|x| 
x.as_ptr())] + } + + fn children(&self) -> Vec> { + vec![self.values.clone()] + } + + fn offset(&self) -> Option { + Some( + self.validity + .as_ref() + .map(|bitmap| bitmap.offset()) + .unwrap_or_default(), + ) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for FixedSizeListArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let child = unsafe { array.child(0)? }; + let values = ffi::try_from(child)?; + + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/fmt.rs b/crates/nano-arrow/src/array/fixed_size_list/fmt.rs new file mode 100644 index 000000000000..ee7d86115a14 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::FixedSizeListArray; + +pub fn write_value( + array: &FixedSizeListArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for FixedSizeListArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "FixedSizeListArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/iterator.rs b/crates/nano-arrow/src/array/fixed_size_list/iterator.rs new file mode 100644 index 000000000000..123658005adc --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/iterator.rs @@ -0,0 +1,43 @@ +use super::FixedSizeListArray; +use crate::array::{Array, ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; + +unsafe impl<'a> ArrayAccessor<'a> for FixedSizeListArray { + type Item = Box; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of a [`FixedSizeListArray`]. 
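// --- Illustrative sketch (not part of this patch) ---------------------------
// Iterating a FixedSizeListArray yields one Option<Box<dyn Array>> per slot;
// downcast it to read the inner values. An Int32 inner type and arrow2-style
// paths are assumed; `first_elements` is a hypothetical helper.
use arrow2::array::{Array, FixedSizeListArray, Int32Array};

fn first_elements(array: &FixedSizeListArray) -> Vec<Option<i32>> {
    array
        .iter()
        .map(|slot| {
            slot.and_then(|inner| {
                let inner = inner.as_any().downcast_ref::<Int32Array>()?;
                // Assumes size >= 1; honour the inner validity as well.
                if inner.is_valid(0) {
                    Some(inner.value(0))
                } else {
                    None
                }
            })
        })
        .collect()
}
// -----------------------------------------------------------------------------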
+pub type FixedSizeListValuesIter<'a> = ArrayValuesIter<'a, FixedSizeListArray>; + +type ZipIter<'a> = ZipValidity, FixedSizeListValuesIter<'a>, BitmapIter<'a>>; + +impl<'a> IntoIterator for &'a FixedSizeListArray { + type Item = Option>; + type IntoIter = ZipIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> FixedSizeListArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a> { + ZipValidity::new_with_validity(FixedSizeListValuesIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> FixedSizeListValuesIter<'a> { + FixedSizeListValuesIter::new(self) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/mod.rs b/crates/nano-arrow/src/array/fixed_size_list/mod.rs new file mode 100644 index 000000000000..25ee0db14874 --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/mod.rs @@ -0,0 +1,220 @@ +use super::{new_empty_array, new_null_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; + +/// The Arrow's equivalent to an immutable `Vec>` where `T` is an Arrow type. +/// Cloning and slicing this struct is `O(1)`. +#[derive(Clone)] +pub struct FixedSizeListArray { + size: usize, // this is redundant with `data_type`, but useful to not have to deconstruct the data_type. + data_type: DataType, + values: Box, + validity: Option, +} + +impl FixedSizeListArray { + /// Creates a new [`FixedSizeListArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::FixedSizeList`] + /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// * The length of `values` is not a multiple of `size` in `data_type` + /// * the validity's length is not equal to `values.len() / size`. + pub fn try_new( + data_type: DataType, + values: Box, + validity: Option, + ) -> Result { + let (child, size) = Self::try_child_and_size(&data_type)?; + + let child_data_type = &child.data_type; + let values_data_type = values.data_type(); + if child_data_type != values_data_type { + return Err(Error::oos( + format!("FixedSizeListArray's child's DataType must match. However, the expected DataType is {child_data_type:?} while it got {values_data_type:?}."), + )); + } + + if values.len() % size != 0 { + return Err(Error::oos(format!( + "values (of len {}) must be a multiple of size ({}) in FixedSizeListArray.", + values.len(), + size + ))); + } + let len = values.len() / size; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "validity mask length must be equal to the number of values divided by size", + )); + } + + Ok(Self { + size, + data_type, + values, + validity, + }) + } + + /// Alias to `Self::try_new(...).unwrap()` + pub fn new(data_type: DataType, values: Box, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. + pub const fn size(&self) -> usize { + self.size + } + + /// Returns a new empty [`FixedSizeListArray`]. 
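// --- Illustrative sketch (not part of this patch) ---------------------------
// Allocating an all-null column of fixed-size lists. `new_null` also allocates
// the inner values (length * size). arrow2-style paths are assumed and
// `null_points` is a hypothetical helper.
use arrow2::array::FixedSizeListArray;
use arrow2::datatypes::{DataType, Field};

fn null_points(rows: usize) -> FixedSizeListArray {
    // Each slot is a pair of nullable Float64 values; every slot starts out null.
    let inner = Box::new(Field::new("item", DataType::Float64, true));
    FixedSizeListArray::new_null(DataType::FixedSizeList(inner, 2), rows)
}
// -----------------------------------------------------------------------------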
+ pub fn new_empty(data_type: DataType) -> Self { + let values = new_empty_array(Self::get_child_and_size(&data_type).0.data_type().clone()); + Self::new(data_type, values, None) + } + + /// Returns a new null [`FixedSizeListArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let (field, size) = Self::get_child_and_size(&data_type); + + let values = new_null_array(field.data_type().clone(), length * size); + Self::new(data_type, values, Some(Bitmap::new_zeroed(length))) + } +} + +// must use +impl FixedSizeListArray { + /// Slices this [`FixedSizeListArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Panics + /// panics iff `offset + length > self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`FixedSizeListArray`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.values + .slice_unchecked(offset * self.size, length * self.size); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// accessors +impl FixedSizeListArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the inner array. + pub fn values(&self) -> &Box { + &self.values + } + + /// Returns the `Vec` at position `i`. + /// # Panic: + /// panics iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> Box { + self.values.sliced(i * self.size, self.size) + } + + /// Returns the `Vec` at position `i`. + /// # Safety + /// Caller must ensure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + self.values.sliced_unchecked(i * self.size, self.size) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } +} + +impl FixedSizeListArray { + pub(crate) fn try_child_and_size(data_type: &DataType) -> Result<(&Field, usize), Error> { + match data_type.to_logical_type() { + DataType::FixedSizeList(child, size) => { + if *size == 0 { + return Err(Error::oos("FixedSizeBinaryArray expects a positive size")); + } + Ok((child.as_ref(), *size)) + }, + _ => Err(Error::oos( + "FixedSizeListArray expects DataType::FixedSizeList", + )), + } + } + + pub(crate) fn get_child_and_size(data_type: &DataType) -> (&Field, usize) { + Self::try_child_and_size(data_type).unwrap() + } + + /// Returns a [`DataType`] consistent with [`FixedSizeListArray`]. 
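// --- Illustrative sketch (not part of this patch) ---------------------------
// Wrapping a flat primitive array as rows of width 3 using `default_datatype`.
// `rows_of_three` is a hypothetical helper; arrow2-style paths are assumed.
use arrow2::array::{Array, FixedSizeListArray, Int32Array};

fn rows_of_three() -> FixedSizeListArray {
    let values = Int32Array::from_vec(vec![1, 2, 3, 4, 5, 6]); // two slots of three
    let dt = FixedSizeListArray::default_datatype(values.data_type().clone(), 3);
    FixedSizeListArray::new(dt, values.boxed(), None)
}
// -----------------------------------------------------------------------------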
+ pub fn default_datatype(data_type: DataType, size: usize) -> DataType { + let field = Box::new(Field::new("item", data_type, true)); + DataType::FixedSizeList(field, size) + } +} + +impl Array for FixedSizeListArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/fixed_size_list/mutable.rs b/crates/nano-arrow/src/array/fixed_size_list/mutable.rs new file mode 100644 index 000000000000..bef25a1cbf1f --- /dev/null +++ b/crates/nano-arrow/src/array/fixed_size_list/mutable.rs @@ -0,0 +1,256 @@ +use std::sync::Arc; + +use super::FixedSizeListArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, PushUnchecked, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; + +/// The mutable version of [`FixedSizeListArray`]. +#[derive(Debug, Clone)] +pub struct MutableFixedSizeListArray { + data_type: DataType, + size: usize, + values: M, + validity: Option, +} + +impl From> for FixedSizeListArray { + fn from(mut other: MutableFixedSizeListArray) -> Self { + FixedSizeListArray::new( + other.data_type, + other.values.as_box(), + other.validity.map(|x| x.into()), + ) + } +} + +impl MutableFixedSizeListArray { + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. + pub fn new(values: M, size: usize) -> Self { + let data_type = FixedSizeListArray::default_datatype(values.data_type().clone(), size); + Self::new_from(values, data_type, size) + } + + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`] and size. + pub fn new_with_field(values: M, name: &str, nullable: bool, size: usize) -> Self { + let data_type = DataType::FixedSizeList( + Box::new(Field::new(name, values.data_type().clone(), nullable)), + size, + ); + Self::new_from(values, data_type, size) + } + + /// Creates a new [`MutableFixedSizeListArray`] from a [`MutableArray`], [`DataType`] and size. + pub fn new_from(values: M, data_type: DataType, size: usize) -> Self { + assert_eq!(values.len(), 0); + match data_type { + DataType::FixedSizeList(..) => (), + _ => panic!("data type must be FixedSizeList (got {data_type:?})"), + }; + Self { + size, + data_type, + values, + validity: None, + } + } + + /// Returns the size (number of elements per slot) of this [`FixedSizeListArray`]. + pub const fn size(&self) -> usize { + self.size + } + + /// The length of this array + pub fn len(&self) -> usize { + self.values.len() / self.size + } + + /// The inner values + pub fn values(&self) -> &M { + &self.values + } + + /// The values as a mutable reference + pub fn mut_values(&mut self) -> &mut M { + &mut self.values + } + + fn init_validity(&mut self) { + let len = self.values.len() / self.size; + + let mut validity = MutableBitmap::new(); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. 
+ pub fn try_push_valid(&mut self) -> Result<()> { + if self.values.len() % self.size != 0 { + return Err(Error::Overflow); + }; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. + pub fn push_valid(&mut self) { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + } + + #[inline] + fn push_null(&mut self) { + (0..self.size).for_each(|_| self.values.push_null()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableFixedSizeListArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableFixedSizeListArray { + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + FixedSizeListArray::new( + self.data_type.clone(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + FixedSizeListArray::new( + self.data_type.clone(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + (0..self.size).for_each(|_| { + self.values.push_null(); + }); + if let Some(validity) = &mut self.validity { + validity.push(false) + } else { + self.init_validity() + } + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl TryExtend> for MutableFixedSizeListArray +where + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_extend>>(&mut self, iter: II) -> Result<()> { + for items in iter { + self.try_push(items)?; + } + Ok(()) + } +} + +impl TryPush> for MutableFixedSizeListArray +where + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_push(&mut self, item: Option) -> Result<()> { + if let Some(items) = item { + self.values.try_extend(items)?; + self.try_push_valid()?; + } else { + self.push_null(); + } + Ok(()) + } +} + +impl PushUnchecked> for MutableFixedSizeListArray +where + M: MutableArray + Extend>, + I: IntoIterator>, +{ + /// # Safety + /// The caller must ensure that the `I` iterates exactly over `size` + /// items, where `size` is the fixed size width. 
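// Usage sketch (illustrative, not part of the patch): filling a
// MutableFixedSizeListArray through the TryPush/TryExtend impls above and
// freezing it via `From`; MutablePrimitiveArray is assumed to be the usual
// arrow2-style inner mutable array, and crate paths are assumptions.
use nano_arrow::array::{
    FixedSizeListArray, MutableFixedSizeListArray, MutablePrimitiveArray, TryExtend, TryPush,
};

fn mutable_fixed_size_list_sketch() -> nano_arrow::error::Result<()> {
    let inner = MutablePrimitiveArray::<i32>::new();
    let mut array = MutableFixedSizeListArray::new(inner, 3);

    // Every non-null item must contribute exactly `size` (= 3) inner values.
    array.try_push(Some(vec![Some(1), Some(2), Some(3)]))?;
    array.try_push(None::<Vec<Option<i32>>>)?;
    array.try_extend(vec![Some(vec![Some(4), None, Some(6)])])?;

    let frozen: FixedSizeListArray = array.into();
    assert_eq!(frozen.len(), 3);
    Ok(())
}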
+ #[inline] + unsafe fn push_unchecked(&mut self, item: Option) { + if let Some(items) = item { + self.values.extend(items); + self.push_valid(); + } else { + self.push_null(); + } + } +} + +impl TryExtendFromSelf for MutableFixedSizeListArray +where + M: MutableArray + TryExtendFromSelf, +{ + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/nano-arrow/src/array/fmt.rs b/crates/nano-arrow/src/array/fmt.rs new file mode 100644 index 000000000000..ebc6937714cc --- /dev/null +++ b/crates/nano-arrow/src/array/fmt.rs @@ -0,0 +1,181 @@ +use std::fmt::{Result, Write}; + +use super::Array; +use crate::bitmap::Bitmap; + +/// Returns a function that writes the value of the element of `array` +/// at position `index` to a [`Write`], +/// writing `null` in the null slots. +pub fn get_value_display<'a, F: Write + 'a>( + array: &'a dyn Array, + null: &'static str, +) -> Box Result + 'a> { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => Box::new(move |f, _| write!(f, "{null}")), + Boolean => Box::new(|f, index| { + super::boolean::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, f) + }), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let writer = super::primitive::fmt::get_write_value::<$T, _>( + array.as_any().downcast_ref().unwrap(), + ); + Box::new(move |f, index| writer(f, index)) + }), + Binary => Box::new(|f, index| { + super::binary::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + FixedSizeBinary => Box::new(|f, index| { + super::fixed_size_binary::fmt::write_value( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + LargeBinary => Box::new(|f, index| { + super::binary::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + Utf8 => Box::new(|f, index| { + super::utf8::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + LargeUtf8 => Box::new(|f, index| { + super::utf8::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + f, + ) + }), + List => Box::new(move |f, index| { + super::list::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + FixedSizeList => Box::new(move |f, index| { + super::fixed_size_list::fmt::write_value( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + LargeList => Box::new(move |f, index| { + super::list::fmt::write_value::( + array.as_any().downcast_ref().unwrap(), + index, + null, + f, + ) + }), + Struct => Box::new(move |f, index| { + super::struct_::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Union => Box::new(move |f, index| { + super::union::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Map => Box::new(move |f, index| { + super::map::fmt::write_value(array.as_any().downcast_ref().unwrap(), index, null, f) + }), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + Box::new(move |f, index| { + super::dictionary::fmt::write_value::<$T,_>(array.as_any().downcast_ref().unwrap(), index, null, f) + }) + }), + } +} + +/// Returns a function that writes the element of `array` +/// at position `index` to a [`Write`], writing `null` to the null slots. 
+pub fn get_display<'a, F: Write + 'a>( + array: &'a dyn Array, + null: &'static str, +) -> Box Result + 'a> { + let value_display = get_value_display(array, null); + Box::new(move |f, row| { + if array.is_null(row) { + f.write_str(null) + } else { + value_display(f, row) + } + }) +} + +pub fn write_vec( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + f.write_char('[')?; + write_list(f, d, validity, len, null, new_lines)?; + f.write_char(']')?; + Ok(()) +} + +fn write_list( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + for index in 0..len { + if index != 0 { + f.write_char(',')?; + f.write_char(if new_lines { '\n' } else { ' ' })?; + } + if let Some(val) = validity { + if val.get_bit(index) { + d(f, index) + } else { + write!(f, "{null}") + } + } else { + d(f, index) + }?; + } + Ok(()) +} + +pub fn write_map( + f: &mut F, + d: D, + validity: Option<&Bitmap>, + len: usize, + null: &'static str, + new_lines: bool, +) -> Result +where + D: Fn(&mut F, usize) -> Result, + F: Write, +{ + f.write_char('{')?; + write_list(f, d, validity, len, null, new_lines)?; + f.write_char('}')?; + Ok(()) +} diff --git a/crates/nano-arrow/src/array/growable/binary.rs b/crates/nano-arrow/src/array/growable/binary.rs new file mode 100644 index 000000000000..ca095f351446 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/binary.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, extend_offset_values, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, BinaryArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::offset::{Offset, Offsets}; + +/// Concrete [`Growable`] for the [`BinaryArray`]. +pub struct GrowableBinary<'a, O: Offset> { + arrays: Vec<&'a BinaryArray>, + data_type: DataType, + validity: MutableBitmap, + values: Vec, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a, O: Offset> GrowableBinary<'a, O> { + /// Creates a new [`GrowableBinary`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a BinaryArray>, mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
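// Usage sketch (illustrative, not part of the patch): the fmt helpers above are
// what the Debug impls of the concrete arrays delegate to, so a formatted array
// prints one entry per slot with the chosen null marker. The exact output text
// is an assumption based on the arrow2 formatting this code mirrors.
use nano_arrow::array::Int32Array;

fn debug_format_sketch() {
    let array = Int32Array::from([Some(1), None, Some(3)]);
    // Prints something like: Int32[1, None, 3]
    println!("{array:?}");
}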
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + Self { + arrays, + data_type, + values: Vec::with_capacity(0), + offsets: Offsets::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> BinaryArray { + let data_type = self.data_type.clone(); + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = std::mem::take(&mut self.values); + + BinaryArray::::new(data_type, offsets.into(), values.into(), validity.into()) + } +} + +impl<'a, O: Offset> Growable<'a> for GrowableBinary<'a, O> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let offsets = array.offsets(); + let values = array.values(); + + self.offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + // values + extend_offset_values::(&mut self.values, offsets.buffer(), values, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + self.to().arced() + } + + fn as_box(&mut self) -> Box { + self.to().boxed() + } +} + +impl<'a, O: Offset> From> for BinaryArray { + fn from(val: GrowableBinary<'a, O>) -> Self { + BinaryArray::::new( + val.data_type, + val.offsets.into(), + val.values.into(), + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/boolean.rs b/crates/nano-arrow/src/array/growable/boolean.rs new file mode 100644 index 000000000000..f69d66f1d696 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/boolean.rs @@ -0,0 +1,91 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, BooleanArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`BooleanArray`]. +pub struct GrowableBoolean<'a> { + arrays: Vec<&'a BooleanArray>, + data_type: DataType, + validity: MutableBitmap, + values: MutableBitmap, + extend_null_bits: Vec>, +} + +impl<'a> GrowableBoolean<'a> { + /// Creates a new [`GrowableBoolean`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a BooleanArray>, mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + Self { + arrays, + data_type, + values: MutableBitmap::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> BooleanArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + + BooleanArray::new(self.data_type.clone(), values.into(), validity.into()) + } +} + +impl<'a> Growable<'a> for GrowableBoolean<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let values = array.values(); + + let (slice, offset, _) = values.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { + self.values + .extend_from_slice_unchecked(slice, start + offset, len); + } + } + + fn extend_validity(&mut self, additional: usize) { + self.values.extend_constant(additional, false); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for BooleanArray { + fn from(val: GrowableBoolean<'a>) -> Self { + BooleanArray::new(val.data_type, val.values.into(), val.validity.into()) + } +} diff --git a/crates/nano-arrow/src/array/growable/dictionary.rs b/crates/nano-arrow/src/array/growable/dictionary.rs new file mode 100644 index 000000000000..fa85cdad6f8e --- /dev/null +++ b/crates/nano-arrow/src/array/growable/dictionary.rs @@ -0,0 +1,157 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`DictionaryArray`]. +/// # Implementation +/// This growable does not perform collision checks and instead concatenates +/// the values of each [`DictionaryArray`] one after the other. +pub struct GrowableDictionary<'a, K: DictionaryKey> { + data_type: DataType, + keys_values: Vec<&'a [K]>, + key_values: Vec, + key_validity: MutableBitmap, + offsets: Vec, + values: Box, + extend_null_bits: Vec>, +} + +fn concatenate_values( + arrays_keys: &[&PrimitiveArray], + arrays_values: &[&dyn Array], + capacity: usize, +) -> (Box, Vec) { + let mut mutable = make_growable(arrays_values, false, capacity); + let mut offsets = Vec::with_capacity(arrays_keys.len() + 1); + offsets.push(0); + for (i, values) in arrays_values.iter().enumerate() { + mutable.extend(i, 0, values.len()); + offsets.push(offsets[i] + values.len()); + } + (mutable.as_box(), offsets) +} + +impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { + /// Creates a new [`GrowableDictionary`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: &[&'a DictionaryArray], mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let arrays_keys = arrays.iter().map(|array| array.keys()).collect::>(); + let keys_values = arrays_keys + .iter() + .map(|array| array.values().as_slice()) + .collect::>(); + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(array.keys(), use_validity)) + .collect(); + + let arrays_values = arrays + .iter() + .map(|array| array.values().as_ref()) + .collect::>(); + + let (values, offsets) = concatenate_values(&arrays_keys, &arrays_values, capacity); + + Self { + data_type, + offsets, + values, + keys_values, + key_values: Vec::with_capacity(capacity), + key_validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + #[inline] + fn to(&mut self) -> DictionaryArray { + let validity = std::mem::take(&mut self.key_validity); + let key_values = std::mem::take(&mut self.key_values); + + #[cfg(debug_assertions)] + { + crate::array::specification::check_indexes(&key_values, self.values.len()).unwrap(); + } + let keys = + PrimitiveArray::::new(T::PRIMITIVE.into(), key_values.into(), validity.into()); + + // Safety - the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new_unchecked( + self.data_type.clone(), + keys, + self.values.clone(), + ) + .unwrap() + } + } +} + +impl<'a, T: DictionaryKey> Growable<'a> for GrowableDictionary<'a, T> { + #[inline] + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.key_validity, start, len); + + let values = &self.keys_values[index][start..start + len]; + let offset = self.offsets[index]; + self.key_values.extend( + values + .iter() + // `.unwrap_or(0)` because this operation does not check for null values, which may contain any key. + .map(|x| { + let x: usize = offset + (*x).try_into().unwrap_or(0); + let x: T = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }), + ); + } + + #[inline] + fn len(&self) -> usize { + self.key_values.len() + } + + #[inline] + fn extend_validity(&mut self, additional: usize) { + self.key_values + .resize(self.key_values.len() + additional, T::default()); + self.key_validity.extend_constant(additional, false); + } + + #[inline] + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + #[inline] + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, T: DictionaryKey> From> for DictionaryArray { + #[inline] + fn from(mut val: GrowableDictionary<'a, T>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/fixed_binary.rs b/crates/nano-arrow/src/array/growable/fixed_binary.rs new file mode 100644 index 000000000000..bc6b307f97f9 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/fixed_binary.rs @@ -0,0 +1,98 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, FixedSizeBinaryArray}; +use crate::bitmap::MutableBitmap; + +/// Concrete [`Growable`] for the [`FixedSizeBinaryArray`]. +pub struct GrowableFixedSizeBinary<'a> { + arrays: Vec<&'a FixedSizeBinaryArray>, + validity: MutableBitmap, + values: Vec, + extend_null_bits: Vec>, + size: usize, // just a cache +} + +impl<'a> GrowableFixedSizeBinary<'a> { + /// Creates a new [`GrowableFixedSizeBinary`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. 
+ pub fn new( + arrays: Vec<&'a FixedSizeBinaryArray>, + mut use_validity: bool, + capacity: usize, + ) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. + if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let size = FixedSizeBinaryArray::get_size(arrays[0].data_type()); + Self { + arrays, + values: Vec::with_capacity(0), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + size, + } + } + + fn to(&mut self) -> FixedSizeBinaryArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + + FixedSizeBinaryArray::new( + self.arrays[0].data_type().clone(), + values.into(), + validity.into(), + ) + } +} + +impl<'a> Growable<'a> for GrowableFixedSizeBinary<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let values = array.values(); + + self.values + .extend_from_slice(&values[start * self.size..start * self.size + len * self.size]); + } + + fn extend_validity(&mut self, additional: usize) { + self.values + .extend_from_slice(&vec![0; self.size * additional]); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for FixedSizeBinaryArray { + fn from(val: GrowableFixedSizeBinary<'a>) -> Self { + FixedSizeBinaryArray::new( + val.arrays[0].data_type().clone(), + val.values.into(), + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/fixed_size_list.rs b/crates/nano-arrow/src/array/growable/fixed_size_list.rs new file mode 100644 index 000000000000..cacad36bb4a7 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/fixed_size_list.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, FixedSizeListArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`FixedSizeListArray`]. +pub struct GrowableFixedSizeList<'a> { + arrays: Vec<&'a FixedSizeListArray>, + validity: MutableBitmap, + values: Box + 'a>, + extend_null_bits: Vec>, + size: usize, +} + +impl<'a> GrowableFixedSizeList<'a> { + /// Creates a new [`GrowableFixedSizeList`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new( + arrays: Vec<&'a FixedSizeListArray>, + mut use_validity: bool, + capacity: usize, + ) -> Self { + assert!(!arrays.is_empty()); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let size = + if let DataType::FixedSizeList(_, size) = &arrays[0].data_type().to_logical_type() { + *size + } else { + unreachable!("`GrowableFixedSizeList` expects `DataType::FixedSizeList`") + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let inner = arrays + .iter() + .map(|array| array.values().as_ref()) + .collect::>(); + let values = make_growable(&inner, use_validity, 0); + + Self { + arrays, + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + size, + } + } + + fn to(&mut self) -> FixedSizeListArray { + let validity = std::mem::take(&mut self.validity); + let values = self.values.as_box(); + + FixedSizeListArray::new(self.arrays[0].data_type().clone(), values, validity.into()) + } +} + +impl<'a> Growable<'a> for GrowableFixedSizeList<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + self.values + .extend(index, start * self.size, len * self.size); + } + + fn extend_validity(&mut self, additional: usize) { + self.values.extend_validity(additional * self.size); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for FixedSizeListArray { + fn from(val: GrowableFixedSizeList<'a>) -> Self { + let mut values = val.values; + let values = values.as_box(); + + Self::new( + val.arrays[0].data_type().clone(), + values, + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/list.rs b/crates/nano-arrow/src/array/growable/list.rs new file mode 100644 index 000000000000..9fdf9eb047bf --- /dev/null +++ b/crates/nano-arrow/src/array/growable/list.rs @@ -0,0 +1,112 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, ListArray}; +use crate::bitmap::MutableBitmap; +use crate::offset::{Offset, Offsets}; + +fn extend_offset_values( + growable: &mut GrowableList<'_, O>, + index: usize, + start: usize, + len: usize, +) { + let array = growable.arrays[index]; + let offsets = array.offsets(); + + growable + .offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + let end = offsets.buffer()[start + len].to_usize(); + let start = offsets.buffer()[start].to_usize(); + let len = end - start; + growable.values.extend(index, start, len); +} + +/// Concrete [`Growable`] for the [`ListArray`]. +pub struct GrowableList<'a, O: Offset> { + arrays: Vec<&'a ListArray>, + validity: MutableBitmap, + values: Box + 'a>, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a, O: Offset> GrowableList<'a, O> { + /// Creates a new [`GrowableList`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a ListArray>, mut use_validity: bool, capacity: usize) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let inner = arrays + .iter() + .map(|array| array.values().as_ref()) + .collect::>(); + let values = make_growable(&inner, use_validity, 0); + + Self { + arrays, + offsets: Offsets::with_capacity(capacity), + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> ListArray { + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = self.values.as_box(); + + ListArray::::new( + self.arrays[0].data_type().clone(), + offsets.into(), + values, + validity.into(), + ) + } +} + +impl<'a, O: Offset> Growable<'a> for GrowableList<'a, O> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + extend_offset_values::(self, index, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, O: Offset> From> for ListArray { + fn from(mut val: GrowableList<'a, O>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/map.rs b/crates/nano-arrow/src/array/growable/map.rs new file mode 100644 index 000000000000..62f9d4c5c53a --- /dev/null +++ b/crates/nano-arrow/src/array/growable/map.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, MapArray}; +use crate::bitmap::MutableBitmap; +use crate::offset::Offsets; + +fn extend_offset_values(growable: &mut GrowableMap<'_>, index: usize, start: usize, len: usize) { + let array = growable.arrays[index]; + let offsets = array.offsets(); + + growable + .offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + let end = offsets.buffer()[start + len] as usize; + let start = offsets.buffer()[start] as usize; + let len = end - start; + growable.values.extend(index, start, len); +} + +/// Concrete [`Growable`] for the [`MapArray`]. +pub struct GrowableMap<'a> { + arrays: Vec<&'a MapArray>, + validity: MutableBitmap, + values: Box + 'a>, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a> GrowableMap<'a> { + /// Creates a new [`GrowableMap`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a MapArray>, mut use_validity: bool, capacity: usize) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
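// Usage sketch (illustrative, not part of the patch): GrowableList re-slices the
// offsets and forwards the corresponding child ranges to a nested growable, as
// implemented in extend_offset_values above. MutableListArray and the crate
// paths are assumptions based on the arrow2 API this crate mirrors.
use nano_arrow::array::growable::{Growable, GrowableList};
use nano_arrow::array::{Array, ListArray, MutableListArray, MutablePrimitiveArray, TryExtend};

fn growable_list_sketch() -> nano_arrow::error::Result<()> {
    let mut mutable = MutableListArray::<i32, MutablePrimitiveArray<i32>>::new();
    mutable.try_extend(vec![Some(vec![Some(1)]), Some(vec![Some(2), Some(3)])])?;
    let array: ListArray<i32> = mutable.into();

    let mut growable = GrowableList::new(vec![&array], false, 1);
    growable.extend(0, 1, 1); // copy only the second list: [2, 3]

    let out: ListArray<i32> = growable.into();
    assert_eq!(out.len(), 1);
    assert_eq!(out.value(0).len(), 2);
    Ok(())
}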
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let inner = arrays + .iter() + .map(|array| array.field().as_ref()) + .collect::>(); + let values = make_growable(&inner, use_validity, 0); + + Self { + arrays, + offsets: Offsets::with_capacity(capacity), + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> MapArray { + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = self.values.as_box(); + + MapArray::new( + self.arrays[0].data_type().clone(), + offsets.into(), + values, + validity.into(), + ) + } +} + +impl<'a> Growable<'a> for GrowableMap<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + extend_offset_values(self, index, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for MapArray { + fn from(mut val: GrowableMap<'a>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/mod.rs b/crates/nano-arrow/src/array/growable/mod.rs new file mode 100644 index 000000000000..a3fe4b739451 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/mod.rs @@ -0,0 +1,149 @@ +//! Contains the trait [`Growable`] and corresponding concreate implementations, one per concrete array, +//! that offer the ability to create a new [`Array`] out of slices of existing [`Array`]s. + +use std::sync::Arc; + +use crate::array::*; +use crate::datatypes::*; + +mod binary; +pub use binary::GrowableBinary; +mod union; +pub use union::GrowableUnion; +mod boolean; +pub use boolean::GrowableBoolean; +mod fixed_binary; +pub use fixed_binary::GrowableFixedSizeBinary; +mod null; +pub use null::GrowableNull; +mod primitive; +pub use primitive::GrowablePrimitive; +mod list; +pub use list::GrowableList; +mod map; +pub use map::GrowableMap; +mod structure; +pub use structure::GrowableStruct; +mod fixed_size_list; +pub use fixed_size_list::GrowableFixedSizeList; +mod utf8; +pub use utf8::GrowableUtf8; +mod dictionary; +pub use dictionary::GrowableDictionary; + +mod utils; + +/// Describes a struct that can be extended from slices of other pre-existing [`Array`]s. +/// This is used in operations where a new array is built out of other arrays, such +/// as filter and concatenation. +pub trait Growable<'a> { + /// Extends this [`Growable`] with elements from the bounded [`Array`] at index `index` from + /// a slice starting at `start` and length `len`. + /// # Panic + /// This function panics if the range is out of bounds, i.e. if `start + len >= array.len()`. + fn extend(&mut self, index: usize, start: usize, len: usize); + + /// Extends this [`Growable`] with null elements, disregarding the bound arrays + fn extend_validity(&mut self, additional: usize); + + /// The current length of the [`Growable`]. + fn len(&self) -> usize; + + /// Converts this [`Growable`] to an [`Arc`], thereby finishing the mutation. + /// Self will be empty after such operation. 
+ fn as_arc(&mut self) -> Arc { + self.as_box().into() + } + + /// Converts this [`Growable`] to an [`Box`], thereby finishing the mutation. + /// Self will be empty after such operation + fn as_box(&mut self) -> Box; +} + +macro_rules! dyn_growable { + ($ty:ty, $arrays:expr, $use_validity:expr, $capacity:expr) => {{ + let arrays = $arrays + .iter() + .map(|array| array.as_any().downcast_ref().unwrap()) + .collect::>(); + Box::new(<$ty>::new(arrays, $use_validity, $capacity)) + }}; +} + +/// Creates a new [`Growable`] from an arbitrary number of [`Array`]s. +/// # Panics +/// This function panics iff +/// * the arrays do not have the same [`DataType`]. +/// * `arrays.is_empty()`. +pub fn make_growable<'a>( + arrays: &[&'a dyn Array], + use_validity: bool, + capacity: usize, +) -> Box + 'a> { + assert!(!arrays.is_empty()); + let data_type = arrays[0].data_type(); + + use PhysicalType::*; + match data_type.to_physical_type() { + Null => Box::new(null::GrowableNull::new(data_type.clone())), + Boolean => dyn_growable!(boolean::GrowableBoolean, arrays, use_validity, capacity), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + dyn_growable!(primitive::GrowablePrimitive::<$T>, arrays, use_validity, capacity) + }), + Utf8 => dyn_growable!(utf8::GrowableUtf8::, arrays, use_validity, capacity), + LargeUtf8 => dyn_growable!(utf8::GrowableUtf8::, arrays, use_validity, capacity), + Binary => dyn_growable!( + binary::GrowableBinary::, + arrays, + use_validity, + capacity + ), + LargeBinary => dyn_growable!( + binary::GrowableBinary::, + arrays, + use_validity, + capacity + ), + FixedSizeBinary => dyn_growable!( + fixed_binary::GrowableFixedSizeBinary, + arrays, + use_validity, + capacity + ), + List => dyn_growable!(list::GrowableList::, arrays, use_validity, capacity), + LargeList => dyn_growable!(list::GrowableList::, arrays, use_validity, capacity), + Struct => dyn_growable!(structure::GrowableStruct, arrays, use_validity, capacity), + FixedSizeList => dyn_growable!( + fixed_size_list::GrowableFixedSizeList, + arrays, + use_validity, + capacity + ), + Union => { + let arrays = arrays + .iter() + .map(|array| array.as_any().downcast_ref().unwrap()) + .collect::>(); + Box::new(union::GrowableUnion::new(arrays, capacity)) + }, + Map => dyn_growable!(map::GrowableMap, arrays, use_validity, capacity), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let arrays = arrays + .iter() + .map(|array| { + array + .as_any() + .downcast_ref::>() + .unwrap() + }) + .collect::>(); + Box::new(dictionary::GrowableDictionary::<$T>::new( + &arrays, + use_validity, + capacity, + )) + }) + }, + } +} diff --git a/crates/nano-arrow/src/array/growable/null.rs b/crates/nano-arrow/src/array/growable/null.rs new file mode 100644 index 000000000000..44e1c2488b0f --- /dev/null +++ b/crates/nano-arrow/src/array/growable/null.rs @@ -0,0 +1,56 @@ +use std::sync::Arc; + +use super::Growable; +use crate::array::{Array, NullArray}; +use crate::datatypes::DataType; + +/// Concrete [`Growable`] for the [`NullArray`]. +pub struct GrowableNull { + data_type: DataType, + length: usize, +} + +impl Default for GrowableNull { + fn default() -> Self { + Self::new(DataType::Null) + } +} + +impl GrowableNull { + /// Creates a new [`GrowableNull`]. 
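// Usage sketch (illustrative, not part of the patch): `make_growable` is the
// dynamic entry point used by filter/concat-style kernels; it dispatches to the
// concrete growables above based on the physical type. Crate paths and the
// Int32Array constructors are assumptions as above.
use nano_arrow::array::growable::{make_growable, Growable};
use nano_arrow::array::{Array, Int32Array};

fn make_growable_sketch() {
    let a = Int32Array::from_slice([1, 2, 3]);
    let b = Int32Array::from([Some(10), None, Some(30)]);
    let arrays: Vec<&dyn Array> = vec![&a, &b];

    // `use_validity = true` because we also append null slots below.
    let mut growable = make_growable(&arrays, true, 5);
    growable.extend(0, 1, 2); // a[1..3] -> 2, 3
    growable.extend(1, 0, 2); // b[0..2] -> 10, null
    growable.extend_validity(1); // one extra null slot

    let out = growable.as_box();
    assert_eq!(out.len(), 5);
    assert_eq!(out.null_count(), 2);
}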
+ pub fn new(data_type: DataType) -> Self { + Self { + data_type, + length: 0, + } + } +} + +impl<'a> Growable<'a> for GrowableNull { + fn extend(&mut self, _: usize, _: usize, len: usize) { + self.length += len; + } + + fn extend_validity(&mut self, additional: usize) { + self.length += additional; + } + + #[inline] + fn len(&self) -> usize { + self.length + } + + fn as_arc(&mut self) -> Arc { + Arc::new(NullArray::new(self.data_type.clone(), self.length)) + } + + fn as_box(&mut self) -> Box { + Box::new(NullArray::new(self.data_type.clone(), self.length)) + } +} + +impl From for NullArray { + fn from(val: GrowableNull) -> Self { + NullArray::new(val.data_type, val.length) + } +} diff --git a/crates/nano-arrow/src/array/growable/primitive.rs b/crates/nano-arrow/src/array/growable/primitive.rs new file mode 100644 index 000000000000..cade744a5936 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/primitive.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::types::NativeType; + +/// Concrete [`Growable`] for the [`PrimitiveArray`]. +pub struct GrowablePrimitive<'a, T: NativeType> { + data_type: DataType, + arrays: Vec<&'a [T]>, + validity: MutableBitmap, + values: Vec, + extend_null_bits: Vec>, +} + +impl<'a, T: NativeType> GrowablePrimitive<'a, T> { + /// Creates a new [`GrowablePrimitive`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new( + arrays: Vec<&'a PrimitiveArray>, + mut use_validity: bool, + capacity: usize, + ) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if !use_validity & arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let data_type = arrays[0].data_type().clone(); + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let arrays = arrays + .iter() + .map(|array| array.values().as_slice()) + .collect::>(); + + Self { + data_type, + arrays, + values: Vec::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + #[inline] + fn to(&mut self) -> PrimitiveArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + + PrimitiveArray::::new(self.data_type.clone(), values.into(), validity.into()) + } +} + +impl<'a, T: NativeType> Growable<'a> for GrowablePrimitive<'a, T> { + #[inline] + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let values = self.arrays[index]; + self.values.extend_from_slice(&values[start..start + len]); + } + + #[inline] + fn extend_validity(&mut self, additional: usize) { + self.values + .resize(self.values.len() + additional, T::default()); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.values.len() + } + + #[inline] + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + #[inline] + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, T: NativeType> From> for PrimitiveArray { + #[inline] + fn from(val: GrowablePrimitive<'a, T>) -> Self { + PrimitiveArray::::new(val.data_type, val.values.into(), val.validity.into()) + } +} diff --git a/crates/nano-arrow/src/array/growable/structure.rs b/crates/nano-arrow/src/array/growable/structure.rs new file mode 100644 index 000000000000..10afd20e7f06 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/structure.rs @@ -0,0 +1,132 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, ExtendNullBits}; +use super::{make_growable, Growable}; +use crate::array::{Array, StructArray}; +use crate::bitmap::MutableBitmap; + +/// Concrete [`Growable`] for the [`StructArray`]. +pub struct GrowableStruct<'a> { + arrays: Vec<&'a StructArray>, + validity: MutableBitmap, + values: Vec + 'a>>, + extend_null_bits: Vec>, +} + +impl<'a> GrowableStruct<'a> { + /// Creates a new [`GrowableStruct`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. + pub fn new(arrays: Vec<&'a StructArray>, mut use_validity: bool, capacity: usize) -> Self { + assert!(!arrays.is_empty()); + + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. 
+ if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + let arrays = arrays + .iter() + .map(|array| array.as_any().downcast_ref::().unwrap()) + .collect::>(); + + // ([field1, field2], [field3, field4]) -> ([field1, field3], [field2, field3]) + let values = (0..arrays[0].values().len()) + .map(|i| { + make_growable( + &arrays + .iter() + .map(|x| x.values()[i].as_ref()) + .collect::>(), + use_validity, + capacity, + ) + }) + .collect::>>(); + + Self { + arrays, + values, + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> StructArray { + let validity = std::mem::take(&mut self.validity); + let values = std::mem::take(&mut self.values); + let values = values.into_iter().map(|mut x| x.as_box()).collect(); + + StructArray::new(self.arrays[0].data_type().clone(), values, validity.into()) + } +} + +impl<'a> Growable<'a> for GrowableStruct<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + if array.null_count() == 0 { + self.values + .iter_mut() + .for_each(|child| child.extend(index, start, len)) + } else { + (start..start + len).for_each(|i| { + if array.is_valid(i) { + self.values + .iter_mut() + .for_each(|child| child.extend(index, i, 1)) + } else { + self.values + .iter_mut() + .for_each(|child| child.extend_validity(1)) + } + }) + } + } + + fn extend_validity(&mut self, additional: usize) { + self.values + .iter_mut() + .for_each(|child| child.extend_validity(additional)); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + // All children should have the same indexing, so just use the first + // one. If we don't have children, we might still have a validity + // array, so use that. + if let Some(child) = self.values.get(0) { + child.len() + } else { + self.validity.len() + } + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a> From> for StructArray { + fn from(val: GrowableStruct<'a>) -> Self { + let values = val.values.into_iter().map(|mut x| x.as_box()).collect(); + + StructArray::new( + val.arrays[0].data_type().clone(), + values, + val.validity.into(), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/union.rs b/crates/nano-arrow/src/array/growable/union.rs new file mode 100644 index 000000000000..4ef39f16fbb3 --- /dev/null +++ b/crates/nano-arrow/src/array/growable/union.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; + +use super::{make_growable, Growable}; +use crate::array::{Array, UnionArray}; + +/// Concrete [`Growable`] for the [`UnionArray`]. +pub struct GrowableUnion<'a> { + arrays: Vec<&'a UnionArray>, + types: Vec, + offsets: Option>, + fields: Vec + 'a>>, +} + +impl<'a> GrowableUnion<'a> { + /// Creates a new [`GrowableUnion`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// Panics iff + /// * `arrays` is empty. 
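// Usage sketch (illustrative, not part of the patch): GrowableStruct grows each
// child with its own nested growable, so copying a row range from a StructArray
// copies that range from every child. The construction helpers are assumptions
// based on the arrow2 API this crate mirrors.
use nano_arrow::array::growable::{Growable, GrowableStruct};
use nano_arrow::array::{Array, BooleanArray, Int32Array, StructArray};
use nano_arrow::datatypes::{DataType, Field};

fn growable_struct_sketch() {
    let fields = vec![
        Field::new("flag", DataType::Boolean, false),
        Field::new("count", DataType::Int32, false),
    ];
    let children = vec![
        BooleanArray::from_slice([true, false, true]).boxed(),
        Int32Array::from_slice([1, 2, 3]).boxed(),
    ];
    let array = StructArray::new(DataType::Struct(fields), children, None);

    let mut growable = GrowableStruct::new(vec![&array], false, 2);
    growable.extend(0, 1, 2); // rows 1..3 of every child

    let out: StructArray = growable.into();
    assert_eq!(out.len(), 2);
}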
+ /// * any of the arrays has a different + pub fn new(arrays: Vec<&'a UnionArray>, capacity: usize) -> Self { + let first = arrays[0].data_type(); + assert!(arrays.iter().all(|x| x.data_type() == first)); + + let has_offsets = arrays[0].offsets().is_some(); + + let fields = (0..arrays[0].fields().len()) + .map(|i| { + make_growable( + &arrays + .iter() + .map(|x| x.fields()[i].as_ref()) + .collect::>(), + false, + capacity, + ) + }) + .collect::>>(); + + Self { + arrays, + fields, + offsets: if has_offsets { + Some(Vec::with_capacity(capacity)) + } else { + None + }, + types: Vec::with_capacity(capacity), + } + } + + fn to(&mut self) -> UnionArray { + let types = std::mem::take(&mut self.types); + let fields = std::mem::take(&mut self.fields); + let offsets = std::mem::take(&mut self.offsets); + let fields = fields.into_iter().map(|mut x| x.as_box()).collect(); + + UnionArray::new( + self.arrays[0].data_type().clone(), + types.into(), + fields, + offsets.map(|x| x.into()), + ) + } +} + +impl<'a> Growable<'a> for GrowableUnion<'a> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + let array = self.arrays[index]; + + let types = &array.types()[start..start + len]; + self.types.extend(types); + if let Some(x) = self.offsets.as_mut() { + let offsets = &array.offsets().unwrap()[start..start + len]; + + // in a dense union, each slot has its own offset. We extend the fields accordingly. + for (&type_, &offset) in types.iter().zip(offsets.iter()) { + let field = &mut self.fields[type_ as usize]; + // The offset for the element that is about to be extended is the current length + // of the child field of the corresponding type. Note that this may be very + // different than the original offset from the array we are extending from as + // it is a function of the previous extensions to this child. + x.push(field.len() as i32); + field.extend(index, offset as usize, 1); + } + } else { + // in a sparse union, every field has the same length => extend all fields equally + self.fields + .iter_mut() + .for_each(|field| field.extend(index, start, len)) + } + } + + fn extend_validity(&mut self, _additional: usize) {} + + #[inline] + fn len(&self) -> usize { + self.types.len() + } + + fn as_arc(&mut self) -> Arc { + self.to().arced() + } + + fn as_box(&mut self) -> Box { + self.to().boxed() + } +} + +impl<'a> From> for UnionArray { + fn from(val: GrowableUnion<'a>) -> Self { + let fields = val.fields.into_iter().map(|mut x| x.as_box()).collect(); + + UnionArray::new( + val.arrays[0].data_type().clone(), + val.types.into(), + fields, + val.offsets.map(|x| x.into()), + ) + } +} diff --git a/crates/nano-arrow/src/array/growable/utf8.rs b/crates/nano-arrow/src/array/growable/utf8.rs new file mode 100644 index 000000000000..1ea01ffd040a --- /dev/null +++ b/crates/nano-arrow/src/array/growable/utf8.rs @@ -0,0 +1,104 @@ +use std::sync::Arc; + +use super::utils::{build_extend_null_bits, extend_offset_values, ExtendNullBits}; +use super::Growable; +use crate::array::{Array, Utf8Array}; +use crate::bitmap::MutableBitmap; +use crate::offset::{Offset, Offsets}; + +/// Concrete [`Growable`] for the [`Utf8Array`]. +pub struct GrowableUtf8<'a, O: Offset> { + arrays: Vec<&'a Utf8Array>, + validity: MutableBitmap, + values: Vec, + offsets: Offsets, + extend_null_bits: Vec>, +} + +impl<'a, O: Offset> GrowableUtf8<'a, O> { + /// Creates a new [`GrowableUtf8`] bound to `arrays` with a pre-allocated `capacity`. + /// # Panics + /// If `arrays` is empty. 
+ pub fn new(arrays: Vec<&'a Utf8Array>, mut use_validity: bool, capacity: usize) -> Self { + // if any of the arrays has nulls, insertions from any array requires setting bits + // as there is at least one array with nulls. + if arrays.iter().any(|array| array.null_count() > 0) { + use_validity = true; + }; + + let extend_null_bits = arrays + .iter() + .map(|array| build_extend_null_bits(*array, use_validity)) + .collect(); + + Self { + arrays: arrays.to_vec(), + values: Vec::with_capacity(0), + offsets: Offsets::with_capacity(capacity), + validity: MutableBitmap::with_capacity(capacity), + extend_null_bits, + } + } + + fn to(&mut self) -> Utf8Array { + let validity = std::mem::take(&mut self.validity); + let offsets = std::mem::take(&mut self.offsets); + let values = std::mem::take(&mut self.values); + + #[cfg(debug_assertions)] + { + crate::array::specification::try_check_utf8(&offsets, &values).unwrap(); + } + + unsafe { + Utf8Array::::try_new_unchecked( + self.arrays[0].data_type().clone(), + offsets.into(), + values.into(), + validity.into(), + ) + .unwrap() + } + } +} + +impl<'a, O: Offset> Growable<'a> for GrowableUtf8<'a, O> { + fn extend(&mut self, index: usize, start: usize, len: usize) { + (self.extend_null_bits[index])(&mut self.validity, start, len); + + let array = self.arrays[index]; + let offsets = array.offsets(); + let values = array.values(); + + self.offsets + .try_extend_from_slice(offsets, start, len) + .unwrap(); + + // values + extend_offset_values::(&mut self.values, offsets.as_slice(), values, start, len); + } + + fn extend_validity(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + self.validity.extend_constant(additional, false); + } + + #[inline] + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn as_arc(&mut self) -> Arc { + Arc::new(self.to()) + } + + fn as_box(&mut self) -> Box { + Box::new(self.to()) + } +} + +impl<'a, O: Offset> From> for Utf8Array { + fn from(mut val: GrowableUtf8<'a, O>) -> Self { + val.to() + } +} diff --git a/crates/nano-arrow/src/array/growable/utils.rs b/crates/nano-arrow/src/array/growable/utils.rs new file mode 100644 index 000000000000..ecdfb522249f --- /dev/null +++ b/crates/nano-arrow/src/array/growable/utils.rs @@ -0,0 +1,40 @@ +use crate::array::Array; +use crate::bitmap::MutableBitmap; +use crate::offset::Offset; + +// function used to extend nulls from arrays. This function's lifetime is bound to the array +// because it reads nulls from it. 
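// Usage sketch (illustrative, not part of the patch): using a concrete growable
// (GrowableUtf8) directly when the input type is statically known, instead of
// going through make_growable. Crate paths assumed as above.
use nano_arrow::array::growable::{Growable, GrowableUtf8};
use nano_arrow::array::Utf8Array;

fn growable_utf8_sketch() {
    let a = Utf8Array::<i32>::from_slice(["hello", "world"]);
    let b = Utf8Array::<i32>::from_slice(["!"]);

    let mut growable = GrowableUtf8::new(vec![&a, &b], true, 4);
    growable.extend(0, 0, 2);
    growable.extend(1, 0, 1);
    growable.extend_validity(1); // append one null slot

    let out: Utf8Array<i32> = growable.into();
    assert_eq!(out.len(), 4);
}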
+pub(super) type ExtendNullBits<'a> = Box; + +pub(super) fn build_extend_null_bits(array: &dyn Array, use_validity: bool) -> ExtendNullBits { + if let Some(bitmap) = array.validity() { + Box::new(move |validity, start, len| { + debug_assert!(start + len <= bitmap.len()); + let (slice, offset, _) = bitmap.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { + validity.extend_from_slice_unchecked(slice, start + offset, len); + } + }) + } else if use_validity { + Box::new(|validity, _, len| { + validity.extend_constant(len, true); + }) + } else { + Box::new(|_, _, _| {}) + } +} + +#[inline] +pub(super) fn extend_offset_values( + buffer: &mut Vec, + offsets: &[O], + values: &[u8], + start: usize, + len: usize, +) { + let start_values = offsets[start].to_usize(); + let end_values = offsets[start + len].to_usize(); + let new_values = &values[start_values..end_values]; + buffer.extend_from_slice(new_values); +} diff --git a/crates/nano-arrow/src/array/indexable.rs b/crates/nano-arrow/src/array/indexable.rs new file mode 100644 index 000000000000..d3f466722aa6 --- /dev/null +++ b/crates/nano-arrow/src/array/indexable.rs @@ -0,0 +1,194 @@ +use std::borrow::Borrow; + +use crate::array::{ + MutableArray, MutableBinaryArray, MutableBinaryValuesArray, MutableBooleanArray, + MutableFixedSizeBinaryArray, MutablePrimitiveArray, MutableUtf8Array, MutableUtf8ValuesArray, +}; +use crate::offset::Offset; +use crate::types::NativeType; + +/// Trait for arrays that can be indexed directly to extract a value. +pub trait Indexable { + /// The type of the element at index `i`; may be a reference type or a value type. + type Value<'a>: Borrow + where + Self: 'a; + + type Type: ?Sized; + + /// Returns the element at index `i`. + /// # Panic + /// May panic if `i >= self.len()`. + fn value_at(&self, index: usize) -> Self::Value<'_>; + + /// Returns the element at index `i`. + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + unsafe fn value_unchecked_at(&self, index: usize) -> Self::Value<'_> { + self.value_at(index) + } +} + +pub trait AsIndexed { + fn as_indexed(&self) -> &M::Type; +} + +impl Indexable for MutableBooleanArray { + type Value<'a> = bool; + type Type = bool; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.values().get(i) + } +} + +impl AsIndexed for bool { + #[inline] + fn as_indexed(&self) -> &bool { + self + } +} + +impl Indexable for MutableBinaryArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + // TODO: add .value() / .value_unchecked() to MutableBinaryArray? + assert!(i < self.len()); + unsafe { self.value_unchecked_at(i) } + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + // TODO: add .value() / .value_unchecked() to MutableBinaryArray? 
+ // soundness: the invariant of the function + let (start, end) = self.offsets().start_end_unchecked(i); + // soundness: the invariant of the struct + self.values().get_unchecked(start..end) + } +} + +impl AsIndexed> for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableBinaryValuesArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl AsIndexed> for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +impl Indexable for MutableFixedSizeBinaryArray { + type Value<'a> = &'a [u8]; + type Type = [u8]; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + // soundness: the invariant of the struct + self.value_unchecked(i) + } +} + +impl AsIndexed for &[u8] { + #[inline] + fn as_indexed(&self) -> &[u8] { + self + } +} + +// TODO: should NativeType derive from Hash? +impl Indexable for MutablePrimitiveArray { + type Value<'a> = T; + type Type = T; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + assert!(i < self.len()); + // TODO: add Length trait? (for both Array and MutableArray) + unsafe { self.value_unchecked_at(i) } + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + *self.values().get_unchecked(i) + } +} + +impl AsIndexed> for T { + #[inline] + fn as_indexed(&self) -> &T { + self + } +} + +impl Indexable for MutableUtf8Array { + type Value<'a> = &'a str; + type Type = str; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl> AsIndexed> for V { + #[inline] + fn as_indexed(&self) -> &str { + self.as_ref() + } +} + +impl Indexable for MutableUtf8ValuesArray { + type Value<'a> = &'a str; + type Type = str; + + #[inline] + fn value_at(&self, i: usize) -> Self::Value<'_> { + self.value(i) + } + + #[inline] + unsafe fn value_unchecked_at(&self, i: usize) -> Self::Value<'_> { + self.value_unchecked(i) + } +} + +impl> AsIndexed> for V { + #[inline] + fn as_indexed(&self) -> &str { + self.as_ref() + } +} diff --git a/crates/nano-arrow/src/array/iterator.rs b/crates/nano-arrow/src/array/iterator.rs new file mode 100644 index 000000000000..5e8ed44d861e --- /dev/null +++ b/crates/nano-arrow/src/array/iterator.rs @@ -0,0 +1,83 @@ +use crate::trusted_len::TrustedLen; + +mod private { + pub trait Sealed {} + + impl<'a, T: super::ArrayAccessor<'a>> Sealed for T {} +} + +/// Sealed trait representing assess to a value of an array. +/// # Safety +/// Implementers of this trait guarantee that +/// `value_unchecked` is safe when called up to `len` +pub unsafe trait ArrayAccessor<'a>: private::Sealed { + type Item: 'a; + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item; + fn len(&self) -> usize; +} + +/// Iterator of values of an [`ArrayAccessor`]. 
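// Usage sketch (illustrative, not part of the patch): Indexable lets kernels
// read a value back from a mutable array by position through one interface.
// The paths below are written as crate-internal (`crate::array::...`) because
// whether the `indexable` module is publicly re-exported is not shown here.
use crate::array::indexable::Indexable;
use crate::array::MutableUtf8Array;

fn indexable_sketch() {
    let mut array = MutableUtf8Array::<i32>::new();
    array.push(Some("polars"));
    array.push(Some("arrow"));

    // `value_at` returns the Value type declared by the impl (&str here).
    let value: &str = array.value_at(1);
    assert_eq!(value, "arrow");
}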
+#[derive(Debug, Clone)] +pub struct ArrayValuesIter<'a, A: ArrayAccessor<'a>> { + array: &'a A, + index: usize, + end: usize, +} + +impl<'a, A: ArrayAccessor<'a>> ArrayValuesIter<'a, A> { + /// Creates a new [`ArrayValuesIter`] + #[inline] + pub fn new(array: &'a A) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a, A: ArrayAccessor<'a>> Iterator for ArrayValuesIter<'a, A> { + type Item = A::Item; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(unsafe { self.array.value_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl<'a, A: ArrayAccessor<'a>> DoubleEndedIterator for ArrayValuesIter<'a, A> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(unsafe { self.array.value_unchecked(self.end) }) + } + } +} + +unsafe impl<'a, A: ArrayAccessor<'a>> TrustedLen for ArrayValuesIter<'a, A> {} +impl<'a, A: ArrayAccessor<'a>> ExactSizeIterator for ArrayValuesIter<'a, A> {} diff --git a/crates/nano-arrow/src/array/list/data.rs b/crates/nano-arrow/src/array/list/data.rs new file mode 100644 index 000000000000..6f3424c96ce6 --- /dev/null +++ b/crates/nano-arrow/src/array/list/data.rs @@ -0,0 +1,38 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, ListArray}; +use crate::bitmap::Bitmap; +use crate::offset::{Offset, OffsetsBuffer}; + +impl Arrow2Arrow for ListArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.offsets.clone().into_inner().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(vec![to_data(self.values.as_ref())]); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type, + offsets, + values: from_data(&data.child_data()[0]), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/list/ffi.rs b/crates/nano-arrow/src/array/list/ffi.rs new file mode 100644 index 000000000000..487b4ad40128 --- /dev/null +++ b/crates/nano-arrow/src/array/list/ffi.rs @@ -0,0 +1,68 @@ +use super::super::ffi::ToFfi; +use super::super::Array; +use super::ListArray; +use crate::array::FromFfi; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for ListArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + ] + } + + fn children(&self) -> Vec> { + vec![self.values.clone()] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if 
bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for ListArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let child = unsafe { array.child(0)? }; + let values = ffi::try_from(child)?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new(data_type, offsets, values, validity)) + } +} diff --git a/crates/nano-arrow/src/array/list/fmt.rs b/crates/nano-arrow/src/array/list/fmt.rs new file mode 100644 index 000000000000..67dcd6b78786 --- /dev/null +++ b/crates/nano-arrow/src/array/list/fmt.rs @@ -0,0 +1,30 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::ListArray; +use crate::offset::Offset; + +pub fn write_value( + array: &ListArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for ListArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + let head = if O::IS_LARGE { + "LargeListArray" + } else { + "ListArray" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/list/iterator.rs b/crates/nano-arrow/src/array/list/iterator.rs new file mode 100644 index 000000000000..28552bf4bb65 --- /dev/null +++ b/crates/nano-arrow/src/array/list/iterator.rs @@ -0,0 +1,68 @@ +use super::ListArray; +use crate::array::{Array, ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for ListArray { + type Item = Box; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of a [`ListArray`]. 
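+///
+/// # Example (editorial sketch, not part of the original patch)
+///
+/// Building a small list array through the mutable API in `mutable.rs` below and
+/// iterating its rows. The snippet is illustrative only (`?` assumes a `Result`
+/// context) and is not compiled as a doc-test.
+///
+/// ```ignore
+/// let mut mutable = MutableListArray::<i32, MutablePrimitiveArray<i32>>::new();
+/// mutable.try_extend(vec![
+///     Some(vec![Some(1i32), Some(2)]),
+///     None,
+///     Some(vec![Some(3)]),
+/// ])?;
+/// let array: ListArray<i32> = mutable.into();
+///
+/// // `values_iter` ignores the validity and yields one boxed child slice per row.
+/// for row in array.values_iter() {
+///     println!("{row:?}");
+/// }
+/// ```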
+pub type ListValuesIter<'a, O> = ArrayValuesIter<'a, ListArray>; + +type ZipIter<'a, O> = ZipValidity, ListValuesIter<'a, O>, BitmapIter<'a>>; + +impl<'a, O: Offset> IntoIterator for &'a ListArray { + type Item = Option>; + type IntoIter = ZipIter<'a, O>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, O: Offset> ListArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a, O> { + ZipValidity::new_with_validity(ListValuesIter::new(self), self.validity.as_ref()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> ListValuesIter<'a, O> { + ListValuesIter::new(self) + } +} + +struct Iter>> { + current: i32, + offsets: std::vec::IntoIter, + values: I, +} + +impl> + Clone> Iterator for Iter { + type Item = Option>>; + + fn next(&mut self) -> Option { + let next = self.offsets.next(); + next.map(|next| { + let length = next - self.current; + let iter = self + .values + .clone() + .skip(self.current as usize) + .take(length as usize); + self.current = next; + Some(iter) + }) + } +} diff --git a/crates/nano-arrow/src/array/list/mod.rs b/crates/nano-arrow/src/array/list/mod.rs new file mode 100644 index 000000000000..dff4584d0cbf --- /dev/null +++ b/crates/nano-arrow/src/array/list/mod.rs @@ -0,0 +1,240 @@ +use super::specification::try_check_offsets_bounds; +use super::{new_empty_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; + +/// An [`Array`] semantically equivalent to `Vec>>>` with Arrow's in-memory. +#[derive(Clone)] +pub struct ListArray { + data_type: DataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, +} + +impl ListArray { + /// Creates a new [`ListArray`]. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// # Implementation + /// This function is `O(1)` + pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + let child_data_type = Self::try_get_child(&data_type)?.data_type(); + let values_data_type = values.data_type(); + if child_data_type != values_data_type { + return Err(Error::oos( + format!("ListArray's child's DataType must match. However, the expected DataType is {child_data_type:?} while it got {values_data_type:?}."), + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Creates a new [`ListArray`]. + /// + /// # Panics + /// This function panics iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. 
+ /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either [`crate::datatypes::PhysicalType::List`] or [`crate::datatypes::PhysicalType::LargeList`]. + /// * The `data_type`'s inner field's data type is not equal to `values.data_type`. + /// # Implementation + /// This function is `O(1)` + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Box, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, values, validity).unwrap() + } + + /// Returns a new empty [`ListArray`]. + pub fn new_empty(data_type: DataType) -> Self { + let values = new_empty_array(Self::get_child_type(&data_type).clone()); + Self::new(data_type, OffsetsBuffer::default(), values, None) + } + + /// Returns a new null [`ListArray`]. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + let child = Self::get_child_type(&data_type).clone(); + Self::new( + data_type, + Offsets::new_zeroed(length).into(), + new_empty_array(child), + Some(Bitmap::new_zeroed(length)), + ) + } +} + +impl ListArray { + /// Slices this [`ListArray`]. + /// # Panics + /// panics iff `offset + length >= self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`ListArray`]. + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); +} + +// Accessors +impl ListArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the element at index `i` + /// # Panic + /// Panics iff `i >= self.len()` + #[inline] + pub fn value(&self, i: usize) -> Box { + assert!(i < self.len()); + // Safety: invariant of this function + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i` as &str + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + // safety: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + let length = end - start; + + // safety: the invariant of the struct + self.values.sliced_unchecked(start, length) + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// The offsets [`Buffer`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The values. + #[inline] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl ListArray { + /// Returns a default [`DataType`]: inner field is named "item" and is nullable + pub fn default_datatype(data_type: DataType) -> DataType { + let field = Box::new(Field::new("item", data_type, true)); + if O::IS_LARGE { + DataType::LargeList(field) + } else { + DataType::List(field) + } + } + + /// Returns a the inner [`Field`] + /// # Panics + /// Panics iff the logical type is not consistent with this struct. 
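+    ///
+    /// # Example (editorial sketch, not part of the original patch)
+    ///
+    /// With the default logical type built by [`ListArray::default_datatype`], the child
+    /// field is the nullable `"item"` field; assumes `Field` exposes its `name` publicly.
+    ///
+    /// ```ignore
+    /// let dt = ListArray::<i32>::default_datatype(DataType::Int32);
+    /// assert_eq!(ListArray::<i32>::get_child_field(&dt).name, "item");
+    /// ```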
+ pub fn get_child_field(data_type: &DataType) -> &Field { + Self::try_get_child(data_type).unwrap() + } + + /// Returns a the inner [`Field`] + /// # Errors + /// Panics iff the logical type is not consistent with this struct. + pub fn try_get_child(data_type: &DataType) -> Result<&Field, Error> { + if O::IS_LARGE { + match data_type.to_logical_type() { + DataType::LargeList(child) => Ok(child.as_ref()), + _ => Err(Error::oos("ListArray expects DataType::LargeList")), + } + } else { + match data_type.to_logical_type() { + DataType::List(child) => Ok(child.as_ref()), + _ => Err(Error::oos("ListArray expects DataType::List")), + } + } + } + + /// Returns a the inner [`DataType`] + /// # Panics + /// Panics iff the logical type is not consistent with this struct. + pub fn get_child_type(data_type: &DataType) -> &DataType { + Self::get_child_field(data_type).data_type() + } +} + +impl Array for ListArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/list/mutable.rs b/crates/nano-arrow/src/array/list/mutable.rs new file mode 100644 index 000000000000..39dc22da3cb0 --- /dev/null +++ b/crates/nano-arrow/src/array/list/mutable.rs @@ -0,0 +1,315 @@ +use std::sync::Arc; + +use super::ListArray; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// The mutable version of [`ListArray`]. +#[derive(Debug, Clone)] +pub struct MutableListArray { + data_type: DataType, + offsets: Offsets, + values: M, + validity: Option, +} + +impl MutableListArray { + /// Creates a new empty [`MutableListArray`]. + pub fn new() -> Self { + let values = M::default(); + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self::new_from(values, data_type, 0) + } + + /// Creates a new [`MutableListArray`] with a capacity. 
+ pub fn with_capacity(capacity: usize) -> Self { + let values = M::default(); + let data_type = ListArray::::default_datatype(values.data_type().clone()); + + let offsets = Offsets::::with_capacity(capacity); + Self { + data_type, + offsets, + values, + validity: None, + } + } +} + +impl Default for MutableListArray { + fn default() -> Self { + Self::new() + } +} + +impl From> for ListArray { + fn from(mut other: MutableListArray) -> Self { + ListArray::new( + other.data_type, + other.offsets.into(), + other.values.as_box(), + other.validity.map(|x| x.into()), + ) + } +} + +impl TryExtend> for MutableListArray +where + O: Offset, + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + fn try_extend>>(&mut self, iter: II) -> Result<()> { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + for items in iter { + self.try_push(items)?; + } + Ok(()) + } +} + +impl TryPush> for MutableListArray +where + O: Offset, + M: MutableArray + TryExtend>, + I: IntoIterator>, +{ + #[inline] + fn try_push(&mut self, item: Option) -> Result<()> { + if let Some(items) = item { + let values = self.mut_values(); + values.try_extend(items)?; + self.try_push_valid()?; + } else { + self.push_null(); + } + Ok(()) + } +} + +impl TryExtendFromSelf for MutableListArray +where + O: Offset, + M: MutableArray + TryExtendFromSelf, +{ + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values)?; + self.offsets.try_extend_from_self(&other.offsets) + } +} + +impl MutableListArray { + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_from(values: M, data_type: DataType, capacity: usize) -> Self { + let offsets = Offsets::::with_capacity(capacity); + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&data_type); + Self { + data_type, + offsets, + values, + validity: None, + } + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`]. + pub fn new_with_field(values: M, name: &str, nullable: bool) -> Self { + let field = Box::new(Field::new(name, values.data_type().clone(), nullable)); + let data_type = if O::IS_LARGE { + DataType::LargeList(field) + } else { + DataType::List(field) + }; + Self::new_from(values, data_type, 0) + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`] and capacity. + pub fn new_with_capacity(values: M, capacity: usize) -> Self { + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self::new_from(values, data_type, capacity) + } + + /// Creates a new [`MutableListArray`] from a [`MutableArray`], [`Offsets`] and + /// [`MutableBitmap`]. + pub fn new_from_mutable( + values: M, + offsets: Offsets, + validity: Option, + ) -> Self { + assert_eq!(values.len(), offsets.last().to_usize()); + let data_type = ListArray::::default_datatype(values.data_type().clone()); + Self { + data_type, + offsets, + values, + validity, + } + } + + #[inline] + /// Needs to be called when a valid value was extended to this array. + /// This is a relatively low level function, prefer `try_push` when you can. 
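+    ///
+    /// # Example (editorial sketch, not part of the original patch)
+    ///
+    /// The low-level pattern: write one row's values into the child array first, then
+    /// seal the row. `?` assumes a surrounding function returning `Result`.
+    ///
+    /// ```ignore
+    /// let mut list = MutableListArray::<i32, MutablePrimitiveArray<i32>>::new();
+    /// list.mut_values().try_extend(vec![Some(1i32), Some(2)])?;
+    /// list.try_push_valid()?; // the row [1, 2] is now committed
+    /// ```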
+ pub fn try_push_valid(&mut self) -> Result<()> { + let total_length = self.values.len(); + let offset = self.offsets.last().to_usize(); + let length = total_length.checked_sub(offset).ok_or(Error::Overflow)?; + + self.offsets.try_push_usize(length)?; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.extend_constant(1); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + /// Expand this array, using elements from the underlying backing array. + /// Assumes the expansion begins at the highest previous offset, or zero if + /// this [`MutableListArray`] is currently empty. + /// + /// Panics if: + /// - the new offsets are not in monotonic increasing order. + /// - any new offset is not in bounds of the backing array. + /// - the passed iterator has no upper bound. + pub fn try_extend_from_lengths(&mut self, iterator: II) -> Result<()> + where + II: TrustedLen> + Clone, + { + self.offsets + .try_extend_from_lengths(iterator.clone().map(|x| x.unwrap_or_default()))?; + if let Some(validity) = &mut self.validity { + validity.extend_from_trusted_len_iter(iterator.map(|x| x.is_some())) + } + assert_eq!(self.offsets.last().to_usize(), self.values.len()); + Ok(()) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// The values + pub fn mut_values(&mut self) -> &mut M { + &mut self.values + } + + /// The offsets + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// The values + pub fn values(&self) -> &M { + &self.values + } + + fn init_validity(&mut self) { + let len = self.offsets.len_proxy(); + + let mut validity = MutableBitmap::with_capacity(self.offsets.capacity()); + validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: ListArray = self.into(); + Arc::new(a) + } + + /// converts itself into [`Box`] + pub fn into_box(self) -> Box { + let a: ListArray = self.into(); + Box::new(a) + } + + /// Reserves `additional` slots. + pub fn reserve(&mut self, additional: usize) { + self.offsets.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Shrinks the capacity of the [`MutableListArray`] to fit its current length. 
+ pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableListArray { + fn len(&self) -> usize { + MutableListArray::len(self) + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit(); + } +} diff --git a/crates/nano-arrow/src/array/map/data.rs b/crates/nano-arrow/src/array/map/data.rs new file mode 100644 index 000000000000..cb8862a4df3d --- /dev/null +++ b/crates/nano-arrow/src/array/map/data.rs @@ -0,0 +1,38 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, MapArray}; +use crate::bitmap::Bitmap; +use crate::offset::OffsetsBuffer; + +impl Arrow2Arrow for MapArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.offsets.clone().into_inner().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(vec![to_data(self.field.as_ref())]); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(data.buffers()[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type: data.data_type().clone().into(), + offsets, + field: from_data(&data.child_data()[0]), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/map/ffi.rs b/crates/nano-arrow/src/array/map/ffi.rs new file mode 100644 index 000000000000..9193e7253753 --- /dev/null +++ b/crates/nano-arrow/src/array/map/ffi.rs @@ -0,0 +1,68 @@ +use super::super::ffi::ToFfi; +use super::super::Array; +use super::MapArray; +use crate::array::FromFfi; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::OffsetsBuffer; + +unsafe impl ToFfi for MapArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + ] + } + + fn children(&self) -> Vec> { + vec![self.field.clone()] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = 
self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + field: self.field.clone(), + } + } +} + +impl FromFfi for MapArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let child = array.child(0)?; + let values = ffi::try_from(child)?; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Self::try_new(data_type, offsets, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/map/fmt.rs b/crates/nano-arrow/src/array/map/fmt.rs new file mode 100644 index 000000000000..60abf56e18c5 --- /dev/null +++ b/crates/nano-arrow/src/array/map/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::MapArray; + +pub fn write_value( + array: &MapArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let values = array.value(index); + let writer = |f: &mut W, index| get_display(values.as_ref(), null)(f, index); + write_vec(f, writer, None, values.len(), null, false) +} + +impl Debug for MapArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "MapArray")?; + write_vec(f, writer, self.validity.as_ref(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/map/iterator.rs b/crates/nano-arrow/src/array/map/iterator.rs new file mode 100644 index 000000000000..f424e91b8043 --- /dev/null +++ b/crates/nano-arrow/src/array/map/iterator.rs @@ -0,0 +1,81 @@ +use super::MapArray; +use crate::array::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::trusted_len::TrustedLen; + +/// Iterator of values of an [`ListArray`]. 
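+/// (Editorial note, not part of the original patch: despite the copied wording this type
+/// iterates a [`MapArray`]; each item is the inner `Struct` array of keys and values,
+/// sliced to a single map entry.)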
+#[derive(Clone, Debug)] +pub struct MapValuesIter<'a> { + array: &'a MapArray, + index: usize, + end: usize, +} + +impl<'a> MapValuesIter<'a> { + #[inline] + pub fn new(array: &'a MapArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a> Iterator for MapValuesIter<'a> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + // Safety: + // self.end is maximized by the length of the array + Some(unsafe { self.array.value_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a> TrustedLen for MapValuesIter<'a> {} + +impl<'a> DoubleEndedIterator for MapValuesIter<'a> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + // Safety: + // self.end is maximized by the length of the array + Some(unsafe { self.array.value_unchecked(self.end) }) + } + } +} + +impl<'a> IntoIterator for &'a MapArray { + type Item = Option>; + type IntoIter = ZipValidity, MapValuesIter<'a>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> MapArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipValidity, MapValuesIter<'a>, BitmapIter<'a>> { + ZipValidity::new_with_validity(MapValuesIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> MapValuesIter<'a> { + MapValuesIter::new(self) + } +} diff --git a/crates/nano-arrow/src/array/map/mod.rs b/crates/nano-arrow/src/array/map/mod.rs new file mode 100644 index 000000000000..fca2e3bf68c1 --- /dev/null +++ b/crates/nano-arrow/src/array/map/mod.rs @@ -0,0 +1,204 @@ +use super::specification::try_check_offsets_bounds; +use super::{new_empty_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; +use crate::offset::OffsetsBuffer; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +pub use iterator::*; + +/// An array representing a (key, value), both of arbitrary logical types. +#[derive(Clone)] +pub struct MapArray { + data_type: DataType, + // invariant: field.len() == offsets.len() + offsets: OffsetsBuffer, + field: Box, + // invariant: offsets.len() - 1 == Bitmap::len() + validity: Option, +} + +impl MapArray { + /// Returns a new [`MapArray`]. + /// # Errors + /// This function errors iff: + /// * The last offset is not equal to the field' length + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`] + /// * The fields' `data_type` is not equal to the inner field of `data_type` + /// * The validity is not `None` and its length is different from `offsets.len() - 1`. 
+ pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + field: Box, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, field.len())?; + + let inner_field = Self::try_get_field(&data_type)?; + if let DataType::Struct(inner) = inner_field.data_type() { + if inner.len() != 2 { + return Err(Error::InvalidArgumentError( + "MapArray's inner `Struct` must have 2 fields (keys and maps)".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "MapArray expects `DataType::Struct` as its inner logical type".to_string(), + )); + } + if field.data_type() != inner_field.data_type() { + return Err(Error::InvalidArgumentError( + "MapArray expects `field.data_type` to match its inner DataType".to_string(), + )); + } + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + Ok(Self { + data_type, + field, + offsets, + validity, + }) + } + + /// Creates a new [`MapArray`]. + /// # Panics + /// * The last offset is not equal to the field' length. + /// * The `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Map`], + /// * The validity is not `None` and its length is different from `offsets.len() - 1`. + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + field: Box, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, field, validity).unwrap() + } + + /// Returns a new null [`MapArray`] of `length`. + pub fn new_null(data_type: DataType, length: usize) -> Self { + let field = new_empty_array(Self::get_field(&data_type).data_type().clone()); + Self::new( + data_type, + vec![0i32; 1 + length].try_into().unwrap(), + field, + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns a new empty [`MapArray`]. + pub fn new_empty(data_type: DataType) -> Self { + let field = new_empty_array(Self::get_field(&data_type).data_type().clone()); + Self::new(data_type, OffsetsBuffer::default(), field, None) + } +} + +impl MapArray { + /// Returns a slice of this [`MapArray`]. + /// # Panics + /// panics iff `offset + length >= self.len()` + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a slice of this [`MapArray`]. + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. 
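+    /// (Editorial note, not part of the original patch: the checked [`MapArray::slice`]
+    /// above asserts the slightly weaker bound `offset + length <= self.len()`, which is
+    /// the condition actually required here.)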
+ #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + pub(crate) fn try_get_field(data_type: &DataType) -> Result<&Field, Error> { + if let DataType::Map(field, _) = data_type.to_logical_type() { + Ok(field.as_ref()) + } else { + Err(Error::oos( + "The data_type's logical type must be DataType::Map", + )) + } + } + + pub(crate) fn get_field(data_type: &DataType) -> &Field { + Self::try_get_field(data_type).unwrap() + } +} + +// Accessors +impl MapArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// returns the offsets + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// Returns the field (guaranteed to be a `Struct`) + #[inline] + pub fn field(&self) -> &Box { + &self.field + } + + /// Returns the element at index `i`. + #[inline] + pub fn value(&self, i: usize) -> Box { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the element at index `i`. + /// # Safety + /// Assumes that the `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> Box { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + let length = end - start; + + // soundness: the invariant of the struct + self.field.sliced_unchecked(start, length) + } +} + +impl Array for MapArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/mod.rs b/crates/nano-arrow/src/array/mod.rs new file mode 100644 index 000000000000..0d8534fc9e71 --- /dev/null +++ b/crates/nano-arrow/src/array/mod.rs @@ -0,0 +1,787 @@ +//! Contains the [`Array`] and [`MutableArray`] trait objects declaring arrays, +//! as well as concrete arrays (such as [`Utf8Array`] and [`MutableUtf8Array`]). +//! +//! Fixed-length containers with optional values +//! that are laid in memory according to the Arrow specification. +//! Each array type has its own `struct`. The following are the main array types: +//! * [`PrimitiveArray`] and [`MutablePrimitiveArray`], an array of values with a fixed length such as integers, floats, etc. +//! * [`BooleanArray`] and [`MutableBooleanArray`], an array of boolean values (stored as a bitmap) +//! * [`Utf8Array`] and [`MutableUtf8Array`], an array of variable length utf8 values +//! * [`BinaryArray`] and [`MutableBinaryArray`], an array of opaque variable length values +//! * [`ListArray`] and [`MutableListArray`], an array of arrays (e.g. `[[1, 2], None, [], [None]]`) +//! * [`StructArray`] and [`MutableStructArray`], an array of arrays identified by a string (e.g. `{"a": [1, 2], "b": [true, false]}`) +//! All immutable arrays implement the trait object [`Array`] and that can be downcasted +//! to a concrete struct based on [`PhysicalType`](crate::datatypes::PhysicalType) available from [`Array::data_type`]. +//! All immutable arrays are backed by [`Buffer`](crate::buffer::Buffer) and thus cloning and slicing them is `O(1)`. +//! +//! 
Most arrays contain a [`MutableArray`] counterpart that is neither clonable nor sliceable, but +//! can be operated in-place. +use std::any::Any; +use std::sync::Arc; + +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::Result; + +pub mod physical_binary; + +/// A trait representing an immutable Arrow array. Arrow arrays are trait objects +/// that are infallibly downcasted to concrete types according to the [`Array::data_type`]. +pub trait Array: Send + Sync + dyn_clone::DynClone + 'static { + /// Converts itself to a reference of [`Any`], which enables downcasting to concrete types. + fn as_any(&self) -> &dyn Any; + + /// Converts itself to a mutable reference of [`Any`], which enables mutable downcasting to concrete types. + fn as_any_mut(&mut self) -> &mut dyn Any; + + /// The length of the [`Array`]. Every array has a length corresponding to the number of + /// elements (slots). + fn len(&self) -> usize; + + /// whether the array is empty + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The [`DataType`] of the [`Array`]. In combination with [`Array::as_any`], this can be + /// used to downcast trait objects (`dyn Array`) to concrete arrays. + fn data_type(&self) -> &DataType; + + /// The validity of the [`Array`]: every array has an optional [`Bitmap`] that, when available + /// specifies whether the array slot is valid or not (null). + /// When the validity is [`None`], all slots are valid. + fn validity(&self) -> Option<&Bitmap>; + + /// The number of null slots on this [`Array`]. + /// # Implementation + /// This is `O(1)` since the number of null elements is pre-computed. + #[inline] + fn null_count(&self) -> usize { + if self.data_type() == &DataType::Null { + return self.len(); + }; + self.validity() + .as_ref() + .map(|x| x.unset_bits()) + .unwrap_or(0) + } + + /// Returns whether slot `i` is null. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + fn is_null(&self, i: usize) -> bool { + assert!(i < self.len()); + unsafe { self.is_null_unchecked(i) } + } + + /// Returns whether slot `i` is null. + /// # Safety + /// The caller must ensure `i < self.len()` + #[inline] + unsafe fn is_null_unchecked(&self, i: usize) -> bool { + self.validity() + .as_ref() + .map(|x| !x.get_bit_unchecked(i)) + .unwrap_or(false) + } + + /// Returns whether slot `i` is valid. + /// # Panic + /// Panics iff `i >= self.len()`. + #[inline] + fn is_valid(&self, i: usize) -> bool { + !self.is_null(i) + } + + /// Slices this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + fn slice(&mut self, offset: usize, length: usize); + + /// Slices the [`Array`]. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize); + + /// Returns a slice of this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + #[must_use] + fn sliced(&self, offset: usize, length: usize) -> Box { + let mut new = self.to_boxed(); + new.slice(offset, length); + new + } + + /// Returns a slice of this [`Array`]. + /// # Implementation + /// This operation is `O(1)` over `len`, as it amounts to increase two ref counts + /// and moving the struct to the heap. 
+ /// # Safety + /// The caller must ensure that `offset + length <= self.len()` + #[must_use] + unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Box { + let mut new = self.to_boxed(); + new.slice_unchecked(offset, length); + new + } + + /// Clones this [`Array`] with a new new assigned bitmap. + /// # Panic + /// This function panics iff `validity.len() != self.len()`. + fn with_validity(&self, validity: Option) -> Box; + + /// Clone a `&dyn Array` to an owned `Box`. + fn to_boxed(&self) -> Box; +} + +dyn_clone::clone_trait_object!(Array); + +/// A trait describing an array with a backing store that can be preallocated to +/// a given size. +pub(crate) trait Container { + /// Create this array with a given capacity. + fn with_capacity(capacity: usize) -> Self + where + Self: Sized; +} + +/// A trait describing a mutable array; i.e. an array whose values can be changed. +/// Mutable arrays cannot be cloned but can be mutated in place, +/// thereby making them useful to perform numeric operations without allocations. +/// As in [`Array`], concrete arrays (such as [`MutablePrimitiveArray`]) implement how they are mutated. +pub trait MutableArray: std::fmt::Debug + Send + Sync { + /// The [`DataType`] of the array. + fn data_type(&self) -> &DataType; + + /// The length of the array. + fn len(&self) -> usize; + + /// Whether the array is empty. + fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// The optional validity of the array. + fn validity(&self) -> Option<&MutableBitmap>; + + /// Convert itself to an (immutable) [`Array`]. + fn as_box(&mut self) -> Box; + + /// Convert itself to an (immutable) atomically reference counted [`Array`]. + // This provided implementation has an extra allocation as it first + // boxes `self`, then converts the box into an `Arc`. Implementors may wish + // to avoid an allocation by skipping the box completely. + fn as_arc(&mut self) -> std::sync::Arc { + self.as_box().into() + } + + /// Convert to `Any`, to enable dynamic casting. + fn as_any(&self) -> &dyn Any; + + /// Convert to mutable `Any`, to enable dynamic casting. + fn as_mut_any(&mut self) -> &mut dyn Any; + + /// Adds a new null element to the array. + fn push_null(&mut self); + + /// Whether `index` is valid / set. + /// # Panic + /// Panics if `index >= self.len()`. + #[inline] + fn is_valid(&self, index: usize) -> bool { + self.validity() + .as_ref() + .map(|x| x.get(index)) + .unwrap_or(true) + } + + /// Reserves additional slots to its capacity. + fn reserve(&mut self, additional: usize); + + /// Shrink the array to fit its length. + fn shrink_to_fit(&mut self); +} + +impl MutableArray for Box { + fn len(&self) -> usize { + self.as_ref().len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.as_ref().validity() + } + + fn as_box(&mut self) -> Box { + self.as_mut().as_box() + } + + fn as_arc(&mut self) -> Arc { + self.as_mut().as_arc() + } + + fn data_type(&self) -> &DataType { + self.as_ref().data_type() + } + + fn as_any(&self) -> &dyn std::any::Any { + self.as_ref().as_any() + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self.as_mut().as_mut_any() + } + + #[inline] + fn push_null(&mut self) { + self.as_mut().push_null() + } + + fn shrink_to_fit(&mut self) { + self.as_mut().shrink_to_fit(); + } + + fn reserve(&mut self, additional: usize) { + self.as_mut().reserve(additional); + } +} + +macro_rules! 
general_dyn { + ($array:expr, $ty:ty, $f:expr) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + ($f)(array) + }}; +} + +macro_rules! fmt_dyn { + ($array:expr, $ty:ty, $f:expr) => {{ + let mut f = |x: &$ty| x.fmt($f); + general_dyn!($array, $ty, f) + }}; +} + +macro_rules! match_integer_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::IntegerType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + } +})} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns, f16, i256}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float16 => __with_ty__! { f16 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +impl std::fmt::Debug for dyn Array + '_ { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use crate::datatypes::PhysicalType::*; + match self.data_type().to_physical_type() { + Null => fmt_dyn!(self, NullArray, f), + Boolean => fmt_dyn!(self, BooleanArray, f), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + fmt_dyn!(self, PrimitiveArray<$T>, f) + }), + Binary => fmt_dyn!(self, BinaryArray, f), + LargeBinary => fmt_dyn!(self, BinaryArray, f), + FixedSizeBinary => fmt_dyn!(self, FixedSizeBinaryArray, f), + Utf8 => fmt_dyn!(self, Utf8Array::, f), + LargeUtf8 => fmt_dyn!(self, Utf8Array::, f), + List => fmt_dyn!(self, ListArray::, f), + LargeList => fmt_dyn!(self, ListArray::, f), + FixedSizeList => fmt_dyn!(self, FixedSizeListArray, f), + Struct => fmt_dyn!(self, StructArray, f), + Union => fmt_dyn!(self, UnionArray, f), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + fmt_dyn!(self, DictionaryArray::<$T>, f) + }) + }, + Map => fmt_dyn!(self, MapArray, f), + } + } +} + +/// Creates a new [`Array`] with a [`Array::len`] of 0. 
+pub fn new_empty_array(data_type: DataType) -> Box { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null => Box::new(NullArray::new_empty(data_type)), + Boolean => Box::new(BooleanArray::new_empty(data_type)), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::new_empty(data_type)) + }), + Binary => Box::new(BinaryArray::::new_empty(data_type)), + LargeBinary => Box::new(BinaryArray::::new_empty(data_type)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_empty(data_type)), + Utf8 => Box::new(Utf8Array::::new_empty(data_type)), + LargeUtf8 => Box::new(Utf8Array::::new_empty(data_type)), + List => Box::new(ListArray::::new_empty(data_type)), + LargeList => Box::new(ListArray::::new_empty(data_type)), + FixedSizeList => Box::new(FixedSizeListArray::new_empty(data_type)), + Struct => Box::new(StructArray::new_empty(data_type)), + Union => Box::new(UnionArray::new_empty(data_type)), + Map => Box::new(MapArray::new_empty(data_type)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::new_empty(data_type)) + }) + }, + } +} + +/// Creates a new [`Array`] of [`DataType`] `data_type` and `length`. +/// The array is guaranteed to have [`Array::null_count`] equal to [`Array::len`] +/// for all types except Union, which does not have a validity. +pub fn new_null_array(data_type: DataType, length: usize) -> Box { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null => Box::new(NullArray::new_null(data_type, length)), + Boolean => Box::new(BooleanArray::new_null(data_type, length)), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::new_null(data_type, length)) + }), + Binary => Box::new(BinaryArray::::new_null(data_type, length)), + LargeBinary => Box::new(BinaryArray::::new_null(data_type, length)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::new_null(data_type, length)), + Utf8 => Box::new(Utf8Array::::new_null(data_type, length)), + LargeUtf8 => Box::new(Utf8Array::::new_null(data_type, length)), + List => Box::new(ListArray::::new_null(data_type, length)), + LargeList => Box::new(ListArray::::new_null(data_type, length)), + FixedSizeList => Box::new(FixedSizeListArray::new_null(data_type, length)), + Struct => Box::new(StructArray::new_null(data_type, length)), + Union => Box::new(UnionArray::new_null(data_type, length)), + Map => Box::new(MapArray::new_null(data_type, length)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::new_null(data_type, length)) + }) + }, + } +} + +/// Trait providing bi-directional conversion between arrow2 [`Array`] and arrow-rs [`ArrayData`] +/// +/// [`ArrayData`]: arrow_data::ArrayData +#[cfg(feature = "arrow")] +pub trait Arrow2Arrow: Array { + /// Convert this [`Array`] into [`ArrayData`] + fn to_data(&self) -> arrow_data::ArrayData; + + /// Create this [`Array`] from [`ArrayData`] + fn from_data(data: &arrow_data::ArrayData) -> Self; +} + +#[cfg(feature = "arrow")] +macro_rules! 
to_data_dyn { + ($array:expr, $ty:ty) => {{ + let f = |x: &$ty| x.to_data(); + general_dyn!($array, $ty, f) + }}; +} + +#[cfg(feature = "arrow")] +impl From> for arrow_array::ArrayRef { + fn from(value: Box) -> Self { + value.as_ref().into() + } +} + +#[cfg(feature = "arrow")] +impl From<&dyn Array> for arrow_array::ArrayRef { + fn from(value: &dyn Array) -> Self { + arrow_array::make_array(to_data(value)) + } +} + +#[cfg(feature = "arrow")] +impl From for Box { + fn from(value: arrow_array::ArrayRef) -> Self { + value.as_ref().into() + } +} + +#[cfg(feature = "arrow")] +impl From<&dyn arrow_array::Array> for Box { + fn from(value: &dyn arrow_array::Array) -> Self { + from_data(&value.to_data()) + } +} + +/// Convert an arrow2 [`Array`] to [`arrow_data::ArrayData`] +#[cfg(feature = "arrow")] +pub fn to_data(array: &dyn Array) -> arrow_data::ArrayData { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => to_data_dyn!(array, NullArray), + Boolean => to_data_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + to_data_dyn!(array, PrimitiveArray<$T>) + }), + Binary => to_data_dyn!(array, BinaryArray), + LargeBinary => to_data_dyn!(array, BinaryArray), + FixedSizeBinary => to_data_dyn!(array, FixedSizeBinaryArray), + Utf8 => to_data_dyn!(array, Utf8Array::), + LargeUtf8 => to_data_dyn!(array, Utf8Array::), + List => to_data_dyn!(array, ListArray::), + LargeList => to_data_dyn!(array, ListArray::), + FixedSizeList => to_data_dyn!(array, FixedSizeListArray), + Struct => to_data_dyn!(array, StructArray), + Union => to_data_dyn!(array, UnionArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + to_data_dyn!(array, DictionaryArray::<$T>) + }) + }, + Map => to_data_dyn!(array, MapArray), + } +} + +/// Convert an [`arrow_data::ArrayData`] to arrow2 [`Array`] +#[cfg(feature = "arrow")] +pub fn from_data(data: &arrow_data::ArrayData) -> Box { + use crate::datatypes::PhysicalType::*; + let data_type: DataType = data.data_type().clone().into(); + match data_type.to_physical_type() { + Null => Box::new(NullArray::from_data(data)), + Boolean => Box::new(BooleanArray::from_data(data)), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::from_data(data)) + }), + Binary => Box::new(BinaryArray::::from_data(data)), + LargeBinary => Box::new(BinaryArray::::from_data(data)), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::from_data(data)), + Utf8 => Box::new(Utf8Array::::from_data(data)), + LargeUtf8 => Box::new(Utf8Array::::from_data(data)), + List => Box::new(ListArray::::from_data(data)), + LargeList => Box::new(ListArray::::from_data(data)), + FixedSizeList => Box::new(FixedSizeListArray::from_data(data)), + Struct => Box::new(StructArray::from_data(data)), + Union => Box::new(UnionArray::from_data(data)), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::from_data(data)) + }) + }, + Map => Box::new(MapArray::from_data(data)), + } +} + +macro_rules! clone_dyn { + ($array:expr, $ty:ty) => {{ + let f = |x: &$ty| Box::new(x.clone()); + general_dyn!($array, $ty, f) + }}; +} + +// macro implementing `sliced` and `sliced_unchecked` +macro_rules! impl_sliced { + () => { + /// Returns this array sliced. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. 
+ #[inline] + #[must_use] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Returns this array sliced. + /// # Implementation + /// This function is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + }; +} + +// macro implementing `with_validity` and `set_validity` +macro_rules! impl_mut_validity { + () => { + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } + } +} + +// macro implementing `with_validity`, `set_validity` and `apply_validity` for mutable arrays +macro_rules! impl_mutable_array_mut_validity { + () => { + /// Returns this array with a new validity. + /// # Panic + /// Panics iff `validity.len() != self.len()`. + #[must_use] + #[inline] + pub fn with_validity(mut self, validity: Option) -> Self { + self.set_validity(validity); + self + } + + /// Sets the validity of this array. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[inline] + pub fn set_validity(&mut self, validity: Option) { + if matches!(&validity, Some(bitmap) if bitmap.len() != self.len()) { + panic!("validity must be equal to the array's length") + } + self.validity = validity; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + #[inline] + pub fn apply_validity MutableBitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + } +} + +// macro implementing `boxed` and `arced` +macro_rules! impl_into_array { + () => { + /// Boxes this array into a [`Box`]. + pub fn boxed(self) -> Box { + Box::new(self) + } + + /// Arcs this array into a [`std::sync::Arc`]. + pub fn arced(self) -> std::sync::Arc { + std::sync::Arc::new(self) + } + }; +} + +// macro implementing common methods of trait `Array` +macro_rules! impl_common_array { + () => { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } + + #[inline] + fn slice(&mut self, offset: usize, length: usize) { + self.slice(offset, length); + } + + #[inline] + unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.slice_unchecked(offset, length); + } + + #[inline] + fn to_boxed(&self) -> Box { + Box::new(self.clone()) + } + }; +} + +/// Clones a dynamic [`Array`]. 
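+///
+/// # Example (editorial sketch, not part of the original patch)
+///
+/// ```ignore
+/// let a = new_null_array(DataType::Int32, 3);
+/// let b = clone(a.as_ref());
+/// assert_eq!(a.len(), b.len());
+/// ```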
+/// # Implementation +/// This operation is `O(1)` over `len`, as it amounts to increase two ref counts +/// and moving the concrete struct under a `Box`. +pub fn clone(array: &dyn Array) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => clone_dyn!(array, NullArray), + Boolean => clone_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + clone_dyn!(array, PrimitiveArray<$T>) + }), + Binary => clone_dyn!(array, BinaryArray), + LargeBinary => clone_dyn!(array, BinaryArray), + FixedSizeBinary => clone_dyn!(array, FixedSizeBinaryArray), + Utf8 => clone_dyn!(array, Utf8Array::), + LargeUtf8 => clone_dyn!(array, Utf8Array::), + List => clone_dyn!(array, ListArray::), + LargeList => clone_dyn!(array, ListArray::), + FixedSizeList => clone_dyn!(array, FixedSizeListArray), + Struct => clone_dyn!(array, StructArray), + Union => clone_dyn!(array, UnionArray), + Map => clone_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + clone_dyn!(array, DictionaryArray::<$T>) + }) + }, + } +} + +// see https://users.rust-lang.org/t/generic-for-dyn-a-or-box-dyn-a-or-arc-dyn-a/69430/3 +// for details +impl<'a> AsRef<(dyn Array + 'a)> for dyn Array { + fn as_ref(&self) -> &(dyn Array + 'a) { + self + } +} + +mod binary; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod fixed_size_list; +mod list; +mod map; +mod null; +mod primitive; +mod specification; +mod struct_; +mod union; +mod utf8; + +mod equal; +mod ffi; +mod fmt; +#[doc(hidden)] +pub mod indexable; +mod iterator; + +pub mod growable; +pub mod ord; + +pub use binary::{BinaryArray, BinaryValueIter, MutableBinaryArray, MutableBinaryValuesArray}; +pub use boolean::{BooleanArray, MutableBooleanArray}; +pub use dictionary::{DictionaryArray, DictionaryKey, MutableDictionaryArray}; +pub use equal::equal; +pub use fixed_size_binary::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray}; +pub use fixed_size_list::{FixedSizeListArray, MutableFixedSizeListArray}; +pub use fmt::{get_display, get_value_display}; +pub(crate) use iterator::ArrayAccessor; +pub use iterator::ArrayValuesIter; +pub use list::{ListArray, ListValuesIter, MutableListArray}; +pub use map::MapArray; +pub use null::{MutableNullArray, NullArray}; +pub use primitive::*; +pub use struct_::{MutableStructArray, StructArray}; +pub use union::UnionArray; +pub use utf8::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array, Utf8ValuesIter}; + +pub(crate) use self::ffi::{offset_buffers_children_dictionary, FromFfi, ToFfi}; + +/// A trait describing the ability of a struct to create itself from a iterator. +/// This is similar to [`Extend`], but accepted the creation to error. +pub trait TryExtend { + /// Fallible version of [`Extend::extend`]. + fn try_extend>(&mut self, iter: I) -> Result<()>; +} + +/// A trait describing the ability of a struct to receive new items. +pub trait TryPush { + /// Tries to push a new element. + fn try_push(&mut self, item: A) -> Result<()>; +} + +/// A trait describing the ability of a struct to receive new items. +pub trait PushUnchecked { + /// Push a new element that holds the invariants of the struct. + /// # Safety + /// The items must uphold the invariants of the struct + /// Read the specific implementation of the trait to understand what these are. + unsafe fn push_unchecked(&mut self, item: A); +} + +/// A trait describing the ability of a struct to extend from a reference of itself. 
+/// Specialization of [`TryExtend`]. +pub trait TryExtendFromSelf { + /// Tries to extend itself with elements from `other`, failing only on overflow. + fn try_extend_from_self(&mut self, other: &Self) -> Result<()>; +} + +/// Trait that [`BinaryArray`] and [`Utf8Array`] implement for the purposes of DRY. +/// # Safety +/// The implementer must ensure that +/// 1. `offsets.len() > 0` +/// 2. `offsets[i] >= offsets[i-1] for all i` +/// 3. `offsets[i] < values.len() for all i` +pub unsafe trait GenericBinaryArray: Array { + /// The values of the array + fn values(&self) -> &[u8]; + /// The offsets of the array + fn offsets(&self) -> &[O]; +} diff --git a/crates/nano-arrow/src/array/null.rs b/crates/nano-arrow/src/array/null.rs new file mode 100644 index 000000000000..0fb9dd6644bd --- /dev/null +++ b/crates/nano-arrow/src/array/null.rs @@ -0,0 +1,200 @@ +use std::any::Any; + +use crate::array::{Array, FromFfi, MutableArray, ToFfi}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::ffi; + +/// The concrete [`Array`] of [`DataType::Null`]. +#[derive(Clone)] +pub struct NullArray { + data_type: DataType, + length: usize, +} + +impl NullArray { + /// Returns a new [`NullArray`]. + /// # Errors + /// This function errors iff: + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn try_new(data_type: DataType, length: usize) -> Result { + if data_type.to_physical_type() != PhysicalType::Null { + return Err(Error::oos( + "NullArray can only be initialized with a DataType whose physical type is Boolean", + )); + } + + Ok(Self { data_type, length }) + } + + /// Returns a new [`NullArray`]. + /// # Panics + /// This function errors iff: + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(data_type: DataType, length: usize) -> Self { + Self::try_new(data_type, length).unwrap() + } + + /// Returns a new empty [`NullArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, 0) + } + + /// Returns a new [`NullArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new(data_type, length) + } + + impl_sliced!(); + impl_into_array!(); +} + +impl NullArray { + /// Returns a slice of the [`NullArray`]. + /// # Panic + /// This function panics iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the arrays' length" + ); + unsafe { self.slice_unchecked(offset, length) }; + } + + /// Returns a slice of the [`NullArray`]. + /// # Safety + /// The caller must ensure that `offset + length < self.len()`. + pub unsafe fn slice_unchecked(&mut self, _offset: usize, length: usize) { + self.length = length; + } + + #[inline] + fn len(&self) -> usize { + self.length + } +} + +impl Array for NullArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + None + } + + fn with_validity(&self, _: Option) -> Box { + panic!("cannot set validity of a null array") + } +} + +#[derive(Debug)] +/// A distinct type to disambiguate +/// clashing methods +pub struct MutableNullArray { + inner: NullArray, +} + +impl MutableNullArray { + /// Returns a new [`MutableNullArray`]. 
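+    ///
+    /// A small sketch of typical use (illustrative; paths as in the other doc
+    /// examples in this crate):
+    /// ```
+    /// use arrow2::array::{MutableArray, MutableNullArray};
+    /// use arrow2::datatypes::DataType;
+    ///
+    /// let mut array = MutableNullArray::new(DataType::Null, 0);
+    /// array.push_null();
+    /// assert_eq!(array.len(), 1);
+    /// ```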
+ /// # Panics + /// This function errors iff: + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Null`]. + pub fn new(data_type: DataType, length: usize) -> Self { + let inner = NullArray::try_new(data_type, length).unwrap(); + Self { inner } + } +} + +impl From for NullArray { + fn from(value: MutableNullArray) -> Self { + value.inner + } +} + +impl MutableArray for MutableNullArray { + fn data_type(&self) -> &DataType { + &DataType::Null + } + + fn len(&self) -> usize { + self.inner.length + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + self.inner.clone().boxed() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn Any { + self + } + + fn push_null(&mut self) { + self.inner.length += 1; + } + + fn reserve(&mut self, _additional: usize) { + // no-op + } + + fn shrink_to_fit(&mut self) { + // no-op + } +} + +impl std::fmt::Debug for NullArray { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NullArray({})", self.len()) + } +} + +unsafe impl ToFfi for NullArray { + fn buffers(&self) -> Vec> { + // `None` is technically not required by the specification, but older C++ implementations require it, so leaving + // it here for backward compatibility + vec![None] + } + + fn offset(&self) -> Option { + Some(0) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for NullArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + Self::try_new(data_type, array.array().len()) + } +} + +#[cfg(feature = "arrow")] +mod arrow { + use arrow_data::{ArrayData, ArrayDataBuilder}; + + use super::*; + impl NullArray { + /// Convert this array into [`arrow_data::ArrayData`] + pub fn to_data(&self) -> ArrayData { + let builder = ArrayDataBuilder::new(arrow_schema::DataType::Null).len(self.len()); + + // Safety: safe by construction + unsafe { builder.build_unchecked() } + } + + /// Create this array from [`ArrayData`] + pub fn from_data(data: &ArrayData) -> Self { + Self::new(DataType::Null, data.len()) + } + } +} diff --git a/crates/nano-arrow/src/array/ord.rs b/crates/nano-arrow/src/array/ord.rs new file mode 100644 index 000000000000..914eff2639f8 --- /dev/null +++ b/crates/nano-arrow/src/array/ord.rs @@ -0,0 +1,245 @@ +//! Contains functions and function factories to order values within arrays. + +use std::cmp::Ordering; + +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::offset::Offset; +use crate::types::NativeType; + +/// Compare the values at two arbitrary indices in two arrays. 
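+///
+/// A comparator of this type is usually obtained from [`build_compare`] below;
+/// a short sketch (illustrative):
+/// ```
+/// use arrow2::array::{ord::build_compare, Int32Array};
+///
+/// # fn main() -> arrow2::error::Result<()> {
+/// let array = Int32Array::from_slice([5, 1]);
+/// let cmp = build_compare(&array, &array)?;
+/// // compares slot 0 (= 5) of the left array with slot 1 (= 1) of the right array
+/// assert_eq!(std::cmp::Ordering::Greater, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```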
+pub type DynComparator = Box Ordering + Send + Sync>; + +/// implements comparison using IEEE 754 total ordering for f32 +// Original implementation from https://doc.rust-lang.org/std/primitive.f32.html#method.total_cmp +// TODO to change to use std when it becomes stable +#[inline] +pub fn total_cmp_f32(l: &f32, r: &f32) -> std::cmp::Ordering { + let mut left = l.to_bits() as i32; + let mut right = r.to_bits() as i32; + + left ^= (((left >> 31) as u32) >> 1) as i32; + right ^= (((right >> 31) as u32) >> 1) as i32; + + left.cmp(&right) +} + +/// implements comparison using IEEE 754 total ordering for f64 +// Original implementation from https://doc.rust-lang.org/std/primitive.f64.html#method.total_cmp +// TODO to change to use std when it becomes stable +#[inline] +pub fn total_cmp_f64(l: &f64, r: &f64) -> std::cmp::Ordering { + let mut left = l.to_bits() as i64; + let mut right = r.to_bits() as i64; + + left ^= (((left >> 63) as u64) >> 1) as i64; + right ^= (((right >> 63) as u64) >> 1) as i64; + + left.cmp(&right) +} + +/// Total order of all native types whose Rust implementation +/// that support total order. +#[inline] +pub fn total_cmp(l: &T, r: &T) -> std::cmp::Ordering +where + T: NativeType + Ord, +{ + l.cmp(r) +} + +fn compare_primitives(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| total_cmp(&left.value(i), &right.value(j))) +} + +fn compare_boolean(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(&right.value(j))) +} + +fn compare_f32(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| total_cmp_f32(&left.value(i), &right.value(j))) +} + +fn compare_f64(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| total_cmp_f64(&left.value(i), &right.value(j))) +} + +fn compare_string(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + +fn compare_binary(left: &dyn Array, right: &dyn Array) -> DynComparator { + let left = left + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + let right = right + .as_any() + .downcast_ref::>() + .unwrap() + .clone(); + Box::new(move |i, j| left.value(i).cmp(right.value(j))) +} + +fn compare_dict(left: &DictionaryArray, right: &DictionaryArray) -> Result +where + K: DictionaryKey, +{ + let left_keys = left.keys().values().clone(); + let right_keys = right.keys().values().clone(); + + let comparator = build_compare(left.values().as_ref(), right.values().as_ref())?; + + Ok(Box::new(move |i: usize, j: usize| { + // safety: all dictionaries keys are guaranteed to be castable to usize + let key_left = unsafe { left_keys[i].as_usize() }; + let key_right = unsafe { right_keys[j].as_usize() }; + 
(comparator)(key_left, key_right) + })) +} + +macro_rules! dyn_dict { + ($key:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs.as_any().downcast_ref().unwrap(); + let rhs = $rhs.as_any().downcast_ref().unwrap(); + compare_dict::<$key>(lhs, rhs)? + }}; +} + +/// returns a comparison function that compares values at two different slots +/// between two [`Array`]. +/// # Example +/// ``` +/// use arrow2::array::{ord::build_compare, PrimitiveArray}; +/// +/// # fn main() -> arrow2::error::Result<()> { +/// let array1 = PrimitiveArray::from_slice([1, 2]); +/// let array2 = PrimitiveArray::from_slice([3, 4]); +/// +/// let cmp = build_compare(&array1, &array2)?; +/// +/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2) +/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1)); +/// # Ok(()) +/// # } +/// ``` +/// # Error +/// The arrays' [`DataType`] must be equal and the types must have a natural order. +// This is a factory of comparisons. +pub fn build_compare(left: &dyn Array, right: &dyn Array) -> Result { + use DataType::*; + use IntervalUnit::*; + use TimeUnit::*; + Ok(match (left.data_type(), right.data_type()) { + (a, b) if a != b => { + return Err(Error::InvalidArgumentError( + "Can't compare arrays of different types".to_string(), + )); + }, + (Boolean, Boolean) => compare_boolean(left, right), + (UInt8, UInt8) => compare_primitives::(left, right), + (UInt16, UInt16) => compare_primitives::(left, right), + (UInt32, UInt32) => compare_primitives::(left, right), + (UInt64, UInt64) => compare_primitives::(left, right), + (Int8, Int8) => compare_primitives::(left, right), + (Int16, Int16) => compare_primitives::(left, right), + (Int32, Int32) + | (Date32, Date32) + | (Time32(Second), Time32(Second)) + | (Time32(Millisecond), Time32(Millisecond)) + | (Interval(YearMonth), Interval(YearMonth)) => compare_primitives::(left, right), + (Int64, Int64) + | (Date64, Date64) + | (Time64(Microsecond), Time64(Microsecond)) + | (Time64(Nanosecond), Time64(Nanosecond)) + | (Timestamp(Second, None), Timestamp(Second, None)) + | (Timestamp(Millisecond, None), Timestamp(Millisecond, None)) + | (Timestamp(Microsecond, None), Timestamp(Microsecond, None)) + | (Timestamp(Nanosecond, None), Timestamp(Nanosecond, None)) + | (Duration(Second), Duration(Second)) + | (Duration(Millisecond), Duration(Millisecond)) + | (Duration(Microsecond), Duration(Microsecond)) + | (Duration(Nanosecond), Duration(Nanosecond)) => compare_primitives::(left, right), + (Float32, Float32) => compare_f32(left, right), + (Float64, Float64) => compare_f64(left, right), + (Decimal(_, _), Decimal(_, _)) => compare_primitives::(left, right), + (Utf8, Utf8) => compare_string::(left, right), + (LargeUtf8, LargeUtf8) => compare_string::(left, right), + (Binary, Binary) => compare_binary::(left, right), + (LargeBinary, LargeBinary) => compare_binary::(left, right), + (Dictionary(key_type_lhs, ..), Dictionary(key_type_rhs, ..)) => { + match (key_type_lhs, key_type_rhs) { + (IntegerType::UInt8, IntegerType::UInt8) => dyn_dict!(u8, left, right), + (IntegerType::UInt16, IntegerType::UInt16) => dyn_dict!(u16, left, right), + (IntegerType::UInt32, IntegerType::UInt32) => dyn_dict!(u32, left, right), + (IntegerType::UInt64, IntegerType::UInt64) => dyn_dict!(u64, left, right), + (IntegerType::Int8, IntegerType::Int8) => dyn_dict!(i8, left, right), + (IntegerType::Int16, IntegerType::Int16) => dyn_dict!(i16, left, right), + (IntegerType::Int32, IntegerType::Int32) => dyn_dict!(i32, left, right), + (IntegerType::Int64, IntegerType::Int64) 
=> dyn_dict!(i64, left, right), + (lhs, _) => { + return Err(Error::InvalidArgumentError(format!( + "Dictionaries do not support keys of type {lhs:?}" + ))) + }, + } + }, + (lhs, _) => { + return Err(Error::InvalidArgumentError(format!( + "The data type type {lhs:?} has no natural order" + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/array/physical_binary.rs b/crates/nano-arrow/src/array/physical_binary.rs new file mode 100644 index 000000000000..694e61a7ea63 --- /dev/null +++ b/crates/nano-arrow/src/array/physical_binary.rs @@ -0,0 +1,230 @@ +use crate::bitmap::MutableBitmap; +use crate::offset::{Offset, Offsets}; + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +#[allow(clippy::type_complexity)] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, Offsets, Vec), E> +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut offsets = Vec::::with_capacity(len + 1); + let mut values = Vec::::new(); + + let mut length = O::default(); + let mut dst = offsets.as_mut_ptr(); + std::ptr::write(dst, length); + dst = dst.add(1); + for item in iterator { + if let Some(item) = item? { + null.push_unchecked(true); + let s = item.as_ref(); + length += O::from_usize(s.len()).unwrap(); + values.extend_from_slice(s); + } else { + null.push_unchecked(false); + }; + + std::ptr::write(dst, length); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(offsets.as_ptr()) as usize, + len + 1, + "Trusted iterator length was not accurately reported" + ); + offsets.set_len(len + 1); + + Ok((null.into(), Offsets::new_unchecked(offsets), values)) +} + +/// Creates [`MutableBitmap`] and two [`Vec`]s from an iterator of `Option`. +/// The first buffer corresponds to a offset buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip( + iterator: I, +) -> (Option, Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = Offsets::::with_capacity(len); + let mut values = Vec::::new(); + let mut validity = MutableBitmap::new(); + + extend_from_trusted_len_iter(&mut offsets, &mut values, &mut validity, iterator); + + let validity = if validity.unset_bits() > 0 { + Some(validity) + } else { + None + }; + + (validity, offsets, values) +} + +/// Creates two [`Buffer`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is [`TrustedLen`]. +#[inline] +pub(crate) unsafe fn trusted_len_values_iter(iterator: I) -> (Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut offsets = Offsets::::with_capacity(len); + let mut values = Vec::::new(); + + extend_from_trusted_len_values_iter(&mut offsets, &mut values, iterator); + + (offsets, values) +} + +// Populates `offsets` and `values` [`Vec`]s with information extracted +// from the incoming `iterator`. 
+// # Safety +// The caller must ensure the `iterator` is [`TrustedLen`] +#[inline] +pub(crate) unsafe fn extend_from_trusted_len_values_iter( + offsets: &mut Offsets, + values: &mut Vec, + iterator: I, +) where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let lengths = iterator.map(|item| { + let s = item.as_ref(); + // Push new entries for both `values` and `offsets` buffer + values.extend_from_slice(s); + s.len() + }); + offsets.try_extend_from_lengths(lengths).unwrap(); +} + +// Populates `offsets` and `values` [`Vec`]s with information extracted +// from the incoming `iterator`. +// the return value indicates how many items were added. +#[inline] +pub(crate) fn extend_from_values_iter( + offsets: &mut Offsets, + values: &mut Vec, + iterator: I, +) -> usize +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (size_hint, _) = iterator.size_hint(); + + offsets.reserve(size_hint); + + let start_index = offsets.len_proxy(); + + for item in iterator { + let bytes = item.as_ref(); + values.extend_from_slice(bytes); + offsets.try_push_usize(bytes.len()).unwrap(); + } + offsets.len_proxy() - start_index +} + +// Populates `offsets`, `values`, and `validity` [`Vec`]s with +// information extracted from the incoming `iterator`. +// +// # Safety +// The caller must ensure that `iterator` is [`TrustedLen`] +#[inline] +pub(crate) unsafe fn extend_from_trusted_len_iter( + offsets: &mut Offsets, + values: &mut Vec, + validity: &mut MutableBitmap, + iterator: I, +) where + O: Offset, + P: AsRef<[u8]>, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("extend_from_trusted_len_iter requires an upper limit"); + + offsets.reserve(additional); + validity.reserve(additional); + + let lengths = iterator.map(|item| { + if let Some(item) = item { + let bytes = item.as_ref(); + values.extend_from_slice(bytes); + validity.push_unchecked(true); + bytes.len() + } else { + validity.push_unchecked(false); + 0 + } + }); + offsets.try_extend_from_lengths(lengths).unwrap(); +} + +/// Creates two [`Vec`]s from an iterator of `&[u8]`. +/// The first buffer corresponds to a offset buffer, the second to a values buffer. 
+#[inline] +pub(crate) fn values_iter(iterator: I) -> (Offsets, Vec) +where + O: Offset, + P: AsRef<[u8]>, + I: Iterator, +{ + let (lower, _) = iterator.size_hint(); + + let mut offsets = Offsets::::with_capacity(lower); + let mut values = Vec::::new(); + + for item in iterator { + let s = item.as_ref(); + values.extend_from_slice(s); + offsets.try_push_usize(s.len()).unwrap(); + } + (offsets, values) +} + +/// Extends `validity` with all items from `other` +pub(crate) fn extend_validity( + length: usize, + validity: &mut Option, + other: &Option, +) { + if let Some(other) = other { + if let Some(validity) = validity { + let slice = other.as_slice(); + // safety: invariant offset + length <= slice.len() + unsafe { validity.extend_from_slice_unchecked(slice, 0, other.len()) } + } else { + let mut new_validity = MutableBitmap::from_len_set(length); + new_validity.extend_from_slice(other.as_slice(), 0, other.len()); + *validity = Some(new_validity); + } + } +} diff --git a/crates/nano-arrow/src/array/primitive/data.rs b/crates/nano-arrow/src/array/primitive/data.rs new file mode 100644 index 000000000000..d4879f796812 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/data.rs @@ -0,0 +1,33 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, PrimitiveArray}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::types::NativeType; + +impl Arrow2Arrow for PrimitiveArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .buffers(vec![self.values.clone().into()]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + + let mut values: Buffer = data.buffers()[0].clone().into(); + values.slice(data.offset(), data.len()); + + Self { + data_type, + values, + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/primitive/ffi.rs b/crates/nano-arrow/src/array/primitive/ffi.rs new file mode 100644 index 000000000000..c74c157f750f --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/ffi.rs @@ -0,0 +1,56 @@ +use super::PrimitiveArray; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::types::NativeType; + +unsafe impl ToFfi for PrimitiveArray { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.values.offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.values.offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + values: self.values.clone(), + } + } +} + +impl FromFfi for PrimitiveArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let values = unsafe { array.buffer::(1) }?; + + Self::try_new(data_type, values, validity) + } +} diff --git 
a/crates/nano-arrow/src/array/primitive/fmt.rs b/crates/nano-arrow/src/array/primitive/fmt.rs new file mode 100644 index 000000000000..3743a16a188e --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/fmt.rs @@ -0,0 +1,149 @@ +#![allow(clippy::redundant_closure_call)] +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::PrimitiveArray; +use crate::array::fmt::write_vec; +use crate::array::Array; +use crate::datatypes::{IntervalUnit, TimeUnit}; +use crate::temporal_conversions; +use crate::types::{days_ms, i256, months_days_ns, NativeType}; + +macro_rules! dyn_primitive { + ($array:expr, $ty:ty, $expr:expr) => {{ + let array = ($array as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(move |f, index| write!(f, "{}", $expr(array.value(index)))) + }}; +} + +pub fn get_write_value<'a, T: NativeType, F: Write>( + array: &'a PrimitiveArray, +) -> Box Result + 'a> { + use crate::datatypes::DataType::*; + match array.data_type().to_logical_type() { + Int8 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int16 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Int64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt8 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt16 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + UInt64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Float16 => unreachable!(), + Float32 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Float64 => Box::new(|f, index| write!(f, "{}", array.value(index))), + Date32 => { + dyn_primitive!(array, i32, temporal_conversions::date32_to_date) + }, + Date64 => { + dyn_primitive!(array, i64, temporal_conversions::date64_to_date) + }, + Time32(TimeUnit::Second) => { + dyn_primitive!(array, i32, temporal_conversions::time32s_to_time) + }, + Time32(TimeUnit::Millisecond) => { + dyn_primitive!(array, i32, temporal_conversions::time32ms_to_time) + }, + Time32(_) => unreachable!(), // remaining are not valid + Time64(TimeUnit::Microsecond) => { + dyn_primitive!(array, i64, temporal_conversions::time64us_to_time) + }, + Time64(TimeUnit::Nanosecond) => { + dyn_primitive!(array, i64, temporal_conversions::time64ns_to_time) + }, + Time64(_) => unreachable!(), // remaining are not valid + Timestamp(time_unit, tz) => { + if let Some(tz) = tz { + let timezone = temporal_conversions::parse_offset(tz); + match timezone { + Ok(timezone) => { + dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_datetime(time, *time_unit, &timezone) + }) + }, + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(tz); + match timezone { + Ok(timezone) => dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_datetime( + time, *time_unit, &timezone, + ) + }), + Err(_) => { + let tz = tz.clone(); + Box::new(move |f, index| { + write!(f, "{} ({})", array.value(index), tz) + }) + }, + } + }, + #[cfg(not(feature = "chrono-tz"))] + _ => { + let tz = tz.clone(); + Box::new(move |f, index| write!(f, "{} ({})", array.value(index), tz)) + }, + } + } else { + dyn_primitive!(array, i64, |time| { + temporal_conversions::timestamp_to_naive_datetime(time, *time_unit) + }) + } + }, + Interval(IntervalUnit::YearMonth) => { + dyn_primitive!(array, i32, |x| format!("{x}m")) + }, + Interval(IntervalUnit::DayTime) => { + 
dyn_primitive!(array, days_ms, |x: days_ms| format!( + "{}d{}ms", + x.days(), + x.milliseconds() + )) + }, + Interval(IntervalUnit::MonthDayNano) => { + dyn_primitive!(array, months_days_ns, |x: months_days_ns| format!( + "{}m{}d{}ns", + x.months(), + x.days(), + x.ns() + )) + }, + Duration(TimeUnit::Second) => dyn_primitive!(array, i64, |x| format!("{x}s")), + Duration(TimeUnit::Millisecond) => dyn_primitive!(array, i64, |x| format!("{x}ms")), + Duration(TimeUnit::Microsecond) => dyn_primitive!(array, i64, |x| format!("{x}us")), + Duration(TimeUnit::Nanosecond) => dyn_primitive!(array, i64, |x| format!("{x}ns")), + Decimal(_, scale) => { + // The number 999.99 has a precision of 5 and scale of 2 + let scale = *scale as u32; + let factor = 10i128.pow(scale); + let display = move |x: i128| { + let base = x / factor; + let decimals = (x - base * factor).abs(); + format!("{base}.{decimals}") + }; + dyn_primitive!(array, i128, display) + }, + Decimal256(_, scale) => { + let scale = *scale as u32; + let factor = (ethnum::I256::ONE * 10).pow(scale); + let display = move |x: i256| { + let base = x.0 / factor; + let decimals = (x.0 - base * factor).abs(); + format!("{base}.{decimals}") + }; + dyn_primitive!(array, i256, display) + }, + _ => unreachable!(), + } +} + +impl Debug for PrimitiveArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = get_write_value(self); + + write!(f, "{:?}", self.data_type())?; + write_vec(f, &*writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/primitive/from_natural.rs b/crates/nano-arrow/src/array/primitive/from_natural.rs new file mode 100644 index 000000000000..0530c748af7e --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/from_natural.rs @@ -0,0 +1,16 @@ +use std::iter::FromIterator; + +use super::{MutablePrimitiveArray, PrimitiveArray}; +use crate::types::NativeType; + +impl]>> From

for PrimitiveArray { + fn from(slice: P) -> Self { + MutablePrimitiveArray::::from(slice).into() + } +} + +impl>> FromIterator for PrimitiveArray { + fn from_iter>(iter: I) -> Self { + MutablePrimitiveArray::::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/primitive/iterator.rs b/crates/nano-arrow/src/array/primitive/iterator.rs new file mode 100644 index 000000000000..9433979dad84 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/iterator.rs @@ -0,0 +1,47 @@ +use super::{MutablePrimitiveArray, PrimitiveArray}; +use crate::array::MutableArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::IntoIter as BitmapIntoIter; +use crate::buffer::IntoIter; +use crate::types::NativeType; + +impl IntoIterator for PrimitiveArray { + type Item = Option; + type IntoIter = ZipValidity, BitmapIntoIter>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + let (_, values, validity) = self.into_inner(); + let values = values.into_iter(); + let validity = + validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.into_iter())); + ZipValidity::new(values, validity) + } +} + +impl<'a, T: NativeType> IntoIterator for &'a PrimitiveArray { + type Item = Option<&'a T>; + type IntoIter = ZipValidity<&'a T, std::slice::Iter<'a, T>, BitmapIter<'a>>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T: NativeType> MutablePrimitiveArray { + /// Returns an iterator over `Option` + #[inline] + pub fn iter(&'a self) -> ZipValidity<&'a T, std::slice::Iter<'a, T>, BitmapIter<'a>> { + ZipValidity::new( + self.values().iter(), + self.validity().as_ref().map(|x| x.iter()), + ) + } + + /// Returns an iterator of `T` + #[inline] + pub fn values_iter(&'a self) -> std::slice::Iter<'a, T> { + self.values().iter() + } +} diff --git a/crates/nano-arrow/src/array/primitive/mod.rs b/crates/nano-arrow/src/array/primitive/mod.rs new file mode 100644 index 000000000000..a3f80a581210 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/mod.rs @@ -0,0 +1,510 @@ +use either::Either; + +use super::Array; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::*; +use crate::error::Error; +use crate::trusted_len::TrustedLen; +use crate::types::{days_ms, f16, i256, months_days_ns, NativeType}; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod from_natural; +mod iterator; +pub use iterator::*; +mod mutable; +pub use mutable::*; + +/// A [`PrimitiveArray`] is Arrow's semantically equivalent of an immutable `Vec>` where +/// T is [`NativeType`] (e.g. [`i32`]). It implements [`Array`]. +/// +/// One way to think about a [`PrimitiveArray`] is `(DataType, Arc>, Option>>)` +/// where: +/// * the first item is the array's logical type +/// * the second is the immutable values +/// * the third is the immutable validity (whether a value is null or not as a bitmap). +/// +/// The size of this struct is `O(1)`, as all data is stored behind an [`std::sync::Arc`]. 
+/// # Example +/// ``` +/// use arrow2::array::PrimitiveArray; +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// +/// let array = PrimitiveArray::from([Some(1i32), None, Some(10)]); +/// assert_eq!(array.value(0), 1); +/// assert_eq!(array.iter().collect::>(), vec![Some(&1i32), None, Some(&10)]); +/// assert_eq!(array.values_iter().copied().collect::>(), vec![1, 0, 10]); +/// // the underlying representation +/// assert_eq!(array.values(), &Buffer::from(vec![1i32, 0, 10])); +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// +/// ``` +#[derive(Clone)] +pub struct PrimitiveArray { + data_type: DataType, + values: Buffer, + validity: Option, +} + +pub(super) fn check( + data_type: &DataType, + values: &[T], + validity_len: Option, +) -> Result<(), Error> { + if validity_len.map_or(false, |len| len != values.len()) { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != PhysicalType::Primitive(T::PRIMITIVE) { + return Err(Error::oos( + "PrimitiveArray can only be initialized with a DataType whose physical type is Primitive", + )); + } + Ok(()) +} + +impl PrimitiveArray { + /// The canonical method to create a [`PrimitiveArray`] out of its internal components. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + pub fn try_new( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|v| v.len()))?; + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Returns a new [`PrimitiveArray`] with a different logical type. + /// + /// This function is useful to assign a different [`DataType`] to the array. + /// Used to change the arrays' logical type (see example). + /// # Example + /// ``` + /// use arrow2::array::Int32Array; + /// use arrow2::datatypes::DataType; + /// + /// let array = Int32Array::from(&[Some(1), None, Some(2)]).to(DataType::Date32); + /// assert_eq!( + /// format!("{:?}", array), + /// "Date32[1970-01-02, None, 1970-01-03]" + /// ); + /// ``` + /// # Panics + /// Panics iff the `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive(T::PRIMITIVE)`] + #[inline] + #[must_use] + pub fn to(self, data_type: DataType) -> Self { + check( + &data_type, + &self.values, + self.validity.as_ref().map(|v| v.len()), + ) + .unwrap(); + Self { + data_type, + values: self.values, + validity: self.validity, + } + } + + /// Creates a (non-null) [`PrimitiveArray`] from a vector of values. + /// This function is `O(1)`. + /// # Examples + /// ``` + /// use arrow2::array::PrimitiveArray; + /// + /// let array = PrimitiveArray::from_vec(vec![1, 2, 3]); + /// assert_eq!(format!("{:?}", array), "Int32[1, 2, 3]"); + /// ``` + pub fn from_vec(values: Vec) -> Self { + Self::new(T::PRIMITIVE.into(), values.into(), None) + } + + /// Returns an iterator over the values and validity, `Option<&T>`. + #[inline] + pub fn iter(&self) -> ZipValidity<&T, std::slice::Iter, BitmapIter> { + ZipValidity::new_with_validity(self.values().iter(), self.validity()) + } + + /// Returns an iterator of the values, `&T`, ignoring the arrays' validity. 
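+    ///
+    /// A short sketch contrasting it with [`PrimitiveArray::iter`] (illustrative):
+    /// ```
+    /// use arrow2::array::PrimitiveArray;
+    ///
+    /// let array = PrimitiveArray::from([Some(1i32), None, Some(3)]);
+    /// // `iter` is validity-aware; `values_iter` exposes the raw slots
+    /// // (the slot behind a null is unspecified; this constructor stores the default, `0`)
+    /// assert_eq!(array.iter().collect::<Vec<_>>(), vec![Some(&1), None, Some(&3)]);
+    /// assert_eq!(array.values_iter().copied().collect::<Vec<_>>(), vec![1, 0, 3]);
+    /// ```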
+ #[inline] + pub fn values_iter(&self) -> std::slice::Iter { + self.values().iter() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// The values [`Buffer`]. + /// Values on null slots are undetermined (they can be anything). + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the arrays' [`DataType`]. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the value at slot `i`. + /// + /// Equivalent to `self.values()[i]`. The value of a null slot is undetermined (it can be anything). + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> T { + self.values[i] + } + + /// Returns the value at index `i`. + /// The value on null slots is undetermined (it can be anything). + /// # Safety + /// Caller must be sure that `i < self.len()` + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> T { + *self.values.get_unchecked(i) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Slices this [`PrimitiveArray`] by an offset and length. + /// # Implementation + /// This operation is `O(1)`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "offset + length may not exceed length of array" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`PrimitiveArray`] by an offset and length. + /// # Implementation + /// This operation is `O(1)`. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.values.slice_unchecked(offset, length); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns this [`PrimitiveArray`] with new values. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + #[must_use] + pub fn with_values(mut self, values: Buffer) -> Self { + self.set_values(values); + self + } + + /// Update the values of this [`PrimitiveArray`]. + /// # Panics + /// This function panics iff `values.len() != self.len()`. + pub fn set_values(&mut self, values: Buffer) { + assert_eq!( + values.len(), + self.len(), + "values' length must be equal to this arrays' length" + ); + self.values = values; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity Bitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } + + /// Returns an option of a mutable reference to the values of this [`PrimitiveArray`]. 
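+    ///
+    /// The mutable slice is only returned while the underlying buffer is not shared;
+    /// a minimal sketch (illustrative):
+    /// ```
+    /// use arrow2::array::PrimitiveArray;
+    ///
+    /// let mut array = PrimitiveArray::from_vec(vec![1i32, 2, 3]);
+    /// // the freshly created buffer is unique, so a mutable slice is available
+    /// if let Some(values) = array.get_mut_values() {
+    ///     values[0] = 10;
+    /// }
+    /// assert_eq!(array.value(0), 10);
+    /// ```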
+ pub fn get_mut_values(&mut self) -> Option<&mut [T]> { + self.values.get_mut_slice() + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, Buffer, Option) { + let Self { + data_type, + values, + validity, + } = self; + (data_type, values, validity) + } + + /// Creates a `[PrimitiveArray]` from its internal representation. + /// This is the inverted from `[PrimitiveArray::into_inner]` + pub fn from_inner( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|v| v.len()))?; + Ok(unsafe { Self::from_inner_unchecked(data_type, values, validity) }) + } + + /// Creates a `[PrimitiveArray]` from its internal representation. + /// This is the inverted from `[PrimitiveArray::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. + pub unsafe fn from_inner_unchecked( + data_type: DataType, + values: Buffer, + validity: Option, + ) -> Self { + Self { + data_type, + values, + validity, + } + } + + /// Try to convert this [`PrimitiveArray`] to a [`MutablePrimitiveArray`] via copy-on-write semantics. + /// + /// A [`PrimitiveArray`] is backed by a [`Buffer`] and [`Bitmap`] which are essentially `Arc>`. + /// This function returns a [`MutablePrimitiveArray`] (via [`std::sync::Arc::get_mut`]) iff both values + /// and validity have not been cloned / are unique references to their underlying vectors. + /// + /// This function is primarily used to re-use memory regions. + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + Left(bitmap) => Left(PrimitiveArray::new( + self.data_type, + self.values, + Some(bitmap), + )), + Right(mutable_bitmap) => match self.values.into_mut() { + Right(values) => Right( + MutablePrimitiveArray::try_new( + self.data_type, + values, + Some(mutable_bitmap), + ) + .unwrap(), + ), + Left(values) => Left(PrimitiveArray::new( + self.data_type, + values, + Some(mutable_bitmap.into()), + )), + }, + } + } else { + match self.values.into_mut() { + Right(values) => { + Right(MutablePrimitiveArray::try_new(self.data_type, values, None).unwrap()) + }, + Left(values) => Left(PrimitiveArray::new(self.data_type, values, None)), + } + } + } + + /// Returns a new empty (zero-length) [`PrimitiveArray`]. + pub fn new_empty(data_type: DataType) -> Self { + Self::new(data_type, Buffer::new(), None) + } + + /// Returns a new [`PrimitiveArray`] where all slots are null / `None`. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new( + data_type, + vec![T::default(); length].into(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Creates a (non-null) [`PrimitiveArray`] from an iterator of values. + /// # Implementation + /// This does not assume that the iterator has a known length. + pub fn from_values>(iter: I) -> Self { + Self::new(T::PRIMITIVE.into(), Vec::::from_iter(iter).into(), None) + } + + /// Creates a (non-null) [`PrimitiveArray`] from a slice of values. + /// # Implementation + /// This is essentially a memcopy and is thus `O(N)` + pub fn from_slice>(slice: P) -> Self { + Self::new( + T::PRIMITIVE.into(), + Vec::::from(slice.as_ref()).into(), + None, + ) + } + + /// Creates a (non-null) [`PrimitiveArray`] from a [`TrustedLen`] of values. + /// # Implementation + /// This does not assume that the iterator has a known length. 
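+    ///
+    /// A sketch, assuming the iterator implements this crate's [`TrustedLen`]
+    /// (a copied slice iterator, as used by `from_slice`, qualifies):
+    /// ```
+    /// use arrow2::array::PrimitiveArray;
+    ///
+    /// let values = [1i32, 2, 3];
+    /// let array = PrimitiveArray::from_trusted_len_values_iter(values.iter().copied());
+    /// assert_eq!(array.len(), 3);
+    /// assert_eq!(array.value(2), 3);
+    /// ```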
+ pub fn from_trusted_len_values_iter>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_values_iter(iter).into() + } + + /// Creates a new [`PrimitiveArray`] from an iterator over values + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_values_iter_unchecked(iter).into() + } + + /// Creates a [`PrimitiveArray`] from a [`TrustedLen`] of optional values. + pub fn from_trusted_len_iter>>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_iter(iter).into() + } + + /// Creates a [`PrimitiveArray`] from an iterator of optional values. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_iter_unchecked>>(iter: I) -> Self { + MutablePrimitiveArray::::from_trusted_len_iter_unchecked(iter).into() + } + + /// Alias for `Self::try_new(..).unwrap()`. + /// # Panics + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`PhysicalType`] is not equal to [`PhysicalType::Primitive`]. + pub fn new(data_type: DataType, values: Buffer, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } +} + +impl Array for PrimitiveArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +/// A type definition [`PrimitiveArray`] for `i8` +pub type Int8Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i16` +pub type Int16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i32` +pub type Int32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i64` +pub type Int64Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i128` +pub type Int128Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `i256` +pub type Int256Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for [`days_ms`] +pub type DaysMsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for [`months_days_ns`] +pub type MonthsDaysNsArray = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f16` +pub type Float16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f32` +pub type Float32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `f64` +pub type Float64Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u8` +pub type UInt8Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u16` +pub type UInt16Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u32` +pub type UInt32Array = PrimitiveArray; +/// A type definition [`PrimitiveArray`] for `u64` +pub type UInt64Array = PrimitiveArray; + +/// A type definition [`MutablePrimitiveArray`] for `i8` +pub type Int8Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i16` +pub type Int16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i32` +pub type Int32Vec = MutablePrimitiveArray; +/// A type definition 
[`MutablePrimitiveArray`] for `i64` +pub type Int64Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i128` +pub type Int128Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `i256` +pub type Int256Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for [`days_ms`] +pub type DaysMsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for [`months_days_ns`] +pub type MonthsDaysNsVec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f16` +pub type Float16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f32` +pub type Float32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `f64` +pub type Float64Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u8` +pub type UInt8Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u16` +pub type UInt16Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u32` +pub type UInt32Vec = MutablePrimitiveArray; +/// A type definition [`MutablePrimitiveArray`] for `u64` +pub type UInt64Vec = MutablePrimitiveArray; + +impl Default for PrimitiveArray { + fn default() -> Self { + PrimitiveArray::new(T::PRIMITIVE.into(), Default::default(), None) + } +} diff --git a/crates/nano-arrow/src/array/primitive/mutable.rs b/crates/nano-arrow/src/array/primitive/mutable.rs new file mode 100644 index 000000000000..fc61b2e74884 --- /dev/null +++ b/crates/nano-arrow/src/array/primitive/mutable.rs @@ -0,0 +1,665 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{check, PrimitiveArray}; +use crate::array::physical_binary::extend_validity; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +/// The Arrow's equivalent to `Vec>` where `T` is byte-size (e.g. `i32`). +/// Converting a [`MutablePrimitiveArray`] into a [`PrimitiveArray`] is `O(1)`. +#[derive(Debug, Clone)] +pub struct MutablePrimitiveArray { + data_type: DataType, + values: Vec, + validity: Option, +} + +impl From> for PrimitiveArray { + fn from(other: MutablePrimitiveArray) -> Self { + let validity = other.validity.and_then(|x| { + let bitmap: Bitmap = x.into(); + if bitmap.unset_bits() == 0 { + None + } else { + Some(bitmap) + } + }); + + PrimitiveArray::::new(other.data_type, other.values.into(), validity) + } +} + +impl]>> From

for MutablePrimitiveArray { + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } +} + +impl MutablePrimitiveArray { + /// Creates a new empty [`MutablePrimitiveArray`]. + pub fn new() -> Self { + Self::with_capacity(0) + } + + /// Creates a new [`MutablePrimitiveArray`] with a capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacity_from(capacity, T::PRIMITIVE.into()) + } + + /// The canonical method to create a [`MutablePrimitiveArray`] out of its internal components. + /// # Implementation + /// This function is `O(1)`. + /// + /// # Errors + /// This function errors iff: + /// * The validity is not `None` and its length is different from `values`'s length + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to [`crate::datatypes::PhysicalType::Primitive(T::PRIMITIVE)`] + pub fn try_new( + data_type: DataType, + values: Vec, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|x| x.len()))?; + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Extract the low-end APIs from the [`MutablePrimitiveArray`]. + pub fn into_inner(self) -> (DataType, Vec, Option) { + (self.data_type, self.values, self.validity) + } + + /// Applies a function `f` to the values of this array, cloning the values + /// iff they are being shared with others + /// + /// This is an API to use clone-on-write + /// # Implementation + /// This function is `O(f)` if the data is not being shared, and `O(N) + O(f)` + /// if it is being shared (since it results in a `O(N)` memcopy). + /// # Panics + /// This function panics iff `f` panics + pub fn apply_values(&mut self, f: F) { + f(&mut self.values); + } +} + +impl Default for MutablePrimitiveArray { + fn default() -> Self { + Self::new() + } +} + +impl From for MutablePrimitiveArray { + fn from(data_type: DataType) -> Self { + assert!(data_type.to_physical_type().eq_primitive(T::PRIMITIVE)); + Self { + data_type, + values: Vec::::new(), + validity: None, + } + } +} + +impl MutablePrimitiveArray { + /// Creates a new [`MutablePrimitiveArray`] from a capacity and [`DataType`]. + pub fn with_capacity_from(capacity: usize, data_type: DataType) -> Self { + assert!(data_type.to_physical_type().eq_primitive(T::PRIMITIVE)); + Self { + data_type, + values: Vec::::with_capacity(capacity), + validity: None, + } + } + + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + self.values.reserve(additional); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Adds a new value to the array. + #[inline] + pub fn push(&mut self, value: Option) { + match value { + Some(value) => { + self.values.push(value); + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(T::default()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => { + self.init_validity(); + }, + } + }, + } + } + + /// Pop a value from the array. + /// Note if the values is empty, this method will return None. 
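+    ///
+    /// A small sketch (illustrative):
+    /// ```
+    /// use arrow2::array::MutablePrimitiveArray;
+    ///
+    /// let mut array = MutablePrimitiveArray::<i32>::new();
+    /// array.push(Some(1));
+    /// array.push(None);
+    /// assert_eq!(array.pop(), None); // the null slot pops as `None`
+    /// assert_eq!(array.pop(), Some(1));
+    /// assert_eq!(array.pop(), None); // the array is now empty
+    /// ```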
+ pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| value)) + .unwrap_or_else(|| Some(value)) + } + + /// Extends the [`MutablePrimitiveArray`] with a constant + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: Option) { + if let Some(value) = value { + self.values.resize(self.values.len() + additional, value); + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, true) + } + } else { + if let Some(validity) = &mut self.validity { + validity.extend_constant(additional, false) + } else { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.extend_constant(additional, false); + self.validity = Some(validity) + } + self.values + .resize(self.values.len() + additional, T::default()); + } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: std::borrow::Borrow, + I: Iterator>, + { + if let Some(validity) = self.validity.as_mut() { + extend_trusted_len_unzip(iterator, validity, &mut self.values) + } else { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + extend_trusted_len_unzip(iterator, &mut validity, &mut self.values); + self.validity = Some(validity); + } + } + /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len` which accepts in iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + I: TrustedLen, + { + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutablePrimitiveArray`] from an iterator of values of trusted len. + /// This differs from `extend_trusted_len_unchecked` which accepts in iterator of optional values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + I: Iterator, + { + self.values.extend(iterator); + self.update_all_valid(); + } + + #[inline] + /// Extends the [`MutablePrimitiveArray`] from a slice + pub fn extend_from_slice(&mut self, items: &[T]) { + self.values.extend_from_slice(items); + self.update_all_valid(); + } + + fn update_all_valid(&mut self) { + // get len before mutable borrow + let len = self.len(); + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(len - validity.len(), true); + } + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity) + } + + /// Changes the arrays' [`DataType`], returning a new [`MutablePrimitiveArray`]. + /// Use to change the logical type without changing the corresponding physical Type. + /// # Implementation + /// This operation is `O(1)`. 
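+    ///
+    /// A sketch of re-tagging the logical type (illustrative; `MutableArray` is
+    /// only imported to read back the data type):
+    /// ```
+    /// use arrow2::array::{MutableArray, MutablePrimitiveArray};
+    /// use arrow2::datatypes::DataType;
+    ///
+    /// let array = MutablePrimitiveArray::<i32>::from_slice([1, 2]).to(DataType::Date32);
+    /// assert_eq!(array.data_type(), &DataType::Date32);
+    /// ```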
+ #[inline] + pub fn to(self, data_type: DataType) -> Self { + Self::try_new(data_type, self.values, self.validity).unwrap() + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: PrimitiveArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutablePrimitiveArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + /// Returns the capacity of this [`MutablePrimitiveArray`]. + pub fn capacity(&self) -> usize { + self.values.capacity() + } +} + +/// Accessors +impl MutablePrimitiveArray { + /// Returns its values. + pub fn values(&self) -> &Vec { + &self.values + } + + /// Returns a mutable slice of values. + pub fn values_mut_slice(&mut self) -> &mut [T] { + self.values.as_mut_slice() + } +} + +/// Setters +impl MutablePrimitiveArray { + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Panic + /// Panics iff index is larger than `self.len()`. + pub fn set(&mut self, index: usize, value: Option) { + assert!(index < self.len()); + // Safety: + // we just checked bounds + unsafe { self.set_unchecked(index, value) } + } + + /// Sets position `index` to `value`. + /// Note that if it is the first time a null appears in this array, + /// this initializes the validity bitmap (`O(N)`). + /// # Safety + /// Caller must ensure `index < self.len()` + pub unsafe fn set_unchecked(&mut self, index: usize, value: Option) { + *self.values.get_unchecked_mut(index) = value.unwrap_or_default(); + + if value.is_none() && self.validity.is_none() { + // When the validity is None, all elements so far are valid. When one of the elements is set of null, + // the validity must be initialized. + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + if let Some(x) = self.validity.as_mut() { + x.set_unchecked(index, value.is_some()) + } + } + + /// Sets the validity. + /// # Panic + /// Panics iff the validity's len is not equal to the existing values' length. + pub fn set_validity(&mut self, validity: Option) { + if let Some(validity) = &validity { + assert_eq!(self.values.len(), validity.len()) + } + self.validity = validity; + } + + /// Sets values. + /// # Panic + /// Panics iff the values' length is not equal to the existing validity's len. 
+ pub fn set_values(&mut self, values: Vec) { + assert_eq!(values.len(), self.values.len()); + self.values = values; + } +} + +impl Extend> for MutablePrimitiveArray { + fn extend>>(&mut self, iter: I) { + let iter = iter.into_iter(); + self.reserve(iter.size_hint().0); + iter.for_each(|x| self.push(x)) + } +} + +impl TryExtend> for MutablePrimitiveArray { + /// This is infalible and is implemented for consistency with all other types + fn try_extend>>(&mut self, iter: I) -> Result<(), Error> { + self.extend(iter); + Ok(()) + } +} + +impl TryPush> for MutablePrimitiveArray { + /// This is infalible and is implemented for consistency with all other types + fn try_push(&mut self, item: Option) -> Result<(), Error> { + self.push(item); + Ok(()) + } +} + +impl MutableArray for MutablePrimitiveArray { + fn len(&self) -> usize { + self.values.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + PrimitiveArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + PrimitiveArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values).into(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl MutablePrimitiveArray { + /// Creates a [`MutablePrimitiveArray`] from a slice of values. + pub fn from_slice>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter().copied()) + } + + /// Creates a [`MutablePrimitiveArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: Iterator>, + { + let (validity, values) = trusted_len_unzip(iterator); + + Self { + data_type: T::PRIMITIVE.into(), + values, + validity, + } + } + + /// Creates a [`MutablePrimitiveArray`] from a [`TrustedLen`]. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: std::borrow::Borrow, + I: TrustedLen>, + { + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iter: I, + ) -> std::result::Result + where + P: std::borrow::Borrow, + I: IntoIterator, E>>, + { + let iterator = iter.into_iter(); + + let (validity, values) = try_trusted_len_unzip(iterator)?; + + Ok(Self { + data_type: T::PRIMITIVE.into(), + values, + validity, + }) + } + + /// Creates a [`MutablePrimitiveArray`] from an fallible iterator of trusted length. 
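+    ///
+    /// A sketch with a fallible iterator (illustrative; assumes the iterator type
+    /// implements this crate's [`TrustedLen`]):
+    /// ```
+    /// use arrow2::array::{MutableArray, MutablePrimitiveArray};
+    ///
+    /// let items = [Ok::<_, ()>(Some(1i32)), Ok(None)];
+    /// let array = MutablePrimitiveArray::<i32>::try_from_trusted_len_iter(items.iter().copied()).unwrap();
+    /// assert_eq!(array.len(), 2);
+    /// ```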
+ #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: std::borrow::Borrow, + I: TrustedLen, E>>, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutablePrimitiveArray`] out an iterator over values + pub fn from_trusted_len_values_iter>(iter: I) -> Self { + Self { + data_type: T::PRIMITIVE.into(), + values: iter.collect(), + validity: None, + } + } + + /// Creates a (non-null) [`MutablePrimitiveArray`] from a vector of values. + /// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. + pub fn from_vec(values: Vec) -> Self { + Self::try_new(T::PRIMITIVE.into(), values, None).unwrap() + } + + /// Creates a new [`MutablePrimitiveArray`] from an iterator over values + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + pub unsafe fn from_trusted_len_values_iter_unchecked>(iter: I) -> Self { + Self { + data_type: T::PRIMITIVE.into(), + values: iter.collect(), + validity: None, + } + } +} + +impl>> FromIterator + for MutablePrimitiveArray +{ + fn from_iter>(iter: I) -> Self { + let iter = iter.into_iter(); + let (lower, _) = iter.size_hint(); + + let mut validity = MutableBitmap::with_capacity(lower); + + let values: Vec = iter + .map(|item| { + if let Some(a) = item.borrow() { + validity.push(true); + *a + } else { + validity.push(false); + T::default() + } + }) + .collect(); + + let validity = Some(validity); + + Self { + data_type: T::PRIMITIVE.into(), + values, + validity, + } + } +} + +/// Extends a [`MutableBitmap`] and a [`Vec`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn extend_trusted_len_unzip( + iterator: I, + validity: &mut MutableBitmap, + buffer: &mut Vec, +) where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let (_, upper) = iterator.size_hint(); + let additional = upper.expect("trusted_len_unzip requires an upper limit"); + + validity.reserve(additional); + let values = iterator.map(|item| { + if let Some(item) = item { + validity.push_unchecked(true); + *item.borrow() + } else { + validity.push_unchecked(false); + T::default() + } + }); + buffer.extend(values); +} + +/// Creates a [`MutableBitmap`] and a [`Vec`] from an iterator of `Option`. +/// The first buffer corresponds to a bitmap buffer, the second one +/// corresponds to a values buffer. +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. +#[inline] +pub(crate) unsafe fn trusted_len_unzip(iterator: I) -> (Option, Vec) +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator>, +{ + let mut validity = MutableBitmap::new(); + let mut buffer = Vec::::new(); + + extend_trusted_len_unzip(iterator, &mut validity, &mut buffer); + + let validity = Some(validity); + + (validity, buffer) +} + +/// # Safety +/// The caller must ensure that `iterator` is `TrustedLen`. 
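// Not part of the patch: a minimal sketch of the value-based constructors above,
// assuming the API in this diff.
//
//     // `from_vec` reuses the Vec as the values buffer (no memcopy, no validity)
//     let dense = MutablePrimitiveArray::<f64>::from_vec(vec![1.0, 2.0, 3.0]);
//     // `collect` goes through `FromIterator<Option<T>>`, building values and validity together
//     let sparse: MutablePrimitiveArray<i32> = (0..4).map(|i| (i % 2 == 0).then(|| i)).collect();
//     assert_eq!(sparse.values(), &vec![0, 0, 2, 0]); // nulls are stored as `T::default()`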
+#[inline] +pub(crate) unsafe fn try_trusted_len_unzip( + iterator: I, +) -> std::result::Result<(Option, Vec), E> +where + T: NativeType, + P: std::borrow::Borrow, + I: Iterator, E>>, +{ + let (_, upper) = iterator.size_hint(); + let len = upper.expect("trusted_len_unzip requires an upper limit"); + + let mut null = MutableBitmap::with_capacity(len); + let mut buffer = Vec::::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let item = if let Some(item) = item? { + null.push(true); + *item.borrow() + } else { + null.push(false); + T::default() + }; + std::ptr::write(dst, item); + dst = dst.add(1); + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + null.set_len(len); + + let validity = Some(null); + + Ok((validity, buffer)) +} + +impl PartialEq for MutablePrimitiveArray { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutablePrimitiveArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + let slice = other.values.as_slice(); + self.values.extend_from_slice(slice); + Ok(()) + } +} diff --git a/crates/nano-arrow/src/array/specification.rs b/crates/nano-arrow/src/array/specification.rs new file mode 100644 index 000000000000..efa8fe1be4a4 --- /dev/null +++ b/crates/nano-arrow/src/array/specification.rs @@ -0,0 +1,178 @@ +use crate::array::DictionaryKey; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +/// Helper trait to support `Offset` and `OffsetBuffer` +pub(crate) trait OffsetsContainer { + fn last(&self) -> usize; + fn as_slice(&self) -> &[O]; +} + +impl OffsetsContainer for OffsetsBuffer { + #[inline] + fn last(&self) -> usize { + self.last().to_usize() + } + + #[inline] + fn as_slice(&self) -> &[O] { + self.buffer() + } +} + +impl OffsetsContainer for Offsets { + #[inline] + fn last(&self) -> usize { + self.last().to_usize() + } + + #[inline] + fn as_slice(&self) -> &[O] { + self.as_slice() + } +} + +pub(crate) fn try_check_offsets_bounds>( + offsets: &C, + values_len: usize, +) -> Result<()> { + if offsets.last() > values_len { + Err(Error::oos("offsets must not exceed the values length")) + } else { + Ok(()) + } +} + +/// # Error +/// * any offset is larger or equal to `values_len`. 
+/// * any slice of `values` between two consecutive pairs from `offsets` is invalid `utf8`, or +pub(crate) fn try_check_utf8>( + offsets: &C, + values: &[u8], +) -> Result<()> { + if offsets.as_slice().len() == 1 { + return Ok(()); + } + + try_check_offsets_bounds(offsets, values.len())?; + + if values.is_ascii() { + Ok(()) + } else { + simdutf8::basic::from_utf8(values)?; + + // offsets can be == values.len() + // find first offset from the end that is smaller + // Example: + // values.len() = 10 + // offsets = [0, 5, 10, 10] + let offsets = offsets.as_slice(); + let last = offsets + .iter() + .enumerate() + .skip(1) + .rev() + .find_map(|(i, offset)| (offset.to_usize() < values.len()).then(|| i)); + + let last = if let Some(last) = last { + // following the example: last = 1 (offset = 5) + last + } else { + // given `l = values.len()`, this branch is hit iff either: + // * `offsets = [0, l, l, ...]`, which was covered by `from_utf8(values)` above + // * `offsets = [0]`, which never happens because offsets.as_slice().len() == 1 is short-circuited above + return Ok(()); + }; + + // truncate to relevant offsets. Note: `=last` because last was computed skipping the first item + // following the example: starts = [0, 5] + let starts = unsafe { offsets.get_unchecked(..=last) }; + + let mut any_invalid = false; + for start in starts { + let start = start.to_usize(); + + // Safety: `try_check_offsets_bounds` just checked for bounds + let b = *unsafe { values.get_unchecked(start) }; + + // A valid code-point iff it does not start with 0b10xxxxxx + // Bit-magic taken from `std::str::is_char_boundary` + if (b as i8) < -0x40 { + any_invalid = true + } + } + if any_invalid { + return Err(Error::oos("Non-valid char boundary detected")); + } + Ok(()) + } +} + +/// Check dictionary indexes without checking usize conversion. +/// # Safety +/// The caller must ensure that `K::as_usize` always succeeds. +pub(crate) unsafe fn check_indexes_unchecked( + keys: &[K], + len: usize, +) -> Result<()> { + let mut invalid = false; + + // this loop is auto-vectorized + keys.iter().for_each(|k| { + if k.as_usize() > len { + invalid = true; + } + }); + + if invalid { + let key = keys.iter().map(|k| k.as_usize()).max().unwrap(); + Err(Error::oos(format!("One of the dictionary keys is {key} but it must be < than the length of the dictionary values, which is {len}"))) + } else { + Ok(()) + } +} + +pub fn check_indexes(keys: &[K], len: usize) -> Result<()> +where + K: std::fmt::Debug + Copy + TryInto, +{ + keys.iter().try_for_each(|key| { + let key: usize = (*key) + .try_into() + .map_err(|_| Error::oos(format!("The dictionary key must fit in a `usize`, but {key:?} does not")))?; + if key >= len { + Err(Error::oos(format!("One of the dictionary keys is {key} but it must be < than the length of the dictionary values, which is {len}"))) + } else { + Ok(()) + } + }) +} + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use super::*; + + pub(crate) fn binary_strategy() -> impl Strategy> { + prop::collection::vec(any::(), 1..100) + } + + proptest! 
{ + // a bit expensive, feel free to run it when changing the code above + // #![proptest_config(ProptestConfig::with_cases(100000))] + #[test] + #[cfg_attr(miri, ignore)] // miri and proptest do not work well + fn check_utf8_validation(values in binary_strategy()) { + + for offset in 0..values.len() - 1 { + let offsets = vec![0, offset as i32, values.len() as i32].try_into().unwrap(); + + let mut is_valid = std::str::from_utf8(&values[..offset]).is_ok(); + is_valid &= std::str::from_utf8(&values[offset..]).is_ok(); + + assert_eq!(try_check_utf8::>(&offsets, &values).is_ok(), is_valid) + } + } + } +} diff --git a/crates/nano-arrow/src/array/struct_/data.rs b/crates/nano-arrow/src/array/struct_/data.rs new file mode 100644 index 000000000000..b96dc4ffe28b --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/data.rs @@ -0,0 +1,28 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, StructArray}; +use crate::bitmap::Bitmap; + +impl Arrow2Arrow for StructArray { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type.clone().into(); + + let builder = ArrayDataBuilder::new(data_type) + .len(self.len()) + .nulls(self.validity.as_ref().map(|b| b.clone().into())) + .child_data(self.values.iter().map(|x| to_data(x.as_ref())).collect()); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + + Self { + data_type, + values: data.child_data().iter().map(from_data).collect(), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/struct_/ffi.rs b/crates/nano-arrow/src/array/struct_/ffi.rs new file mode 100644 index 000000000000..95abe00694b2 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/ffi.rs @@ -0,0 +1,72 @@ +use super::super::ffi::ToFfi; +use super::super::{Array, FromFfi}; +use super::StructArray; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for StructArray { + fn buffers(&self) -> Vec> { + vec![self.validity.as_ref().map(|x| x.as_ptr())] + } + + fn children(&self) -> Vec> { + self.values.clone() + } + + fn offset(&self) -> Option { + Some( + self.validity + .as_ref() + .map(|bitmap| bitmap.offset()) + .unwrap_or_default(), + ) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for StructArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let fields = Self::get_fields(&data_type); + + let arrow_array = array.array(); + let validity = unsafe { array.validity() }?; + let len = arrow_array.len(); + let offset = arrow_array.offset(); + let values = (0..fields.len()) + .map(|index| { + let child = array.child(index)?; + ffi::try_from(child).map(|arr| { + // there is a discrepancy with how arrow2 exports sliced + // struct array and how pyarrow does it. 
+ // # Pyarrow + // ## struct array len 3 + // * slice 1 by with len 2 + // offset on struct array: 1 + // length on struct array: 2 + // offset on value array: 0 + // length on value array: 3 + // # Arrow2 + // ## struct array len 3 + // * slice 1 by with len 2 + // offset on struct array: 0 + // length on struct array: 3 + // offset on value array: 1 + // length on value array: 2 + // + // this branch will ensure both can round trip + if arr.len() >= (len + offset) { + arr.sliced(offset, len) + } else { + arr + } + }) + }) + .collect::>>>()?; + + Self::try_new(data_type, values, validity) + } +} diff --git a/crates/nano-arrow/src/array/struct_/fmt.rs b/crates/nano-arrow/src/array/struct_/fmt.rs new file mode 100644 index 000000000000..999cd8b67e08 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/fmt.rs @@ -0,0 +1,34 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_map, write_vec}; +use super::StructArray; + +pub fn write_value( + array: &StructArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let writer = |f: &mut W, _index| { + for (i, (field, column)) in array.fields().iter().zip(array.values()).enumerate() { + if i != 0 { + write!(f, ", ")?; + } + let writer = get_display(column.as_ref(), null); + write!(f, "{}: ", field.name)?; + writer(f, index)?; + } + Ok(()) + }; + + write_map(f, writer, None, 1, null, false) +} + +impl Debug for StructArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "StructArray")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/struct_/iterator.rs b/crates/nano-arrow/src/array/struct_/iterator.rs new file mode 100644 index 000000000000..cb8e6aafbb09 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/iterator.rs @@ -0,0 +1,96 @@ +use super::StructArray; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::scalar::{new_scalar, Scalar}; +use crate::trusted_len::TrustedLen; + +pub struct StructValueIter<'a> { + array: &'a StructArray, + index: usize, + end: usize, +} + +impl<'a> StructValueIter<'a> { + #[inline] + pub fn new(array: &'a StructArray) -> Self { + Self { + array, + index: 0, + end: array.len(), + } + } +} + +impl<'a> Iterator for StructValueIter<'a> { + type Item = Vec>; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + + // Safety: + // self.end is maximized by the length of the array + Some( + self.array + .values() + .iter() + .map(|v| new_scalar(v.as_ref(), old)) + .collect(), + ) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } +} + +unsafe impl<'a> TrustedLen for StructValueIter<'a> {} + +impl<'a> DoubleEndedIterator for StructValueIter<'a> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + + // Safety: + // self.end is maximized by the length of the array + Some( + self.array + .values() + .iter() + .map(|v| new_scalar(v.as_ref(), self.end)) + .collect(), + ) + } + } +} + +type ValuesIter<'a> = StructValueIter<'a>; +type ZipIter<'a> = ZipValidity>, ValuesIter<'a>, BitmapIter<'a>>; + +impl<'a> IntoIterator for &'a StructArray { + type Item = Option>>; + type IntoIter = ZipIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> 
StructArray { + /// Returns an iterator of `Option>` + pub fn iter(&'a self) -> ZipIter<'a> { + ZipValidity::new_with_validity(StructValueIter::new(self), self.validity()) + } + + /// Returns an iterator of `Box` + pub fn values_iter(&'a self) -> ValuesIter<'a> { + StructValueIter::new(self) + } +} diff --git a/crates/nano-arrow/src/array/struct_/mod.rs b/crates/nano-arrow/src/array/struct_/mod.rs new file mode 100644 index 000000000000..e38597036574 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/mod.rs @@ -0,0 +1,254 @@ +use super::{new_empty_array, new_null_array, Array}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; +mod mutable; +pub use mutable::*; + +/// A [`StructArray`] is a nested [`Array`] with an optional validity representing +/// multiple [`Array`] with the same number of rows. +/// # Example +/// ``` +/// use arrow2::array::*; +/// use arrow2::datatypes::*; +/// let boolean = BooleanArray::from_slice(&[false, false, true, true]).boxed(); +/// let int = Int32Array::from_slice(&[42, 28, 19, 31]).boxed(); +/// +/// let fields = vec![ +/// Field::new("b", DataType::Boolean, false), +/// Field::new("c", DataType::Int32, false), +/// ]; +/// +/// let array = StructArray::new(DataType::Struct(fields), vec![boolean, int], None); +/// ``` +#[derive(Clone)] +pub struct StructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +impl StructArray { + /// Returns a new [`StructArray`]. + /// # Errors + /// This function errors iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `data_type` are empty + /// * the values's len is different from children's length + /// * any of the values's data type is different from its corresponding children' data type + /// * any element of values has a different length than the first element + /// * the validity's length is not equal to the length of the first element + pub fn try_new( + data_type: DataType, + values: Vec>, + validity: Option, + ) -> Result { + let fields = Self::try_get_fields(&data_type)?; + if fields.is_empty() { + return Err(Error::oos("A StructArray must contain at least one field")); + } + if fields.len() != values.len() { + return Err(Error::oos( + "A StructArray must have a number of fields in its DataType equal to the number of child values", + )); + } + + fields + .iter().map(|a| &a.data_type) + .zip(values.iter().map(|a| a.data_type())) + .enumerate() + .try_for_each(|(index, (data_type, child))| { + if data_type != child { + Err(Error::oos(format!( + "The children DataTypes of a StructArray must equal the children data types. + However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + ))) + } else { + Ok(()) + } + })?; + + let len = values[0].len(); + values + .iter() + .map(|a| a.len()) + .enumerate() + .try_for_each(|(index, a_len)| { + if a_len != len { + Err(Error::oos(format!( + "The children must have an equal number of values. + However, the values at index {index} have a length of {a_len}, which is different from values at index 0, {len}." 
+ ))) + } else { + Ok(()) + } + })?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != len) + { + return Err(Error::oos( + "The validity length of a StructArray must match its number of elements", + )); + } + + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Returns a new [`StructArray`] + /// # Panics + /// This function panics iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Struct`]. + /// * the children of `data_type` are empty + /// * the values's len is different from children's length + /// * any of the values's data type is different from its corresponding children' data type + /// * any element of values has a different length than the first element + /// * the validity's length is not equal to the length of the first element + pub fn new(data_type: DataType, values: Vec>, validity: Option) -> Self { + Self::try_new(data_type, values, validity).unwrap() + } + + /// Creates an empty [`StructArray`]. + pub fn new_empty(data_type: DataType) -> Self { + if let DataType::Struct(fields) = &data_type.to_logical_type() { + let values = fields + .iter() + .map(|field| new_empty_array(field.data_type().clone())) + .collect(); + Self::new(data_type, values, None) + } else { + panic!("StructArray must be initialized with DataType::Struct"); + } + } + + /// Creates a null [`StructArray`] of length `length`. + pub fn new_null(data_type: DataType, length: usize) -> Self { + if let DataType::Struct(fields) = &data_type { + let values = fields + .iter() + .map(|field| new_null_array(field.data_type().clone(), length)) + .collect(); + Self::new(data_type, values, Some(Bitmap::new_zeroed(length))) + } else { + panic!("StructArray must be initialized with DataType::Struct"); + } + } +} + +// must use +impl StructArray { + /// Deconstructs the [`StructArray`] into its individual components. + #[must_use] + pub fn into_data(self) -> (Vec, Vec>, Option) { + let Self { + data_type, + values, + validity, + } = self; + let fields = if let DataType::Struct(fields) = data_type { + fields + } else { + unreachable!() + }; + (fields, values, validity) + } + + /// Slices this [`StructArray`]. + /// # Panics + /// * `offset + length` must be smaller than `self.len()`. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "offset + length may not exceed length of array" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`StructArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.values + .iter_mut() + .for_each(|x| x.slice_unchecked(offset, length)); + } + + impl_sliced!(); + + impl_mut_validity!(); + + impl_into_array!(); +} + +// Accessors +impl StructArray { + #[inline] + fn len(&self) -> usize { + self.values[0].len() + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Returns the values of this [`StructArray`]. + pub fn values(&self) -> &[Box] { + &self.values + } + + /// Returns the fields of this [`StructArray`]. 
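// Illustrative sketch, not part of the patch, reusing the `boolean`/`int` children from
// the module-level doc example above.
//
//     let fields = vec![
//         Field::new("b", DataType::Boolean, false),
//         Field::new("c", DataType::Int32, false),
//     ];
//     let mut array = StructArray::new(DataType::Struct(fields), vec![boolean, int], None);
//     array.slice(1, 2);                      // O(F): each of the F children is sliced in place
//     assert_eq!(array.values()[0].len(), 2);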
+ pub fn fields(&self) -> &[Field] { + Self::get_fields(&self.data_type) + } +} + +impl StructArray { + /// Returns the fields the `DataType::Struct`. + pub(crate) fn try_get_fields(data_type: &DataType) -> Result<&[Field], Error> { + match data_type.to_logical_type() { + DataType::Struct(fields) => Ok(fields), + _ => Err(Error::oos( + "Struct array must be created with a DataType whose physical type is Struct", + )), + } + } + + /// Returns the fields the `DataType::Struct`. + pub fn get_fields(data_type: &DataType) -> &[Field] { + Self::try_get_fields(data_type).unwrap() + } +} + +impl Array for StructArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} diff --git a/crates/nano-arrow/src/array/struct_/mutable.rs b/crates/nano-arrow/src/array/struct_/mutable.rs new file mode 100644 index 000000000000..8060a698fb63 --- /dev/null +++ b/crates/nano-arrow/src/array/struct_/mutable.rs @@ -0,0 +1,245 @@ +use std::sync::Arc; + +use super::StructArray; +use crate::array::{Array, MutableArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Error; + +/// Converting a [`MutableStructArray`] into a [`StructArray`] is `O(1)`. +#[derive(Debug)] +pub struct MutableStructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +fn check( + data_type: &DataType, + values: &[Box], + validity: Option, +) -> Result<(), Error> { + let fields = StructArray::try_get_fields(data_type)?; + if fields.is_empty() { + return Err(Error::oos("A StructArray must contain at least one field")); + } + if fields.len() != values.len() { + return Err(Error::oos( + "A StructArray must have a number of fields in its DataType equal to the number of child values", + )); + } + + fields + .iter().map(|a| &a.data_type) + .zip(values.iter().map(|a| a.data_type())) + .enumerate() + .try_for_each(|(index, (data_type, child))| { + if data_type != child { + Err(Error::oos(format!( + "The children DataTypes of a StructArray must equal the children data types. + However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + ))) + } else { + Ok(()) + } + })?; + + let len = values[0].len(); + values + .iter() + .map(|a| a.len()) + .enumerate() + .try_for_each(|(index, a_len)| { + if a_len != len { + Err(Error::oos(format!( + "The children must have an equal number of values. + However, the values at index {index} have a length of {a_len}, which is different from values at index 0, {len}." + ))) + } else { + Ok(()) + } + })?; + + if validity.map_or(false, |validity| validity != len) { + return Err(Error::oos( + "The validity length of a StructArray must match its number of elements", + )); + } + Ok(()) +} + +impl From for StructArray { + fn from(other: MutableStructArray) -> Self { + let validity = if other.validity.as_ref().map(|x| x.unset_bits()).unwrap_or(0) > 0 { + other.validity.map(|x| x.into()) + } else { + None + }; + + StructArray::new( + other.data_type, + other.values.into_iter().map(|mut v| v.as_box()).collect(), + validity, + ) + } +} + +impl MutableStructArray { + /// Creates a new [`MutableStructArray`]. + pub fn new(data_type: DataType, values: Vec>) -> Self { + Self::try_new(data_type, values, None).unwrap() + } + + /// Create a [`MutableStructArray`] out of low-end APIs. 
+ /// # Errors + /// This function errors iff: + /// * `data_type` is not [`DataType::Struct`] + /// * The inner types of `data_type` are not equal to those of `values` + /// * `validity` is not `None` and its length is different from the `values`'s length + pub fn try_new( + data_type: DataType, + values: Vec>, + validity: Option, + ) -> Result { + check(&data_type, &values, validity.as_ref().map(|x| x.len()))?; + Ok(Self { + data_type, + values, + validity, + }) + } + + /// Extract the low-end APIs from the [`MutableStructArray`]. + pub fn into_inner(self) -> (DataType, Vec>, Option) { + (self.data_type, self.values, self.validity) + } + + /// The mutable values + pub fn mut_values(&mut self) -> &mut Vec> { + &mut self.values + } + + /// The values + pub fn values(&self) -> &Vec> { + &self.values + } + + /// Return the `i`th child array. + pub fn value(&mut self, i: usize) -> Option<&mut A> { + self.values[i].as_mut_any().downcast_mut::() + } +} + +impl MutableStructArray { + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + for v in &mut self.values { + v.reserve(additional); + } + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Call this once for each "row" of children you push. + pub fn push(&mut self, valid: bool) { + match &mut self.validity { + Some(validity) => validity.push(valid), + None => match valid { + true => (), + false => self.init_validity(), + }, + }; + } + + fn push_null(&mut self) { + for v in &mut self.values { + v.push_null(); + } + self.push(false); + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + let len = self.len(); + if len > 0 { + validity.extend_constant(len, true); + validity.set(len - 1, false); + } + self.validity = Some(validity) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: StructArray = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableStructArray`] to fit its current length. 
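// Not part of the patch: a sketch of the push discipline documented above -- push into every
// child, then call `push(valid)` exactly once per row. `struct_dt`, `ints` and `flags` are
// hypothetical, and the turbofish on `value` assumes its generic parameter is the child's
// concrete mutable array type.
//
//     let mut s = MutableStructArray::new(struct_dt, vec![Box::new(ints), Box::new(flags)]);
//     s.value::<MutablePrimitiveArray<i32>>(0).unwrap().push(Some(1));
//     s.value::<MutableBooleanArray>(1).unwrap().push(None);
//     s.push(true); // row 0 is valid; a fully-null row would go through `push_null` instead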
+ pub fn shrink_to_fit(&mut self) { + for v in &mut self.values { + v.shrink_to_fit(); + } + if let Some(validity) = self.validity.as_mut() { + validity.shrink_to_fit() + } + } +} + +impl MutableArray for MutableStructArray { + fn len(&self) -> usize { + self.values.first().map(|v| v.len()).unwrap_or(0) + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + StructArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values) + .into_iter() + .map(|mut v| v.as_box()) + .collect(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> Arc { + StructArray::new( + self.data_type.clone(), + std::mem::take(&mut self.values) + .into_iter() + .map(|mut v| v.as_box()) + .collect(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + self.push_null() + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } +} diff --git a/crates/nano-arrow/src/array/union/data.rs b/crates/nano-arrow/src/array/union/data.rs new file mode 100644 index 000000000000..6de6c0074231 --- /dev/null +++ b/crates/nano-arrow/src/array/union/data.rs @@ -0,0 +1,70 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{from_data, to_data, Arrow2Arrow, UnionArray}; +use crate::buffer::Buffer; +use crate::datatypes::DataType; + +impl Arrow2Arrow for UnionArray { + fn to_data(&self) -> ArrayData { + let data_type = arrow_schema::DataType::from(self.data_type.clone()); + let len = self.len(); + + let builder = match self.offsets.clone() { + Some(offsets) => ArrayDataBuilder::new(data_type) + .len(len) + .buffers(vec![self.types.clone().into(), offsets.into()]) + .child_data(self.fields.iter().map(|x| to_data(x.as_ref())).collect()), + None => ArrayDataBuilder::new(data_type) + .len(len) + .buffers(vec![self.types.clone().into()]) + .child_data( + self.fields + .iter() + .map(|x| to_data(x.as_ref()).slice(self.offset, len)) + .collect(), + ), + }; + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type: DataType = data.data_type().clone().into(); + + let fields = data.child_data().iter().map(from_data).collect(); + let buffers = data.buffers(); + let mut types: Buffer = buffers[0].clone().into(); + types.slice(data.offset(), data.len()); + let offsets = match buffers.len() == 2 { + true => { + let mut offsets: Buffer = buffers[1].clone().into(); + offsets.slice(data.offset(), data.len()); + Some(offsets) + }, + false => None, + }; + + // Map from type id to array index + let map = match &data_type { + DataType::Union(_, Some(ids), _) => { + let mut map = [0; 127]; + for (pos, &id) in ids.iter().enumerate() { + map[id as usize] = pos; + } + Some(map) + }, + DataType::Union(_, None, _) => None, + _ => unreachable!("must be Union type"), + }; + + Self { + types, + map, + fields, + offsets, + data_type, + offset: data.offset(), + } + } +} diff --git a/crates/nano-arrow/src/array/union/ffi.rs b/crates/nano-arrow/src/array/union/ffi.rs new file mode 100644 index 000000000000..590afec0c6c5 --- /dev/null +++ b/crates/nano-arrow/src/array/union/ffi.rs @@ -0,0 +1,60 @@ +use super::super::ffi::ToFfi; +use 
super::super::Array; +use super::UnionArray; +use crate::array::FromFfi; +use crate::error::Result; +use crate::ffi; + +unsafe impl ToFfi for UnionArray { + fn buffers(&self) -> Vec> { + if let Some(offsets) = &self.offsets { + vec![ + Some(self.types.as_ptr().cast::()), + Some(offsets.as_ptr().cast::()), + ] + } else { + vec![Some(self.types.as_ptr().cast::())] + } + } + + fn children(&self) -> Vec> { + self.fields.clone() + } + + fn offset(&self) -> Option { + Some(self.types.offset()) + } + + fn to_ffi_aligned(&self) -> Self { + self.clone() + } +} + +impl FromFfi for UnionArray { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let fields = Self::get_fields(&data_type); + + let mut types = unsafe { array.buffer::(0) }?; + let offsets = if Self::is_sparse(&data_type) { + None + } else { + Some(unsafe { array.buffer::(1) }?) + }; + + let length = array.array().len(); + let offset = array.array().offset(); + let fields = (0..fields.len()) + .map(|index| { + let child = array.child(index)?; + ffi::try_from(child) + }) + .collect::>>>()?; + + if offset > 0 { + types.slice(offset, length); + }; + + Self::try_new(data_type, types, fields, offsets) + } +} diff --git a/crates/nano-arrow/src/array/union/fmt.rs b/crates/nano-arrow/src/array/union/fmt.rs new file mode 100644 index 000000000000..521201fffd6d --- /dev/null +++ b/crates/nano-arrow/src/array/union/fmt.rs @@ -0,0 +1,24 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::{get_display, write_vec}; +use super::UnionArray; + +pub fn write_value( + array: &UnionArray, + index: usize, + null: &'static str, + f: &mut W, +) -> Result { + let (field, index) = array.index(index); + + get_display(array.fields()[field].as_ref(), null)(f, index) +} + +impl Debug for UnionArray { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, "None", f); + + write!(f, "UnionArray")?; + write_vec(f, writer, None, self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/union/iterator.rs b/crates/nano-arrow/src/array/union/iterator.rs new file mode 100644 index 000000000000..bdcf5825af6c --- /dev/null +++ b/crates/nano-arrow/src/array/union/iterator.rs @@ -0,0 +1,59 @@ +use super::UnionArray; +use crate::scalar::Scalar; +use crate::trusted_len::TrustedLen; + +#[derive(Debug, Clone)] +pub struct UnionIter<'a> { + array: &'a UnionArray, + current: usize, +} + +impl<'a> UnionIter<'a> { + #[inline] + pub fn new(array: &'a UnionArray) -> Self { + Self { array, current: 0 } + } +} + +impl<'a> Iterator for UnionIter<'a> { + type Item = Box; + + #[inline] + fn next(&mut self) -> Option { + if self.current == self.array.len() { + None + } else { + let old = self.current; + self.current += 1; + Some(unsafe { self.array.value_unchecked(old) }) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.array.len() - self.current; + (len, Some(len)) + } +} + +impl<'a> IntoIterator for &'a UnionArray { + type Item = Box; + type IntoIter = UnionIter<'a>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a> UnionArray { + /// constructs a new iterator + #[inline] + pub fn iter(&'a self) -> UnionIter<'a> { + UnionIter::new(self) + } +} + +impl<'a> std::iter::ExactSizeIterator for UnionIter<'a> {} + +unsafe impl<'a> TrustedLen for UnionIter<'a> {} diff --git a/crates/nano-arrow/src/array/union/mod.rs b/crates/nano-arrow/src/array/union/mod.rs new file mode 100644 index 
000000000000..75c83fb91759 --- /dev/null +++ b/crates/nano-arrow/src/array/union/mod.rs @@ -0,0 +1,377 @@ +use super::{new_empty_array, new_null_array, Array}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::{DataType, Field, UnionMode}; +use crate::error::Error; +use crate::scalar::{new_scalar, Scalar}; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod iterator; + +type UnionComponents<'a> = (&'a [Field], Option<&'a [i32]>, UnionMode); + +/// [`UnionArray`] represents an array whose each slot can contain different values. +/// +// How to read a value at slot i: +// ``` +// let index = self.types()[i] as usize; +// let field = self.fields()[index]; +// let offset = self.offsets().map(|x| x[index]).unwrap_or(i); +// let field = field.as_any().downcast to correct type; +// let value = field.value(offset); +// ``` +#[derive(Clone)] +pub struct UnionArray { + // Invariant: every item in `types` is `> 0 && < fields.len()` + types: Buffer, + // Invariant: `map.len() == fields.len()` + // Invariant: every item in `map` is `> 0 && < fields.len()` + map: Option<[usize; 127]>, + fields: Vec>, + // Invariant: when set, `offsets.len() == types.len()` + offsets: Option>, + data_type: DataType, + offset: usize, +} + +impl UnionArray { + /// Returns a new [`UnionArray`]. + /// # Errors + /// This function errors iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `data_type`'s children's length + /// * The number of `fields` is larger than `i8::MAX` + /// * any of the values's data type is different from its corresponding children' data type + pub fn try_new( + data_type: DataType, + types: Buffer, + fields: Vec>, + offsets: Option>, + ) -> Result { + let (f, ids, mode) = Self::try_get_all(&data_type)?; + + if f.len() != fields.len() { + return Err(Error::oos( + "The number of `fields` must equal the number of children fields in DataType::Union", + )); + }; + let number_of_fields: i8 = fields + .len() + .try_into() + .map_err(|_| Error::oos("The number of `fields` cannot be larger than i8::MAX"))?; + + f + .iter().map(|a| a.data_type()) + .zip(fields.iter().map(|a| a.data_type())) + .enumerate() + .try_for_each(|(index, (data_type, child))| { + if data_type != child { + Err(Error::oos(format!( + "The children DataTypes of a UnionArray must equal the children data types. + However, the field {index} has data type {data_type:?} but the value has data type {child:?}" + ))) + } else { + Ok(()) + } + })?; + + if let Some(offsets) = &offsets { + if offsets.len() != types.len() { + return Err(Error::oos( + "In a UnionArray, the offsets' length must be equal to the number of types", + )); + } + } + if offsets.is_none() != mode.is_sparse() { + return Err(Error::oos( + "In a sparse UnionArray, the offsets must be set (and vice-versa)", + )); + } + + // build hash + let map = if let Some(&ids) = ids.as_ref() { + if ids.len() != fields.len() { + return Err(Error::oos( + "In a union, when the ids are set, their length must be equal to the number of fields", + )); + } + + // example: + // * types = [5, 7, 5, 7, 7, 7, 5, 7, 7, 5, 5] + // * ids = [5, 7] + // => hash = [0, 0, 0, 0, 0, 0, 1, 0, ...] 
+ let mut hash = [0; 127]; + + for (pos, &id) in ids.iter().enumerate() { + if !(0..=127).contains(&id) { + return Err(Error::oos( + "In a union, when the ids are set, every id must belong to [0, 128[", + )); + } + hash[id as usize] = pos; + } + + types.iter().try_for_each(|&type_| { + if type_ < 0 { + return Err(Error::oos("In a union, when the ids are set, every type must be >= 0")); + } + let id = hash[type_ as usize]; + if id >= fields.len() { + Err(Error::oos("In a union, when the ids are set, each id must be smaller than the number of fields.")) + } else { + Ok(()) + } + })?; + + Some(hash) + } else { + // Safety: every type in types is smaller than number of fields + let mut is_valid = true; + for &type_ in types.iter() { + if type_ < 0 || type_ >= number_of_fields { + is_valid = false + } + } + if !is_valid { + return Err(Error::oos( + "Every type in `types` must be larger than 0 and smaller than the number of fields.", + )); + } + + None + }; + + Ok(Self { + data_type, + map, + fields, + offsets, + types, + offset: 0, + }) + } + + /// Returns a new [`UnionArray`]. + /// # Panics + /// This function panics iff: + /// * `data_type`'s physical type is not [`crate::datatypes::PhysicalType::Union`]. + /// * the fields's len is different from the `data_type`'s children's length + /// * any of the values's data type is different from its corresponding children' data type + pub fn new( + data_type: DataType, + types: Buffer, + fields: Vec>, + offsets: Option>, + ) -> Self { + Self::try_new(data_type, types, fields, offsets).unwrap() + } + + /// Creates a new null [`UnionArray`]. + pub fn new_null(data_type: DataType, length: usize) -> Self { + if let DataType::Union(f, _, mode) = &data_type { + let fields = f + .iter() + .map(|x| new_null_array(x.data_type().clone(), length)) + .collect(); + + let offsets = if mode.is_sparse() { + None + } else { + Some((0..length as i32).collect::>().into()) + }; + + // all from the same field + let types = vec![0i8; length].into(); + + Self::new(data_type, types, fields, offsets) + } else { + panic!("Union struct must be created with the corresponding Union DataType") + } + } + + /// Creates a new empty [`UnionArray`]. + pub fn new_empty(data_type: DataType) -> Self { + if let DataType::Union(f, _, mode) = data_type.to_logical_type() { + let fields = f + .iter() + .map(|x| new_empty_array(x.data_type().clone())) + .collect(); + + let offsets = if mode.is_sparse() { + None + } else { + Some(Buffer::default()) + }; + + Self { + data_type, + map: None, + fields, + offsets, + types: Buffer::new(), + offset: 0, + } + } else { + panic!("Union struct must be created with the corresponding Union DataType") + } + } +} + +impl UnionArray { + /// Returns a slice of this [`UnionArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Panic + /// This function panics iff `offset + length >= self.len()`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the existing length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a slice of this [`UnionArray`]. + /// # Implementation + /// This operation is `O(F)` where `F` is the number of fields. + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. 
+ #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + debug_assert!(offset + length <= self.len()); + + self.types.slice_unchecked(offset, length); + if let Some(offsets) = self.offsets.as_mut() { + offsets.slice_unchecked(offset, length) + } + self.offset += offset; + } + + impl_sliced!(); + impl_into_array!(); +} + +impl UnionArray { + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.types.len() + } + + /// The optional offsets. + pub fn offsets(&self) -> Option<&Buffer> { + self.offsets.as_ref() + } + + /// The fields. + pub fn fields(&self) -> &Vec> { + &self.fields + } + + /// The types. + pub fn types(&self) -> &Buffer { + &self.types + } + + #[inline] + unsafe fn field_slot_unchecked(&self, index: usize) -> usize { + self.offsets() + .as_ref() + .map(|x| *x.get_unchecked(index) as usize) + .unwrap_or(index + self.offset) + } + + /// Returns the index and slot of the field to select from `self.fields`. + #[inline] + pub fn index(&self, index: usize) -> (usize, usize) { + assert!(index < self.len()); + unsafe { self.index_unchecked(index) } + } + + /// Returns the index and slot of the field to select from `self.fields`. + /// The first value is guaranteed to be `< self.fields().len()` + /// # Safety + /// This function is safe iff `index < self.len`. + #[inline] + pub unsafe fn index_unchecked(&self, index: usize) -> (usize, usize) { + debug_assert!(index < self.len()); + // Safety: assumption of the function + let type_ = unsafe { *self.types.get_unchecked(index) }; + // Safety: assumption of the struct + let type_ = self + .map + .as_ref() + .map(|map| unsafe { *map.get_unchecked(type_ as usize) }) + .unwrap_or(type_ as usize); + // Safety: assumption of the function + let index = self.field_slot_unchecked(index); + (type_, index) + } + + /// Returns the slot `index` as a [`Scalar`]. + /// # Panics + /// iff `index >= self.len()` + pub fn value(&self, index: usize) -> Box { + assert!(index < self.len()); + unsafe { self.value_unchecked(index) } + } + + /// Returns the slot `index` as a [`Scalar`]. + /// # Safety + /// This function is safe iff `i < self.len`. + pub unsafe fn value_unchecked(&self, index: usize) -> Box { + debug_assert!(index < self.len()); + let (type_, index) = self.index_unchecked(index); + // Safety: assumption of the struct + debug_assert!(type_ < self.fields.len()); + let field = self.fields.get_unchecked(type_).as_ref(); + new_scalar(field, index) + } +} + +impl Array for UnionArray { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + None + } + + fn with_validity(&self, _: Option) -> Box { + panic!("cannot set validity of a union array") + } +} + +impl UnionArray { + fn try_get_all(data_type: &DataType) -> Result { + match data_type.to_logical_type() { + DataType::Union(fields, ids, mode) => { + Ok((fields, ids.as_ref().map(|x| x.as_ref()), *mode)) + }, + _ => Err(Error::oos( + "The UnionArray requires a logical type of DataType::Union", + )), + } + } + + fn get_all(data_type: &DataType) -> (&[Field], Option<&[i32]>, UnionMode) { + Self::try_get_all(data_type).unwrap() + } + + /// Returns all fields from [`DataType::Union`]. + /// # Panic + /// Panics iff `data_type`'s logical type is not [`DataType::Union`]. + pub fn get_fields(data_type: &DataType) -> &[Field] { + Self::get_all(data_type).0 + } + + /// Returns whether the [`DataType::Union`] is sparse or not. + /// # Panic + /// Panics iff `data_type`'s logical type is not [`DataType::Union`]. 
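// Sketch only, not part of the patch: reading a `UnionArray` slot by slot, assuming a
// `union: UnionArray` built elsewhere. This mirrors the "how to read a value at slot i"
// comment at the top of this module.
//
//     for i in 0..union.len() {
//         let (child, _slot) = union.index(i);  // which child array, and the row within it
//         let scalar = union.value(i);          // the same slot as a boxed Scalar
//         assert_eq!(scalar.data_type(), union.fields()[child].data_type());
//     }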
+ pub fn is_sparse(data_type: &DataType) -> bool { + Self::get_all(data_type).2.is_sparse() + } +} diff --git a/crates/nano-arrow/src/array/utf8/data.rs b/crates/nano-arrow/src/array/utf8/data.rs new file mode 100644 index 000000000000..16674c969372 --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/data.rs @@ -0,0 +1,42 @@ +use arrow_data::{ArrayData, ArrayDataBuilder}; + +use crate::array::{Arrow2Arrow, Utf8Array}; +use crate::bitmap::Bitmap; +use crate::offset::{Offset, OffsetsBuffer}; + +impl Arrow2Arrow for Utf8Array { + fn to_data(&self) -> ArrayData { + let data_type = self.data_type().clone().into(); + let builder = ArrayDataBuilder::new(data_type) + .len(self.offsets().len_proxy()) + .buffers(vec![ + self.offsets.clone().into_inner().into(), + self.values.clone().into(), + ]) + .nulls(self.validity.as_ref().map(|b| b.clone().into())); + + // Safety: Array is valid + unsafe { builder.build_unchecked() } + } + + fn from_data(data: &ArrayData) -> Self { + let data_type = data.data_type().clone().into(); + if data.is_empty() { + // Handle empty offsets + return Self::new_empty(data_type); + } + + let buffers = data.buffers(); + + // Safety: ArrayData is valid + let mut offsets = unsafe { OffsetsBuffer::new_unchecked(buffers[0].clone().into()) }; + offsets.slice(data.offset(), data.len() + 1); + + Self { + data_type, + offsets, + values: buffers[1].clone().into(), + validity: data.nulls().map(|n| Bitmap::from_null_buffer(n.clone())), + } + } +} diff --git a/crates/nano-arrow/src/array/utf8/ffi.rs b/crates/nano-arrow/src/array/utf8/ffi.rs new file mode 100644 index 000000000000..2129a85a6f8f --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/ffi.rs @@ -0,0 +1,62 @@ +use super::Utf8Array; +use crate::array::{FromFfi, ToFfi}; +use crate::bitmap::align; +use crate::error::Result; +use crate::ffi; +use crate::offset::{Offset, OffsetsBuffer}; + +unsafe impl ToFfi for Utf8Array { + fn buffers(&self) -> Vec> { + vec![ + self.validity.as_ref().map(|x| x.as_ptr()), + Some(self.offsets.buffer().as_ptr().cast::()), + Some(self.values.as_ptr().cast::()), + ] + } + + fn offset(&self) -> Option { + let offset = self.offsets.buffer().offset(); + if let Some(bitmap) = self.validity.as_ref() { + if bitmap.offset() == offset { + Some(offset) + } else { + None + } + } else { + Some(offset) + } + } + + fn to_ffi_aligned(&self) -> Self { + let offset = self.offsets.buffer().offset(); + + let validity = self.validity.as_ref().map(|bitmap| { + if bitmap.offset() == offset { + bitmap.clone() + } else { + align(bitmap, offset) + } + }); + + Self { + data_type: self.data_type.clone(), + validity, + offsets: self.offsets.clone(), + values: self.values.clone(), + } + } +} + +impl FromFfi for Utf8Array { + unsafe fn try_from_ffi(array: A) -> Result { + let data_type = array.data_type().clone(); + let validity = unsafe { array.validity() }?; + let offsets = unsafe { array.buffer::(1) }?; + let values = unsafe { array.buffer::(2)? 
}; + + // assumption that data from FFI is well constructed + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets) }; + + Ok(Self::new_unchecked(data_type, offsets, values, validity)) + } +} diff --git a/crates/nano-arrow/src/array/utf8/fmt.rs b/crates/nano-arrow/src/array/utf8/fmt.rs new file mode 100644 index 000000000000..4466444ffe3b --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/fmt.rs @@ -0,0 +1,23 @@ +use std::fmt::{Debug, Formatter, Result, Write}; + +use super::super::fmt::write_vec; +use super::Utf8Array; +use crate::offset::Offset; + +pub fn write_value(array: &Utf8Array, index: usize, f: &mut W) -> Result { + write!(f, "{}", array.value(index)) +} + +impl Debug for Utf8Array { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let writer = |f: &mut Formatter, index| write_value(self, index, f); + + let head = if O::IS_LARGE { + "LargeUtf8Array" + } else { + "Utf8Array" + }; + write!(f, "{head}")?; + write_vec(f, writer, self.validity(), self.len(), "None", false) + } +} diff --git a/crates/nano-arrow/src/array/utf8/from.rs b/crates/nano-arrow/src/array/utf8/from.rs new file mode 100644 index 000000000000..c1dcaf09b10d --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/from.rs @@ -0,0 +1,11 @@ +use std::iter::FromIterator; + +use super::{MutableUtf8Array, Utf8Array}; +use crate::offset::Offset; + +impl> FromIterator> for Utf8Array { + #[inline] + fn from_iter>>(iter: I) -> Self { + MutableUtf8Array::::from_iter(iter).into() + } +} diff --git a/crates/nano-arrow/src/array/utf8/iterator.rs b/crates/nano-arrow/src/array/utf8/iterator.rs new file mode 100644 index 000000000000..262b98c10d79 --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/iterator.rs @@ -0,0 +1,79 @@ +use super::{MutableUtf8Array, MutableUtf8ValuesArray, Utf8Array}; +use crate::array::{ArrayAccessor, ArrayValuesIter}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::offset::Offset; + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for Utf8Array { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`Utf8Array`]. +pub type Utf8ValuesIter<'a, O> = ArrayValuesIter<'a, Utf8Array>; + +impl<'a, O: Offset> IntoIterator for &'a Utf8Array { + type Item = Option<&'a str>; + type IntoIter = ZipValidity<&'a str, Utf8ValuesIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableUtf8Array { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +/// Iterator of values of an [`MutableUtf8ValuesArray`]. 
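// Not part of the patch: the two iteration modes wired up above, shown on the same data
// as the `Utf8Array` doc example in utf8/mod.rs below.
//
//     let a = Utf8Array::<i32>::from([Some("hi"), None, Some("there")]);
//     assert_eq!(a.iter().collect::<Vec<_>>(), vec![Some("hi"), None, Some("there")]);
//     assert_eq!(a.values_iter().collect::<Vec<_>>(), vec!["hi", "", "there"]); // nulls read as ""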
+pub type MutableUtf8ValuesIter<'a, O> = ArrayValuesIter<'a, MutableUtf8ValuesArray>; + +impl<'a, O: Offset> IntoIterator for &'a MutableUtf8Array { + type Item = Option<&'a str>; + type IntoIter = ZipValidity<&'a str, MutableUtf8ValuesIter<'a, O>, BitmapIter<'a>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +unsafe impl<'a, O: Offset> ArrayAccessor<'a> for MutableUtf8ValuesArray { + type Item = &'a str; + + #[inline] + unsafe fn value_unchecked(&'a self, index: usize) -> Self::Item { + self.value_unchecked(index) + } + + #[inline] + fn len(&self) -> usize { + self.len() + } +} + +impl<'a, O: Offset> IntoIterator for &'a MutableUtf8ValuesArray { + type Item = &'a str; + type IntoIter = ArrayValuesIter<'a, MutableUtf8ValuesArray>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/crates/nano-arrow/src/array/utf8/mod.rs b/crates/nano-arrow/src/array/utf8/mod.rs new file mode 100644 index 000000000000..ab2c2a7bab8b --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/mod.rs @@ -0,0 +1,545 @@ +use either::Either; + +use super::specification::{try_check_offsets_bounds, try_check_utf8}; +use super::{Array, GenericBinaryArray}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; + +#[cfg(feature = "arrow")] +mod data; +mod ffi; +pub(super) mod fmt; +mod from; +mod iterator; +mod mutable; +mod mutable_values; +pub use iterator::*; +pub use mutable::*; +pub use mutable_values::MutableUtf8ValuesArray; + +// Auxiliary struct to allow presenting &str as [u8] to a generic function +pub(super) struct StrAsBytes
<P>
(P); +impl> AsRef<[u8]> for StrAsBytes { + #[inline(always)] + fn as_ref(&self) -> &[u8] { + self.0.as_ref().as_bytes() + } +} + +/// A [`Utf8Array`] is arrow's semantic equivalent of an immutable `Vec>`. +/// Cloning and slicing this struct is `O(1)`. +/// # Example +/// ``` +/// use arrow2::bitmap::Bitmap; +/// use arrow2::buffer::Buffer; +/// use arrow2::array::Utf8Array; +/// # fn main() { +/// let array = Utf8Array::::from([Some("hi"), None, Some("there")]); +/// assert_eq!(array.value(0), "hi"); +/// assert_eq!(array.iter().collect::>(), vec![Some("hi"), None, Some("there")]); +/// assert_eq!(array.values_iter().collect::>(), vec!["hi", "", "there"]); +/// // the underlying representation +/// assert_eq!(array.validity(), Some(&Bitmap::from([true, false, true]))); +/// assert_eq!(array.values(), &Buffer::from(b"hithere".to_vec())); +/// assert_eq!(array.offsets().buffer(), &Buffer::from(vec![0, 2, 2, 2 + 5])); +/// # } +/// ``` +/// +/// # Generic parameter +/// The generic parameter [`Offset`] can only be `i32` or `i64` and tradeoffs maximum array length with +/// memory usage: +/// * the sum of lengths of all elements cannot exceed `Offset::MAX` +/// * the total size of the underlying data is `array.len() * size_of::() + sum of lengths of all elements` +/// +/// # Safety +/// The following invariants hold: +/// * Two consecutives `offsets` casted (`as`) to `usize` are valid slices of `values`. +/// * A slice of `values` taken from two consecutives `offsets` is valid `utf8`. +/// * `len` is equal to `validity.len()`, when defined. +#[derive(Clone)] +pub struct Utf8Array { + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, +} + +// constructors +impl Utf8Array { + /// Returns a [`Utf8Array`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Result { + try_check_utf8(&offsets, &values)?; + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "Utf8Array can only be initialized with DataType::Utf8 or DataType::LargeUtf8", + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Returns a [`Utf8Array`] from a slice of `&str`. + /// + /// A convenience method that uses [`Self::from_trusted_len_values_iter`]. + pub fn from_slice, P: AsRef<[T]>>(slice: P) -> Self { + Self::from_trusted_len_values_iter(slice.as_ref().iter()) + } + + /// Returns a new [`Utf8Array`] from a slice of `&str`. + /// + /// A convenience method that uses [`Self::from_trusted_len_iter`]. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. 
+ pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + MutableUtf8Array::::from(slice).into() + } + + /// Returns an iterator of `Option<&str>` + pub fn iter(&self) -> ZipValidity<&str, Utf8ValuesIter, BitmapIter> { + ZipValidity::new_with_validity(self.values_iter(), self.validity()) + } + + /// Returns an iterator of `&str` + pub fn values_iter(&self) -> Utf8ValuesIter { + Utf8ValuesIter::new(self) + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &str { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end_unchecked(i); + + // soundness: the invariant of the struct + let slice = self.values.get_unchecked(start..end); + + // soundness: the invariant of the struct + std::str::from_utf8_unchecked(slice) + } + + /// Returns the element at index `i` or `None` if it is null + /// # Panics + /// iff `i >= self.len()` + #[inline] + pub fn get(&self, i: usize) -> Option<&str> { + if !self.is_null(i) { + // soundness: Array::is_null panics if i >= self.len + unsafe { Some(self.value_unchecked(i)) } + } else { + None + } + } + + /// Returns the [`DataType`] of this array. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns the values of this [`Utf8Array`]. + #[inline] + pub fn values(&self) -> &Buffer { + &self.values + } + + /// Returns the offsets of this [`Utf8Array`]. + #[inline] + pub fn offsets(&self) -> &OffsetsBuffer { + &self.offsets + } + + /// The optional validity. + #[inline] + pub fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + /// Slices this [`Utf8Array`]. + /// # Implementation + /// This function is `O(1)`. + /// # Panics + /// iff `offset + length > self.len()`. + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new array cannot exceed the arrays' length" + ); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices this [`Utf8Array`]. + /// # Implementation + /// This function is `O(1)` + /// # Safety + /// The caller must ensure that `offset + length <= self.len()`. 
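// Sketch, not part of the patch, of the accessors and O(1) slicing defined above.
//
//     let mut a = Utf8Array::<i32>::from([Some("hi"), None, Some("there")]);
//     assert_eq!(a.value(1), "");   // ignores validity
//     assert_eq!(a.get(1), None);   // honours validity
//     a.slice(1, 2);                // O(1): only the offsets window and validity are adjusted
//     assert_eq!(a.len(), 2);
//     assert_eq!(a.get(1), Some("there"));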
+ pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.validity.as_mut().and_then(|bitmap| { + bitmap.slice_unchecked(offset, length); + (bitmap.unset_bits() > 0).then(|| bitmap) + }); + self.offsets.slice_unchecked(offset, length + 1); + } + + impl_sliced!(); + impl_mut_validity!(); + impl_into_array!(); + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (DataType, OffsetsBuffer, Buffer, Option) { + let Self { + data_type, + offsets, + values, + validity, + } = self; + (data_type, offsets, values, validity) + } + + /// Try to convert this `Utf8Array` to a `MutableUtf8Array` + #[must_use] + pub fn into_mut(self) -> Either> { + use Either::*; + if let Some(bitmap) = self.validity { + match bitmap.into_mut() { + // Safety: invariants are preserved + Left(bitmap) => Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + self.offsets, + self.values, + Some(bitmap), + ) + }), + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + // Safety: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + offsets, + values, + Some(mutable_bitmap.into()), + ) + }) + }, + (Left(values), Right(offsets)) => { + // Safety: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + offsets.into(), + values, + Some(mutable_bitmap.into()), + ) + }) + }, + (Right(values), Left(offsets)) => { + // Safety: invariants are preserved + Left(unsafe { + Utf8Array::new_unchecked( + self.data_type, + offsets, + values.into(), + Some(mutable_bitmap.into()), + ) + }) + }, + (Right(values), Right(offsets)) => Right(unsafe { + MutableUtf8Array::new_unchecked( + self.data_type, + offsets, + values, + Some(mutable_bitmap), + ) + }), + }, + } + } else { + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(unsafe { Utf8Array::new_unchecked(self.data_type, offsets, values, None) }) + }, + (Left(values), Right(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.data_type, offsets.into(), values, None) + }), + (Right(values), Left(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.data_type, offsets, values.into(), None) + }), + (Right(values), Right(offsets)) => Right(unsafe { + MutableUtf8Array::new_unchecked(self.data_type, offsets, values, None) + }), + } + } + } + + /// Returns a new empty [`Utf8Array`]. + /// + /// The array is guaranteed to have no elements nor validity. + #[inline] + pub fn new_empty(data_type: DataType) -> Self { + unsafe { Self::new_unchecked(data_type, OffsetsBuffer::new(), Buffer::new(), None) } + } + + /// Returns a new [`Utf8Array`] whose all slots are null / `None`. + #[inline] + pub fn new_null(data_type: DataType, length: usize) -> Self { + Self::new( + data_type, + Offsets::new_zeroed(length).into(), + Buffer::new(), + Some(Bitmap::new_zeroed(length)), + ) + } + + /// Returns a default [`DataType`] of this array, which depends on the generic parameter `O`: `DataType::Utf8` or `DataType::LargeUtf8` + pub fn default_data_type() -> DataType { + if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + } + } + + /// Creates a new [`Utf8Array`] without checking for offsets monotinicity nor utf8-validity + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. 
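`into_inner` and `try_new` together give a checked round-trip through the raw parts; a short sketch:

```rust
use arrow2::array::Utf8Array;

let array = Utf8Array::<i32>::from([Some("hi"), None, Some("there")]);
let (data_type, offsets, values, validity) = array.into_inner();

// `try_new` re-validates utf8 and the validity length, so the round-trip is checked.
let rebuilt = Utf8Array::<i32>::try_new(data_type, offsets, values, validity).unwrap();
assert_eq!(rebuilt.get(0), Some("hi"));
assert_eq!(rebuilt.get(1), None);
```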
+ /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// # Safety + /// This function is unsound iff: + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn try_new_unchecked( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Result { + try_check_offsets_bounds(&offsets, values.len())?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != offsets.len_proxy()) + { + return Err(Error::oos( + "validity mask length must match the number of values", + )); + } + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "BinaryArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8", + )); + } + + Ok(Self { + data_type, + offsets, + values, + validity, + }) + } + + /// Creates a new [`Utf8Array`]. + /// # Panics + /// This function panics iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn new( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new(data_type, offsets, values, validity).unwrap() + } + + /// Creates a new [`Utf8Array`] without checking for offsets monotinicity. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// # Safety + /// This function is unsound iff: + /// * the offsets are not monotonically increasing + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn new_unchecked( + data_type: DataType, + offsets: OffsetsBuffer, + values: Buffer, + validity: Option, + ) -> Self { + Self::try_new_unchecked(data_type, offsets, values, validity).unwrap() + } + + /// Returns a (non-null) [`Utf8Array`] created from a [`TrustedLen`] of `&str`. + /// # Implementation + /// This function is `O(N)` + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + MutableUtf8Array::::from_trusted_len_values_iter(iterator).into() + } + + /// Creates a new [`Utf8Array`] from a [`Iterator`] of `&str`. + pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableUtf8Array::::from_iter_values(iterator).into() + } + + /// Creates a [`Utf8Array`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator>, + { + MutableUtf8Array::::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a [`Utf8Array`] from an iterator of trusted length. 
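With the `i64` offset flavour the same constructors produce a `LargeUtf8` array; for example, building from an iterator of owned values (a sketch using the `arrow2`-style paths):

```rust
use arrow2::array::Utf8Array;
use arrow2::datatypes::DataType;

// `from_iter_values` accepts anything `AsRef<str>`, including owned `String`s.
let array = Utf8Array::<i64>::from_iter_values((0..3).map(|i| format!("row-{i}")));
assert_eq!(array.data_type(), &DataType::LargeUtf8);
assert_eq!(array.value(2), "row-2");
```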
+ #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen>, + { + MutableUtf8Array::::from_trusted_len_iter(iterator).into() + } + + /// Creates a [`Utf8Array`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef, + I: IntoIterator, E>>, + { + MutableUtf8Array::::try_from_trusted_len_iter_unchecked(iterator).map(|x| x.into()) + } + + /// Creates a [`Utf8Array`] from an fallible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iter: I) -> std::result::Result + where + P: AsRef, + I: TrustedLen, E>>, + { + MutableUtf8Array::::try_from_trusted_len_iter(iter).map(|x| x.into()) + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity Bitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } +} + +impl Array for Utf8Array { + impl_common_array!(); + + fn validity(&self) -> Option<&Bitmap> { + self.validity.as_ref() + } + + #[inline] + fn with_validity(&self, validity: Option) -> Box { + Box::new(self.clone().with_validity(validity)) + } +} + +unsafe impl GenericBinaryArray for Utf8Array { + #[inline] + fn values(&self) -> &[u8] { + self.values() + } + + #[inline] + fn offsets(&self) -> &[O] { + self.offsets().buffer() + } +} + +impl Default for Utf8Array { + fn default() -> Self { + let data_type = if O::IS_LARGE { + DataType::LargeUtf8 + } else { + DataType::Utf8 + }; + Utf8Array::new(data_type, Default::default(), Default::default(), None) + } +} diff --git a/crates/nano-arrow/src/array/utf8/mutable.rs b/crates/nano-arrow/src/array/utf8/mutable.rs new file mode 100644 index 000000000000..3fc47b3eae1d --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/mutable.rs @@ -0,0 +1,549 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{MutableUtf8ValuesArray, MutableUtf8ValuesIter, StrAsBytes, Utf8Array}; +use crate::array::physical_binary::*; +use crate::array::{Array, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::utils::{BitmapIter, ZipValidity}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`Utf8Array`]. It differs +/// from [`MutableUtf8ValuesArray`] in that it can build nullable [`Utf8Array`]s. +#[derive(Debug, Clone)] +pub struct MutableUtf8Array { + values: MutableUtf8ValuesArray, + validity: Option, +} + +impl From> for Utf8Array { + fn from(other: MutableUtf8Array) -> Self { + let validity = other.validity.and_then(|x| { + let validity: Option = x.into(); + validity + }); + let array: Utf8Array = other.values.into(); + array.with_validity(validity) + } +} + +impl Default for MutableUtf8Array { + fn default() -> Self { + Self::new() + } +} + +impl MutableUtf8Array { + /// Initializes a new empty [`MutableUtf8Array`]. 
+ pub fn new() -> Self { + Self { + values: Default::default(), + validity: None, + } + } + + /// Returns a [`MutableUtf8Array`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * the validity's length is not equal to `offsets.len()`. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new( + data_type: DataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Result { + let values = MutableUtf8ValuesArray::try_new(data_type, offsets, values)?; + + if validity + .as_ref() + .map_or(false, |validity| validity.len() != values.len()) + { + return Err(Error::oos( + "validity's length must be equal to the number of values", + )); + } + + Ok(Self { values, validity }) + } + + /// Create a [`MutableUtf8Array`] out of low-end APIs. + /// # Safety + /// The caller must ensure that every value between offsets is a valid utf8. + /// # Panics + /// This function panics iff: + /// * The `offsets` and `values` are inconsistent + /// * The validity is not `None` and its length is different from `offsets`'s length minus one. + pub unsafe fn new_unchecked( + data_type: DataType, + offsets: Offsets, + values: Vec, + validity: Option, + ) -> Self { + let values = MutableUtf8ValuesArray::new_unchecked(data_type, offsets, values); + if let Some(ref validity) = validity { + assert_eq!(values.len(), validity.len()); + } + Self { values, validity } + } + + /// Creates a new [`MutableUtf8Array`] from a slice of optional `&[u8]`. + // Note: this can't be `impl From` because Rust does not allow double `AsRef` on it. + pub fn from, P: AsRef<[Option]>>(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().map(|x| x.as_ref())) + } + + fn default_data_type() -> DataType { + Utf8Array::::default_data_type() + } + + /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableUtf8Array`] with a pre-allocated capacity of slots and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + values: MutableUtf8ValuesArray::with_capacities(capacity, values), + validity: None, + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.values.reserve(additional, additional_values); + if let Some(x) = self.validity.as_mut() { + x.reserve(additional) + } + } + + /// Reserves `additional` elements and `additional_values` on the values buffer. + pub fn capacity(&self) -> usize { + self.values.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.values.len() + } + + /// Pushes a new element to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. + #[inline] + pub fn push>(&mut self, value: Option) { + self.try_push(value).unwrap() + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Safety + /// This function is safe iff `i < self.len`. 
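Putting the builder API above together, a typical build-then-freeze cycle looks roughly like this (sketch):

```rust
use arrow2::array::{MutableUtf8Array, Utf8Array};

let mut builder = MutableUtf8Array::<i32>::with_capacity(3);
builder.push(Some("a"));
builder.push::<&str>(None);   // the turbofish is needed because `None` carries no type
builder.push(Some("c"));

let array: Utf8Array<i32> = builder.into();
assert_eq!(array.iter().collect::<Vec<_>>(), vec![Some("a"), None, Some("c")]);
```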
+ #[inline] + pub fn value(&self, i: usize) -> &str { + self.values.value(i) + } + + /// Returns the value of the element at index `i`, ignoring the array's validity. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + self.values.value_unchecked(i) + } + + /// Pop the last entry from [`MutableUtf8Array`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option { + let value = self.values.pop()?; + self.validity + .as_mut() + .map(|x| x.pop()?.then(|| ())) + .unwrap_or_else(|| Some(())) + .map(|_| value) + } + + fn init_validity(&mut self) { + let mut validity = MutableBitmap::with_capacity(self.values.capacity()); + validity.extend_constant(self.len(), true); + validity.set(self.len() - 1, false); + self.validity = Some(validity); + } + + /// Returns an iterator of `Option<&str>` + pub fn iter(&self) -> ZipValidity<&str, MutableUtf8ValuesIter, BitmapIter> { + ZipValidity::new(self.values_iter(), self.validity.as_ref().map(|x| x.iter())) + } + + /// Converts itself into an [`Array`]. + pub fn into_arc(self) -> Arc { + let a: Utf8Array = self.into(); + Arc::new(a) + } + + /// Shrinks the capacity of the [`MutableUtf8Array`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + if let Some(validity) = &mut self.validity { + validity.shrink_to_fit() + } + } + + /// Extract the low-end APIs from the [`MutableUtf8Array`]. + pub fn into_data(self) -> (DataType, Offsets, Vec, Option) { + let (data_type, offsets, values) = self.values.into_inner(); + (data_type, offsets, values, self.validity) + } + + /// Returns an iterator of `&str` + pub fn values_iter(&self) -> MutableUtf8ValuesIter { + self.values.iter() + } + + /// Sets the validity. + /// # Panic + /// Panics iff the validity's len is not equal to the existing values' length. + pub fn set_validity(&mut self, validity: Option) { + if let Some(validity) = &validity { + assert_eq!(self.values.len(), validity.len()) + } + self.validity = validity; + } + + /// Applies a function `f` to the validity of this array. + /// + /// This is an API to leverage clone-on-write + /// # Panics + /// This function panics if the function `f` modifies the length of the [`Bitmap`]. + pub fn apply_validity MutableBitmap>(&mut self, f: F) { + if let Some(validity) = std::mem::take(&mut self.validity) { + self.set_validity(Some(f(validity))) + } + } +} + +impl MutableUtf8Array { + /// returns its values. + pub fn values(&self) -> &Vec { + self.values.values() + } + + /// returns its offsets. 
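A short sketch of `pop`, which truncates both the values and the validity and hands back the removed entry as an owned `String`:

```rust
use arrow2::array::MutableUtf8Array;

let mut builder = MutableUtf8Array::<i32>::from([Some("x"), Some("y")]);
assert_eq!(builder.pop(), Some("y".to_string()));
assert_eq!(builder.len(), 1);
```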
+ pub fn offsets(&self) -> &Offsets { + self.values.offsets() + } +} + +impl MutableArray for MutableUtf8Array { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let array: Utf8Array = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: Utf8Array = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + if O::IS_LARGE { + &DataType::LargeUtf8 + } else { + &DataType::Utf8 + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&str>(None) + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator> for MutableUtf8Array { + fn from_iter>>(iter: I) -> Self { + Self::try_from_iter(iter).unwrap() + } +} + +impl MutableUtf8Array { + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_trusted_len_values(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_values_unchecked(iterator) } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values. + /// This differs from `extended_trusted_len` which accepts iterator of optional values. + #[inline] + pub fn extend_values(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let length = self.values.len(); + self.values.extend(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of values of trusted len. + /// This differs from `extended_trusted_len_unchecked` which accepts iterator of optional + /// values. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_values_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let length = self.values.len(); + self.values.extend_trusted_len_unchecked(iterator); + let additional = self.values.len() - length; + + if let Some(validity) = self.validity.as_mut() { + validity.extend_constant(additional, true); + } + } + + /// Extends the [`MutableUtf8Array`] from an iterator of trusted len. + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen>, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8Array`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator>, + { + if self.validity.is_none() { + let mut validity = MutableBitmap::new(); + validity.extend_constant(self.len(), true); + self.validity = Some(validity); + } + + self.values + .extend_from_trusted_len_iter(self.validity.as_mut().unwrap(), iterator); + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. 
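The `extend_*` family above distinguishes plain values from optional values; a sketch combining both (slice iterators are assumed to be `TrustedLen`, as in arrow2):

```rust
use arrow2::array::{MutableUtf8Array, Utf8Array};

let mut builder = MutableUtf8Array::<i32>::new();
builder.extend_values(["a", "b"].iter().copied());
// the optional, trusted-length variant can pre-size its buffers
builder.extend_trusted_len([Some("c"), None].iter().copied());

let array: Utf8Array<i32> = builder.into();
assert_eq!(array.len(), 4);
assert_eq!(array.get(3), None);
```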
+ #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator>, + { + let iterator = iterator.map(|x| x.map(StrAsBytes)); + let (validity, offsets, values) = trusted_len_unzip(iterator); + + // soundness: P is `str` + Self::new_unchecked(Self::default_data_type(), offsets, values, validity) + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen>, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a [`MutableUtf8Array`] from an iterator of trusted length of `&str`. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_values_iter_unchecked, I: Iterator>( + iterator: I, + ) -> Self { + MutableUtf8ValuesArray::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a new [`MutableUtf8Array`] from a [`TrustedLen`] of `&str`. + #[inline] + pub fn from_trusted_len_values_iter, I: TrustedLen>( + iterator: I, + ) -> Self { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_values_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableUtf8Array`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). + fn try_from_iter, I: IntoIterator>>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked( + iterator: I, + ) -> std::result::Result + where + P: AsRef, + I: IntoIterator, E>>, + { + let iterator = iterator.into_iter(); + + let iterator = iterator.map(|x| x.map(|x| x.map(StrAsBytes))); + let (validity, offsets, values) = try_trusted_len_unzip(iterator)?; + + // soundness: P is `str` + Ok(Self::new_unchecked( + Self::default_data_type(), + offsets, + values, + validity, + )) + } + + /// Creates a [`MutableUtf8Array`] from an falible iterator of trusted length. + #[inline] + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + P: AsRef, + I: TrustedLen, E>>, + { + // soundness: I: TrustedLen + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableUtf8Array`] from a [`Iterator`] of `&str`. 
+ pub fn from_iter_values, I: Iterator>(iterator: I) -> Self { + MutableUtf8ValuesArray::from_iter(iterator).into() + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator, E>>, + T: AsRef, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend> for MutableUtf8Array { + fn extend>>(&mut self, iter: I) { + self.try_extend(iter).unwrap(); + } +} + +impl> TryExtend> for MutableUtf8Array { + fn try_extend>>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush> for MutableUtf8Array { + #[inline] + fn try_push(&mut self, value: Option) -> Result<()> { + match value { + Some(value) => { + self.values.try_push(value.as_ref())?; + + match &mut self.validity { + Some(validity) => validity.push(true), + None => {}, + } + }, + None => { + self.values.push(""); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + }, + } + Ok(()) + } +} + +impl PartialEq for MutableUtf8Array { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl TryExtendFromSelf for MutableUtf8Array { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + extend_validity(self.len(), &mut self.validity, &other.validity); + + self.values.try_extend_from_self(&other.values) + } +} diff --git a/crates/nano-arrow/src/array/utf8/mutable_values.rs b/crates/nano-arrow/src/array/utf8/mutable_values.rs new file mode 100644 index 000000000000..8810d30febb5 --- /dev/null +++ b/crates/nano-arrow/src/array/utf8/mutable_values.rs @@ -0,0 +1,407 @@ +use std::iter::FromIterator; +use std::sync::Arc; + +use super::{MutableUtf8Array, StrAsBytes, Utf8Array}; +use crate::array::physical_binary::*; +use crate::array::specification::{try_check_offsets_bounds, try_check_utf8}; +use crate::array::{Array, ArrayValuesIter, MutableArray, TryExtend, TryExtendFromSelf, TryPush}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; +use crate::trusted_len::TrustedLen; + +/// A [`MutableArray`] that builds a [`Utf8Array`]. It differs +/// from [`MutableUtf8Array`] in that it builds non-null [`Utf8Array`]. +#[derive(Debug, Clone)] +pub struct MutableUtf8ValuesArray { + data_type: DataType, + offsets: Offsets, + values: Vec, +} + +impl From> for Utf8Array { + fn from(other: MutableUtf8ValuesArray) -> Self { + // Safety: + // `MutableUtf8ValuesArray` has the same invariants as `Utf8Array` and thus + // `Utf8Array` can be safely created from `MutableUtf8ValuesArray` without checks. + unsafe { + Utf8Array::::new_unchecked( + other.data_type, + other.offsets.into(), + other.values.into(), + None, + ) + } + } +} + +impl From> for MutableUtf8Array { + fn from(other: MutableUtf8ValuesArray) -> Self { + // Safety: + // `MutableUtf8ValuesArray` has the same invariants as `MutableUtf8Array` + unsafe { + MutableUtf8Array::::new_unchecked(other.data_type, other.offsets, other.values, None) + } + } +} + +impl Default for MutableUtf8ValuesArray { + fn default() -> Self { + Self::new() + } +} + +impl MutableUtf8ValuesArray { + /// Returns an empty [`MutableUtf8ValuesArray`]. 
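The fallible counterparts implemented above (`TryExtend` and `TryPush` for `MutableUtf8Array`) surface offset overflow as an error instead of panicking; a sketch:

```rust
use arrow2::array::{MutableUtf8Array, TryExtend, TryPush};

let mut builder = MutableUtf8Array::<i32>::new();
builder.try_push(Some("a")).unwrap();
// errors only if the `i32` offsets would overflow
builder.try_extend([None, Some("b")]).unwrap();
assert_eq!(builder.len(), 3);
```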
+ pub fn new() -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::new(), + values: Vec::::new(), + } + } + + /// Returns a [`MutableUtf8ValuesArray`] created from its internal representation. + /// + /// # Errors + /// This function returns an error iff: + /// * The last offset is not equal to the values' length. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is not equal to either `Utf8` or `LargeUtf8`. + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(N)` - checking utf8 is `O(N)` + pub fn try_new(data_type: DataType, offsets: Offsets, values: Vec) -> Result { + try_check_utf8(&offsets, &values)?; + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + return Err(Error::oos( + "MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8", + )); + } + + Ok(Self { + data_type, + offsets, + values, + }) + } + + /// Returns a [`MutableUtf8ValuesArray`] created from its internal representation. + /// + /// # Panic + /// This function does not panic iff: + /// * The last offset is equal to the values' length. + /// * The `data_type`'s [`crate::datatypes::PhysicalType`] is equal to either `Utf8` or `LargeUtf8`. + /// # Safety + /// This function is safe iff: + /// * the offsets are monotonically increasing + /// * The `values` between two consecutive `offsets` are not valid utf8 + /// # Implementation + /// This function is `O(1)` + pub unsafe fn new_unchecked(data_type: DataType, offsets: Offsets, values: Vec) -> Self { + try_check_offsets_bounds(&offsets, values.len()) + .expect("The length of the values must be equal to the last offset value"); + + if data_type.to_physical_type() != Self::default_data_type().to_physical_type() { + panic!("MutableUtf8ValuesArray can only be initialized with DataType::Utf8 or DataType::LargeUtf8") + } + + Self { + data_type, + offsets, + values, + } + } + + /// Returns the default [`DataType`] of this container: [`DataType::Utf8`] or [`DataType::LargeUtf8`] + /// depending on the generic [`Offset`]. + pub fn default_data_type() -> DataType { + Utf8Array::::default_data_type() + } + + /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items. + pub fn with_capacity(capacity: usize) -> Self { + Self::with_capacities(capacity, 0) + } + + /// Initializes a new [`MutableUtf8ValuesArray`] with a pre-allocated capacity of items and values. + pub fn with_capacities(capacity: usize, values: usize) -> Self { + Self { + data_type: Self::default_data_type(), + offsets: Offsets::::with_capacity(capacity), + values: Vec::::with_capacity(values), + } + } + + /// returns its values. + #[inline] + pub fn values(&self) -> &Vec { + &self.values + } + + /// returns its offsets. + #[inline] + pub fn offsets(&self) -> &Offsets { + &self.offsets + } + + /// Reserves `additional` elements and `additional_values` on the values. + #[inline] + pub fn reserve(&mut self, additional: usize, additional_values: usize) { + self.offsets.reserve(additional + 1); + self.values.reserve(additional_values); + } + + /// Returns the capacity in number of items + pub fn capacity(&self) -> usize { + self.offsets.capacity() + } + + /// Returns the length of this array + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + /// Pushes a new item to the array. + /// # Panic + /// This operation panics iff the length of all values (in bytes) exceeds `O` maximum value. 
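A sketch of the values-only builder described above; because it cannot hold nulls, the frozen array never carries a validity bitmap:

```rust
use arrow2::array::{MutableUtf8ValuesArray, Utf8Array};

let mut values = MutableUtf8ValuesArray::<i32>::with_capacity(2);
values.push("hello");
values.push("world");

let array: Utf8Array<i32> = values.into();
assert_eq!(array.validity(), None);
assert_eq!(array.value(1), "world");
```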
+ #[inline] + pub fn push>(&mut self, value: T) { + self.try_push(value).unwrap() + } + + /// Pop the last entry from [`MutableUtf8ValuesArray`]. + /// This function returns `None` iff this array is empty. + pub fn pop(&mut self) -> Option { + if self.len() == 0 { + return None; + } + self.offsets.pop()?; + let start = self.offsets.last().to_usize(); + let value = self.values.split_off(start); + // Safety: utf8 is validated on initialization + Some(unsafe { String::from_utf8_unchecked(value) }) + } + + /// Returns the value of the element at index `i`. + /// # Panic + /// This function panics iff `i >= self.len`. + #[inline] + pub fn value(&self, i: usize) -> &str { + assert!(i < self.len()); + unsafe { self.value_unchecked(i) } + } + + /// Returns the value of the element at index `i`. + /// # Safety + /// This function is safe iff `i < self.len`. + #[inline] + pub unsafe fn value_unchecked(&self, i: usize) -> &str { + // soundness: the invariant of the function + let (start, end) = self.offsets.start_end(i); + + // soundness: the invariant of the struct + let slice = self.values.get_unchecked(start..end); + + // soundness: the invariant of the struct + std::str::from_utf8_unchecked(slice) + } + + /// Returns an iterator of `&str` + pub fn iter(&self) -> ArrayValuesIter { + ArrayValuesIter::new(self) + } + + /// Shrinks the capacity of the [`MutableUtf8ValuesArray`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.values.shrink_to_fit(); + self.offsets.shrink_to_fit(); + } + + /// Extract the low-end APIs from the [`MutableUtf8ValuesArray`]. + pub fn into_inner(self) -> (DataType, Offsets, Vec) { + (self.data_type, self.offsets, self.values) + } +} + +impl MutableArray for MutableUtf8ValuesArray { + fn len(&self) -> usize { + self.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let array: Utf8Array = std::mem::take(self).into(); + array.boxed() + } + + fn as_arc(&mut self) -> Arc { + let array: Utf8Array = std::mem::take(self).into(); + array.arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push::<&str>("") + } + + fn reserve(&mut self, additional: usize) { + self.reserve(additional, 0) + } + + fn shrink_to_fit(&mut self) { + self.shrink_to_fit() + } +} + +impl> FromIterator
for MutableUtf8ValuesArray { + fn from_iter>(iter: I) -> Self { + let (offsets, values) = values_iter(iter.into_iter().map(StrAsBytes)); + // soundness: T: AsRef and offsets are monotonically increasing + unsafe { Self::new_unchecked(Self::default_data_type(), offsets, values) } + } +} + +impl MutableUtf8ValuesArray { + pub(crate) unsafe fn extend_from_trusted_len_iter( + &mut self, + validity: &mut MutableBitmap, + iterator: I, + ) where + P: AsRef, + I: Iterator>, + { + let iterator = iterator.map(|x| x.map(StrAsBytes)); + extend_from_trusted_len_iter(&mut self.offsets, &mut self.values, validity, iterator); + } + + /// Extends the [`MutableUtf8ValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn extend_trusted_len(&mut self, iterator: I) + where + P: AsRef, + I: TrustedLen, + { + unsafe { self.extend_trusted_len_unchecked(iterator) } + } + + /// Extends [`MutableUtf8ValuesArray`] from an iterator of trusted len. + /// # Safety + /// The iterator must be trusted len. + #[inline] + pub unsafe fn extend_trusted_len_unchecked(&mut self, iterator: I) + where + P: AsRef, + I: Iterator, + { + let iterator = iterator.map(StrAsBytes); + extend_from_trusted_len_values_iter(&mut self.offsets, &mut self.values, iterator); + } + + /// Creates a [`MutableUtf8ValuesArray`] from a [`TrustedLen`] + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + P: AsRef, + I: TrustedLen, + { + // soundness: I is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Returns a new [`MutableUtf8ValuesArray`] from an iterator of trusted length. + /// # Safety + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + P: AsRef, + I: Iterator, + { + let iterator = iterator.map(StrAsBytes); + let (offsets, values) = trusted_len_values_iter(iterator); + + // soundness: P is `str` and offsets are monotonically increasing + Self::new_unchecked(Self::default_data_type(), offsets, values) + } + + /// Returns a new [`MutableUtf8ValuesArray`] from an iterator. + /// # Error + /// This operation errors iff the total length in bytes on the iterator exceeds `O`'s maximum value. + /// (`i32::MAX` or `i64::MAX` respectively). 
+ pub fn try_from_iter, I: IntoIterator>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut array = Self::with_capacity(lower); + for item in iterator { + array.try_push(item)?; + } + Ok(array) + } + + /// Extend with a fallible iterator + pub fn extend_fallible(&mut self, iter: I) -> std::result::Result<(), E> + where + E: std::error::Error, + I: IntoIterator>, + T: AsRef, + { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| { + self.push(x?); + Ok(()) + }) + } +} + +impl> Extend for MutableUtf8ValuesArray { + fn extend>(&mut self, iter: I) { + extend_from_values_iter( + &mut self.offsets, + &mut self.values, + iter.into_iter().map(StrAsBytes), + ); + } +} + +impl> TryExtend for MutableUtf8ValuesArray { + fn try_extend>(&mut self, iter: I) -> Result<()> { + let mut iter = iter.into_iter(); + self.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| self.try_push(x)) + } +} + +impl> TryPush for MutableUtf8ValuesArray { + #[inline] + fn try_push(&mut self, value: T) -> Result<()> { + let bytes = value.as_ref().as_bytes(); + self.values.extend_from_slice(bytes); + self.offsets.try_push_usize(bytes.len()) + } +} + +impl TryExtendFromSelf for MutableUtf8ValuesArray { + fn try_extend_from_self(&mut self, other: &Self) -> Result<()> { + self.values.extend_from_slice(&other.values); + self.offsets.try_extend_from_self(&other.offsets) + } +} diff --git a/crates/nano-arrow/src/bitmap/assign_ops.rs b/crates/nano-arrow/src/bitmap/assign_ops.rs new file mode 100644 index 000000000000..b4d3702c69eb --- /dev/null +++ b/crates/nano-arrow/src/bitmap/assign_ops.rs @@ -0,0 +1,190 @@ +use super::utils::{BitChunk, BitChunkIterExact, BitChunksExact}; +use crate::bitmap::{Bitmap, MutableBitmap}; + +/// Applies a function to every bit of this [`MutableBitmap`] in chunks +/// +/// This function can be for operations like `!` to a [`MutableBitmap`]. 
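`unary_assign` backs the `Not` implementation shown below; a small sketch of the user-facing effect:

```rust
use arrow2::bitmap::MutableBitmap;

let bitmap = MutableBitmap::from([true, false, true]);
let inverted = !bitmap;   // flips 64 bits per chunk, then the remainder byte by byte
assert_eq!(inverted, MutableBitmap::from([false, true, false]));
```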
+pub fn unary_assign T>(bitmap: &mut MutableBitmap, op: F) { + let mut chunks = bitmap.bitchunks_exact_mut::(); + + chunks.by_ref().for_each(|chunk| { + let new_chunk: T = match (chunk as &[u8]).try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + let new_chunk = op(new_chunk); + chunk.copy_from_slice(new_chunk.to_ne_bytes().as_ref()); + }); + + if chunks.remainder().is_empty() { + return; + } + let mut new_remainder = T::zero().to_ne_bytes(); + chunks + .remainder() + .iter() + .enumerate() + .for_each(|(index, b)| new_remainder[index] = *b); + new_remainder = op(T::from_ne_bytes(new_remainder)).to_ne_bytes(); + + let len = chunks.remainder().len(); + chunks + .remainder() + .copy_from_slice(&new_remainder.as_ref()[..len]); +} + +impl std::ops::Not for MutableBitmap { + type Output = Self; + + #[inline] + fn not(mut self) -> Self { + unary_assign(&mut self, |a: u64| !a); + self + } +} + +fn binary_assign_impl(lhs: &mut MutableBitmap, mut rhs: I, op: F) +where + I: BitChunkIterExact, + T: BitChunk, + F: Fn(T, T) -> T, +{ + let mut lhs_chunks = lhs.bitchunks_exact_mut::(); + + lhs_chunks + .by_ref() + .zip(rhs.by_ref()) + .for_each(|(lhs, rhs)| { + let new_chunk: T = match (lhs as &[u8]).try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + let new_chunk = op(new_chunk, rhs); + lhs.copy_from_slice(new_chunk.to_ne_bytes().as_ref()); + }); + + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs.remainder(); + if rem_lhs.is_empty() { + return; + } + let mut new_remainder = T::zero().to_ne_bytes(); + lhs_chunks + .remainder() + .iter() + .enumerate() + .for_each(|(index, b)| new_remainder[index] = *b); + new_remainder = op(T::from_ne_bytes(new_remainder), rem_rhs).to_ne_bytes(); + + let len = lhs_chunks.remainder().len(); + lhs_chunks + .remainder() + .copy_from_slice(&new_remainder.as_ref()[..len]); +} + +/// Apply a bitwise binary operation to a [`MutableBitmap`]. +/// +/// This function can be used for operations like `&=` to a [`MutableBitmap`]. 
+/// # Panics +/// This function panics iff `lhs.len() != `rhs.len()` +pub fn binary_assign(lhs: &mut MutableBitmap, rhs: &Bitmap, op: F) +where + F: Fn(T, T) -> T, +{ + assert_eq!(lhs.len(), rhs.len()); + + let (slice, offset, length) = rhs.as_slice(); + if offset == 0 { + let iter = BitChunksExact::::new(slice, length); + binary_assign_impl(lhs, iter, op) + } else { + let rhs_chunks = rhs.chunks::(); + binary_assign_impl(lhs, rhs_chunks, op) + } +} + +#[inline] +/// Compute bitwise OR operation in-place +fn or_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + if rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), true); + } else if rhs.unset_bits() == rhs.len() { + // bitmap remains + } else { + binary_assign(lhs, rhs, |x: T, y| x | y) + } +} + +impl<'a> std::ops::BitOrAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitor_assign(&mut self, rhs: &'a Bitmap) { + or_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitOr<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitor(mut self, rhs: &'a Bitmap) -> Self { + or_assign::(&mut self, rhs); + self + } +} + +#[inline] +/// Compute bitwise `&` between `lhs` and `rhs`, assigning it to `lhs` +fn and_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + if rhs.unset_bits() == 0 { + // bitmap remains + } + if rhs.unset_bits() == rhs.len() { + assert_eq!(lhs.len(), rhs.len()); + lhs.clear(); + lhs.extend_constant(rhs.len(), false); + } else { + binary_assign(lhs, rhs, |x: T, y| x & y) + } +} + +impl<'a> std::ops::BitAndAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitand_assign(&mut self, rhs: &'a Bitmap) { + and_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitAnd<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitand(mut self, rhs: &'a Bitmap) -> Self { + and_assign::(&mut self, rhs); + self + } +} + +#[inline] +/// Compute bitwise XOR operation +fn xor_assign(lhs: &mut MutableBitmap, rhs: &Bitmap) { + binary_assign(lhs, rhs, |x: T, y| x ^ y) +} + +impl<'a> std::ops::BitXorAssign<&'a Bitmap> for &mut MutableBitmap { + #[inline] + fn bitxor_assign(&mut self, rhs: &'a Bitmap) { + xor_assign::(self, rhs) + } +} + +impl<'a> std::ops::BitXor<&'a Bitmap> for MutableBitmap { + type Output = Self; + + #[inline] + fn bitxor(mut self, rhs: &'a Bitmap) -> Self { + xor_assign::(&mut self, rhs); + self + } +} diff --git a/crates/nano-arrow/src/bitmap/bitmap_ops.rs b/crates/nano-arrow/src/bitmap/bitmap_ops.rs new file mode 100644 index 000000000000..c83e63255093 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/bitmap_ops.rs @@ -0,0 +1,268 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use super::utils::{BitChunk, BitChunkIterExact, BitChunksExact}; +use super::Bitmap; +use crate::bitmap::MutableBitmap; +use crate::trusted_len::TrustedLen; + +/// Creates a [Vec] from an [`Iterator`] of [`BitChunk`]. +/// # Safety +/// The iterator must be [`TrustedLen`]. 
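The operator impls above let a `MutableBitmap` be combined in place with an immutable `Bitmap`; a sketch using the consuming `BitAnd`:

```rust
use arrow2::bitmap::{Bitmap, MutableBitmap};

let acc = MutableBitmap::from([true, true, false, false]);
let mask = Bitmap::from([true, false, true, false]);

// `BitAnd<&Bitmap>` consumes the mutable side and writes the result into its buffer.
let acc = acc & &mask;
assert_eq!(acc, MutableBitmap::from([true, false, false, false]));
```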
+pub unsafe fn from_chunk_iter_unchecked>( + iterator: I, +) -> Vec { + let (_, upper) = iterator.size_hint(); + let upper = upper.expect("try_from_trusted_len_iter requires an upper limit"); + let len = upper * std::mem::size_of::(); + + let mut buffer = Vec::with_capacity(len); + + let mut dst = buffer.as_mut_ptr(); + for item in iterator { + let bytes = item.to_ne_bytes(); + for i in 0..std::mem::size_of::() { + std::ptr::write(dst, bytes[i]); + dst = dst.add(1); + } + } + assert_eq!( + dst.offset_from(buffer.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.set_len(len); + buffer +} + +/// Creates a [`Vec`] from a [`TrustedLen`] of [`BitChunk`]. +pub fn chunk_iter_to_vec>(iter: I) -> Vec { + unsafe { from_chunk_iter_unchecked(iter) } +} + +/// Apply a bitwise operation `op` to four inputs and return the result as a [`Bitmap`]. +pub fn quaternary(a1: &Bitmap, a2: &Bitmap, a3: &Bitmap, a4: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64, u64, u64) -> u64, +{ + assert_eq!(a1.len(), a2.len()); + assert_eq!(a1.len(), a3.len()); + assert_eq!(a1.len(), a4.len()); + let a1_chunks = a1.chunks(); + let a2_chunks = a2.chunks(); + let a3_chunks = a3.chunks(); + let a4_chunks = a4.chunks(); + + let rem_a1 = a1_chunks.remainder(); + let rem_a2 = a2_chunks.remainder(); + let rem_a3 = a3_chunks.remainder(); + let rem_a4 = a4_chunks.remainder(); + + let chunks = a1_chunks + .zip(a2_chunks) + .zip(a3_chunks) + .zip(a4_chunks) + .map(|(((a1, a2), a3), a4)| op(a1, a2, a3, a4)); + let buffer = + chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3, rem_a4)))); + + let length = a1.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to three inputs and return the result as a [`Bitmap`]. +pub fn ternary(a1: &Bitmap, a2: &Bitmap, a3: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64, u64) -> u64, +{ + assert_eq!(a1.len(), a2.len()); + assert_eq!(a1.len(), a3.len()); + let a1_chunks = a1.chunks(); + let a2_chunks = a2.chunks(); + let a3_chunks = a3.chunks(); + + let rem_a1 = a1_chunks.remainder(); + let rem_a2 = a2_chunks.remainder(); + let rem_a3 = a3_chunks.remainder(); + + let chunks = a1_chunks + .zip(a2_chunks) + .zip(a3_chunks) + .map(|((a1, a2), a3)| op(a1, a2, a3)); + + let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_a1, rem_a2, rem_a3)))); + + let length = a1.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to two inputs and return the result as a [`Bitmap`]. +pub fn binary(lhs: &Bitmap, rhs: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64, u64) -> u64, +{ + assert_eq!(lhs.len(), rhs.len()); + let lhs_chunks = lhs.chunks(); + let rhs_chunks = rhs.chunks(); + let rem_lhs = lhs_chunks.remainder(); + let rem_rhs = rhs_chunks.remainder(); + + let chunks = lhs_chunks + .zip(rhs_chunks) + .map(|(left, right)| op(left, right)); + + let buffer = chunk_iter_to_vec(chunks.chain(std::iter::once(op(rem_lhs, rem_rhs)))); + + let length = lhs.len(); + + Bitmap::from_u8_vec(buffer, length) +} + +fn unary_impl(iter: I, op: F, length: usize) -> Bitmap +where + I: BitChunkIterExact, + F: Fn(u64) -> u64, +{ + let rem = op(iter.remainder()); + + let iterator = iter.map(op).chain(std::iter::once(rem)); + + let buffer = chunk_iter_to_vec(iterator); + + Bitmap::from_u8_vec(buffer, length) +} + +/// Apply a bitwise operation `op` to one input and return the result as a [`Bitmap`]. 
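These chunked helpers are also usable directly for custom elementwise kernels; a sketch with `binary` (assuming it is reachable from the `bitmap` module, as in arrow2):

```rust
use arrow2::bitmap::{binary, Bitmap};

let a = Bitmap::from([true, true, false, false]);
let b = Bitmap::from([true, false, true, false]);

// elementwise `a & !b`, computed 64 bits at a time plus a remainder chunk
let and_not = binary(&a, &b, |x, y| x & !y);
assert_eq!(and_not, Bitmap::from([false, true, false, false]));
```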
+pub fn unary(lhs: &Bitmap, op: F) -> Bitmap +where + F: Fn(u64) -> u64, +{ + let (slice, offset, length) = lhs.as_slice(); + if offset == 0 { + let iter = BitChunksExact::::new(slice, length); + unary_impl(iter, op, lhs.len()) + } else { + let iter = lhs.chunks::(); + unary_impl(iter, op, lhs.len()) + } +} + +// create a new [`Bitmap`] semantically equal to ``bitmap`` but with an offset equal to ``offset`` +pub(crate) fn align(bitmap: &Bitmap, new_offset: usize) -> Bitmap { + let length = bitmap.len(); + + let bitmap: Bitmap = std::iter::repeat(false) + .take(new_offset) + .chain(bitmap.iter()) + .collect(); + + bitmap.sliced(new_offset, length) +} + +#[inline] +/// Compute bitwise AND operation +pub fn and(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + if lhs.unset_bits() == lhs.len() || rhs.unset_bits() == rhs.len() { + assert_eq!(lhs.len(), rhs.len()); + Bitmap::new_zeroed(lhs.len()) + } else { + binary(lhs, rhs, |x, y| x & y) + } +} + +#[inline] +/// Compute bitwise OR operation +pub fn or(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + if lhs.unset_bits() == 0 || rhs.unset_bits() == 0 { + assert_eq!(lhs.len(), rhs.len()); + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| x | y) + } +} + +#[inline] +/// Compute bitwise XOR operation +pub fn xor(lhs: &Bitmap, rhs: &Bitmap) -> Bitmap { + let lhs_nulls = lhs.unset_bits(); + let rhs_nulls = rhs.unset_bits(); + + // all false or all true + if lhs_nulls == rhs_nulls && rhs_nulls == rhs.len() || lhs_nulls == 0 && rhs_nulls == 0 { + assert_eq!(lhs.len(), rhs.len()); + Bitmap::new_zeroed(rhs.len()) + } + // all false and all true or vice versa + else if (lhs_nulls == 0 && rhs_nulls == rhs.len()) + || (lhs_nulls == lhs.len() && rhs_nulls == 0) + { + assert_eq!(lhs.len(), rhs.len()); + let mut mutable = MutableBitmap::with_capacity(lhs.len()); + mutable.extend_constant(lhs.len(), true); + mutable.into() + } else { + binary(lhs, rhs, |x, y| x ^ y) + } +} + +fn eq(lhs: &Bitmap, rhs: &Bitmap) -> bool { + if lhs.len() != rhs.len() { + return false; + } + + let mut lhs_chunks = lhs.chunks::(); + let mut rhs_chunks = rhs.chunks::(); + + let equal_chunks = lhs_chunks + .by_ref() + .zip(rhs_chunks.by_ref()) + .all(|(left, right)| left == right); + + if !equal_chunks { + return false; + } + let lhs_remainder = lhs_chunks.remainder_iter(); + let rhs_remainder = rhs_chunks.remainder_iter(); + lhs_remainder.zip(rhs_remainder).all(|(x, y)| x == y) +} + +impl PartialEq for Bitmap { + fn eq(&self, other: &Self) -> bool { + eq(self, other) + } +} + +impl<'a, 'b> BitOr<&'b Bitmap> for &'a Bitmap { + type Output = Bitmap; + + fn bitor(self, rhs: &'b Bitmap) -> Bitmap { + or(self, rhs) + } +} + +impl<'a, 'b> BitAnd<&'b Bitmap> for &'a Bitmap { + type Output = Bitmap; + + fn bitand(self, rhs: &'b Bitmap) -> Bitmap { + and(self, rhs) + } +} + +impl<'a, 'b> BitXor<&'b Bitmap> for &'a Bitmap { + type Output = Bitmap; + + fn bitxor(self, rhs: &'b Bitmap) -> Bitmap { + xor(self, rhs) + } +} + +impl Not for &Bitmap { + type Output = Bitmap; + + fn not(self) -> Bitmap { + unary(self, |a| !a) + } +} diff --git a/crates/nano-arrow/src/bitmap/immutable.rs b/crates/nano-arrow/src/bitmap/immutable.rs new file mode 100644 index 000000000000..91f1c5942b55 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/immutable.rs @@ -0,0 +1,471 @@ +use std::iter::FromIterator; +use std::ops::Deref; +use std::sync::Arc; + +use either::Either; + +use super::utils::{count_zeros, fmt, get_bit, 
get_bit_unchecked, BitChunk, BitChunks, BitmapIter}; +use super::{chunk_iter_to_vec, IntoIter, MutableBitmap}; +use crate::buffer::Bytes; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +/// An immutable container semantically equivalent to `Arc>` but represented as `Arc>` where +/// each boolean is represented as a single bit. +/// +/// # Examples +/// ``` +/// use arrow2::bitmap::{Bitmap, MutableBitmap}; +/// +/// let bitmap = Bitmap::from([true, false, true]); +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true]); +/// +/// // creation directly from bytes +/// let bitmap = Bitmap::try_new(vec![0b00001101], 5).unwrap(); +/// // note: the first bit is the left-most of the first byte +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true, true, false]); +/// // we can also get the slice: +/// assert_eq!(bitmap.as_slice(), ([0b00001101u8].as_ref(), 0, 5)); +/// // debug helps :) +/// assert_eq!(format!("{:?}", bitmap), "[0b___01101]".to_string()); +/// +/// // it supports copy-on-write semantics (to a `MutableBitmap`) +/// let bitmap: MutableBitmap = bitmap.into_mut().right().unwrap(); +/// assert_eq!(bitmap, MutableBitmap::from([true, false, true, true, false])); +/// +/// // slicing is 'O(1)' (data is shared) +/// let bitmap = Bitmap::try_new(vec![0b00001101], 5).unwrap(); +/// let mut sliced = bitmap.clone(); +/// sliced.slice(1, 4); +/// assert_eq!(sliced.as_slice(), ([0b00001101u8].as_ref(), 1, 4)); // 1 here is the offset: +/// assert_eq!(format!("{:?}", sliced), "[0b___0110_]".to_string()); +/// // when sliced (or cloned), it is no longer possible to `into_mut`. +/// let same: Bitmap = sliced.into_mut().left().unwrap(); +/// ``` +#[derive(Clone)] +pub struct Bitmap { + bytes: Arc>, + // both are measured in bits. They are used to bound the bitmap to a region of Bytes. + offset: usize, + length: usize, + // this is a cache: it is computed on initialization + unset_bits: usize, +} + +impl std::fmt::Debug for Bitmap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (bytes, offset, len) = self.as_slice(); + fmt(bytes, offset, len, f) + } +} + +impl Default for Bitmap { + fn default() -> Self { + MutableBitmap::new().into() + } +} + +pub(super) fn check(bytes: &[u8], offset: usize, length: usize) -> Result<(), Error> { + if offset + length > bytes.len().saturating_mul(8) { + return Err(Error::InvalidArgumentError(format!( + "The offset + length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", + offset + length, + bytes.len().saturating_mul(8) + ))); + } + Ok(()) +} + +impl Bitmap { + /// Initializes an empty [`Bitmap`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Initializes a new [`Bitmap`] from vector of bytes and a length. + /// # Errors + /// This function errors iff `length > bytes.len() * 8` + #[inline] + pub fn try_new(bytes: Vec, length: usize) -> Result { + check(&bytes, 0, length)?; + let unset_bits = count_zeros(&bytes, 0, length); + Ok(Self { + length, + offset: 0, + bytes: Arc::new(bytes.into()), + unset_bits, + }) + } + + /// Returns the length of the [`Bitmap`]. + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether [`Bitmap`] is empty + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns a new iterator of `bool` over this bitmap + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(&self.bytes, self.offset, self.length) + } + + /// Returns an iterator over bits in bit chunks [`BitChunk`]. 
+ /// + /// This iterator is useful to operate over multiple bits via e.g. bitwise. + pub fn chunks(&self) -> BitChunks { + BitChunks::new(&self.bytes, self.offset, self.length) + } + + /// Returns the byte slice of this [`Bitmap`]. + /// + /// The returned tuple contains: + /// * `.1`: The byte slice, truncated to the start of the first bit. So the start of the slice + /// is within the first 8 bits. + /// * `.2`: The start offset in bits on a range `0 <= offsets < 8`. + /// * `.3`: The length in number of bits. + #[inline] + pub fn as_slice(&self) -> (&[u8], usize, usize) { + let start = self.offset / 8; + let len = (self.offset % 8 + self.length).saturating_add(7) / 8; + ( + &self.bytes[start..start + len], + self.offset % 8, + self.length, + ) + } + + /// Returns the number of unset bits on this [`Bitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// # Implementation + /// This function is `O(1)` - the number of unset bits is computed when the bitmap is + /// created + pub const fn unset_bits(&self) -> usize { + self.unset_bits + } + + /// Returns the number of unset bits on this [`Bitmap`]. + #[inline] + #[deprecated(since = "0.13.0", note = "use `unset_bits` instead")] + pub fn null_count(&self) -> usize { + self.unset_bits + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Panic + /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` + /// exceeds the allocated capacity of `self`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!(offset + length <= self.length); + unsafe { self.slice_unchecked(offset, length) } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Safety + /// The caller must ensure that `self.offset + offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + // first guard a no-op slice so that we don't do a bitcount + // if there isn't any data sliced + if !(offset == 0 && length == self.length) { + // count the smallest chunk + if length < self.length / 2 { + // count the null values in the slice + self.unset_bits = count_zeros(&self.bytes, self.offset + offset, length); + } else { + // subtract the null count of the chunks we slice off + let start_end = self.offset + offset + length; + let head_count = count_zeros(&self.bytes, self.offset, offset); + let tail_count = count_zeros(&self.bytes, start_end, self.length - length - offset); + self.unset_bits -= head_count + tail_count; + } + self.offset += offset; + self.length = length; + } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Panic + /// Panics iff `offset + length > self.length`, i.e. if the offset and `length` + /// exceeds the allocated capacity of `self`. + #[inline] + #[must_use] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!(offset + length <= self.length); + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Slices `self`, offsetting by `offset` and truncating up to `length` bits. + /// # Safety + /// The caller must ensure that `self.offset + offset + length <= self.len()` + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + + /// Returns whether the bit at position `i` is set. + /// # Panics + /// Panics iff `i >= self.len()`. 
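A sketch of the slicing behaviour described above; the buffer is shared and only the cached unset-bit count is recomputed, over the smaller of the two regions:

```rust
use arrow2::bitmap::Bitmap;

let mut bitmap = Bitmap::from([true, false, false, true, true]);
assert_eq!(bitmap.unset_bits(), 2);

bitmap.slice(1, 3); // O(1) on the buffer; only metadata and the null count change
assert_eq!(bitmap.len(), 3);
assert_eq!(bitmap.unset_bits(), 2); // both unset bits fall inside the slice
```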
+ #[inline] + pub fn get_bit(&self, i: usize) -> bool { + get_bit(&self.bytes, self.offset + i) + } + + /// Unsafely returns whether the bit at position `i` is set. + /// # Safety + /// Unsound iff `i >= self.len()`. + #[inline] + pub unsafe fn get_bit_unchecked(&self, i: usize) -> bool { + get_bit_unchecked(&self.bytes, self.offset + i) + } + + /// Returns a pointer to the start of this [`Bitmap`] (ignores `offsets`) + /// This pointer is allocated iff `self.len() > 0`. + pub(crate) fn as_ptr(&self) -> *const u8 { + self.bytes.deref().as_ptr() + } + + /// Returns a pointer to the start of this [`Bitmap`] (ignores `offsets`) + /// This pointer is allocated iff `self.len() > 0`. + pub(crate) fn offset(&self) -> usize { + self.offset + } + + /// Converts this [`Bitmap`] to [`MutableBitmap`], returning itself if the conversion + /// is not possible + /// + /// This operation returns a [`MutableBitmap`] iff: + /// * this [`Bitmap`] is not an offsetted slice of another [`Bitmap`] + /// * this [`Bitmap`] has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * this [`Bitmap`] was not imported from the c data interface (FFI) + pub fn into_mut(mut self) -> Either { + match ( + self.offset, + Arc::get_mut(&mut self.bytes).and_then(|b| b.get_vec()), + ) { + (0, Some(v)) => { + let data = std::mem::take(v); + Either::Right(MutableBitmap::from_vec(data, self.length)) + }, + _ => Either::Left(self), + } + } + + /// Converts this [`Bitmap`] into a [`MutableBitmap`], cloning its internal + /// buffer if required (clone-on-write). + pub fn make_mut(self) -> MutableBitmap { + match self.into_mut() { + Either::Left(data) => { + if data.offset > 0 { + // re-align the bits (remove the offset) + let chunks = data.chunks::(); + let remainder = chunks.remainder(); + let vec = chunk_iter_to_vec(chunks.chain(std::iter::once(remainder))); + MutableBitmap::from_vec(vec, data.length) + } else { + MutableBitmap::from_vec(data.bytes.as_ref().to_vec(), data.length) + } + }, + Either::Right(data) => data, + } + } + + /// Initializes an new [`Bitmap`] filled with unset values. + #[inline] + pub fn new_zeroed(length: usize) -> Self { + // don't use `MutableBitmap::from_len_zeroed().into()` + // it triggers a bitcount + let bytes = vec![0; length.saturating_add(7) / 8]; + unsafe { Bitmap::from_inner_unchecked(Arc::new(bytes.into()), 0, length, length) } + } + + /// Counts the nulls (unset bits) starting from `offset` bits and for `length` bits. + #[inline] + pub fn null_count_range(&self, offset: usize, length: usize) -> usize { + count_zeros(&self.bytes, self.offset + offset, length) + } + + /// Creates a new [`Bitmap`] from a slice and length. + /// # Panic + /// Panics iff `length <= bytes.len() * 8` + #[inline] + pub fn from_u8_slice>(slice: T, length: usize) -> Self { + Bitmap::try_new(slice.as_ref().to_vec(), length).unwrap() + } + + /// Alias for `Bitmap::try_new().unwrap()` + /// This function is `O(1)` + /// # Panic + /// This function panics iff `length <= bytes.len() * 8` + #[inline] + pub fn from_u8_vec(vec: Vec, length: usize) -> Self { + Bitmap::try_new(vec, length).unwrap() + } + + /// Returns whether the bit at position `i` is set. 
+    #[inline]
+    pub fn get(&self, i: usize) -> Option<bool> {
+        if i < self.len() {
+            Some(unsafe { self.get_bit_unchecked(i) })
+        } else {
+            None
+        }
+    }
+
+    /// Returns its internal representation
+    #[must_use]
+    pub fn into_inner(self) -> (Arc<Bytes<u8>>, usize, usize, usize) {
+        let Self {
+            bytes,
+            offset,
+            length,
+            unset_bits,
+        } = self;
+        (bytes, offset, length, unset_bits)
+    }
+
+    /// Creates a [`Bitmap`] from its internal representation.
+    /// This is the inverse of [`Bitmap::into_inner`].
+    ///
+    /// # Safety
+    /// The invariants of this struct must be upheld.
+    pub unsafe fn from_inner(
+        bytes: Arc<Bytes<u8>>,
+        offset: usize,
+        length: usize,
+        unset_bits: usize,
+    ) -> Result<Self, Error> {
+        check(&bytes, offset, length)?;
+        Ok(Self {
+            bytes,
+            offset,
+            length,
+            unset_bits,
+        })
+    }
+
+    /// Creates a [`Bitmap`] from its internal representation.
+    /// This is the inverse of [`Bitmap::into_inner`].
+    ///
+    /// # Safety
+    /// Callers must ensure all invariants of this struct are upheld.
+    pub unsafe fn from_inner_unchecked(
+        bytes: Arc<Bytes<u8>>,
+        offset: usize,
+        length: usize,
+        unset_bits: usize,
+    ) -> Self {
+        Self {
+            bytes,
+            offset,
+            length,
+            unset_bits,
+        }
+    }
+}
+
+impl<P: AsRef<[bool]>> From<P>
for Bitmap { + fn from(slice: P) -> Self { + Self::from_trusted_len_iter(slice.as_ref().iter().copied()) + } +} + +impl FromIterator for Bitmap { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + MutableBitmap::from_iter(iter).into() + } +} + +impl Bitmap { + /// Creates a new [`Bitmap`] from an iterator of booleans. + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked>(iterator: I) -> Self { + MutableBitmap::from_trusted_len_iter_unchecked(iterator).into() + } + + /// Creates a new [`Bitmap`] from an iterator of booleans. + #[inline] + pub fn from_trusted_len_iter>(iterator: I) -> Self { + MutableBitmap::from_trusted_len_iter(iterator).into() + } + + /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + #[inline] + pub fn try_from_trusted_len_iter>>( + iterator: I, + ) -> std::result::Result { + Ok(MutableBitmap::try_from_trusted_len_iter(iterator)?.into()) + } + + /// Creates a new [`Bitmap`] from a fallible iterator of booleans. + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn try_from_trusted_len_iter_unchecked< + E, + I: Iterator>, + >( + iterator: I, + ) -> std::result::Result { + Ok(MutableBitmap::try_from_trusted_len_iter_unchecked(iterator)?.into()) + } + + /// Create a new [`Bitmap`] from an arrow [`NullBuffer`] + /// + /// [`NullBuffer`]: arrow_buffer::buffer::NullBuffer + #[cfg(feature = "arrow")] + pub fn from_null_buffer(value: arrow_buffer::buffer::NullBuffer) -> Self { + let offset = value.offset(); + let length = value.len(); + let unset_bits = value.null_count(); + Self { + offset, + length, + unset_bits, + bytes: Arc::new(crate::buffer::to_bytes(value.buffer().clone())), + } + } +} + +impl<'a> IntoIterator for &'a Bitmap { + type Item = bool; + type IntoIter = BitmapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitmapIter::<'a>::new(&self.bytes, self.offset, self.length) + } +} + +impl IntoIterator for Bitmap { + type Item = bool; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +#[cfg(feature = "arrow")] +impl From for arrow_buffer::buffer::NullBuffer { + fn from(value: Bitmap) -> Self { + let null_count = value.unset_bits; + let buffer = crate::buffer::to_buffer(value.bytes); + let buffer = arrow_buffer::buffer::BooleanBuffer::new(buffer, value.offset, value.length); + // Safety: null count is accurate + unsafe { arrow_buffer::buffer::NullBuffer::new_unchecked(buffer, null_count) } + } +} diff --git a/crates/nano-arrow/src/bitmap/iterator.rs b/crates/nano-arrow/src/bitmap/iterator.rs new file mode 100644 index 000000000000..93ca7fb8576a --- /dev/null +++ b/crates/nano-arrow/src/bitmap/iterator.rs @@ -0,0 +1,68 @@ +use super::Bitmap; +use crate::trusted_len::TrustedLen; + +/// This crates' equivalent of [`std::vec::IntoIter`] for [`Bitmap`]. 
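A short sketch of the constructors above: `slice.iter().copied()` is exactly the pattern the `From` impl uses for its `TrustedLen` fast path, and the by-value `IntoIter` defined in `iterator.rs` below supports reverse iteration. Same path assumptions as before:

```rust
use arrow2::bitmap::Bitmap;

fn main() {
    let values = [true, false, true, false, true];

    // `FromIterator<bool>` works for any iterator and packs bits byte by byte
    let collected: Bitmap = values.iter().copied().collect();

    // the `TrustedLen` constructor skips the incremental size-hint bookkeeping
    let trusted = Bitmap::from_trusted_len_iter(values.iter().copied());
    assert_eq!(collected.unset_bits(), trusted.unset_bits());

    // iteration is available by reference and by value, in both directions
    let forwards: Vec<bool> = (&collected).into_iter().collect();
    let backwards: Vec<bool> = collected.into_iter().rev().collect();
    assert_eq!(forwards, backwards.into_iter().rev().collect::<Vec<_>>());
}
```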
+#[derive(Debug, Clone)] +pub struct IntoIter { + values: Bitmap, + index: usize, + end: usize, +} + +impl IntoIter { + /// Creates a new [`IntoIter`] from a [`Bitmap`] + #[inline] + pub fn new(values: Bitmap) -> Self { + let end = values.len(); + Self { + values, + index: 0, + end, + } + } +} + +impl Iterator for IntoIter { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(unsafe { self.values.get_bit_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(unsafe { self.values.get_bit_unchecked(self.end) }) + } + } +} + +unsafe impl TrustedLen for IntoIter {} diff --git a/crates/nano-arrow/src/bitmap/mod.rs b/crates/nano-arrow/src/bitmap/mod.rs new file mode 100644 index 000000000000..dea2645c1466 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/mod.rs @@ -0,0 +1,17 @@ +//! contains [`Bitmap`] and [`MutableBitmap`], containers of `bool`. +mod immutable; +pub use immutable::*; + +mod iterator; +pub use iterator::IntoIter; + +mod mutable; +pub use mutable::MutableBitmap; + +mod bitmap_ops; +pub use bitmap_ops::*; + +mod assign_ops; +pub use assign_ops::*; + +pub mod utils; diff --git a/crates/nano-arrow/src/bitmap/mutable.rs b/crates/nano-arrow/src/bitmap/mutable.rs new file mode 100644 index 000000000000..e52e39ba3200 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/mutable.rs @@ -0,0 +1,755 @@ +use std::hint::unreachable_unchecked; +use std::iter::FromIterator; +use std::sync::Arc; + +use super::utils::{ + count_zeros, fmt, get_bit, set, set_bit, BitChunk, BitChunksExactMut, BitmapIter, +}; +use super::Bitmap; +use crate::bitmap::utils::{merge_reversed, set_bit_unchecked}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +/// A container of booleans. [`MutableBitmap`] is semantically equivalent +/// to [`Vec`]. +/// +/// The two main differences against [`Vec`] is that each element stored as a single bit, +/// thereby: +/// * it uses 8x less memory +/// * it cannot be represented as `&[bool]` (i.e. no pointer arithmetics). +/// +/// A [`MutableBitmap`] can be converted to a [`Bitmap`] at `O(1)`. 
+/// # Examples +/// ``` +/// use arrow2::bitmap::MutableBitmap; +/// +/// let bitmap = MutableBitmap::from([true, false, true]); +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true]); +/// +/// // creation directly from bytes +/// let mut bitmap = MutableBitmap::try_new(vec![0b00001101], 5).unwrap(); +/// // note: the first bit is the left-most of the first byte +/// assert_eq!(bitmap.iter().collect::>(), vec![true, false, true, true, false]); +/// // we can also get the slice: +/// assert_eq!(bitmap.as_slice(), [0b00001101u8].as_ref()); +/// // debug helps :) +/// assert_eq!(format!("{:?}", bitmap), "[0b___01101]".to_string()); +/// +/// // It supports mutation in place +/// bitmap.set(0, false); +/// assert_eq!(format!("{:?}", bitmap), "[0b___01100]".to_string()); +/// // and `O(1)` random access +/// assert_eq!(bitmap.get(0), false); +/// ``` +/// # Implementation +/// This container is internally a [`Vec`]. +#[derive(Clone)] +pub struct MutableBitmap { + buffer: Vec, + // invariant: length.saturating_add(7) / 8 == buffer.len(); + length: usize, +} + +impl std::fmt::Debug for MutableBitmap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt(&self.buffer, 0, self.len(), f) + } +} + +impl PartialEq for MutableBitmap { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl MutableBitmap { + /// Initializes an empty [`MutableBitmap`]. + #[inline] + pub fn new() -> Self { + Self { + buffer: Vec::new(), + length: 0, + } + } + + /// Initializes a new [`MutableBitmap`] from a [`Vec`] and a length. + /// # Errors + /// This function errors iff `length > bytes.len() * 8` + #[inline] + pub fn try_new(bytes: Vec, length: usize) -> Result { + if length > bytes.len().saturating_mul(8) { + return Err(Error::InvalidArgumentError(format!( + "The length of the bitmap ({}) must be `<=` to the number of bytes times 8 ({})", + length, + bytes.len().saturating_mul(8) + ))); + } + Ok(Self { + length, + buffer: bytes, + }) + } + + /// Initializes a [`MutableBitmap`] from a [`Vec`] and a length. + /// This function is `O(1)`. + /// # Panic + /// Panics iff the length is larger than the length of the buffer times 8. + #[inline] + pub fn from_vec(buffer: Vec, length: usize) -> Self { + Self::try_new(buffer, length).unwrap() + } + + /// Initializes a pre-allocated [`MutableBitmap`] with capacity for `capacity` bits. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + buffer: Vec::with_capacity(capacity.saturating_add(7) / 8), + length: 0, + } + } + + /// Pushes a new bit to the [`MutableBitmap`], re-sizing it if necessary. + #[inline] + pub fn push(&mut self, value: bool) { + if self.length % 8 == 0 { + self.buffer.push(0); + } + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + *byte = set(*byte, self.length % 8, value); + self.length += 1; + } + + /// Pop the last bit from the [`MutableBitmap`]. + /// Note if the [`MutableBitmap`] is empty, this method will return None. + #[inline] + pub fn pop(&mut self) -> Option { + if self.is_empty() { + return None; + } + + self.length -= 1; + let value = self.get(self.length); + if self.length % 8 == 0 { + self.buffer.pop(); + } + Some(value) + } + + /// Returns whether the position `index` is set. + /// # Panics + /// Panics iff `index >= self.len()`. + #[inline] + pub fn get(&self, index: usize) -> bool { + get_bit(&self.buffer, index) + } + + /// Sets the position `index` to `value` + /// # Panics + /// Panics iff `index >= self.len()`. 
+ #[inline] + pub fn set(&mut self, index: usize, value: bool) { + set_bit(self.buffer.as_mut_slice(), index, value) + } + + /// constructs a new iterator over the bits of [`MutableBitmap`]. + pub fn iter(&self) -> BitmapIter { + BitmapIter::new(&self.buffer, 0, self.length) + } + + /// Empties the [`MutableBitmap`]. + #[inline] + pub fn clear(&mut self) { + self.length = 0; + self.buffer.clear(); + } + + /// Extends [`MutableBitmap`] by `additional` values of constant `value`. + /// # Implementation + /// This function is an order of magnitude faster than pushing element by element. + #[inline] + pub fn extend_constant(&mut self, additional: usize, value: bool) { + if additional == 0 { + return; + } + + if value { + self.extend_set(additional) + } else { + self.extend_unset(additional) + } + } + + /// Initializes a zeroed [`MutableBitmap`]. + #[inline] + pub fn from_len_zeroed(length: usize) -> Self { + Self { + buffer: vec![0; length.saturating_add(7) / 8], + length, + } + } + + /// Initializes a [`MutableBitmap`] with all values set to valid/ true. + #[inline] + pub fn from_len_set(length: usize) -> Self { + Self { + buffer: vec![u8::MAX; length.saturating_add(7) / 8], + length, + } + } + + /// Reserves `additional` bits in the [`MutableBitmap`], potentially re-allocating its buffer. + #[inline(always)] + pub fn reserve(&mut self, additional: usize) { + self.buffer + .reserve((self.length + additional).saturating_add(7) / 8 - self.buffer.len()) + } + + /// Returns the capacity of [`MutableBitmap`] in number of bits. + #[inline] + pub fn capacity(&self) -> usize { + self.buffer.capacity() * 8 + } + + /// Pushes a new bit to the [`MutableBitmap`] + /// # Safety + /// The caller must ensure that the [`MutableBitmap`] has sufficient capacity. + #[inline] + pub unsafe fn push_unchecked(&mut self, value: bool) { + if self.length % 8 == 0 { + self.buffer.push(0); + } + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + *byte = set(*byte, self.length % 8, value); + self.length += 1; + } + + /// Returns the number of unset bits on this [`MutableBitmap`]. + /// + /// Guaranteed to be `<= self.len()`. + /// # Implementation + /// This function is `O(N)` + pub fn unset_bits(&self) -> usize { + count_zeros(&self.buffer, 0, self.length) + } + + /// Returns the number of unset bits on this [`MutableBitmap`]. + #[deprecated(since = "0.13.0", note = "use `unset_bits` instead")] + pub fn null_count(&self) -> usize { + self.unset_bits() + } + + /// Returns the length of the [`MutableBitmap`]. + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether [`MutableBitmap`] is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// # Safety + /// The caller must ensure that the [`MutableBitmap`] was properly initialized up to `len`. 
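A sketch of the bulk constructors and `extend_constant` shown above (same path assumptions); the expected counts can be checked by hand:

```rust
use arrow2::bitmap::MutableBitmap;

fn main() {
    // start from three unset bits, then extend with a constant run
    let mut bitmap = MutableBitmap::from_len_zeroed(3);
    bitmap.extend_constant(5, true); // bulk path: much faster than pushing one by one
    bitmap.push(false);

    assert_eq!(bitmap.len(), 9);
    assert_eq!(bitmap.unset_bits(), 4); // the three leading zeros plus the pushed `false`

    // `pop` removes from the back and shrinks the byte buffer when a byte empties
    assert_eq!(bitmap.pop(), Some(false));
    assert_eq!(bitmap.len(), 8);
}
```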
+ #[inline] + pub(crate) unsafe fn set_len(&mut self, len: usize) { + self.buffer.set_len(len.saturating_add(7) / 8); + self.length = len; + } + + fn extend_set(&mut self, mut additional: usize) { + let offset = self.length % 8; + let added = if offset != 0 { + // offset != 0 => at least one byte in the buffer + let last_index = self.buffer.len() - 1; + let last = &mut self.buffer[last_index]; + + let remaining = 0b11111111u8; + let remaining = remaining >> 8usize.saturating_sub(additional); + let remaining = remaining << offset; + *last |= remaining; + std::cmp::min(additional, 8 - offset) + } else { + 0 + }; + self.length += added; + additional = additional.saturating_sub(added); + if additional > 0 { + debug_assert_eq!(self.length % 8, 0); + let existing = self.length.saturating_add(7) / 8; + let required = (self.length + additional).saturating_add(7) / 8; + // add remaining as full bytes + self.buffer + .extend(std::iter::repeat(0b11111111u8).take(required - existing)); + self.length += additional; + } + } + + fn extend_unset(&mut self, mut additional: usize) { + let offset = self.length % 8; + let added = if offset != 0 { + // offset != 0 => at least one byte in the buffer + let last_index = self.buffer.len() - 1; + let last = &mut self.buffer[last_index]; + *last &= 0b11111111u8 >> (8 - offset); // unset them + std::cmp::min(additional, 8 - offset) + } else { + 0 + }; + self.length += added; + additional = additional.saturating_sub(added); + if additional > 0 { + debug_assert_eq!(self.length % 8, 0); + self.buffer + .resize((self.length + additional).saturating_add(7) / 8, 0); + self.length += additional; + } + } + + /// Sets the position `index` to `value` + /// # Safety + /// Caller must ensure that `index < self.len()` + #[inline] + pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { + set_bit_unchecked(self.buffer.as_mut_slice(), index, value) + } + + /// Shrinks the capacity of the [`MutableBitmap`] to fit its current length. + pub fn shrink_to_fit(&mut self) { + self.buffer.shrink_to_fit(); + } + + /// Returns an iterator over mutable slices, [`BitChunksExactMut`] + pub(crate) fn bitchunks_exact_mut(&mut self) -> BitChunksExactMut { + BitChunksExactMut::new(&mut self.buffer, self.length) + } +} + +impl From for Bitmap { + #[inline] + fn from(buffer: MutableBitmap) -> Self { + Bitmap::try_new(buffer.buffer, buffer.length).unwrap() + } +} + +impl From for Option { + #[inline] + fn from(buffer: MutableBitmap) -> Self { + let unset_bits = buffer.unset_bits(); + if unset_bits > 0 { + // safety: + // invariants of the `MutableBitmap` equal that of `Bitmap` + let bitmap = unsafe { + Bitmap::from_inner_unchecked( + Arc::new(buffer.buffer.into()), + 0, + buffer.length, + unset_bits, + ) + }; + Some(bitmap) + } else { + None + } + } +} + +impl> From

for MutableBitmap { + #[inline] + fn from(slice: P) -> Self { + MutableBitmap::from_trusted_len_iter(slice.as_ref().iter().copied()) + } +} + +impl FromIterator for MutableBitmap { + fn from_iter(iter: I) -> Self + where + I: IntoIterator, + { + let mut iterator = iter.into_iter(); + let mut buffer = { + let byte_capacity: usize = iterator.size_hint().0.saturating_add(7) / 8; + Vec::with_capacity(byte_capacity) + }; + + let mut length = 0; + + loop { + let mut exhausted = false; + let mut byte_accum: u8 = 0; + let mut mask: u8 = 1; + + //collect (up to) 8 bits into a byte + while mask != 0 { + if let Some(value) = iterator.next() { + length += 1; + byte_accum |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } else { + exhausted = true; + break; + } + } + + // break if the iterator was exhausted before it provided a bool for this byte + if exhausted && mask == 1 { + break; + } + + //ensure we have capacity to write the byte + if buffer.len() == buffer.capacity() { + //no capacity for new byte, allocate 1 byte more (plus however many more the iterator advertises) + let additional_byte_capacity = 1usize.saturating_add( + iterator.size_hint().0.saturating_add(7) / 8, //convert bit count to byte count, rounding up + ); + buffer.reserve(additional_byte_capacity) + } + + // Soundness: capacity was allocated above + buffer.push(byte_accum); + if exhausted { + break; + } + } + Self { buffer, length } + } +} + +// [7, 6, 5, 4, 3, 2, 1, 0], [15, 14, 13, 12, 11, 10, 9, 8] +// [00000001_00000000_00000000_00000000_...] // u64 +/// # Safety +/// The iterator must be trustedLen and its len must be least `len`. +#[inline] +unsafe fn get_chunk_unchecked(iterator: &mut impl Iterator) -> u64 { + let mut byte = 0u64; + let mut mask; + for i in 0..8 { + mask = 1u64 << (8 * i); + for _ in 0..8 { + let value = match iterator.next() { + Some(value) => value, + None => unsafe { unreachable_unchecked() }, + }; + + byte |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } + } + byte +} + +/// # Safety +/// The iterator must be trustedLen and its len must be least `len`. +#[inline] +unsafe fn get_byte_unchecked(len: usize, iterator: &mut impl Iterator) -> u8 { + let mut byte_accum: u8 = 0; + let mut mask: u8 = 1; + for _ in 0..len { + let value = match iterator.next() { + Some(value) => value, + None => unsafe { unreachable_unchecked() }, + }; + + byte_accum |= match value { + true => mask, + false => 0, + }; + mask <<= 1; + } + byte_accum +} + +/// Extends the [`Vec`] from `iterator` +/// # Safety +/// The iterator MUST be [`TrustedLen`]. 
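The `FromIterator` impl above packs bits byte by byte and uses the size hint only for capacity, so it also handles iterators that cannot report an exact length. A small sketch under the same import assumptions:

```rust
use arrow2::bitmap::MutableBitmap;

fn main() {
    // `filter` cannot report an exact length, so this exercises the byte-packing
    // loop rather than the `TrustedLen` fast path
    let bitmap: MutableBitmap = (0u32..20)
        .filter(|i| i % 7 != 0)
        .map(|i| i % 2 == 0)
        .collect();

    assert_eq!(bitmap.len(), 17); // 20 values minus i = 0, 7 and 14
    assert_eq!(bitmap.unset_bits(), 9); // the 8 kept even numbers produce set bits
}
```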
+#[inline] +unsafe fn extend_aligned_trusted_iter_unchecked( + buffer: &mut Vec, + mut iterator: impl Iterator, +) -> usize { + let additional_bits = iterator.size_hint().1.unwrap(); + let chunks = additional_bits / 64; + let remainder = additional_bits % 64; + + let additional = (additional_bits + 7) / 8; + assert_eq!( + additional, + // a hint of how the following calculation will be done + chunks * 8 + remainder / 8 + (remainder % 8 > 0) as usize + ); + buffer.reserve(additional); + + // chunks of 64 bits + for _ in 0..chunks { + let chunk = get_chunk_unchecked(&mut iterator); + buffer.extend_from_slice(&chunk.to_le_bytes()); + } + + // remaining complete bytes + for _ in 0..(remainder / 8) { + let byte = unsafe { get_byte_unchecked(8, &mut iterator) }; + buffer.push(byte) + } + + // remaining bits + let remainder = remainder % 8; + if remainder > 0 { + let byte = unsafe { get_byte_unchecked(remainder, &mut iterator) }; + buffer.push(byte) + } + additional_bits +} + +impl MutableBitmap { + /// Extends `self` from a [`TrustedLen`] iterator. + #[inline] + pub fn extend_from_trusted_len_iter>(&mut self, iterator: I) { + // safety: I: TrustedLen + unsafe { self.extend_from_trusted_len_iter_unchecked(iterator) } + } + + /// Extends `self` from an iterator of trusted len. + /// # Safety + /// The caller must guarantee that the iterator has a trusted len. + #[inline] + pub unsafe fn extend_from_trusted_len_iter_unchecked>( + &mut self, + mut iterator: I, + ) { + // the length of the iterator throughout this function. + let mut length = iterator.size_hint().1.unwrap(); + + let bit_offset = self.length % 8; + + if length < 8 - bit_offset { + if bit_offset == 0 { + self.buffer.push(0); + } + // the iterator will not fill the last byte + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + let mut i = bit_offset; + for value in iterator { + *byte = set(*byte, i, value); + i += 1; + } + self.length += length; + return; + } + + // at this point we know that length will hit a byte boundary and thus + // increase the buffer. + + if bit_offset != 0 { + // we are in the middle of a byte; lets finish it + let byte = self.buffer.as_mut_slice().last_mut().unwrap(); + (bit_offset..8).for_each(|i| { + *byte = set(*byte, i, iterator.next().unwrap()); + }); + self.length += 8 - bit_offset; + length -= 8 - bit_offset; + } + + // everything is aligned; proceed with the bulk operation + debug_assert_eq!(self.length % 8, 0); + + unsafe { extend_aligned_trusted_iter_unchecked(&mut self.buffer, iterator) }; + self.length += length; + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + /// # Safety + /// The iterator must report an accurate length. + #[inline] + pub unsafe fn from_trusted_len_iter_unchecked(iterator: I) -> Self + where + I: Iterator, + { + let mut buffer = Vec::::new(); + + let length = extend_aligned_trusted_iter_unchecked(&mut buffer, iterator); + + Self { buffer, length } + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + #[inline] + pub fn from_trusted_len_iter(iterator: I) -> Self + where + I: TrustedLen, + { + // Safety: Iterator is `TrustedLen` + unsafe { Self::from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBitmap`] from an iterator of booleans. + pub fn try_from_trusted_len_iter(iterator: I) -> std::result::Result + where + I: TrustedLen>, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iterator) } + } + + /// Creates a new [`MutableBitmap`] from an falible iterator of booleans. 
+ /// # Safety + /// The caller must guarantee that the iterator is `TrustedLen`. + pub unsafe fn try_from_trusted_len_iter_unchecked( + mut iterator: I, + ) -> std::result::Result + where + I: Iterator>, + { + let length = iterator.size_hint().1.unwrap(); + + let mut buffer = vec![0u8; (length + 7) / 8]; + + let chunks = length / 8; + let reminder = length % 8; + + let data = buffer.as_mut_slice(); + data[..chunks].iter_mut().try_for_each(|byte| { + (0..8).try_for_each(|i| { + *byte = set(*byte, i, iterator.next().unwrap()?); + Ok(()) + }) + })?; + + if reminder != 0 { + let last = &mut data[chunks]; + iterator.enumerate().try_for_each(|(i, value)| { + *last = set(*last, i, value?); + Ok(()) + })?; + } + + Ok(Self { buffer, length }) + } + + fn extend_unaligned(&mut self, slice: &[u8], offset: usize, length: usize) { + // e.g. + // [a, b, --101010] <- to be extended + // [00111111, 11010101] <- to extend + // [a, b, 11101010, --001111] expected result + + let aligned_offset = offset / 8; + let own_offset = self.length % 8; + debug_assert_eq!(offset % 8, 0); // assumed invariant + debug_assert!(own_offset != 0); // assumed invariant + + let bytes_len = length.saturating_add(7) / 8; + let items = &slice[aligned_offset..aligned_offset + bytes_len]; + // self has some offset => we need to shift all `items`, and merge the first + let buffer = self.buffer.as_mut_slice(); + let last = &mut buffer[buffer.len() - 1]; + + // --101010 | 00111111 << 6 = 11101010 + // erase previous + *last &= 0b11111111u8 >> (8 - own_offset); // unset before setting + *last |= items[0] << own_offset; + + if length + own_offset <= 8 { + // no new bytes needed + self.length += length; + return; + } + let additional = length - (8 - own_offset); + + let remaining = [items[items.len() - 1], 0]; + let bytes = items + .windows(2) + .chain(std::iter::once(remaining.as_ref())) + .map(|w| merge_reversed(w[0], w[1], 8 - own_offset)) + .take(additional.saturating_add(7) / 8); + self.buffer.extend(bytes); + + self.length += length; + } + + fn extend_aligned(&mut self, slice: &[u8], offset: usize, length: usize) { + let aligned_offset = offset / 8; + let bytes_len = length.saturating_add(7) / 8; + let items = &slice[aligned_offset..aligned_offset + bytes_len]; + self.buffer.extend_from_slice(items); + self.length += length; + } + + /// Extends the [`MutableBitmap`] from a slice of bytes with optional offset. + /// This is the fastest way to extend a [`MutableBitmap`]. + /// # Implementation + /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, + /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + /// # Safety + /// Caller must ensure `offset + length <= slice.len() * 8` + #[inline] + pub unsafe fn extend_from_slice_unchecked( + &mut self, + slice: &[u8], + offset: usize, + length: usize, + ) { + if length == 0 { + return; + }; + let is_aligned = self.length % 8 == 0; + let other_is_aligned = offset % 8 == 0; + match (is_aligned, other_is_aligned) { + (true, true) => self.extend_aligned(slice, offset, length), + (false, true) => self.extend_unaligned(slice, offset, length), + // todo: further optimize the other branches. + _ => self.extend_from_trusted_len_iter(BitmapIter::new(slice, offset, length)), + } + // internal invariant: + debug_assert_eq!(self.length.saturating_add(7) / 8, self.buffer.len()); + } + + /// Extends the [`MutableBitmap`] from a slice of bytes with optional offset. + /// This is the fastest way to extend a [`MutableBitmap`]. 
+ /// # Implementation + /// When both [`MutableBitmap`]'s length and `offset` are both multiples of 8, + /// this function performs a memcopy. Else, it first aligns bit by bit and then performs a memcopy. + #[inline] + pub fn extend_from_slice(&mut self, slice: &[u8], offset: usize, length: usize) { + assert!(offset + length <= slice.len() * 8); + // safety: invariant is asserted + unsafe { self.extend_from_slice_unchecked(slice, offset, length) } + } + + /// Extends the [`MutableBitmap`] from a [`Bitmap`]. + #[inline] + pub fn extend_from_bitmap(&mut self, bitmap: &Bitmap) { + let (slice, offset, length) = bitmap.as_slice(); + // safety: bitmap.as_slice adheres to the invariant + unsafe { + self.extend_from_slice_unchecked(slice, offset, length); + } + } + + /// Returns the slice of bytes of this [`MutableBitmap`]. + /// Note that the last byte may not be fully used. + #[inline] + pub fn as_slice(&self) -> &[u8] { + let len = (self.length).saturating_add(7) / 8; + &self.buffer[..len] + } +} + +impl Default for MutableBitmap { + fn default() -> Self { + Self::new() + } +} + +impl<'a> IntoIterator for &'a MutableBitmap { + type Item = bool; + type IntoIter = BitmapIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + BitmapIter::<'a>::new(&self.buffer, 0, self.length) + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs new file mode 100644 index 000000000000..4ab9d300ba02 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/chunks_exact.rs @@ -0,0 +1,101 @@ +use std::convert::TryInto; +use std::slice::ChunksExact; + +use super::{BitChunk, BitChunkIterExact}; +use crate::trusted_len::TrustedLen; + +/// An iterator over a slice of bytes in [`BitChunk`]s. +#[derive(Debug)] +pub struct BitChunksExact<'a, T: BitChunk> { + iter: ChunksExact<'a, u8>, + remainder: &'a [u8], + remainder_len: usize, + phantom: std::marker::PhantomData, +} + +impl<'a, T: BitChunk> BitChunksExact<'a, T> { + /// Creates a new [`BitChunksExact`]. + #[inline] + pub fn new(bitmap: &'a [u8], length: usize) -> Self { + assert!(length <= bitmap.len() * 8); + let size_of = std::mem::size_of::(); + + let bitmap = &bitmap[..length.saturating_add(7) / 8]; + + let split = (length / 8 / size_of) * size_of; + let (chunks, remainder) = bitmap.split_at(split); + let remainder_len = length - chunks.len() * 8; + let iter = chunks.chunks_exact(size_of); + + Self { + iter, + remainder, + remainder_len, + phantom: std::marker::PhantomData, + } + } + + /// Returns the number of chunks of this iterator + #[inline] + pub fn len(&self) -> usize { + self.iter.len() + } + + /// Returns whether there are still elements in this iterator + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the remaining [`BitChunk`]. It is zero iff `len / 8 == 0`. 
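A sketch of `extend_from_slice` appending bits from a byte slice into an unaligned `MutableBitmap`, which exercises the shift-and-merge path described above (same path assumptions; bits are read LSB-first within each byte):

```rust
use arrow2::bitmap::MutableBitmap;

fn main() {
    // 3 bits already present, so the bytes below are shifted and merged in
    // rather than memcopied
    let mut bitmap = MutableBitmap::from([true, false, true]);

    // append the first 4 bits of `source`, i.e. 0, 1, 1, 0
    let source: &[u8] = &[0b1111_0110];
    bitmap.extend_from_slice(source, 0, 4);

    assert_eq!(bitmap.len(), 7);
    let bits: Vec<bool> = bitmap.iter().collect();
    assert_eq!(bits, vec![true, false, true, false, true, true, false]);
}
```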
+ #[inline] + pub fn remainder(&self) -> T { + let remainder_bytes = self.remainder; + if remainder_bytes.is_empty() { + return T::zero(); + } + let remainder = match remainder_bytes.try_into() { + Ok(a) => a, + Err(_) => { + let mut remainder = T::zero().to_ne_bytes(); + remainder_bytes + .iter() + .enumerate() + .for_each(|(index, b)| remainder[index] = *b); + remainder + }, + }; + T::from_ne_bytes(remainder) + } +} + +impl Iterator for BitChunksExact<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|x| match x.try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +unsafe impl TrustedLen for BitChunksExact<'_, T> {} + +impl BitChunkIterExact for BitChunksExact<'_, T> { + #[inline] + fn remainder(&self) -> T { + self.remainder() + } + + #[inline] + fn remainder_len(&self) -> usize { + self.remainder_len + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/chunk_iterator/merge.rs b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/merge.rs new file mode 100644 index 000000000000..81e08df0059e --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/merge.rs @@ -0,0 +1,61 @@ +use super::BitChunk; + +/// Merges 2 [`BitChunk`]s into a single [`BitChunk`] so that the new items represents +/// the bitmap where bits from `next` are placed in `current` according to `offset`. +/// # Panic +/// The caller must ensure that `0 < offset < size_of::() * 8` +/// # Example +/// ```rust,ignore +/// let current = 0b01011001; +/// let next = 0b01011011; +/// let result = merge_reversed(current, next, 1); +/// assert_eq!(result, 0b10101100); +/// ``` +#[inline] +pub fn merge_reversed(mut current: T, mut next: T, offset: usize) -> T +where + T: BitChunk, +{ + // 8 _bits_: + // current = [c0, c1, c2, c3, c4, c5, c6, c7] + // next = [n0, n1, n2, n3, n4, n5, n6, n7] + // offset = 3 + // expected = [n5, n6, n7, c0, c1, c2, c3, c4] + + // 1. unset most significants of `next` up to `offset` + let inverse_offset = std::mem::size_of::() * 8 - offset; + next <<= inverse_offset; + // next = [n5, n6, n7, 0 , 0 , 0 , 0 , 0 ] + + // 2. unset least significants of `current` up to `offset` + current >>= offset; + // current = [0 , 0 , 0 , c0, c1, c2, c3, c4] + + current | next +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_merge_reversed() { + let current = 0b00000000; + let next = 0b00000001; + let result = merge_reversed::(current, next, 1); + assert_eq!(result, 0b10000000); + + let current = 0b01011001; + let next = 0b01011011; + let result = merge_reversed::(current, next, 1); + assert_eq!(result, 0b10101100); + } + + #[test] + fn test_merge_reversed_offset2() { + let current = 0b00000000; + let next = 0b00000001; + let result = merge_reversed::(current, next, 3); + assert_eq!(result, 0b00100000); + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/chunk_iterator/mod.rs b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/mod.rs new file mode 100644 index 000000000000..71f56a284274 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunk_iterator/mod.rs @@ -0,0 +1,206 @@ +use std::convert::TryInto; + +mod chunks_exact; +mod merge; + +pub use chunks_exact::BitChunksExact; +pub(crate) use merge::merge_reversed; + +use crate::trusted_len::TrustedLen; +pub use crate::types::BitChunk; +use crate::types::BitChunkIter; + +/// Trait representing an exact iterator over bytes in [`BitChunk`]. 
+pub trait BitChunkIterExact: TrustedLen { + /// The remainder of the iterator. + fn remainder(&self) -> B; + + /// The number of items in the remainder + fn remainder_len(&self) -> usize; + + /// An iterator over individual items of the remainder + #[inline] + fn remainder_iter(&self) -> BitChunkIter { + BitChunkIter::new(self.remainder(), self.remainder_len()) + } +} + +/// This struct is used to efficiently iterate over bit masks by loading bytes on +/// the stack with alignments of `uX`. This allows efficient iteration over bitmaps. +#[derive(Debug)] +pub struct BitChunks<'a, T: BitChunk> { + chunk_iterator: std::slice::ChunksExact<'a, u8>, + current: T, + remainder_bytes: &'a [u8], + last_chunk: T, + remaining: usize, + /// offset inside a byte + bit_offset: usize, + len: usize, + phantom: std::marker::PhantomData, +} + +/// writes `bytes` into `dst`. +#[inline] +fn copy_with_merge(dst: &mut T::Bytes, bytes: &[u8], bit_offset: usize) { + bytes + .windows(2) + .chain(std::iter::once([bytes[bytes.len() - 1], 0].as_ref())) + .take(std::mem::size_of::()) + .enumerate() + .for_each(|(i, w)| { + let val = merge_reversed(w[0], w[1], bit_offset); + dst[i] = val; + }); +} + +impl<'a, T: BitChunk> BitChunks<'a, T> { + /// Creates a [`BitChunks`]. + pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self { + assert!(offset + len <= slice.len() * 8); + + let slice = &slice[offset / 8..]; + let bit_offset = offset % 8; + let size_of = std::mem::size_of::(); + + let bytes_len = len / 8; + let bytes_upper_len = (len + bit_offset + 7) / 8; + let mut chunks = slice[..bytes_len].chunks_exact(size_of); + + let remainder = &slice[bytes_len - chunks.remainder().len()..bytes_upper_len]; + + let remainder_bytes = if chunks.len() == 0 { slice } else { remainder }; + + let last_chunk = remainder_bytes + .first() + .map(|first| { + let mut last = T::zero().to_ne_bytes(); + last[0] = *first; + T::from_ne_bytes(last) + }) + .unwrap_or_else(T::zero); + + let remaining = chunks.size_hint().0; + + let current = chunks + .next() + .map(|x| match x.try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }) + .unwrap_or_else(T::zero); + + Self { + chunk_iterator: chunks, + len, + current, + remaining, + remainder_bytes, + last_chunk, + bit_offset, + phantom: std::marker::PhantomData, + } + } + + #[inline] + fn load_next(&mut self) { + self.current = match self.chunk_iterator.next().unwrap().try_into() { + Ok(a) => T::from_ne_bytes(a), + Err(_) => unreachable!(), + }; + } + + /// Returns the remainder [`BitChunk`]. + pub fn remainder(&self) -> T { + // remaining bytes may not fit in `size_of::()`. We complement + // them to fit by allocating T and writing to it byte by byte + let mut remainder = T::zero().to_ne_bytes(); + + let remainder = match (self.remainder_bytes.is_empty(), self.bit_offset == 0) { + (true, _) => remainder, + (false, true) => { + // all remaining bytes + self.remainder_bytes + .iter() + .take(std::mem::size_of::()) + .enumerate() + .for_each(|(i, val)| remainder[i] = *val); + + remainder + }, + (false, false) => { + // all remaining bytes + copy_with_merge::(&mut remainder, self.remainder_bytes, self.bit_offset); + remainder + }, + }; + T::from_ne_bytes(remainder) + } + + /// Returns the remainder bits in [`BitChunks::remainder`]. 
+ pub fn remainder_len(&self) -> usize { + self.len - (std::mem::size_of::() * ((self.len / 8) / std::mem::size_of::()) * 8) + } +} + +impl Iterator for BitChunks<'_, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + + let current = self.current; + let combined = if self.bit_offset == 0 { + // fast case where there is no offset. In this case, there is bit-alignment + // at byte boundary and thus the bytes correspond exactly. + if self.remaining >= 2 { + self.load_next(); + } + current + } else { + let next = if self.remaining >= 2 { + // case where `next` is complete and thus we can take it all + self.load_next(); + self.current + } else { + // case where the `next` is incomplete and thus we take the remaining + self.last_chunk + }; + merge_reversed(current, next, self.bit_offset) + }; + + self.remaining -= 1; + Some(combined) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + // it contains always one more than the chunk_iterator, which is the last + // one where the remainder is merged into current. + (self.remaining, Some(self.remaining)) + } +} + +impl BitChunkIterExact for BitChunks<'_, T> { + #[inline] + fn remainder(&self) -> T { + self.remainder() + } + + #[inline] + fn remainder_len(&self) -> usize { + self.remainder_len() + } +} + +impl ExactSizeIterator for BitChunks<'_, T> { + #[inline] + fn len(&self) -> usize { + self.chunk_iterator.len() + } +} + +unsafe impl TrustedLen for BitChunks<'_, T> {} diff --git a/crates/nano-arrow/src/bitmap/utils/chunks_exact_mut.rs b/crates/nano-arrow/src/bitmap/utils/chunks_exact_mut.rs new file mode 100644 index 000000000000..7a5a91a12805 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/chunks_exact_mut.rs @@ -0,0 +1,63 @@ +use super::BitChunk; + +/// An iterator over mutable slices of bytes of exact size. +/// +/// # Safety +/// The slices returned by this iterator are guaranteed to have length equal to +/// `std::mem::size_of::()`. +#[derive(Debug)] +pub struct BitChunksExactMut<'a, T: BitChunk> { + chunks: std::slice::ChunksExactMut<'a, u8>, + remainder: &'a mut [u8], + remainder_len: usize, + marker: std::marker::PhantomData, +} + +impl<'a, T: BitChunk> BitChunksExactMut<'a, T> { + /// Returns a new [`BitChunksExactMut`] + #[inline] + pub fn new(bitmap: &'a mut [u8], length: usize) -> Self { + assert!(length <= bitmap.len() * 8); + let size_of = std::mem::size_of::(); + + let bitmap = &mut bitmap[..length.saturating_add(7) / 8]; + + let split = (length / 8 / size_of) * size_of; + let (chunks, remainder) = bitmap.split_at_mut(split); + let remainder_len = length - chunks.len() * 8; + + let chunks = chunks.chunks_exact_mut(size_of); + Self { + chunks, + remainder, + remainder_len, + marker: std::marker::PhantomData, + } + } + + /// The remainder slice + #[inline] + pub fn remainder(&mut self) -> &mut [u8] { + self.remainder + } + + /// The length of the remainder slice in bits. 
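A sketch of consuming a bitmap through `BitChunks`: whole 64-bit chunks are processed first and the tail is handled through the `BitChunkIterExact` remainder helpers, which stop at the logical length. It assumes `Bitmap::chunks` takes the chunk type as a type parameter, as the turbofish in `make_mut` above suggests, plus the usual path assumptions:

```rust
use arrow2::bitmap::Bitmap;
use arrow2::bitmap::utils::BitChunkIterExact;

fn main() {
    let bitmap = Bitmap::from_u8_vec(vec![0b1111_0000u8; 10], 75);

    // whole 64-bit chunks first ...
    let chunks = bitmap.chunks::<u64>();

    // ... then the 11 trailing bits, exposed bit by bit so that padding past
    // the length is never counted
    let tail_set = chunks.remainder_iter().filter(|bit| *bit).count();

    let set_bits: usize =
        chunks.map(|chunk| chunk.count_ones() as usize).sum::<usize>() + tail_set;

    assert_eq!(set_bits, bitmap.len() - bitmap.unset_bits());
}
```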
+ #[inline] + pub fn remainder_len(&mut self) -> usize { + self.remainder_len + } +} + +impl<'a, T: BitChunk> Iterator for BitChunksExactMut<'a, T> { + type Item = &'a mut [u8]; + + #[inline] + fn next(&mut self) -> Option { + self.chunks.next() + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.chunks.size_hint() + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/fmt.rs b/crates/nano-arrow/src/bitmap/utils/fmt.rs new file mode 100644 index 000000000000..45fe9ec9ced3 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/fmt.rs @@ -0,0 +1,72 @@ +use std::fmt::Write; + +use super::is_set; + +/// Formats `bytes` taking into account an offset and length of the form +pub fn fmt( + bytes: &[u8], + offset: usize, + length: usize, + f: &mut std::fmt::Formatter<'_>, +) -> std::fmt::Result { + assert!(offset < 8); + + f.write_char('[')?; + let mut remaining = length; + if remaining == 0 { + f.write_char(']')?; + return Ok(()); + } + + let first = bytes[0]; + let bytes = &bytes[1..]; + let empty_before = 8usize.saturating_sub(remaining + offset); + f.write_str("0b")?; + for _ in 0..empty_before { + f.write_char('_')?; + } + let until = std::cmp::min(8, offset + remaining); + for i in offset..until { + if is_set(first, offset + until - 1 - i) { + f.write_char('1')?; + } else { + f.write_char('0')?; + } + } + for _ in 0..offset { + f.write_char('_')?; + } + remaining -= until - offset; + + if remaining == 0 { + f.write_char(']')?; + return Ok(()); + } + + let number_of_bytes = remaining / 8; + for byte in &bytes[..number_of_bytes] { + f.write_str(", ")?; + f.write_fmt(format_args!("{byte:#010b}"))?; + } + remaining -= number_of_bytes * 8; + if remaining == 0 { + f.write_char(']')?; + return Ok(()); + } + + let last = bytes[std::cmp::min((length + offset + 7) / 8, bytes.len() - 1)]; + let remaining = (length + offset) % 8; + f.write_str(", ")?; + f.write_str("0b")?; + for _ in 0..(8 - remaining) { + f.write_char('_')?; + } + for i in 0..remaining { + if is_set(last, remaining - 1 - i) { + f.write_char('1')?; + } else { + f.write_char('0')?; + } + } + f.write_char(']') +} diff --git a/crates/nano-arrow/src/bitmap/utils/iterator.rs b/crates/nano-arrow/src/bitmap/utils/iterator.rs new file mode 100644 index 000000000000..1a35ad56b562 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/iterator.rs @@ -0,0 +1,82 @@ +use super::get_bit_unchecked; +use crate::trusted_len::TrustedLen; + +/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit), +/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`. +#[derive(Debug, Clone)] +pub struct BitmapIter<'a> { + bytes: &'a [u8], + index: usize, + end: usize, +} + +impl<'a> BitmapIter<'a> { + /// Creates a new [`BitmapIter`]. 
+ pub fn new(slice: &'a [u8], offset: usize, len: usize) -> Self { + // example: + // slice.len() = 4 + // offset = 9 + // len = 23 + // result: + let bytes = &slice[offset / 8..]; + // bytes.len() = 3 + let index = offset % 8; + // index = 9 % 8 = 1 + let end = len + index; + // end = 23 + 1 = 24 + assert!(end <= bytes.len() * 8); + // maximum read before UB in bits: bytes.len() * 8 = 24 + // the first read from the end is `end - 1`, thus, end = 24 is ok + + Self { bytes, index, end } + } +} + +impl<'a> Iterator for BitmapIter<'a> { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + // See comment in `new` + Some(unsafe { get_bit_unchecked(self.bytes, old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let exact = self.end - self.index; + (exact, Some(exact)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl<'a> DoubleEndedIterator for BitmapIter<'a> { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + // See comment in `new`; end was first decreased + Some(unsafe { get_bit_unchecked(self.bytes, self.end) }) + } + } +} + +unsafe impl TrustedLen for BitmapIter<'_> {} +impl ExactSizeIterator for BitmapIter<'_> {} diff --git a/crates/nano-arrow/src/bitmap/utils/mod.rs b/crates/nano-arrow/src/bitmap/utils/mod.rs new file mode 100644 index 000000000000..b064ffd8bed7 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/mod.rs @@ -0,0 +1,143 @@ +//! General utilities for bitmaps representing items where LSB is the first item. 
+mod chunk_iterator; +mod chunks_exact_mut; +mod fmt; +mod iterator; +mod slice_iterator; +mod zip_validity; + +use std::convert::TryInto; + +pub(crate) use chunk_iterator::merge_reversed; +pub use chunk_iterator::{BitChunk, BitChunkIterExact, BitChunks, BitChunksExact}; +pub use chunks_exact_mut::BitChunksExactMut; +pub use fmt::fmt; +pub use iterator::BitmapIter; +pub use slice_iterator::SlicesIterator; +pub use zip_validity::{ZipValidity, ZipValidityIter}; + +const BIT_MASK: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; +const UNSET_BIT_MASK: [u8; 8] = [ + 255 - 1, + 255 - 2, + 255 - 4, + 255 - 8, + 255 - 16, + 255 - 32, + 255 - 64, + 255 - 128, +]; + +/// Returns whether bit at position `i` in `byte` is set or not +#[inline] +pub fn is_set(byte: u8, i: usize) -> bool { + (byte & BIT_MASK[i]) != 0 +} + +/// Sets bit at position `i` in `byte` +#[inline] +pub fn set(byte: u8, i: usize, value: bool) -> u8 { + if value { + byte | BIT_MASK[i] + } else { + byte & UNSET_BIT_MASK[i] + } +} + +/// Sets bit at position `i` in `data` +/// # Panics +/// panics if `i >= data.len() / 8` +#[inline] +pub fn set_bit(data: &mut [u8], i: usize, value: bool) { + data[i / 8] = set(data[i / 8], i % 8, value); +} + +/// Sets bit at position `i` in `data` without doing bound checks +/// # Safety +/// caller must ensure that `i < data.len() / 8` +#[inline] +pub unsafe fn set_bit_unchecked(data: &mut [u8], i: usize, value: bool) { + let byte = data.get_unchecked_mut(i / 8); + *byte = set(*byte, i % 8, value); +} + +/// Returns whether bit at position `i` in `data` is set +/// # Panic +/// This function panics iff `i / 8 >= bytes.len()` +#[inline] +pub fn get_bit(bytes: &[u8], i: usize) -> bool { + is_set(bytes[i / 8], i % 8) +} + +/// Returns whether bit at position `i` in `data` is set or not. +/// +/// # Safety +/// `i >= data.len() * 8` results in undefined behavior +#[inline] +pub unsafe fn get_bit_unchecked(data: &[u8], i: usize) -> bool { + (*data.as_ptr().add(i >> 3) & BIT_MASK[i & 7]) != 0 +} + +/// Returns the number of bytes required to hold `bits` bits. +#[inline] +pub fn bytes_for(bits: usize) -> usize { + bits.saturating_add(7) / 8 +} + +/// Returns the number of zero bits in the slice offsetted by `offset` and a length of `length`. +/// # Panics +/// This function panics iff `(offset + len).saturating_add(7) / 8 >= slice.len()` +/// because it corresponds to the situation where `len` is beyond bounds. +pub fn count_zeros(slice: &[u8], offset: usize, len: usize) -> usize { + if len == 0 { + return 0; + }; + + let mut slice = &slice[offset / 8..(offset + len).saturating_add(7) / 8]; + let offset = offset % 8; + + if (offset + len) / 8 == 0 { + // all within a single byte + let byte = (slice[0] >> offset) << (8 - len); + return len - byte.count_ones() as usize; + } + + // slice: [a1,a2,a3,a4], [a5,a6,a7,a8] + // offset: 3 + // len: 4 + // [__,__,__,a4], [a5,a6,a7,__] + let mut set_count = 0; + if offset != 0 { + // count all ignoring the first `offset` bits + // i.e. [__,__,__,a4] + set_count += (slice[0] >> offset).count_ones() as usize; + slice = &slice[1..]; + } + if (offset + len) % 8 != 0 { + let end_offset = (offset + len) % 8; // i.e. 3 + 4 = 7 + let last_index = slice.len() - 1; + // count all ignoring the last `offset` bits + // i.e. 
[a5,a6,a7,__] + set_count += (slice[last_index] << (8 - end_offset)).count_ones() as usize; + slice = &slice[..last_index]; + } + + // finally, count any and all bytes in the middle in groups of 8 + let mut chunks = slice.chunks_exact(8); + set_count += chunks + .by_ref() + .map(|chunk| { + let a = u64::from_ne_bytes(chunk.try_into().unwrap()); + a.count_ones() as usize + }) + .sum::(); + + // and any bytes that do not fit in the group + set_count += chunks + .remainder() + .iter() + .map(|byte| byte.count_ones() as usize) + .sum::(); + + len - set_count +} diff --git a/crates/nano-arrow/src/bitmap/utils/slice_iterator.rs b/crates/nano-arrow/src/bitmap/utils/slice_iterator.rs new file mode 100644 index 000000000000..dc388f1d41b5 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/slice_iterator.rs @@ -0,0 +1,145 @@ +use crate::bitmap::Bitmap; + +/// Internal state of [`SlicesIterator`] +#[derive(Debug, Clone, PartialEq)] +enum State { + // normal iteration + Nominal, + // nothing more to iterate. + Finished, +} + +/// Iterator over a bitmap that returns slices of set regions +/// This is the most efficient method to extract slices of values from arrays +/// with a validity bitmap. +/// For example, the bitmap `00101111` returns `[(0,4), (6,1)]` +#[derive(Debug, Clone)] +pub struct SlicesIterator<'a> { + values: std::slice::Iter<'a, u8>, + count: usize, + mask: u8, + max_len: usize, + current_byte: &'a u8, + state: State, + len: usize, + start: usize, + on_region: bool, +} + +impl<'a> SlicesIterator<'a> { + /// Creates a new [`SlicesIterator`] + pub fn new(values: &'a Bitmap) -> Self { + let (buffer, offset, _) = values.as_slice(); + let mut iter = buffer.iter(); + + let (current_byte, state) = match iter.next() { + Some(b) => (b, State::Nominal), + None => (&0, State::Finished), + }; + + Self { + state, + count: values.len() - values.unset_bits(), + max_len: values.len(), + values: iter, + mask: 1u8.rotate_left(offset as u32), + current_byte, + len: 0, + start: 0, + on_region: false, + } + } + + #[inline] + fn finish(&mut self) -> Option<(usize, usize)> { + self.state = State::Finished; + if self.on_region { + Some((self.start, self.len)) + } else { + None + } + } + + #[inline] + fn current_len(&self) -> usize { + self.start + self.len + } + + /// Returns the total number of slots. + /// It corresponds to the sum of all lengths of all slices. 
+ #[inline] + pub fn slots(&self) -> usize { + self.count + } +} + +impl<'a> Iterator for SlicesIterator<'a> { + type Item = (usize, usize); + + #[inline] + fn next(&mut self) -> Option { + loop { + if self.state == State::Finished { + return None; + } + if self.current_len() == self.max_len { + return self.finish(); + } + + if self.mask == 1 { + // at the beginning of a byte => try to skip it all together + match (self.on_region, self.current_byte) { + (true, &255u8) => { + self.len = std::cmp::min(self.max_len - self.start, self.len + 8); + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + continue; + }, + (false, &0) => { + self.len = std::cmp::min(self.max_len - self.start, self.len + 8); + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + continue; + }, + _ => (), // we need to run over all bits of this byte + } + }; + + let value = (self.current_byte & self.mask) != 0; + self.mask = self.mask.rotate_left(1); + + match (self.on_region, value) { + (true, true) => self.len += 1, + (false, false) => self.len += 1, + (true, false) => { + self.on_region = false; + let result = (self.start, self.len); + self.start += self.len; + self.len = 1; + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + if let Some(v) = self.values.next() { + self.current_byte = v; + }; + } + return Some(result); + }, + (false, true) => { + self.start += self.len; + self.len = 1; + self.on_region = true; + }, + } + + if self.mask == 1 { + // reached a new byte => try to fetch it from the iterator + match self.values.next() { + Some(v) => self.current_byte = v, + None => return self.finish(), + }; + } + } + } +} diff --git a/crates/nano-arrow/src/bitmap/utils/zip_validity.rs b/crates/nano-arrow/src/bitmap/utils/zip_validity.rs new file mode 100644 index 000000000000..40965bab4113 --- /dev/null +++ b/crates/nano-arrow/src/bitmap/utils/zip_validity.rs @@ -0,0 +1,216 @@ +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::Bitmap; +use crate::trusted_len::TrustedLen; + +/// An [`Iterator`] over validity and values. +#[derive(Debug, Clone)] +pub struct ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + values: I, + validity: V, +} + +impl ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + /// Creates a new [`ZipValidityIter`]. 
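A sketch of `SlicesIterator` turning a validity bitmap into `(start, length)` runs of set bits (same path assumptions; bits are read LSB-first within each byte):

```rust
use arrow2::bitmap::Bitmap;
use arrow2::bitmap::utils::SlicesIterator;

fn main() {
    // bits, LSB-first: 1,0,1,1,1,0,0,0 -> runs of set bits at (0, 1) and (2, 3)
    let bitmap = Bitmap::from_u8_vec(vec![0b0001_1101u8], 8);

    let slices: Vec<(usize, usize)> = SlicesIterator::new(&bitmap).collect();
    assert_eq!(slices, vec![(0, 1), (2, 3)]);

    // `slots` is the total number of set bits across all slices
    assert_eq!(SlicesIterator::new(&bitmap).slots(), 4);
}
```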
+ /// # Panics + /// This function panics if the size_hints of the iterators are different + pub fn new(values: I, validity: V) -> Self { + assert_eq!(values.size_hint(), validity.size_hint()); + Self { values, validity } + } +} + +impl Iterator for ZipValidityIter +where + I: Iterator, + V: Iterator, +{ + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + let value = self.values.next(); + let is_valid = self.validity.next(); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.values.size_hint() + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let value = self.values.nth(n); + let is_valid = self.validity.nth(n); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } +} + +impl DoubleEndedIterator for ZipValidityIter +where + I: DoubleEndedIterator, + V: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + let value = self.values.next_back(); + let is_valid = self.validity.next_back(); + is_valid + .zip(value) + .map(|(is_valid, value)| is_valid.then(|| value)) + } +} + +unsafe impl TrustedLen for ZipValidityIter +where + I: TrustedLen, + V: TrustedLen, +{ +} + +impl ExactSizeIterator for ZipValidityIter +where + I: ExactSizeIterator, + V: ExactSizeIterator, +{ +} + +/// An [`Iterator`] over [`Option`] +/// This enum can be used in two distinct ways: +/// * as an iterator, via `Iterator::next` +/// * as an enum of two iterators, via `match self` +/// The latter allows specializalizing to when there are no nulls +#[derive(Debug, Clone)] +pub enum ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// There are no null values + Required(I), + /// There are null values + Optional(ZipValidityIter), +} + +impl ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// Returns a new [`ZipValidity`] + pub fn new(values: I, validity: Option) -> Self { + match validity { + Some(validity) => Self::Optional(ZipValidityIter::new(values, validity)), + _ => Self::Required(values), + } + } +} + +impl<'a, T, I> ZipValidity> +where + I: Iterator, +{ + /// Returns a new [`ZipValidity`] and drops the `validity` if all values + /// are valid. + pub fn new_with_validity(values: I, validity: Option<&'a Bitmap>) -> Self { + // only if the validity has nulls we take the optional branch. 
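A sketch of `ZipValidity` pairing a values iterator with validity bits, assuming `Bitmap::iter` (used by `new_with_validity` below) is available, plus the usual path assumptions:

```rust
use arrow2::bitmap::Bitmap;
use arrow2::bitmap::utils::ZipValidity;

fn main() {
    let values = [10i32, 20, 30, 40];
    let validity = Bitmap::from([true, false, true, true]);

    // zip values with validity bits; slots whose bit is unset become `None`
    let zipped: Vec<Option<i32>> =
        ZipValidity::new(values.iter().copied(), Some(validity.iter())).collect();
    assert_eq!(zipped, vec![Some(10), None, Some(30), Some(40)]);
}
```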
+ match validity.and_then(|validity| (validity.unset_bits() > 0).then(|| validity.iter())) { + Some(validity) => Self::Optional(ZipValidityIter::new(values, validity)), + _ => Self::Required(values), + } + } +} + +impl Iterator for ZipValidity +where + I: Iterator, + V: Iterator, +{ + type Item = Option; + + #[inline] + fn next(&mut self) -> Option { + match self { + Self::Required(values) => values.next().map(Some), + Self::Optional(zipped) => zipped.next(), + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + match self { + Self::Required(values) => values.size_hint(), + Self::Optional(zipped) => zipped.size_hint(), + } + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + match self { + Self::Required(values) => values.nth(n).map(Some), + Self::Optional(zipped) => zipped.nth(n), + } + } +} + +impl DoubleEndedIterator for ZipValidity +where + I: DoubleEndedIterator, + V: DoubleEndedIterator, +{ + #[inline] + fn next_back(&mut self) -> Option { + match self { + Self::Required(values) => values.next_back().map(Some), + Self::Optional(zipped) => zipped.next_back(), + } + } +} + +impl ExactSizeIterator for ZipValidity +where + I: ExactSizeIterator, + V: ExactSizeIterator, +{ +} + +unsafe impl TrustedLen for ZipValidity +where + I: TrustedLen, + V: TrustedLen, +{ +} + +impl ZipValidity +where + I: Iterator, + V: Iterator, +{ + /// Unwrap into an iterator that has no null values. + pub fn unwrap_required(self) -> I { + match self { + ZipValidity::Required(i) => i, + _ => panic!("Could not 'unwrap_required'. 'ZipValidity' iterator has nulls."), + } + } + + /// Unwrap into an iterator that has null values. + pub fn unwrap_optional(self) -> ZipValidityIter { + match self { + ZipValidity::Optional(i) => i, + _ => panic!("Could not 'unwrap_optional'. 'ZipValidity' iterator has no nulls."), + } + } +} diff --git a/crates/nano-arrow/src/buffer/immutable.rs b/crates/nano-arrow/src/buffer/immutable.rs new file mode 100644 index 000000000000..b9b87336b359 --- /dev/null +++ b/crates/nano-arrow/src/buffer/immutable.rs @@ -0,0 +1,328 @@ +use std::iter::FromIterator; +use std::ops::Deref; +use std::sync::Arc; +use std::usize; + +use either::Either; + +use super::{Bytes, IntoIter}; + +/// [`Buffer`] is a contiguous memory region that can be shared across +/// thread boundaries. +/// +/// The easiest way to think about [`Buffer`] is being equivalent to +/// a `Arc>`, with the following differences: +/// * slicing and cloning is `O(1)`. +/// * it supports external allocated memory +/// +/// The easiest way to create one is to use its implementation of `From>`. +/// +/// # Examples +/// ``` +/// use arrow2::buffer::Buffer; +/// +/// let mut buffer: Buffer = vec![1, 2, 3].into(); +/// assert_eq!(buffer.as_ref(), [1, 2, 3].as_ref()); +/// +/// // it supports copy-on-write semantics (i.e. back to a `Vec`) +/// let vec: Vec = buffer.into_mut().right().unwrap(); +/// assert_eq!(vec, vec![1, 2, 3]); +/// +/// // cloning and slicing is `O(1)` (data is shared) +/// let mut buffer: Buffer = vec![1, 2, 3].into(); +/// let mut sliced = buffer.clone(); +/// sliced.slice(1, 1); +/// assert_eq!(sliced.as_ref(), [2].as_ref()); +/// // but cloning forbids getting mut since `slice` and `buffer` now share data +/// assert_eq!(buffer.get_mut_slice(), None); +/// ``` +#[derive(Clone)] +pub struct Buffer { + /// the internal byte buffer. + data: Arc>, + + /// The offset into the buffer. + offset: usize, + + // the length of the buffer. 
Given a region `data` of N bytes, [offset..offset+length] is visible + // to this buffer. + length: usize, +} + +impl PartialEq for Buffer { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.deref() == other.deref() + } +} + +impl std::fmt::Debug for Buffer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&**self, f) + } +} + +impl Default for Buffer { + #[inline] + fn default() -> Self { + Vec::new().into() + } +} + +impl Buffer { + /// Creates an empty [`Buffer`]. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Auxiliary method to create a new Buffer + pub(crate) fn from_bytes(bytes: Bytes) -> Self { + let length = bytes.len(); + Buffer { + data: Arc::new(bytes), + offset: 0, + length, + } + } + + /// Returns the number of bytes in the buffer + #[inline] + pub fn len(&self) -> usize { + self.length + } + + /// Returns whether the buffer is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns whether underlying data is sliced. + /// If sliced the [`Buffer`] is backed by + /// more data than the length of `Self`. + pub fn is_sliced(&self) -> bool { + self.data.len() != self.length + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[T] { + // Safety: + // invariant of this struct `offset + length <= data.len()` + debug_assert!(self.offset + self.length <= self.data.len()); + unsafe { + self.data + .get_unchecked(self.offset..self.offset + self.length) + } + } + + /// Returns the byte slice stored in this buffer + /// # Safety + /// `index` must be smaller than `len` + #[inline] + pub(super) unsafe fn get_unchecked(&self, index: usize) -> &T { + // Safety: + // invariant of this function + debug_assert!(index < self.length); + unsafe { self.data.get_unchecked(self.offset + index) } + } + + /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// # Panics + /// Panics iff `offset + length` is larger than `len`. + #[inline] + pub fn sliced(self, offset: usize, length: usize) -> Self { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + // Safety: we just checked bounds + unsafe { self.sliced_unchecked(offset, length) } + } + + /// Slices this buffer starting at `offset`. + /// # Panics + /// Panics iff `offset` is larger than `len`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!( + offset + length <= self.len(), + "the offset of the new Buffer cannot exceed the existing length" + ); + // Safety: we just checked bounds + unsafe { self.slice_unchecked(offset, length) } + } + + /// Returns a new [`Buffer`] that is a slice of this buffer starting at `offset`. + /// Doing so allows the same memory region to be shared between buffers. + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + #[must_use] + pub unsafe fn sliced_unchecked(mut self, offset: usize, length: usize) -> Self { + self.slice_unchecked(offset, length); + self + } + + /// Slices this buffer starting at `offset`. + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.offset += offset; + self.length = length; + } + + /// Returns a pointer to the start of this buffer. 
+ #[inline] + pub(crate) fn as_ptr(&self) -> *const T { + self.data.deref().as_ptr() + } + + /// Returns the offset of this buffer. + #[inline] + pub fn offset(&self) -> usize { + self.offset + } + + /// # Safety + /// The caller must ensure that the buffer was properly initialized up to `len`. + #[inline] + pub unsafe fn set_len(&mut self, len: usize) { + self.length = len; + } + + /// Returns a mutable reference to its underlying [`Vec`], if possible. + /// + /// This operation returns [`Either::Right`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + #[inline] + pub fn into_mut(mut self) -> Either> { + match Arc::get_mut(&mut self.data) + .and_then(|b| b.get_vec()) + .map(std::mem::take) + { + Some(inner) => Either::Right(inner), + None => Either::Left(self), + } + } + + /// Returns a mutable reference to its underlying `Vec`, if possible. + /// Note that only `[self.offset(), self.offset() + self.len()[` in this vector is visible + /// by this buffer. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + /// # Safety + /// The caller must ensure that the vector in the mutable reference keeps a length of at least `self.offset() + self.len() - 1`. + #[inline] + pub unsafe fn get_mut(&mut self) -> Option<&mut Vec> { + Arc::get_mut(&mut self.data).and_then(|b| b.get_vec()) + } + + /// Returns a mutable reference to its slice, if possible. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + #[inline] + pub fn get_mut_slice(&mut self) -> Option<&mut [T]> { + Arc::get_mut(&mut self.data) + .and_then(|b| b.get_vec()) + // Safety: the invariant of this struct + .map(|x| unsafe { x.get_unchecked_mut(self.offset..self.offset + self.length) }) + } + + /// Get the strong count of underlying `Arc` data buffer. + pub fn shared_count_strong(&self) -> usize { + Arc::strong_count(&self.data) + } + + /// Get the weak count of underlying `Arc` data buffer. + pub fn shared_count_weak(&self) -> usize { + Arc::weak_count(&self.data) + } + + /// Returns its internal representation + #[must_use] + pub fn into_inner(self) -> (Arc>, usize, usize) { + let Self { + data, + offset, + length, + } = self; + (data, offset, length) + } + + /// Creates a `[Bitmap]` from its internal representation. + /// This is the inverted from `[Bitmap::into_inner]` + /// + /// # Safety + /// Callers must ensure all invariants of this struct are upheld. 
+ pub unsafe fn from_inner_unchecked(data: Arc>, offset: usize, length: usize) -> Self { + Self { + data, + offset, + length, + } + } +} + +impl From> for Buffer { + #[inline] + fn from(p: Vec) -> Self { + let bytes: Bytes = p.into(); + Self { + offset: 0, + length: bytes.len(), + data: Arc::new(bytes), + } + } +} + +impl std::ops::Deref for Buffer { + type Target = [T]; + + #[inline] + fn deref(&self) -> &[T] { + self.as_slice() + } +} + +impl FromIterator for Buffer { + #[inline] + fn from_iter>(iter: I) -> Self { + Vec::from_iter(iter).into() + } +} + +impl IntoIterator for Buffer { + type Item = T; + + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +#[cfg(feature = "arrow")] +impl From for Buffer { + fn from(value: arrow_buffer::Buffer) -> Self { + Self::from_bytes(crate::buffer::to_bytes(value)) + } +} + +#[cfg(feature = "arrow")] +impl From> for arrow_buffer::Buffer { + fn from(value: Buffer) -> Self { + crate::buffer::to_buffer(value.data).slice_with_length( + value.offset * std::mem::size_of::(), + value.length * std::mem::size_of::(), + ) + } +} diff --git a/crates/nano-arrow/src/buffer/iterator.rs b/crates/nano-arrow/src/buffer/iterator.rs new file mode 100644 index 000000000000..93511c480284 --- /dev/null +++ b/crates/nano-arrow/src/buffer/iterator.rs @@ -0,0 +1,68 @@ +use super::Buffer; +use crate::trusted_len::TrustedLen; + +/// This crates' equivalent of [`std::vec::IntoIter`] for [`Buffer`]. +#[derive(Debug, Clone)] +pub struct IntoIter { + values: Buffer, + index: usize, + end: usize, +} + +impl IntoIter { + /// Creates a new [`Buffer`] + #[inline] + pub fn new(values: Buffer) -> Self { + let end = values.len(); + Self { + values, + index: 0, + end, + } + } +} + +impl Iterator for IntoIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + if self.index == self.end { + return None; + } + let old = self.index; + self.index += 1; + Some(*unsafe { self.values.get_unchecked(old) }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.end - self.index, Some(self.end - self.index)) + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + let new_index = self.index + n; + if new_index > self.end { + self.index = self.end; + None + } else { + self.index = new_index; + self.next() + } + } +} + +impl DoubleEndedIterator for IntoIter { + #[inline] + fn next_back(&mut self) -> Option { + if self.index == self.end { + None + } else { + self.end -= 1; + Some(*unsafe { self.values.get_unchecked(self.end) }) + } + } +} + +unsafe impl TrustedLen for IntoIter {} diff --git a/crates/nano-arrow/src/buffer/mod.rs b/crates/nano-arrow/src/buffer/mod.rs new file mode 100644 index 000000000000..b75825d0ada1 --- /dev/null +++ b/crates/nano-arrow/src/buffer/mod.rs @@ -0,0 +1,96 @@ +//! Contains [`Buffer`], an immutable container for all Arrow physical types (e.g. i32, f64). + +mod immutable; +mod iterator; + +use std::ops::Deref; + +use crate::ffi::InternalArrowArray; + +pub(crate) enum BytesAllocator { + InternalArrowArray(InternalArrowArray), + + #[cfg(feature = "arrow")] + Arrow(arrow_buffer::Buffer), +} +pub(crate) type BytesInner = foreign_vec::ForeignVec; + +/// Bytes representation. +#[repr(transparent)] +pub struct Bytes(BytesInner); + +impl Bytes { + /// Takes ownership of an allocated memory region. 
+ /// # Panics + /// This function panics if and only if pointer is not null + /// # Safety + /// This function is safe if and only if `ptr` is valid for `length` + /// # Implementation + /// This function leaks if and only if `owner` does not deallocate + /// the region `[ptr, ptr+length[` when dropped. + #[inline] + pub(crate) unsafe fn from_foreign(ptr: *const T, length: usize, owner: BytesAllocator) -> Self { + Self(BytesInner::from_foreign(ptr, length, owner)) + } + + /// Returns a `Some` mutable reference of [`Vec`] iff this was initialized + /// from a [`Vec`] and `None` otherwise. + #[inline] + pub(crate) fn get_vec(&mut self) -> Option<&mut Vec> { + self.0.get_vec() + } +} + +impl Deref for Bytes { + type Target = [T]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From> for Bytes { + #[inline] + fn from(data: Vec) -> Self { + let inner: BytesInner = data.into(); + Bytes(inner) + } +} + +impl From> for Bytes { + #[inline] + fn from(value: BytesInner) -> Self { + Self(value) + } +} + +#[cfg(feature = "arrow")] +pub(crate) fn to_buffer( + value: std::sync::Arc>, +) -> arrow_buffer::Buffer { + // This should never panic as ForeignVec pointer must be non-null + let ptr = std::ptr::NonNull::new(value.as_ptr() as _).unwrap(); + let len = value.len() * std::mem::size_of::(); + // Safety: allocation is guaranteed to be valid for `len` bytes + unsafe { arrow_buffer::Buffer::from_custom_allocation(ptr, len, value) } +} + +#[cfg(feature = "arrow")] +pub(crate) fn to_bytes(value: arrow_buffer::Buffer) -> Bytes { + let ptr = value.as_ptr(); + let align = ptr.align_offset(std::mem::align_of::()); + assert_eq!(align, 0, "not aligned"); + let len = value.len() / std::mem::size_of::(); + + // Valid as `NativeType: Pod` and checked alignment above + let ptr = value.as_ptr() as *const T; + + let owner = crate::buffer::BytesAllocator::Arrow(value); + + // Safety: slice is valid for len elements of T + unsafe { Bytes::from_foreign(ptr, len, owner) } +} + +pub use immutable::Buffer; +pub(super) use iterator::IntoIter; diff --git a/crates/nano-arrow/src/chunk.rs b/crates/nano-arrow/src/chunk.rs new file mode 100644 index 000000000000..ffc857bcc134 --- /dev/null +++ b/crates/nano-arrow/src/chunk.rs @@ -0,0 +1,84 @@ +//! Contains [`Chunk`], a container of [`Array`] where every array has the +//! same length. + +use crate::array::Array; +use crate::error::{Error, Result}; + +/// A vector of trait objects of [`Array`] where every item has +/// the same length, [`Chunk::len`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Chunk> { + arrays: Vec, +} + +impl> Chunk { + /// Creates a new [`Chunk`]. + /// # Panic + /// Iff the arrays do not have the same length + pub fn new(arrays: Vec) -> Self { + Self::try_new(arrays).unwrap() + } + + /// Creates a new [`Chunk`]. 
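For illustration, a usage sketch of `Chunk` together with the fallible constructor defined just below; the `Int32Array` constructor follows the pattern used in this crate's other doc examples, and the `crate::` paths are assumptions:

```rust
use crate::array::{Array, Int32Array};
use crate::chunk::Chunk;

// Two columns of equal length form a valid Chunk.
let a: Box<dyn Array> = Box::new(Int32Array::from(&[Some(1), None, Some(3)]));
let b: Box<dyn Array> = Box::new(Int32Array::from(&[Some(4), Some(5), Some(6)]));
let chunk = Chunk::new(vec![a, b]);
assert_eq!(chunk.len(), 3);

// Arrays of different lengths are rejected by the fallible constructor.
let short: Box<dyn Array> = Box::new(Int32Array::from(&[Some(1)]));
let long: Box<dyn Array> = Box::new(Int32Array::from(&[Some(1), Some(2)]));
assert!(Chunk::try_new(vec![short, long]).is_err());
```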
+ /// # Error + /// Iff the arrays do not have the same length + pub fn try_new(arrays: Vec) -> Result { + if !arrays.is_empty() { + let len = arrays.first().unwrap().as_ref().len(); + if arrays + .iter() + .map(|array| array.as_ref()) + .any(|array| array.len() != len) + { + return Err(Error::InvalidArgumentError( + "Chunk require all its arrays to have an equal number of rows".to_string(), + )); + } + } + Ok(Self { arrays }) + } + + /// returns the [`Array`]s in [`Chunk`] + pub fn arrays(&self) -> &[A] { + &self.arrays + } + + /// returns the [`Array`]s in [`Chunk`] + pub fn columns(&self) -> &[A] { + &self.arrays + } + + /// returns the number of rows of every array + pub fn len(&self) -> usize { + self.arrays + .first() + .map(|x| x.as_ref().len()) + .unwrap_or_default() + } + + /// returns whether the columns have any rows + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Consumes [`Chunk`] into its underlying arrays. + /// The arrays are guaranteed to have the same length + pub fn into_arrays(self) -> Vec { + self.arrays + } +} + +impl> From> for Vec { + fn from(c: Chunk) -> Self { + c.into_arrays() + } +} + +impl> std::ops::Deref for Chunk { + type Target = [A]; + + #[inline] + fn deref(&self) -> &[A] { + self.arrays() + } +} diff --git a/crates/nano-arrow/src/compute/README.md b/crates/nano-arrow/src/compute/README.md new file mode 100644 index 000000000000..6b5bec7e703e --- /dev/null +++ b/crates/nano-arrow/src/compute/README.md @@ -0,0 +1,32 @@ +# Design + +This document outlines the design guide lines of this module. + +This module is composed by independent operations common in analytics. Below are some design of its principles: + +- APIs MUST return an error when either: + - The arguments are incorrect + - The execution results in a predictable error (e.g. divide by zero) + +- APIs MAY error when an operation overflows (e.g. `i32 + i32`) + +- kernels MUST NOT have side-effects + +- kernels MUST NOT take ownership of any of its arguments (i.e. everything must be a reference). + +- APIs SHOULD error when an operation on variable sized containers can overflow the maximum size of `usize`. + +- Kernels SHOULD use the arrays' logical type to decide whether kernels + can be applied on an array. For example, `Date32 + Date32` is meaningless and SHOULD NOT be implemented. + +- Kernels SHOULD be implemented via `clone`, `slice` or the `iterator` API provided by `Buffer`, `Bitmap`, `Vec` or `MutableBitmap`. + +- Kernels MUST NOT use any API to read bits other than the ones provided by `Bitmap`. + +- Implementations SHOULD aim for auto-vectorization, which is usually accomplished via `from_trusted_len_iter`. + +- Implementations MUST feature-gate any implementation that requires external dependencies + +- When a kernel accepts dynamically-typed arrays, it MUST expect them as `&dyn Array`. + +- When an API returns `&dyn Array`, it MUST return `Box`. The rational is that a `Box` is mutable, while an `Arc` is not. As such, `Box` offers the most flexible API to consumers and the compiler. Users can cast a `Box` into `Arc` via `.into()`. 
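As a concrete sketch of a kernel written against these rules (the kernel itself is illustrative and not part of this module; it assumes the `unary` arity helper and `Error::InvalidArgumentError` used throughout the crate):

```rust
use crate::array::{Array, PrimitiveArray};
use crate::compute::arity::unary;
use crate::error::{Error, Result};

/// Hypothetical kernel: doubles every value of an Int32 array.
/// It borrows its argument, has no side effects, and returns an
/// error instead of panicking on an unsupported type.
fn double_i32(array: &dyn Array) -> Result<Box<dyn Array>> {
    let array = array
        .as_any()
        .downcast_ref::<PrimitiveArray<i32>>()
        .ok_or_else(|| Error::InvalidArgumentError("expected an Int32 array".to_string()))?;
    let out = unary(array, |x| x * 2, array.data_type().clone());
    // Returned as `Box<dyn Array>`; callers needing shared ownership can
    // convert it with `let shared: std::sync::Arc<dyn Array> = boxed.into();`.
    Ok(Box::new(out))
}
```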
diff --git a/crates/nano-arrow/src/compute/aggregate/memory.rs b/crates/nano-arrow/src/compute/aggregate/memory.rs new file mode 100644 index 000000000000..3af974a79b14 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/memory.rs @@ -0,0 +1,118 @@ +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; + +fn validity_size(validity: Option<&Bitmap>) -> usize { + validity.as_ref().map(|b| b.as_slice().0.len()).unwrap_or(0) +} + +macro_rules! dyn_binary { + ($array:expr, $ty:ty, $o:ty) => {{ + let array = $array.as_any().downcast_ref::<$ty>().unwrap(); + let offsets = array.offsets().buffer(); + + // in case of Binary/Utf8/List the offsets are sliced, + // not the values buffer + let values_start = offsets[0] as usize; + let values_end = offsets[offsets.len() - 1] as usize; + + values_end - values_start + + offsets.len() * std::mem::size_of::<$o>() + + validity_size(array.validity()) + }}; +} + +/// Returns the total (heap) allocated size of the array in bytes. +/// # Implementation +/// This estimation is the sum of the size of its buffers, validity, including nested arrays. +/// Multiple arrays may share buffers and bitmaps. Therefore, the size of 2 arrays is not the +/// sum of the sizes computed from this function. In particular, [`StructArray`]'s size is an upper bound. +/// +/// When an array is sliced, its allocated size remains constant because the buffer unchanged. +/// However, this function will yield a smaller number. This is because this function returns +/// the visible size of the buffer, not its total capacity. +/// +/// FFI buffers are included in this estimation. +pub fn estimated_bytes_size(array: &dyn Array) -> usize { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => 0, + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + array.values().as_slice().0.len() + validity_size(array.validity()) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + array.values().len() * std::mem::size_of::<$T>() + validity_size(array.validity()) + }), + Binary => dyn_binary!(array, BinaryArray, i32), + FixedSizeBinary => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + array.values().len() + validity_size(array.validity()) + }, + LargeBinary => dyn_binary!(array, BinaryArray, i64), + Utf8 => dyn_binary!(array, Utf8Array, i32), + LargeUtf8 => dyn_binary!(array, Utf8Array, i64), + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + estimated_bytes_size(array.values().as_ref()) + + array.offsets().len_proxy() * std::mem::size_of::() + + validity_size(array.validity()) + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + estimated_bytes_size(array.values().as_ref()) + validity_size(array.validity()) + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + estimated_bytes_size(array.values().as_ref()) + + array.offsets().len_proxy() * std::mem::size_of::() + + validity_size(array.validity()) + }, + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + array + .values() + .iter() + .map(|x| x.as_ref()) + .map(estimated_bytes_size) + .sum::() + + validity_size(array.validity()) + }, + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + let types = array.types().len() * std::mem::size_of::(); + let offsets = array + .offsets() + .as_ref() + .map(|x| x.len() * std::mem::size_of::()) + 
.unwrap_or_default(); + let fields = array + .fields() + .iter() + .map(|x| x.as_ref()) + .map(estimated_bytes_size) + .sum::(); + types + offsets + fields + }, + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + estimated_bytes_size(array.keys()) + estimated_bytes_size(array.values().as_ref()) + }), + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let offsets = array.offsets().len_proxy() * std::mem::size_of::(); + offsets + estimated_bytes_size(array.field().as_ref()) + validity_size(array.validity()) + }, + } +} diff --git a/crates/nano-arrow/src/compute/aggregate/min_max.rs b/crates/nano-arrow/src/compute/aggregate/min_max.rs new file mode 100644 index 000000000000..e733c6657ccd --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/min_max.rs @@ -0,0 +1,416 @@ +#![allow(clippy::redundant_closure_call)] +use multiversion::multiversion; + +use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; +use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; +use crate::error::{Error, Result}; +use crate::offset::Offset; +use crate::scalar::*; +use crate::types::simd::*; +use crate::types::NativeType; + +/// Trait describing a type describing multiple lanes with an order relationship +/// consistent with the same order of `T`. +pub trait SimdOrd { + /// The minimum value + const MIN: T; + /// The maximum value + const MAX: T; + /// reduce itself to the minimum + fn max_element(self) -> T; + /// reduce itself to the maximum + fn min_element(self) -> T; + /// lane-wise maximum between two instances + fn max_lane(self, x: Self) -> Self; + /// lane-wise minimum between two instances + fn min_lane(self, x: Self) -> Self; + /// returns a new instance with all lanes equal to `MIN` + fn new_min() -> Self; + /// returns a new instance with all lanes equal to `MAX` + fn new_max() -> Self; +} + +#[multiversion(targets = "simd")] +fn nonnull_min_primitive(values: &[T]) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let chunks = values.chunks_exact(T::Simd::LANES); + let remainder = chunks.remainder(); + + let chunk_reduced = chunks.fold(T::Simd::new_min(), |acc, chunk| { + let chunk = T::Simd::from_chunk(chunk); + acc.min_lane(chunk) + }); + + let remainder = T::Simd::from_incomplete_chunk(remainder, T::Simd::MAX); + let reduced = chunk_reduced.min_lane(remainder); + + reduced.min_element() +} + +#[multiversion(targets = "simd")] +fn null_min_primitive_impl(values: &[T], mut validity_masks: I) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, + I: BitChunkIterExact<<::Simd as NativeSimd>::Chunk>, +{ + let mut chunks = values.chunks_exact(T::Simd::LANES); + + let chunk_reduced = chunks.by_ref().zip(validity_masks.by_ref()).fold( + T::Simd::new_min(), + |acc, (chunk, validity_chunk)| { + let chunk = T::Simd::from_chunk(chunk); + let mask = ::Mask::from_chunk(validity_chunk); + let chunk = chunk.select(mask, T::Simd::new_min()); + acc.min_lane(chunk) + }, + ); + + let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::Simd::MAX); + let mask = ::Mask::from_chunk(validity_masks.remainder()); + let remainder = remainder.select(mask, T::Simd::new_min()); + let reduced = chunk_reduced.min_lane(remainder); + + reduced.min_element() +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. 
+fn null_min_primitive(values: &[T], bitmap: &Bitmap) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let (slice, offset, length) = bitmap.as_slice(); + if offset == 0 { + let validity_masks = BitChunksExact::<::Chunk>::new(slice, length); + null_min_primitive_impl(values, validity_masks) + } else { + let validity_masks = bitmap.chunks::<::Chunk>(); + null_min_primitive_impl(values, validity_masks) + } +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. +fn null_max_primitive(values: &[T], bitmap: &Bitmap) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let (slice, offset, length) = bitmap.as_slice(); + if offset == 0 { + let validity_masks = BitChunksExact::<::Chunk>::new(slice, length); + null_max_primitive_impl(values, validity_masks) + } else { + let validity_masks = bitmap.chunks::<::Chunk>(); + null_max_primitive_impl(values, validity_masks) + } +} + +#[multiversion(targets = "simd")] +fn nonnull_max_primitive(values: &[T]) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let chunks = values.chunks_exact(T::Simd::LANES); + let remainder = chunks.remainder(); + + let chunk_reduced = chunks.fold(T::Simd::new_max(), |acc, chunk| { + let chunk = T::Simd::from_chunk(chunk); + acc.max_lane(chunk) + }); + + let remainder = T::Simd::from_incomplete_chunk(remainder, T::Simd::MIN); + let reduced = chunk_reduced.max_lane(remainder); + + reduced.max_element() +} + +#[multiversion(targets = "simd")] +fn null_max_primitive_impl(values: &[T], mut validity_masks: I) -> T +where + T: NativeType + Simd, + T::Simd: SimdOrd, + I: BitChunkIterExact<<::Simd as NativeSimd>::Chunk>, +{ + let mut chunks = values.chunks_exact(T::Simd::LANES); + + let chunk_reduced = chunks.by_ref().zip(validity_masks.by_ref()).fold( + T::Simd::new_max(), + |acc, (chunk, validity_chunk)| { + let chunk = T::Simd::from_chunk(chunk); + let mask = ::Mask::from_chunk(validity_chunk); + let chunk = chunk.select(mask, T::Simd::new_max()); + acc.max_lane(chunk) + }, + ); + + let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::Simd::MIN); + let mask = ::Mask::from_chunk(validity_masks.remainder()); + let remainder = remainder.select(mask, T::Simd::new_max()); + let reduced = chunk_reduced.max_lane(remainder); + + reduced.max_element() +} + +/// Returns the minimum value in the array, according to the natural order. +/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn min_primitive(array: &PrimitiveArray) -> Option +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let null_count = array.null_count(); + + // Includes case array.len() == 0 + if null_count == array.len() { + return None; + } + let values = array.values(); + + Some(if let Some(validity) = array.validity() { + null_min_primitive(values, validity) + } else { + nonnull_min_primitive(values) + }) +} + +/// Returns the maximum value in the array, according to the natural order. 
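All of the chunked reductions in this module follow the same recipe: fold full SIMD chunks lane-wise, pad the incomplete remainder with the identity element (`MAX` when minimizing, `MIN` when maximizing) so the unused lanes cannot influence the result, and only then reduce horizontally across lanes. A scalar sketch of that recipe, with a fixed lane count of 8 and no validity mask (names are illustrative, not part of this crate):

```rust
// Scalar analogue of the non-null chunked `min`; the real code uses T::Simd::LANES.
fn min_chunked(values: &[i32]) -> i32 {
    const LANES: usize = 8;
    // Lane-wise accumulator, initialised to the identity for `min`.
    let mut acc = [i32::MAX; LANES];
    let mut chunks = values.chunks_exact(LANES);
    for chunk in chunks.by_ref() {
        for (a, v) in acc.iter_mut().zip(chunk) {
            *a = (*a).min(*v);
        }
    }
    // Pad the remainder with MAX so its unused lanes stay neutral.
    let remainder = chunks.remainder();
    let mut tail = [i32::MAX; LANES];
    tail[..remainder.len()].copy_from_slice(remainder);
    for (a, v) in acc.iter_mut().zip(tail.iter()) {
        *a = (*a).min(*v);
    }
    // Horizontal reduction across lanes.
    acc.into_iter().min().unwrap()
}

assert_eq!(min_chunked(&[3, 9, 1, 7, 5, 2, 8, 4, 0, 6]), 0);
```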
+/// For floating point arrays any NaN values are considered to be greater than any other non-null value +pub fn max_primitive(array: &PrimitiveArray) -> Option +where + T: NativeType + Simd, + T::Simd: SimdOrd, +{ + let null_count = array.null_count(); + + // Includes case array.len() == 0 + if null_count == array.len() { + return None; + } + let values = array.values(); + + Some(if let Some(validity) = array.validity() { + null_max_primitive(values, validity) + } else { + nonnull_max_primitive(values) + }) +} + +/// Helper to compute min/max of [`BinaryArray`] and [`Utf8Array`] +macro_rules! min_max_binary_utf8 { + ($array: expr, $cmp: expr) => { + if $array.null_count() == $array.len() { + None + } else if $array.validity().is_some() { + $array + .iter() + .reduce(|v1, v2| match (v1, v2) { + (None, v2) => v2, + (v1, None) => v1, + (Some(v1), Some(v2)) => { + if $cmp(v1, v2) { + Some(v2) + } else { + Some(v1) + } + }, + }) + .unwrap_or(None) + } else { + $array + .values_iter() + .reduce(|v1, v2| if $cmp(v1, v2) { v2 } else { v1 }) + } + }; +} + +/// Returns the maximum value in the binary array, according to the natural order. +pub fn max_binary(array: &BinaryArray) -> Option<&[u8]> { + min_max_binary_utf8!(array, |a, b| a < b) +} + +/// Returns the minimum value in the binary array, according to the natural order. +pub fn min_binary(array: &BinaryArray) -> Option<&[u8]> { + min_max_binary_utf8!(array, |a, b| a > b) +} + +/// Returns the maximum value in the string array, according to the natural order. +pub fn max_string(array: &Utf8Array) -> Option<&str> { + min_max_binary_utf8!(array, |a, b| a < b) +} + +/// Returns the minimum value in the string array, according to the natural order. +pub fn min_string(array: &Utf8Array) -> Option<&str> { + min_max_binary_utf8!(array, |a, b| a > b) +} + +/// Returns the minimum value in the boolean array. +/// +/// ``` +/// use arrow2::{ +/// array::BooleanArray, +/// compute::aggregate::min_boolean, +/// }; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(min_boolean(&a), Some(false)) +/// ``` +pub fn min_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + Some(array.values().unset_bits() == 0) + } else { + // Note the min bool is false (0), so short circuit as soon as we see it + array + .iter() + .find(|&b| b == Some(false)) + .flatten() + .or(Some(true)) + } +} + +/// Returns the maximum value in the boolean array +/// +/// ``` +/// use arrow2::{ +/// array::BooleanArray, +/// compute::aggregate::max_boolean, +/// }; +/// +/// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); +/// assert_eq!(max_boolean(&a), Some(true)) +/// ``` +pub fn max_boolean(array: &BooleanArray) -> Option { + // short circuit if all nulls / zero length array + let null_count = array.null_count(); + if null_count == array.len() { + None + } else if null_count == 0 { + Some(array.values().unset_bits() < array.len()) + } else { + // Note the max bool is true (1), so short circuit as soon as we see it + array + .iter() + .find(|&b| b == Some(true)) + .flatten() + .or(Some(false)) + } +} + +macro_rules! dyn_generic { + ($array_ty:ty, $scalar_ty:ty, $array:expr, $f:ident) => {{ + let array = $array.as_any().downcast_ref::<$array_ty>().unwrap(); + Box::new(<$scalar_ty>::new($f(array))) + }}; +} + +macro_rules! 
with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => return Err(Error::InvalidArgumentError(format!( + "`min` and `max` operator do not support primitive `{:?}`", + $key_type, + ))), + } +})} + +/// Returns the maximum of [`Array`]. The scalar is null when all elements are null. +/// # Error +/// Errors iff the type does not support this operation. +pub fn max(array: &dyn Array) -> Result> { + Ok(match array.data_type().to_physical_type() { + PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, max_boolean), + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::<$T>::new(data_type, max_primitive::<$T>(array))) + }), + PhysicalType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, max_string), + PhysicalType::Binary => { + dyn_generic!(BinaryArray, BinaryScalar, array, max_binary) + }, + PhysicalType::LargeBinary => { + dyn_generic!(BinaryArray, BinaryScalar, array, min_binary) + }, + _ => { + return Err(Error::InvalidArgumentError(format!( + "The `max` operator does not support type `{:?}`", + array.data_type(), + ))) + }, + }) +} + +/// Returns the minimum of [`Array`]. The scalar is null when all elements are null. +/// # Error +/// Errors iff the type does not support this operation. 
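Before the dynamically typed `min` counterpart below, a short usage sketch of the typed kernels defined above; the `Int32Array` constructor matches this module's other doc examples and the `crate::` paths are assumptions:

```rust
use crate::array::Int32Array;
use crate::compute::aggregate::{max_primitive, min_primitive};

let a = Int32Array::from(&[Some(5), None, Some(-1), Some(7)]);
assert_eq!(min_primitive(&a), Some(-1));
assert_eq!(max_primitive(&a), Some(7));

// All-null (or empty) arrays have no defined extremum.
let b = Int32Array::from(&[None, None]);
assert_eq!(min_primitive(&b), None);
```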
+pub fn min(array: &dyn Array) -> Result> { + Ok(match array.data_type().to_physical_type() { + PhysicalType::Boolean => dyn_generic!(BooleanArray, BooleanScalar, array, min_boolean), + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::<$T>::new(data_type, min_primitive::<$T>(array))) + }), + PhysicalType::Utf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + PhysicalType::LargeUtf8 => dyn_generic!(Utf8Array, Utf8Scalar, array, min_string), + PhysicalType::Binary => { + dyn_generic!(BinaryArray, BinaryScalar, array, min_binary) + }, + PhysicalType::LargeBinary => { + dyn_generic!(BinaryArray, BinaryScalar, array, min_binary) + }, + _ => { + return Err(Error::InvalidArgumentError(format!( + "The `max` operator does not support type `{:?}`", + array.data_type(), + ))) + }, + }) +} + +/// Whether [`min`] supports `data_type` +pub fn can_min(data_type: &DataType) -> bool { + let physical = data_type.to_physical_type(); + if let PhysicalType::Primitive(primitive) = physical { + use PrimitiveType::*; + matches!( + primitive, + Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 + ) + } else { + use PhysicalType::*; + matches!(physical, Boolean | Utf8 | LargeUtf8 | Binary | LargeBinary) + } +} + +/// Whether [`max`] supports `data_type` +pub fn can_max(data_type: &DataType) -> bool { + can_min(data_type) +} diff --git a/crates/nano-arrow/src/compute/aggregate/mod.rs b/crates/nano-arrow/src/compute/aggregate/mod.rs new file mode 100644 index 000000000000..b513238f9fd9 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/mod.rs @@ -0,0 +1,15 @@ +//! Contains different aggregation functions +#[cfg(feature = "compute_aggregate")] +mod sum; +#[cfg(feature = "compute_aggregate")] +pub use sum::*; + +#[cfg(feature = "compute_aggregate")] +mod min_max; +#[cfg(feature = "compute_aggregate")] +pub use min_max::*; + +mod memory; +pub use memory::*; +#[cfg(feature = "compute_aggregate")] +mod simd; diff --git a/crates/nano-arrow/src/compute/aggregate/simd/mod.rs b/crates/nano-arrow/src/compute/aggregate/simd/mod.rs new file mode 100644 index 000000000000..25558e9a9e19 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/simd/mod.rs @@ -0,0 +1,109 @@ +use std::ops::Add; + +use super::{SimdOrd, Sum}; +use crate::types::simd::{i128x8, NativeSimd}; + +macro_rules! simd_add { + ($simd:tt, $type:ty, $lanes:expr, $add:tt) => { + impl std::ops::AddAssign for $simd { + #[inline] + fn add_assign(&mut self, rhs: Self) { + for i in 0..$lanes { + self[i] = <$type>::$add(self[i], rhs[i]); + } + } + } + + impl std::ops::Add for $simd { + type Output = Self; + + #[inline] + fn add(self, rhs: Self) -> Self::Output { + let mut result = Self::default(); + for i in 0..$lanes { + result[i] = <$type>::$add(self[i], rhs[i]); + } + result + } + } + + impl Sum<$type> for $simd { + #[inline] + fn simd_sum(self) -> $type { + let mut reduced = <$type>::default(); + (0..<$simd>::LANES).for_each(|i| { + reduced += self[i]; + }); + reduced + } + } + }; +} + +macro_rules! 
simd_ord_int { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::MIN; + const MAX: $type = <$type>::MAX; + + #[inline] + fn max_element(self) -> $type { + self.0.iter().copied().fold(Self::MIN, <$type>::max) + } + + #[inline] + fn min_element(self) -> $type { + self.0.iter().copied().fold(Self::MAX, <$type>::min) + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).max(*c)); + result + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).min(*c)); + result + } + + #[inline] + fn new_min() -> Self { + Self([Self::MAX; <$simd>::LANES]) + } + + #[inline] + fn new_max() -> Self { + Self([Self::MIN; <$simd>::LANES]) + } + } + }; +} + +pub(super) use {simd_add, simd_ord_int}; + +simd_add!(i128x8, i128, 8, add); +simd_ord_int!(i128x8, i128); + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +#[cfg_attr(docsrs, doc(cfg(feature = "simd")))] +pub use packed::*; diff --git a/crates/nano-arrow/src/compute/aggregate/simd/native.rs b/crates/nano-arrow/src/compute/aggregate/simd/native.rs new file mode 100644 index 000000000000..d6a0275f35e9 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/simd/native.rs @@ -0,0 +1,81 @@ +use std::ops::Add; + +use super::super::min_max::SimdOrd; +use super::super::sum::Sum; +use super::{simd_add, simd_ord_int}; +use crate::types::simd::*; + +simd_add!(u8x64, u8, 64, wrapping_add); +simd_add!(u16x32, u16, 32, wrapping_add); +simd_add!(u32x16, u32, 16, wrapping_add); +simd_add!(u64x8, u64, 8, wrapping_add); +simd_add!(i8x64, i8, 64, wrapping_add); +simd_add!(i16x32, i16, 32, wrapping_add); +simd_add!(i32x16, i32, 16, wrapping_add); +simd_add!(i64x8, i64, 8, wrapping_add); +simd_add!(f32x16, f32, 16, add); +simd_add!(f64x8, f64, 8, add); + +macro_rules! 
simd_ord_float { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::NAN; + const MAX: $type = <$type>::NAN; + + #[inline] + fn max_element(self) -> $type { + self.0.iter().copied().fold(Self::MIN, <$type>::max) + } + + #[inline] + fn min_element(self) -> $type { + self.0.iter().copied().fold(Self::MAX, <$type>::min) + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).max(*c)); + result + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + let mut result = <$simd>::default(); + result + .0 + .iter_mut() + .zip(self.0.iter()) + .zip(x.0.iter()) + .for_each(|((a, b), c)| *a = (*b).min(*c)); + result + } + + #[inline] + fn new_min() -> Self { + Self([Self::MAX; <$simd>::LANES]) + } + + #[inline] + fn new_max() -> Self { + Self([Self::MIN; <$simd>::LANES]) + } + } + }; +} + +simd_ord_int!(u8x64, u8); +simd_ord_int!(u16x32, u16); +simd_ord_int!(u32x16, u32); +simd_ord_int!(u64x8, u64); +simd_ord_int!(i8x64, i8); +simd_ord_int!(i16x32, i16); +simd_ord_int!(i32x16, i32); +simd_ord_int!(i64x8, i64); +simd_ord_float!(f32x16, f32); +simd_ord_float!(f64x8, f64); diff --git a/crates/nano-arrow/src/compute/aggregate/simd/packed.rs b/crates/nano-arrow/src/compute/aggregate/simd/packed.rs new file mode 100644 index 000000000000..40094d31e239 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/simd/packed.rs @@ -0,0 +1,116 @@ +use std::simd::{SimdFloat as _, SimdInt as _, SimdOrd as _, SimdUint as _}; + +use super::super::min_max::SimdOrd; +use super::super::sum::Sum; +use crate::types::simd::*; + +macro_rules! simd_sum { + ($simd:tt, $type:ty, $sum:tt) => { + impl Sum<$type> for $simd { + #[inline] + fn simd_sum(self) -> $type { + self.$sum() + } + } + }; +} + +simd_sum!(f32x16, f32, reduce_sum); +simd_sum!(f64x8, f64, reduce_sum); +simd_sum!(u8x64, u8, reduce_sum); +simd_sum!(u16x32, u16, reduce_sum); +simd_sum!(u32x16, u32, reduce_sum); +simd_sum!(u64x8, u64, reduce_sum); +simd_sum!(i8x64, i8, reduce_sum); +simd_sum!(i16x32, i16, reduce_sum); +simd_sum!(i32x16, i32, reduce_sum); +simd_sum!(i64x8, i64, reduce_sum); + +macro_rules! simd_ord_int { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::MIN; + const MAX: $type = <$type>::MAX; + + #[inline] + fn max_element(self) -> $type { + self.reduce_max() + } + + #[inline] + fn min_element(self) -> $type { + self.reduce_min() + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + self.simd_max(x) + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + self.simd_min(x) + } + + #[inline] + fn new_min() -> Self { + Self::splat(Self::MAX) + } + + #[inline] + fn new_max() -> Self { + Self::splat(Self::MIN) + } + } + }; +} + +macro_rules! 
simd_ord_float { + ($simd:tt, $type:ty) => { + impl SimdOrd<$type> for $simd { + const MIN: $type = <$type>::NAN; + const MAX: $type = <$type>::NAN; + + #[inline] + fn max_element(self) -> $type { + self.reduce_max() + } + + #[inline] + fn min_element(self) -> $type { + self.reduce_min() + } + + #[inline] + fn max_lane(self, x: Self) -> Self { + self.simd_max(x) + } + + #[inline] + fn min_lane(self, x: Self) -> Self { + self.simd_min(x) + } + + #[inline] + fn new_min() -> Self { + Self::splat(<$type>::NAN) + } + + #[inline] + fn new_max() -> Self { + Self::splat(<$type>::NAN) + } + } + }; +} + +simd_ord_int!(u8x64, u8); +simd_ord_int!(u16x32, u16); +simd_ord_int!(u32x16, u32); +simd_ord_int!(u64x8, u64); +simd_ord_int!(i8x64, i8); +simd_ord_int!(i16x32, i16); +simd_ord_int!(i32x16, i32); +simd_ord_int!(i64x8, i64); +simd_ord_float!(f32x16, f32); +simd_ord_float!(f64x8, f64); diff --git a/crates/nano-arrow/src/compute/aggregate/sum.rs b/crates/nano-arrow/src/compute/aggregate/sum.rs new file mode 100644 index 000000000000..738440c9f0d2 --- /dev/null +++ b/crates/nano-arrow/src/compute/aggregate/sum.rs @@ -0,0 +1,159 @@ +use std::ops::Add; + +use multiversion::multiversion; + +use crate::array::{Array, PrimitiveArray}; +use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; +use crate::error::{Error, Result}; +use crate::scalar::*; +use crate::types::simd::*; +use crate::types::NativeType; + +/// Object that can reduce itself to a number. This is used in the context of SIMD to reduce +/// a MD (e.g. `[f32; 16]`) into a single number (`f32`). +pub trait Sum { + /// Reduces this element to a single value. + fn simd_sum(self) -> T; +} + +#[multiversion(targets = "simd")] +/// Compute the sum of a slice +pub fn sum_slice(values: &[T]) -> T +where + T: NativeType + Simd + Add + std::iter::Sum, + T::Simd: Sum + Add, +{ + let (head, simd_vals, tail) = T::Simd::align(values); + + let mut reduced = T::Simd::from_incomplete_chunk(&[], T::default()); + for chunk in simd_vals { + reduced = reduced + *chunk; + } + + reduced.simd_sum() + head.iter().copied().sum() + tail.iter().copied().sum() +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. +#[multiversion(targets = "simd")] +fn null_sum_impl(values: &[T], mut validity_masks: I) -> T +where + T: NativeType + Simd, + T::Simd: Add + Sum, + I: BitChunkIterExact<<::Simd as NativeSimd>::Chunk>, +{ + let mut chunks = values.chunks_exact(T::Simd::LANES); + + let sum = chunks.by_ref().zip(validity_masks.by_ref()).fold( + T::Simd::default(), + |acc, (chunk, validity_chunk)| { + let chunk = T::Simd::from_chunk(chunk); + let mask = ::Mask::from_chunk(validity_chunk); + let selected = chunk.select(mask, T::Simd::default()); + acc + selected + }, + ); + + let remainder = T::Simd::from_incomplete_chunk(chunks.remainder(), T::default()); + let mask = ::Mask::from_chunk(validity_masks.remainder()); + let remainder = remainder.select(mask, T::Simd::default()); + let reduced = sum + remainder; + + reduced.simd_sum() +} + +/// # Panics +/// iff `values.len() != bitmap.len()` or the operation overflows. 
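The masked path above never branches per element: lanes flagged invalid by the validity bitmap are replaced with the additive identity via `select`, so they contribute nothing to the total. A scalar sketch of the same idea, using a plain `bool` slice in place of the packed bitmap (names are illustrative):

```rust
// Scalar analogue of the null-aware sum: invalid slots are zeroed out.
fn masked_sum(values: &[i64], validity: &[bool]) -> i64 {
    assert_eq!(values.len(), validity.len());
    values
        .iter()
        .zip(validity)
        .map(|(v, valid)| if *valid { *v } else { 0 })
        .sum()
}

assert_eq!(masked_sum(&[1, 2, 3], &[true, false, true]), 4);
```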
+fn null_sum(values: &[T], bitmap: &Bitmap) -> T +where + T: NativeType + Simd, + T::Simd: Add + Sum, +{ + let (slice, offset, length) = bitmap.as_slice(); + if offset == 0 { + let validity_masks = BitChunksExact::<::Chunk>::new(slice, length); + null_sum_impl(values, validity_masks) + } else { + let validity_masks = bitmap.chunks::<::Chunk>(); + null_sum_impl(values, validity_masks) + } +} + +/// Returns the sum of values in the array. +/// +/// Returns `None` if the array is empty or only contains null values. +pub fn sum_primitive(array: &PrimitiveArray) -> Option +where + T: NativeType + Simd + Add + std::iter::Sum, + T::Simd: Add + Sum, +{ + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + match array.validity() { + None => Some(sum_slice(array.values())), + Some(bitmap) => Some(null_sum(array.values(), bitmap)), + } +} + +/// Whether [`sum`] supports `data_type` +pub fn can_sum(data_type: &DataType) -> bool { + if let PhysicalType::Primitive(primitive) = data_type.to_physical_type() { + use PrimitiveType::*; + matches!( + primitive, + Int8 | Int16 | Int64 | Int128 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 + ) + } else { + false + } +} + +macro_rules! with_match_primitive_type {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + _ => return Err(Error::InvalidArgumentError(format!( + "`sum` operator do not support primitive `{:?}`", + $key_type, + ))), + } +})} + +/// Returns the sum of all elements in `array` as a [`Scalar`] of the same physical +/// and logical types as `array`. +/// # Error +/// Errors iff the operation is not supported. +pub fn sum(array: &dyn Array) -> Result> { + Ok(match array.data_type().to_physical_type() { + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let data_type = array.data_type().clone(); + let array = array.as_any().downcast_ref().unwrap(); + Box::new(PrimitiveScalar::new(data_type, sum_primitive::<$T>(array))) + }), + _ => { + return Err(Error::InvalidArgumentError(format!( + "The `sum` operator does not support type `{:?}`", + array.data_type(), + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/add.rs b/crates/nano-arrow/src/compute/arithmetics/basic/add.rs new file mode 100644 index 000000000000..5919b65fdbd5 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/add.rs @@ -0,0 +1,337 @@ +//! Definition of basic add operations with primitive arrays +use std::ops::Add; + +use num_traits::ops::overflowing::OverflowingAdd; +use num_traits::{CheckedAdd, SaturatingAdd, WrappingAdd}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayAdd, ArrayCheckedAdd, ArrayOverflowingAdd, ArraySaturatingAdd, ArrayWrappingAdd, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; + +/// Adds two primitive arrays with the same type. 
+/// Panics if the sum of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); +/// let b = PrimitiveArray::from([Some(5), None, None, Some(6)]); +/// let result = add(&a, &b); +/// let expected = PrimitiveArray::from([None, None, None, Some(12)]); +/// assert_eq!(result, expected) +/// ``` +pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Add, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a + b) +} + +/// Wrapping addition of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(-100i8), Some(100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = wrapping_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(-100i8), Some(-56i8), Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + let op = move |a: T, b: T| a.wrapping_add(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked addition of two primitive arrays. If the result from the sum +/// overflows, the validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8), Some(100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = checked_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(100i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + let op = move |a: T, b: T| a.checked_add(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating addition of two primitive arrays. If the result from the sum is +/// larger than the possible number for this type, the result for the operation +/// will be the saturated value. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8)]); +/// let b = PrimitiveArray::from([Some(100i8)]); +/// let result = saturating_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(127)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + let op = move |a: T, b: T| a.saturating_add(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing addition of two primitive arrays. If the result from the sum is +/// larger than the possible number for this type, the result for the operation +/// will be an array with overflowed values and a validity array indicating +/// the overflowing elements from the array. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_add; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(2i8), Some(-56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingAdd, +{ + let op = move |a: T, b: T| a.overflowing_add(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays +impl ArrayAdd> for PrimitiveArray +where + T: NativeArithmetics + Add, +{ + fn add(&self, rhs: &PrimitiveArray) -> Self { + add(self, rhs) + } +} + +impl ArrayWrappingAdd> for PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + fn wrapping_add(&self, rhs: &PrimitiveArray) -> Self { + wrapping_add(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays +impl ArrayCheckedAdd> for PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + fn checked_add(&self, rhs: &PrimitiveArray) -> Self { + checked_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArraySaturatingAdd> for PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + fn saturating_add(&self, rhs: &PrimitiveArray) -> Self { + saturating_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArrayOverflowingAdd> for PrimitiveArray +where + T: NativeArithmetics + OverflowingAdd, +{ + fn overflowing_add(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_add(self, rhs) + } +} + +/// Adds a scalar T to a primitive array of type T. +/// Panics if the sum of the values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::add_scalar; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(6)]); +/// let result = add_scalar(&a, &1i32); +/// let expected = PrimitiveArray::from([None, Some(7), None, Some(7)]); +/// assert_eq!(result, expected) +/// ``` +pub fn add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Add, +{ + let rhs = *rhs; + unary(lhs, |a| a + rhs, lhs.data_type().clone()) +} + +/// Wrapping addition of a scalar T to a [`PrimitiveArray`] of type T. +/// It do nothing if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_add_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100)]); +/// let result = wrapping_add_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, Some(-56)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingAdd, +{ + unary(lhs, |a| a.wrapping_add(rhs), lhs.data_type().clone()) +} + +/// Checked addition of a scalar T to a primitive array of type T. 
If the +/// result from the sum overflows then the validity index for that value is +/// changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_add_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); +/// let result = checked_add_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_add(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated addition of a scalar T to a primitive array of type T. If the +/// result from the sum is larger than the possible number for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_add_scalar; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8)]); +/// let result = saturating_add_scalar(&a, &100i8); +/// let expected = PrimitiveArray::from([Some(127)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_add(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing addition of a scalar T to a primitive array of type T. If the +/// result from the sum is larger than the possible number for this type, then +/// the result will be an array with overflowed values and a validity array +/// indicating the overflowing elements from the array +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_add_scalar; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_add_scalar(&a, &100i8); +/// let expected = PrimitiveArray::from([Some(101i8), Some(-56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_add_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingAdd, +{ + let rhs = *rhs; + let op = move |a: T| a.overflowing_add(&rhs); + + unary_with_bitmap(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays with a scalar +impl ArrayAdd for PrimitiveArray +where + T: NativeArithmetics + Add, +{ + fn add(&self, rhs: &T) -> Self { + add_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays with a scalar +impl ArrayCheckedAdd for PrimitiveArray +where + T: NativeArithmetics + CheckedAdd, +{ + fn checked_add(&self, rhs: &T) -> Self { + checked_add_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar +impl ArraySaturatingAdd for PrimitiveArray +where + T: NativeArithmetics + SaturatingAdd, +{ + fn saturating_add(&self, rhs: &T) -> Self { + saturating_add_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays with a scalar +impl ArrayOverflowingAdd for PrimitiveArray +where + T: NativeArithmetics + OverflowingAdd, +{ + fn overflowing_add(&self, rhs: &T) -> (Self, Bitmap) { + overflowing_add_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/div.rs 
b/crates/nano-arrow/src/compute/arithmetics/basic/div.rs new file mode 100644 index 000000000000..eb8f2ae0ac7c --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/div.rs @@ -0,0 +1,204 @@ +//! Definition of basic div operations with primitive arrays +use std::ops::Div; + +use num_traits::{CheckedDiv, NumCast}; +use strength_reduce::{ + StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, +}; + +use super::NativeArithmetics; +use crate::array::{Array, PrimitiveArray}; +use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; +use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; +use crate::compute::utils::check_same_len; +use crate::datatypes::PrimitiveType; + +/// Divides two primitive arrays with the same type. +/// Panics if the divisor is zero of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::div; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[Some(10), Some(1), Some(6)]); +/// let b = Int32Array::from(&[Some(5), None, Some(6)]); +/// let result = div(&a, &b); +/// let expected = Int32Array::from(&[Some(2), None, Some(1)]); +/// assert_eq!(result, expected) +/// ``` +pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Div, +{ + if rhs.null_count() == 0 { + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a / b) + } else { + check_same_len(lhs, rhs).unwrap(); + let values = lhs.iter().zip(rhs.iter()).map(|(l, r)| match (l, r) { + (Some(l), Some(r)) => Some(*l / *r), + _ => None, + }); + + PrimitiveArray::from_trusted_len_iter(values).to(lhs.data_type().clone()) + } +} + +/// Checked division of two primitive arrays. If the result from the division +/// overflows, the result for the operation will change the validity array +/// making this operation None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_div; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]); +/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]); +/// let result = checked_div(&a, &b); +/// let expected = Int8Array::from(&[Some(-1i8), None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + let op = move |a: T, b: T| a.checked_div(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayDiv trait for PrimitiveArrays +impl ArrayDiv> for PrimitiveArray +where + T: NativeArithmetics + Div, +{ + fn div(&self, rhs: &PrimitiveArray) -> Self { + div(self, rhs) + } +} + +// Implementation of ArrayCheckedDiv trait for PrimitiveArrays +impl ArrayCheckedDiv> for PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + fn checked_div(&self, rhs: &PrimitiveArray) -> Self { + checked_div(self, rhs) + } +} + +/// Divide a primitive array of type T by a scalar T. +/// Panics if the divisor is zero. 
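For unsigned integer types the scalar division below is strength-reduced: the divisor's magic constants are computed once with the `strength_reduce` crate and reused for every element, which is much cheaper than a hardware divide per value. A minimal sketch of that usage pattern (the values are arbitrary):

```rust
use strength_reduce::StrengthReducedU64;

// Precompute the constants for dividing by 7 once...
let divisor = StrengthReducedU64::new(7);
// ...then reuse them for the whole column.
let out: Vec<u64> = [10u64, 21, 35].iter().map(|&v| v / divisor).collect();
assert_eq!(out, vec![1u64, 3, 5]);
```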
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::div_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let result = div_scalar(&a, &2i32); +/// let expected = Int32Array::from(&[None, Some(3), None, Some(3)]); +/// assert_eq!(result, expected) +/// ``` +pub fn div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Div + NumCast, +{ + let rhs = *rhs; + match T::PRIMITIVE { + PrimitiveType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u64().unwrap(); + + let reduced_div = StrengthReducedU64::new(rhs); + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u32().unwrap(); + + let reduced_div = StrengthReducedU32::new(rhs); + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u16().unwrap(); + + let reduced_div = StrengthReducedU16::new(rhs); + + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u8().unwrap(); + + let reduced_div = StrengthReducedU8::new(rhs); + let r = unary(lhs, |a| a / reduced_div, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + _ => unary(lhs, |a| a / rhs, lhs.data_type().clone()), + } +} + +/// Checked division of a primitive array of type T by a scalar T. If the +/// divisor is zero then the validity array is changed to None. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_div_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = checked_div_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-1i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_div_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_div(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArrayDiv trait for PrimitiveArrays with a scalar +impl ArrayDiv for PrimitiveArray +where + T: NativeArithmetics + Div + NumCast, +{ + fn div(&self, rhs: &T) -> Self { + div_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedDiv trait for PrimitiveArrays with a scalar +impl ArrayCheckedDiv for PrimitiveArray +where + T: NativeArithmetics + CheckedDiv, +{ + fn checked_div(&self, rhs: &T) -> Self { + checked_div_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/mod.rs b/crates/nano-arrow/src/compute/arithmetics/basic/mod.rs new file mode 100644 index 000000000000..898a69f59536 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/mod.rs @@ -0,0 +1,100 @@ +//! Contains arithmetic functions for [`PrimitiveArray`]s. +//! +//! Each operation has four variants, like the rest of Rust's ecosystem: +//! * usual, that [`panic!`]s on overflow +//! * `checked_*` that turns overflowings to `None` +//! 
* `overflowing_*` returning a [`Bitmap`](crate::bitmap::Bitmap) with items that overflow. +//! * `saturating_*` that saturates the result. +mod add; +pub use add::*; +mod div; +pub use div::*; +mod mul; +pub use mul::*; +mod pow; +pub use pow::*; +mod rem; +pub use rem::*; +mod sub; +use std::ops::Neg; + +use num_traits::{CheckedNeg, WrappingNeg}; +pub use sub::*; + +use super::super::arity::{unary, unary_checked}; +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +/// Trait describing a [`NativeType`] whose semantics of arithmetic in Arrow equals +/// the semantics in Rust. +/// A counter example is `i128`, that in arrow represents a decimal while in rust represents +/// a signed integer. +pub trait NativeArithmetics: NativeType {} +impl NativeArithmetics for u8 {} +impl NativeArithmetics for u16 {} +impl NativeArithmetics for u32 {} +impl NativeArithmetics for u64 {} +impl NativeArithmetics for i8 {} +impl NativeArithmetics for i16 {} +impl NativeArithmetics for i32 {} +impl NativeArithmetics for i64 {} +impl NativeArithmetics for f32 {} +impl NativeArithmetics for f64 {} + +/// Negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::negate; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([None, Some(6), None, Some(7)]); +/// let result = negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); +/// assert_eq!(result, expected) +/// ``` +pub fn negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Neg, +{ + unary(array, |a| -a, array.data_type().clone()) +} + +/// Checked negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_negate; +/// use arrow2::array::{Array, PrimitiveArray}; +/// +/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); +/// let result = checked_negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), None, Some(-7)]); +/// assert_eq!(result, expected); +/// assert!(!result.is_valid(2)) +/// ``` +pub fn checked_negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + CheckedNeg, +{ + unary_checked(array, |a| a.checked_neg(), array.data_type().clone()) +} + +/// Wrapping negates values from array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_negate; +/// use arrow2::array::{Array, PrimitiveArray}; +/// +/// let a = PrimitiveArray::from([None, Some(6), Some(i8::MIN), Some(7)]); +/// let result = wrapping_negate(&a); +/// let expected = PrimitiveArray::from([None, Some(-6), Some(i8::MIN), Some(-7)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_negate(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + WrappingNeg, +{ + unary(array, |a| a.wrapping_neg(), array.data_type().clone()) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/mul.rs b/crates/nano-arrow/src/compute/arithmetics/basic/mul.rs new file mode 100644 index 000000000000..e006abe186e5 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/mul.rs @@ -0,0 +1,338 @@ +//! 
Definition of basic mul operations with primitive arrays +use std::ops::Mul; + +use num_traits::ops::overflowing::OverflowingMul; +use num_traits::{CheckedMul, SaturatingMul, WrappingMul}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayCheckedMul, ArrayMul, ArrayOverflowingMul, ArraySaturatingMul, ArrayWrappingMul, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; + +/// Multiplies two primitive arrays with the same type. +/// Panics if the multiplication of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::mul; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]); +/// let result = mul(&a, &b); +/// let expected = Int32Array::from(&[None, None, None, Some(36)]); +/// assert_eq!(result, expected) +/// ``` +pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a * b) +} + +/// Wrapping multiplication of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_mul; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(100i8), Some(0x10i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(0x10i8), Some(0i8)]); +/// let result = wrapping_mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(0), Some(0), Some(0)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingMul, +{ + let op = move |a: T, b: T| a.wrapping_mul(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked multiplication of two primitive arrays. If the result from the +/// multiplications overflows, the validity for that index is changed +/// returned. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_mul; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(100i8), Some(100i8), Some(100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(1i8)]); +/// let result = checked_mul(&a, &b); +/// let expected = Int8Array::from(&[Some(100i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + let op = move |a: T, b: T| a.checked_mul(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating multiplication of two primitive arrays. If the result from the +/// multiplication overflows, the result for the +/// operation will be the saturated value. 
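The wrapping, checked, saturating and overflowing kernels above differ only in how an overflowing element is reported. A minimal sketch of those per-element semantics on plain `i8` values, using the `std` methods the `num_traits` bounds ultimately delegate to (illustration only, not part of this diff):

```rust
fn main() {
    let (a, b) = (100i8, 2i8); // 200 does not fit in i8

    assert_eq!(a.checked_mul(b), None);            // overflow becomes None
    assert_eq!(a.wrapping_mul(b), -56);            // 200 wraps around to -56
    assert_eq!(a.saturating_mul(b), i8::MAX);      // clamped to 127
    assert_eq!(a.overflowing_mul(b), (-56, true)); // wrapped value plus an overflow flag
}
```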
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_mul; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let b = Int8Array::from(&[Some(100i8)]); +/// let result = saturating_mul(&a, &b); +/// let expected = Int8Array::from(&[Some(-128)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + let op = move |a: T, b: T| a.saturating_mul(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing multiplication of two primitive arrays. If the result from the +/// mul overflows, the result for the operation will be an array with +/// overflowed values and a validity array indicating the overflowing elements +/// from the array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_mul; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_mul(&a, &b); +/// let expected = Int8Array::from(&[Some(1i8), Some(-16i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingMul, +{ + let op = move |a: T, b: T| a.overflowing_mul(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayMul trait for PrimitiveArrays +impl ArrayMul> for PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + fn mul(&self, rhs: &PrimitiveArray) -> Self { + mul(self, rhs) + } +} + +impl ArrayWrappingMul> for PrimitiveArray +where + T: NativeArithmetics + WrappingMul, +{ + fn wrapping_mul(&self, rhs: &PrimitiveArray) -> Self { + wrapping_mul(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays +impl ArrayCheckedMul> for PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + fn checked_mul(&self, rhs: &PrimitiveArray) -> Self { + checked_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArraySaturatingMul> for PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + fn saturating_mul(&self, rhs: &PrimitiveArray) -> Self { + saturating_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArrayOverflowingMul> for PrimitiveArray +where + T: NativeArithmetics + OverflowingMul, +{ + fn overflowing_mul(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_mul(self, rhs) + } +} + +/// Multiply a scalar T to a primitive array of type T. +/// Panics if the multiplication of the values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::mul_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let result = mul_scalar(&a, &2i32); +/// let expected = Int32Array::from(&[None, Some(12), None, Some(12)]); +/// assert_eq!(result, expected) +/// ``` +pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + let rhs = *rhs; + unary(lhs, |a| a * rhs, lhs.data_type().clone()) +} + +/// Wrapping multiplication of a scalar T to a [`PrimitiveArray`] of type T. +/// It do nothing if the result overflows. 
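The `overflowing_mul` kernel above also returns a `Bitmap` whose set bits mark the elements that overflowed, which the doc example discards. A short sketch that inspects it, written against the same `arrow2` paths the examples use and assuming `Bitmap::get_bit` for the read (illustration only):

```rust
use arrow2::array::Int8Array;
use arrow2::compute::arithmetics::basic::overflowing_mul;

fn main() {
    let a = Int8Array::from(&[Some(1i8), Some(-100i8)]);
    let b = Int8Array::from(&[Some(1i8), Some(100i8)]);

    let (result, overflow) = overflowing_mul(&a, &b);

    assert_eq!(result, Int8Array::from(&[Some(1i8), Some(-16i8)]));
    assert!(!overflow.get_bit(0)); // 1 * 1 fits
    assert!(overflow.get_bit(1));  // -100 * 100 wrapped to -16
}
```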
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(0x10)]); +/// let result = wrapping_mul_scalar(&a, &0x10); +/// let expected = Int8Array::from(&[None, Some(0)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingMul, +{ + unary(lhs, |a| a.wrapping_mul(rhs), lhs.data_type().clone()) +} + +/// Checked multiplication of a scalar T to a primitive array of type T. If the +/// result from the multiplication overflows, then the validity for that index is +/// changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(100), None, Some(100)]); +/// let result = checked_mul_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_mul(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated multiplication of a scalar T to a primitive array of type T. If the +/// result from the mul overflows for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = saturating_mul_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-128i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_mul(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing multiplication of a scalar T to a primitive array of type T. 
If +/// the result from the mul overflows for this type, +/// then the result will be an array with overflowed values and a validity +/// array indicating the overflowing elements from the array +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_mul_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_mul_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(100i8), Some(16i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_mul_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingMul, +{ + let rhs = *rhs; + let op = move |a: T| a.overflowing_mul(&rhs); + + unary_with_bitmap(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArrayMul trait for PrimitiveArrays with a scalar +impl ArrayMul for PrimitiveArray +where + T: NativeArithmetics + Mul, +{ + fn mul(&self, rhs: &T) -> Self { + mul_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays with a scalar +impl ArrayCheckedMul for PrimitiveArray +where + T: NativeArithmetics + CheckedMul, +{ + fn checked_mul(&self, rhs: &T) -> Self { + checked_mul_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar +impl ArraySaturatingMul for PrimitiveArray +where + T: NativeArithmetics + SaturatingMul, +{ + fn saturating_mul(&self, rhs: &T) -> Self { + saturating_mul_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays with a scalar +impl ArrayOverflowingMul for PrimitiveArray +where + T: NativeArithmetics + OverflowingMul, +{ + fn overflowing_mul(&self, rhs: &T) -> (Self, Bitmap) { + overflowing_mul_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/pow.rs b/crates/nano-arrow/src/compute/arithmetics/basic/pow.rs new file mode 100644 index 000000000000..ea8908db6a51 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/pow.rs @@ -0,0 +1,49 @@ +//! Definition of basic pow operations with primitive arrays +use num_traits::{checked_pow, CheckedMul, One, Pow}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::compute::arity::{unary, unary_checked}; + +/// Raises an array of primitives to the power of exponent. Panics if one of +/// the values values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::powf_scalar; +/// use arrow2::array::Float32Array; +/// +/// let a = Float32Array::from(&[Some(2f32), None]); +/// let actual = powf_scalar(&a, 2.0); +/// let expected = Float32Array::from(&[Some(4f32), None]); +/// assert_eq!(expected, actual); +/// ``` +pub fn powf_scalar(array: &PrimitiveArray, exponent: T) -> PrimitiveArray +where + T: NativeArithmetics + Pow, +{ + unary(array, |x| x.pow(exponent), array.data_type().clone()) +} + +/// Checked operation of raising an array of primitives to the power of +/// exponent. If the result from the multiplications overflows, the validity +/// for that index is changed returned. 
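`checked_powf_scalar` lifts `num_traits::checked_pow` over the array via `unary_checked`, so the per-element behaviour can be shown directly on scalars (a minimal sketch, not part of this diff):

```rust
use num_traits::checked_pow;

fn main() {
    assert_eq!(checked_pow(2i8, 6), Some(64)); // 64 fits in i8
    assert_eq!(checked_pow(7i8, 8), None);     // 7^8 = 5_764_801 overflows i8
}
```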
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_powf_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), None, Some(7i8)]); +/// let actual = checked_powf_scalar(&a, 8usize); +/// let expected = Int8Array::from(&[Some(1i8), None, None]); +/// assert_eq!(expected, actual); +/// ``` +pub fn checked_powf_scalar(array: &PrimitiveArray, exponent: usize) -> PrimitiveArray +where + T: NativeArithmetics + CheckedMul + One, +{ + let op = move |a: T| checked_pow(a, exponent); + + unary_checked(array, op, array.data_type().clone()) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/rem.rs b/crates/nano-arrow/src/compute/arithmetics/basic/rem.rs new file mode 100644 index 000000000000..6c400fce2b07 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/rem.rs @@ -0,0 +1,196 @@ +use std::ops::Rem; + +use num_traits::{CheckedRem, NumCast}; +use strength_reduce::{ + StrengthReducedU16, StrengthReducedU32, StrengthReducedU64, StrengthReducedU8, +}; + +use super::NativeArithmetics; +use crate::array::{Array, PrimitiveArray}; +use crate::compute::arithmetics::{ArrayCheckedRem, ArrayRem}; +use crate::compute::arity::{binary, binary_checked, unary, unary_checked}; +use crate::datatypes::PrimitiveType; + +/// Remainder of two primitive arrays with the same type. +/// Panics if the divisor is zero of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::rem; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[Some(10), Some(7)]); +/// let b = Int32Array::from(&[Some(5), Some(6)]); +/// let result = rem(&a, &b); +/// let expected = Int32Array::from(&[Some(0), Some(1)]); +/// assert_eq!(result, expected) +/// ``` +pub fn rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Rem, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a % b) +} + +/// Checked remainder of two primitive arrays. If the result from the remainder +/// overflows, the result for the operation will change the validity array +/// making this operation None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_rem; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8), Some(10i8)]); +/// let b = Int8Array::from(&[Some(100i8), Some(0i8)]); +/// let result = checked_rem(&a, &b); +/// let expected = Int8Array::from(&[Some(-0i8), None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_rem(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + let op = move |a: T, b: T| a.checked_rem(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +impl ArrayRem> for PrimitiveArray +where + T: NativeArithmetics + Rem, +{ + fn rem(&self, rhs: &PrimitiveArray) -> Self { + rem(self, rhs) + } +} + +impl ArrayCheckedRem> for PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + fn checked_rem(&self, rhs: &PrimitiveArray) -> Self { + checked_rem(self, rhs) + } +} + +/// Remainder a primitive array of type T by a scalar T. +/// Panics if the divisor is zero. 
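For unsigned element types, `rem_scalar` below (like `div_scalar` earlier) downcasts the array and precomputes a `StrengthReduced*` divisor so each element's `%` avoids a full hardware division. A minimal standalone sketch of the `strength_reduce` API exactly as it is used here (illustration only):

```rust
use strength_reduce::StrengthReducedU64;

fn main() {
    // Precompute once for a fixed divisor, then reuse it for every element.
    // StrengthReducedU64::new panics on a zero divisor, matching the docs above.
    let divisor = StrengthReducedU64::new(7);

    assert_eq!(100u64 / divisor, 14);
    assert_eq!(100u64 % divisor, 2);
}
```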
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::rem_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(7)]); +/// let result = rem_scalar(&a, &2i32); +/// let expected = Int32Array::from(&[None, Some(0), None, Some(1)]); +/// assert_eq!(result, expected) +/// ``` +pub fn rem_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Rem + NumCast, +{ + let rhs = *rhs; + + match T::PRIMITIVE { + PrimitiveType::UInt64 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u64().unwrap(); + + let reduced_rem = StrengthReducedU64::new(rhs); + + // small hack to avoid a transmute of `PrimitiveArray` to `PrimitiveArray` + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt32 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u32().unwrap(); + + let reduced_rem = StrengthReducedU32::new(rhs); + + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt16 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u16().unwrap(); + + let reduced_rem = StrengthReducedU16::new(rhs); + + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + PrimitiveType::UInt8 => { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let rhs = rhs.to_u8().unwrap(); + + let reduced_rem = StrengthReducedU8::new(rhs); + + let r = unary(lhs, |a| a % reduced_rem, lhs.data_type().clone()); + // small hack to avoid an unsafe transmute of `PrimitiveArray` to `PrimitiveArray` + (&r as &dyn Array) + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + }, + _ => unary(lhs, |a| a % rhs, lhs.data_type().clone()), + } +} + +/// Checked remainder of a primitive array of type T by a scalar T. If the +/// divisor is zero then the validity array is changed to None. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_rem_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = checked_rem_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(0i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_rem_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_rem(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +impl ArrayRem for PrimitiveArray +where + T: NativeArithmetics + Rem + NumCast, +{ + fn rem(&self, rhs: &T) -> Self { + rem_scalar(self, rhs) + } +} + +impl ArrayCheckedRem for PrimitiveArray +where + T: NativeArithmetics + CheckedRem, +{ + fn checked_rem(&self, rhs: &T) -> Self { + checked_rem_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/basic/sub.rs b/crates/nano-arrow/src/compute/arithmetics/basic/sub.rs new file mode 100644 index 000000000000..5b2dcd36cb25 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/basic/sub.rs @@ -0,0 +1,337 @@ +//! 
Definition of basic sub operations with primitive arrays +use std::ops::Sub; + +use num_traits::ops::overflowing::OverflowingSub; +use num_traits::{CheckedSub, SaturatingSub, WrappingSub}; + +use super::NativeArithmetics; +use crate::array::PrimitiveArray; +use crate::bitmap::Bitmap; +use crate::compute::arithmetics::{ + ArrayCheckedSub, ArrayOverflowingSub, ArraySaturatingSub, ArraySub, ArrayWrappingSub, +}; +use crate::compute::arity::{ + binary, binary_checked, binary_with_bitmap, unary, unary_checked, unary_with_bitmap, +}; + +/// Subtracts two primitive arrays with the same type. +/// Panics if the subtraction of one pair of values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::sub; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let b = Int32Array::from(&[Some(5), None, None, Some(6)]); +/// let result = sub(&a, &b); +/// let expected = Int32Array::from(&[None, None, None, Some(0)]); +/// assert_eq!(result, expected) +/// ``` +pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a - b) +} + +/// Wrapping subtraction of two [`PrimitiveArray`]s. +/// It wraps around at the boundary of the type if the result overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_sub; +/// use arrow2::array::PrimitiveArray; +/// +/// let a = PrimitiveArray::from([Some(-100i8), Some(-100i8), Some(100i8)]); +/// let b = PrimitiveArray::from([Some(0i8), Some(100i8), Some(0i8)]); +/// let result = wrapping_sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(-100i8), Some(56i8), Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + WrappingSub, +{ + let op = move |a: T, b: T| a.wrapping_sub(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked subtraction of two primitive arrays. If the result from the +/// subtraction overflow, the validity for that index is changed +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_sub; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(100i8), Some(-100i8), Some(100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8), Some(0i8)]); +/// let result = checked_sub(&a, &b); +/// let expected = Int8Array::from(&[Some(99i8), None, Some(100i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + let op = move |a: T, b: T| a.checked_sub(&b); + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturating subtraction of two primitive arrays. If the result from the sub +/// is smaller than the possible number for this type, the result for the +/// operation will be the saturated value. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_sub; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let b = Int8Array::from(&[Some(100i8)]); +/// let result = saturating_sub(&a, &b); +/// let expected = Int8Array::from(&[Some(-128)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + let op = move |a: T, b: T| a.saturating_sub(&b); + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Overflowing subtraction of two primitive arrays. If the result from the sub +/// is smaller than the possible number for this type, the result for the +/// operation will be an array with overflowed values and a validity array +/// indicating the overflowing elements from the array. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_sub; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); +/// let b = Int8Array::from(&[Some(1i8), Some(100i8)]); +/// let (result, overflow) = overflowing_sub(&a, &b); +/// let expected = Int8Array::from(&[Some(0i8), Some(56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingSub, +{ + let op = move |a: T, b: T| a.overflowing_sub(&b); + + binary_with_bitmap(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArraySub trait for PrimitiveArrays +impl ArraySub> for PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + fn sub(&self, rhs: &PrimitiveArray) -> Self { + sub(self, rhs) + } +} + +impl ArrayWrappingSub> for PrimitiveArray +where + T: NativeArithmetics + WrappingSub, +{ + fn wrapping_sub(&self, rhs: &PrimitiveArray) -> Self { + wrapping_sub(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays +impl ArrayCheckedSub> for PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { + checked_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArraySaturatingSub> for PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { + saturating_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArrayOverflowingSub> for PrimitiveArray +where + T: NativeArithmetics + OverflowingSub, +{ + fn overflowing_sub(&self, rhs: &PrimitiveArray) -> (Self, Bitmap) { + overflowing_sub(self, rhs) + } +} + +/// Subtract a scalar T to a primitive array of type T. +/// Panics if the subtraction of the values overflows. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::sub_scalar; +/// use arrow2::array::Int32Array; +/// +/// let a = Int32Array::from(&[None, Some(6), None, Some(6)]); +/// let result = sub_scalar(&a, &1i32); +/// let expected = Int32Array::from(&[None, Some(5), None, Some(5)]); +/// assert_eq!(result, expected) +/// ``` +pub fn sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + let rhs = *rhs; + unary(lhs, |a| a - rhs, lhs.data_type().clone()) +} + +/// Wrapping subtraction of a scalar T to a [`PrimitiveArray`] of type T. +/// It do nothing if the result overflows. 
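For subtraction the interesting boundary is underflow. A minimal sketch of the element-level semantics the kernels above lift to arrays, including the unsigned case (plain `std` methods, not part of this diff):

```rust
fn main() {
    assert_eq!(0u8.wrapping_sub(1), 255);                     // wraps at the type boundary
    assert_eq!(0u8.checked_sub(1), None);                     // checked: underflow becomes None
    assert_eq!(i8::MIN.saturating_sub(1), i8::MIN);           // clamped to the smallest value
    assert_eq!(i8::MIN.overflowing_sub(1), (i8::MAX, true));  // wrapped value plus a flag
}
```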
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::wrapping_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(-100)]); +/// let result = wrapping_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, Some(56)]); +/// assert_eq!(result, expected); +/// ``` +pub fn wrapping_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + WrappingSub, +{ + unary(lhs, |a| a.wrapping_sub(rhs), lhs.data_type().clone()) +} + +/// Checked subtraction of a scalar T to a primitive array of type T. If the +/// result from the subtraction overflows, then the validity for that index +/// is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::checked_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[None, Some(-100), None, Some(-100)]); +/// let result = checked_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[None, None, None, None]); +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + let rhs = *rhs; + let op = move |a: T| a.checked_sub(&rhs); + + unary_checked(lhs, op, lhs.data_type().clone()) +} + +/// Saturated subtraction of a scalar T to a primitive array of type T. If the +/// result from the sub is smaller than the possible number for this type, then +/// the result will be saturated +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::saturating_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(-100i8)]); +/// let result = saturating_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-128i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + let rhs = *rhs; + let op = move |a: T| a.saturating_sub(&rhs); + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Overflowing subtraction of a scalar T to a primitive array of type T. 
If +/// the result from the sub is smaller than the possible number for this type, +/// then the result will be an array with overflowed values and a validity +/// array indicating the overflowing elements from the array +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::basic::overflowing_sub_scalar; +/// use arrow2::array::Int8Array; +/// +/// let a = Int8Array::from(&[Some(1i8), Some(-100i8)]); +/// let (result, overflow) = overflowing_sub_scalar(&a, &100i8); +/// let expected = Int8Array::from(&[Some(-99i8), Some(56i8)]); +/// assert_eq!(result, expected); +/// ``` +pub fn overflowing_sub_scalar(lhs: &PrimitiveArray, rhs: &T) -> (PrimitiveArray, Bitmap) +where + T: NativeArithmetics + OverflowingSub, +{ + let rhs = *rhs; + let op = move |a: T| a.overflowing_sub(&rhs); + + unary_with_bitmap(lhs, op, lhs.data_type().clone()) +} + +// Implementation of ArraySub trait for PrimitiveArrays with a scalar +impl ArraySub for PrimitiveArray +where + T: NativeArithmetics + Sub, +{ + fn sub(&self, rhs: &T) -> Self { + sub_scalar(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays with a scalar +impl ArrayCheckedSub for PrimitiveArray +where + T: NativeArithmetics + CheckedSub, +{ + fn checked_sub(&self, rhs: &T) -> Self { + checked_sub_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar +impl ArraySaturatingSub for PrimitiveArray +where + T: NativeArithmetics + SaturatingSub, +{ + fn saturating_sub(&self, rhs: &T) -> Self { + saturating_sub_scalar(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays with a scalar +impl ArrayOverflowingSub for PrimitiveArray +where + T: NativeArithmetics + OverflowingSub, +{ + fn overflowing_sub(&self, rhs: &T) -> (Self, Bitmap) { + overflowing_sub_scalar(self, rhs) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/add.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/add.rs new file mode 100644 index 000000000000..dccdb6b144c1 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/add.rs @@ -0,0 +1,236 @@ +//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals. +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd}; +use crate::compute::arity::{binary, binary_checked}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// Adds two decimal [`PrimitiveArray`] with the same precision and scale. +/// # Error +/// Errors if the precision and scale are different. +/// # Panic +/// This function panics iff the added numbers result in a number larger than +/// the possible number for the precision. 
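The decimal kernels operate on the raw `i128` values and only check each result against the largest value the declared precision can represent, `10^precision - 1`. A small worked sketch of that bound for `Decimal(5, 2)` (plain arithmetic, illustration only):

```rust
fn main() {
    // Decimal(5, 2): at most 5 digits, 2 of them after the point, stored as i128.
    // 990.00 and 10.00 are stored as 99_000 and 1_000.
    let precision = 5u32;
    let max = 10i128.pow(precision) - 1; // 99_999, i.e. 999.99

    let sum = 99_000i128 + 1_000i128; // 100_000, i.e. 1000.00
    // `add` panics on this, `checked_add` yields None, `saturating_add` clamps to 99_999.
    assert!(sum.abs() > max);
}
```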
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = add(&a, &b); +/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let op = move |a, b| { + let res: i128 = a + b; + + assert!( + res.abs() <= max, + "Overflow in addition presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturated addition of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the sum is larger than +/// the possible number with the selected precision then the resulted number in +/// the arrow array is the maximum number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_add(&a, &b); +/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let op = move |a, b| { + let res: i128 = a + b; + + if res.abs() > max { + if res > 0 { + max + } else { + -max + } + } else { + res + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked addition of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the sum is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_add(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_add(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + let op = move |a, b| { + let result: i128 = a + b; + + if result.abs() > max { + None + } else { + Some(result) + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayAdd trait for PrimitiveArrays +impl ArrayAdd> for PrimitiveArray { + fn add(&self, rhs: &PrimitiveArray) -> Self { + add(self, rhs) + } +} + +// Implementation of ArrayCheckedAdd trait for PrimitiveArrays +impl ArrayCheckedAdd> for PrimitiveArray { + fn checked_add(&self, rhs: &PrimitiveArray) -> Self { + checked_add(self, rhs) + } +} + +// Implementation of ArraySaturatingAdd trait for PrimitiveArrays +impl ArraySaturatingAdd> for PrimitiveArray { + fn saturating_add(&self, rhs: &PrimitiveArray) -> Self { + saturating_add(self, rhs) + } +} + +/// Adaptive addition of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. 
If during the +/// addition one of the results is larger than the max possible value, the +/// result precision is changed to the precision of the max value +/// +/// ```nocode +/// 11111.11 -> 7, 2 +/// 11111.111 -> 8, 3 +/// ------------------ +/// 22222.221 -> 8, 3 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_add; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2)); +/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3)); +/// let result = adaptive_add(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_add( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + l + r * shift + } else { + l * shift + r + }; + + // The precision of the resulting array will change if one of the + // sums during the iteration produces a value bigger than the + // possible value for the initial precision + + // 99.9999 -> 6, 4 + // 00.0001 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/div.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/div.rs new file mode 100644 index 000000000000..1576fc061947 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/div.rs @@ -0,0 +1,302 @@ +//! Defines the division arithmetic kernels for Decimal +//! `PrimitiveArrays`. + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedDiv, ArrayDiv}; +use crate::compute::arity::{binary, binary_checked, unary}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::scalar::{PrimitiveScalar, Scalar}; + +/// Divide two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the dividend is divided by 0 or None. +/// This function also panics if the division produces a number larger +/// than the possible number for the array precision. 
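Decimal division first scales the dividend up by `10^scale` so that the quotient comes out at the same scale as its inputs. A worked sketch of that trick for `Decimal(5, 2)` values (plain `i128` arithmetic, illustration only):

```rust
fn main() {
    // 6.00 / 2.00 with scale = 2: both operands stored as i128 without the point.
    let (a, b) = (6_00i128, 2_00i128);
    let scale = 10i128.pow(2); // 100

    let numeral = a * scale;              // 60_000: dividend scaled up once
    let res = numeral.checked_div(b)      // 300, i.e. 3.00 at scale 2
        .expect("division by zero");
    assert_eq!(res, 3_00i128);
}
```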
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = div(&a, &b); +/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + let op = move |a: i128, b: i128| { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division + + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + let numeral: i128 = a * scale; + + // The division can overflow if the dividend is divided + // by zero. + let res: i128 = numeral.checked_div(b).expect("Found division by zero"); + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the multiplied numbers result in a number +/// larger than the possible number for the selected precision. +pub fn div_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let rhs = if let Some(rhs) = *rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128| { + // The division is done using the numbers without scale. + // The dividend is scaled up to maintain precision after the + // division + + // 222.222 --> 222222000 + // 123.456 --> 123456 + // -------- --------- + // 1.800 <-- 1800 + let numeral: i128 = a * scale; + + // The division can overflow if the dividend is divided + // by zero. + let res: i128 = numeral.checked_div(rhs).expect("Found division by zero"); + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Saturated division of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the division is +/// larger than the possible number with the selected precision then the +/// resulted number in the arrow array is the maximum number for the selected +/// precision. The function panics if divided by zero. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(000_01i128), Some(2_00i128), Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_div(&a, &b); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_div( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + let numeral: i128 = a * scale; + + match numeral.checked_div(b) { + Some(res) => match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + }, + None => 0, + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked division of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the divisor is zero, then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1_00i128), Some(4_00i128), Some(6_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(000_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_div(&a, &b); +/// let expected = PrimitiveArray::from([None, None, Some(3_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_div(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + let numeral: i128 = a * scale; + + match numeral.checked_div(b) { + Some(res) => match res { + res if res.abs() > max => None, + _ => Some(res), + }, + None => None, + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayDiv trait for PrimitiveArrays +impl ArrayDiv> for PrimitiveArray { + fn div(&self, rhs: &PrimitiveArray) -> Self { + div(self, rhs) + } +} + +// Implementation of ArrayCheckedDiv trait for PrimitiveArrays +impl ArrayCheckedDiv> for PrimitiveArray { + fn checked_div(&self, rhs: &PrimitiveArray) -> Self { + checked_div(self, rhs) + } +} + +/// Adaptive division of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. If during the +/// division one of the results is larger than the max possible value, the +/// result precision is changed to the precision of the max value. The function +/// panics when divided by zero. 
+/// +/// ```nocode +/// 1000.00 -> 7, 2 +/// 10.0000 -> 6, 4 +/// ----------------- +/// 100.0000 -> 9, 4 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_div; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1000_00i128)]).to(DataType::Decimal(7, 2)); +/// let b = PrimitiveArray::from([Some(10_0000i128)]).to(DataType::Decimal(6, 4)); +/// let result = adaptive_div(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(9, 4)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_div( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + let numeral: i128 = l * shift_1; + + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + numeral.checked_div(r * shift) + } else { + (numeral * shift).checked_div(*r) + } + .expect("Found division by zero"); + + // The precision of the resulting array will change if one of the + // multiplications during the iteration produces a value bigger + // than the possible value for the initial precision + + // 10.0000 -> 6, 4 + // 00.1000 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/mod.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/mod.rs new file mode 100644 index 000000000000..4b412ef13c6e --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/mod.rs @@ -0,0 +1,119 @@ +//! Defines the arithmetic kernels for Decimal `PrimitiveArrays`. The +//! [`Decimal`](crate::datatypes::DataType::Decimal) type specifies the +//! precision and scale parameters. These affect the arithmetic operations and +//! need to be considered while doing operations with Decimal numbers. 
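As the module docs above note, a decimal value is just an `i128` plus the `(precision, scale)` pair carried by its data type. A minimal sketch of that encoding (plain Rust, illustration only):

```rust
fn main() {
    // Decimal(5, 2): "123.45" is stored as the integer 12_345.
    let raw = 12_345i128;
    let scale = 2u32;

    let integer_part = raw / 10i128.pow(scale);    // 123
    let fractional_part = raw % 10i128.pow(scale); // 45
    assert_eq!((integer_part, fractional_part), (123, 45));
}
```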
+ +mod add; +pub use add::*; +mod div; +pub use div::*; +mod mul; +pub use mul::*; +mod sub; +pub use sub::*; + +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// Maximum value that can exist with a selected precision +#[inline] +fn max_value(precision: usize) -> i128 { + 10i128.pow(precision as u32) - 1 +} + +// Calculates the number of digits in a i128 number +fn number_digits(num: i128) -> usize { + let mut num = num.abs(); + let mut digit: i128 = 0; + let base = 10i128; + + while num != 0 { + num /= base; + digit += 1; + } + + digit as usize +} + +fn get_parameters(lhs: &DataType, rhs: &DataType) -> Result<(usize, usize)> { + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.to_logical_type(), rhs.to_logical_type()) + { + if lhs_p == rhs_p && lhs_s == rhs_s { + Ok((*lhs_p, *lhs_s)) + } else { + Err(Error::InvalidArgumentError( + "Arrays must have the same precision and scale".to_string(), + )) + } + } else { + unreachable!() + } +} + +/// Returns the adjusted precision and scale for the lhs and rhs precision and +/// scale +fn adjusted_precision_scale( + lhs_p: usize, + lhs_s: usize, + rhs_p: usize, + rhs_s: usize, +) -> (usize, usize, usize) { + // The initial new precision and scale is based on the number of digits + // that lhs and rhs number has before and after the point. The max + // number of digits before and after the point will make the last + // precision and scale of the result + + // Digits before/after point + // before after + // 11.1111 -> 5, 4 -> 2 4 + // 11111.01 -> 7, 2 -> 5 2 + // ----------------- + // 11122.1211 -> 9, 4 -> 5 4 + let lhs_digits_before = lhs_p - lhs_s; + let rhs_digits_before = rhs_p - rhs_s; + + let res_digits_before = std::cmp::max(lhs_digits_before, rhs_digits_before); + + let (res_s, diff) = if lhs_s > rhs_s { + (lhs_s, lhs_s - rhs_s) + } else { + (rhs_s, rhs_s - lhs_s) + }; + + let res_p = res_digits_before + res_s; + + (res_p, res_s, diff) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_max_value() { + assert_eq!(999, max_value(3)); + assert_eq!(99999, max_value(5)); + assert_eq!(999999, max_value(6)); + } + + #[test] + fn test_number_digits() { + assert_eq!(2, number_digits(12i128)); + assert_eq!(3, number_digits(123i128)); + assert_eq!(4, number_digits(1234i128)); + assert_eq!(6, number_digits(123456i128)); + assert_eq!(7, number_digits(1234567i128)); + assert_eq!(7, number_digits(-1234567i128)); + assert_eq!(3, number_digits(-123i128)); + } + + #[test] + fn test_adjusted_precision_scale() { + // 11.1111 -> 5, 4 -> 2 4 + // 11111.01 -> 7, 2 -> 5 2 + // ----------------- + // 11122.1211 -> 9, 4 -> 5 4 + assert_eq!((9, 4, 2), adjusted_precision_scale(5, 4, 7, 2)) + } +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/mul.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/mul.rs new file mode 100644 index 000000000000..a944279a133e --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/mul.rs @@ -0,0 +1,314 @@ +//! Defines the multiplication arithmetic kernels for Decimal +//! `PrimitiveArrays`. 
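Decimal multiplication goes the opposite way from division: the raw product carries the scale twice, so the kernel divides one factor of `10^scale` back out. A worked sketch for `Decimal(5, 2)` values (plain `i128` arithmetic, illustration only):

```rust
fn main() {
    // 1.50 * 2.00 with scale = 2: stored as 150 and 200.
    let (a, b) = (1_50i128, 2_00i128);
    let scale = 10i128.pow(2); // 100

    let raw = a.checked_mul(b).expect("i128 overflow"); // 30_000 carries the scale twice
    let res = raw / scale;                              // 300, i.e. 3.00 at scale 2
    assert_eq!(res, 3_00i128);
}
```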
+ +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedMul, ArrayMul, ArraySaturatingMul}; +use crate::compute::arity::{binary, binary_checked, unary}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::scalar::{PrimitiveScalar, Scalar}; + +/// Multiply two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the multiplied numbers result in a number +/// larger than the possible number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1_00i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(1_00i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| { + // The multiplication between i128 can overflow if they are + // very large numbers. For that reason a checked + // multiplication is used. + let res: i128 = a.checked_mul(b).expect("Mayor overflow for multiplication"); + + // The multiplication is done using the numbers without scale. + // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + let res = res / scale; + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Multiply a decimal [`PrimitiveArray`] with a [`PrimitiveScalar`] with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. This function panics if the multiplied numbers result in a number +/// larger than the possible number for the selected precision. +pub fn mul_scalar(lhs: &PrimitiveArray, rhs: &PrimitiveScalar) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let rhs = if let Some(rhs) = *rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128| { + // The multiplication between i128 can overflow if they are + // very large numbers. For that reason a checked + // multiplication is used. + let res: i128 = a + .checked_mul(rhs) + .expect("Mayor overflow for multiplication"); + + // The multiplication is done using the numbers without scale. 
+ // The resulting scale of the value has to be corrected by + // dividing by (10^scale) + + // 111.111 --> 111111 + // 222.222 --> 222222 + // -------- ------- + // 24691.308 <-- 24691308642 + let res = res / scale; + + assert!( + res.abs() <= max, + "Overflow in multiplication presented for precision {precision}" + ); + + res + }; + + unary(lhs, op, lhs.data_type().clone()) +} + +/// Saturated multiplication of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the multiplication is +/// larger than the possible number with the selected precision then the +/// resulted number in the arrow array is the maximum number for the selected +/// precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_mul(&a, &b); +/// let expected = PrimitiveArray::from([Some(999_99i128), Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + } + }, + None => max, + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Checked multiplication of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the mul is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(999_99i128), Some(1_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(10_00i128), Some(2_00i128), None, Some(2_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_mul(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(2_00i128), None, Some(4_00i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let scale = 10i128.pow(scale as u32); + let max = max_value(precision); + + let op = move |a: i128, b: i128| match a.checked_mul(b) { + Some(res) => { + let res = res / scale; + + match res { + res if res.abs() > max => None, + _ => Some(res), + } + }, + None => None, + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArrayMul trait for PrimitiveArrays +impl ArrayMul> for PrimitiveArray { + fn mul(&self, rhs: &PrimitiveArray) -> Self { + mul(self, rhs) + } +} + +// Implementation of ArrayCheckedMul trait for PrimitiveArrays +impl ArrayCheckedMul> for PrimitiveArray { + fn checked_mul(&self, rhs: &PrimitiveArray) -> Self { + checked_mul(self, rhs) + } +} + +// Implementation of ArraySaturatingMul trait for PrimitiveArrays +impl ArraySaturatingMul> for PrimitiveArray { + fn saturating_mul(&self, rhs: &PrimitiveArray) -> Self { + saturating_mul(self, rhs) + } +} + +/// Adaptive multiplication of two decimal primitive arrays with different +/// precision and scale. If the precision and scale is different, then the +/// smallest scale and precision is adjusted to the largest precision and +/// scale. 
If during the multiplication one of the results is larger than the +/// max possible value, the result precision is changed to the precision of the +/// max value +/// +/// ```nocode +/// 11111.0 -> 6, 1 +/// 10.002 -> 5, 3 +/// ----------------- +/// 111132.222 -> 9, 3 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_mul; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(11111_0i128), Some(1_0i128)]).to(DataType::Decimal(6, 1)); +/// let b = PrimitiveArray::from([Some(10_002i128), Some(2_000i128)]).to(DataType::Decimal(5, 3)); +/// let result = adaptive_mul(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(111132_222i128), Some(2_000i128)]).to(DataType::Decimal(9, 3)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_mul( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let shift_1 = 10i128.pow(res_s as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res = if lhs_s > rhs_s { + l.checked_mul(r * shift) + } else { + (l * shift).checked_mul(*r) + } + .expect("Mayor overflow for multiplication"); + + let res = res / shift_1; + + // The precision of the resulting array will change if one of the + // multiplications during the iteration produces a value bigger + // than the possible value for the initial precision + + // 10.0000 -> 6, 4 + // 10.0000 -> 6, 4 + // ----------------- + // 100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/decimal/sub.rs b/crates/nano-arrow/src/compute/arithmetics/decimal/sub.rs new file mode 100644 index 000000000000..2a0f7a72da17 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/decimal/sub.rs @@ -0,0 +1,238 @@ +//! Defines the subtract arithmetic kernels for Decimal `PrimitiveArrays`. + +use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; +use crate::array::PrimitiveArray; +use crate::compute::arithmetics::{ArrayCheckedSub, ArraySaturatingSub, ArraySub}; +use crate::compute::arity::{binary, binary_checked}; +use crate::compute::utils::{check_same_len, combine_validities}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +/// Subtract two decimal primitive arrays with the same precision and scale. If +/// the precision and scale is different, then an InvalidArgumentError is +/// returned. 
This function panics if the subtracted numbers result in a number +/// smaller than the possible number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(0i128), Some(-1i128), None, Some(0i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + assert!( + res.abs() <= max, + "Overflow in subtract presented for precision {precision}" + ); + + res + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Saturated subtraction of two decimal primitive arrays with the same +/// precision and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. If the result from the sum is smaller +/// than the possible number with the selected precision then the resulted +/// number in the arrow array is the minimum number for the selected precision. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::saturating_sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = saturating_sub(&a, &b); +/// let expected = PrimitiveArray::from([Some(-99999i128), Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn saturating_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => { + if res > 0 { + max + } else { + -max + } + }, + _ => res, + } + }; + + binary(lhs, rhs, lhs.data_type().clone(), op) +} + +// Implementation of ArraySub trait for PrimitiveArrays +impl ArraySub> for PrimitiveArray { + fn sub(&self, rhs: &PrimitiveArray) -> Self { + sub(self, rhs) + } +} + +// Implementation of ArrayCheckedSub trait for PrimitiveArrays +impl ArrayCheckedSub> for PrimitiveArray { + fn checked_sub(&self, rhs: &PrimitiveArray) -> Self { + checked_sub(self, rhs) + } +} + +// Implementation of ArraySaturatingSub trait for PrimitiveArrays +impl ArraySaturatingSub> for PrimitiveArray { + fn saturating_sub(&self, rhs: &PrimitiveArray) -> Self { + saturating_sub(self, rhs) + } +} +/// Checked subtract of two decimal primitive arrays with the same precision +/// and scale. If the precision and scale is different, then an +/// InvalidArgumentError is returned. 
If the result from the sub is larger than +/// the possible number with the selected precision (overflowing), then the +/// validity for that index is changed to None +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::checked_sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(-99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2)); +/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// let result = checked_sub(&a, &b); +/// let expected = PrimitiveArray::from([None, Some(-11100i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn checked_sub(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { + let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); + + let max = max_value(precision); + + let op = move |a, b| { + let res: i128 = a - b; + + match res { + res if res.abs() > max => None, + _ => Some(res), + } + }; + + binary_checked(lhs, rhs, lhs.data_type().clone(), op) +} + +/// Adaptive subtract of two decimal primitive arrays with different precision +/// and scale. If the precision and scale is different, then the smallest scale +/// and precision is adjusted to the largest precision and scale. If during the +/// addition one of the results is smaller than the min possible value, the +/// result precision is changed to the precision of the min value +/// +/// ```nocode +/// 99.9999 -> 6, 4 +/// -00.0001 -> 6, 4 +/// ----------------- +/// 100.0000 -> 7, 4 +/// ``` +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::decimal::adaptive_sub; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::DataType; +/// +/// let a = PrimitiveArray::from([Some(99_9999i128)]).to(DataType::Decimal(6, 4)); +/// let b = PrimitiveArray::from([Some(-00_0001i128)]).to(DataType::Decimal(6, 4)); +/// let result = adaptive_sub(&a, &b).unwrap(); +/// let expected = PrimitiveArray::from([Some(100_0000i128)]).to(DataType::Decimal(7, 4)); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn adaptive_sub( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + check_same_len(lhs, rhs)?; + + let (lhs_p, lhs_s, rhs_p, rhs_s) = + if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) = + (lhs.data_type(), rhs.data_type()) + { + (*lhs_p, *lhs_s, *rhs_p, *rhs_s) + } else { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the array".to_string(), + )); + }; + + // The resulting precision is mutable because it could change while + // looping through the iterator + let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s); + + let shift = 10i128.pow(diff as u32); + let mut max = max_value(res_p); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + // Based on the array's scales one of the arguments in the sum has to be shifted + // to the left to match the final scale + let res: i128 = if lhs_s > rhs_s { + l - r * shift + } else { + l * shift - r + }; + + // The precision of the resulting array will change if one of the + // subtraction during the iteration produces a value bigger than the + // possible value for the initial precision + + // -99.9999 -> 6, 4 + // 00.0001 -> 6, 4 + // ----------------- + // -100.0000 -> 7, 4 + if res.abs() > max { + res_p = number_digits(res); + max = 
max_value(res_p); + } + + res + }) + .collect::>(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + Ok(PrimitiveArray::::new( + DataType::Decimal(res_p, res_s), + values.into(), + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/arithmetics/mod.rs b/crates/nano-arrow/src/compute/arithmetics/mod.rs new file mode 100644 index 000000000000..1d520e9ad644 --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/mod.rs @@ -0,0 +1,581 @@ +//! Defines basic arithmetic kernels for [`PrimitiveArray`](crate::array::PrimitiveArray)s. +//! +//! The Arithmetics module is composed by basic arithmetics operations that can +//! be performed on [`PrimitiveArray`](crate::array::PrimitiveArray). +//! +//! Whenever possible, each operation declares variations +//! of the basic operation that offers different guarantees: +//! * plain: panics on overflowing and underflowing. +//! * checked: turns an overflowing to a null. +//! * saturating: turns the overflowing to the MAX or MIN value respectively. +//! * overflowing: returns an extra [`Bitmap`] denoting whether the operation overflowed. +//! * adaptive: for [`Decimal`](crate::datatypes::DataType::Decimal) only, +//! adjusts the precision and scale to make the resulting value fit. +#[forbid(unsafe_code)] +pub mod basic; +#[cfg(feature = "compute_arithmetics_decimal")] +pub mod decimal; +pub mod time; + +use crate::array::{Array, DictionaryArray, PrimitiveArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; +use crate::scalar::{PrimitiveScalar, Scalar}; +use crate::types::NativeType; + +fn binary_dyn, &PrimitiveArray) -> PrimitiveArray>( + lhs: &dyn Array, + rhs: &dyn Array, + op: F, +) -> Box { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + op(lhs, rhs).boxed() +} + +// Macro to create a `match` statement with dynamic dispatch to functions based on +// the array's logical types +macro_rules! arith { + ($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{ + let lhs = $lhs; + let rhs = $rhs; + use DataType::*; + match (lhs.data_type(), rhs.data_type()) { + (Int8, Int8) => binary_dyn::(lhs, rhs, basic::$op), + (Int16, Int16) => binary_dyn::(lhs, rhs, basic::$op), + (Int32, Int32) => binary_dyn::(lhs, rhs, basic::$op), + (Int64, Int64) | (Duration(_), Duration(_)) => { + binary_dyn::(lhs, rhs, basic::$op) + } + (UInt8, UInt8) => binary_dyn::(lhs, rhs, basic::$op), + (UInt16, UInt16) => binary_dyn::(lhs, rhs, basic::$op), + (UInt32, UInt32) => binary_dyn::(lhs, rhs, basic::$op), + (UInt64, UInt64) => binary_dyn::(lhs, rhs, basic::$op), + (Float32, Float32) => binary_dyn::(lhs, rhs, basic::$op), + (Float64, Float64) => binary_dyn::(lhs, rhs, basic::$op), + $ ( + (Decimal(_, _), Decimal(_, _)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(decimal::$op_decimal(lhs, rhs)) as Box + } + )? 
+ $ ( + (Time32(TimeUnit::Second), Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Date32, Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Date64, Duration(_)) + | (Timestamp(_, _), Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + Box::new(time::$op_duration::(lhs, rhs)) as Box + } + )? + $ ( + (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + $ ( + (Timestamp(_, None), Timestamp(_, None)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + } + )? + _ => todo!( + "Addition of {:?} with {:?} is not supported", + lhs.data_type(), + rhs.data_type() + ), + } + }}; +} + +fn binary_scalar, &T) -> PrimitiveArray>( + lhs: &PrimitiveArray, + rhs: &PrimitiveScalar, + op: F, +) -> PrimitiveArray { + let rhs = if let Some(rhs) = *rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + op(lhs, &rhs) +} + +fn binary_scalar_dyn, &T) -> PrimitiveArray>( + lhs: &dyn Array, + rhs: &dyn Scalar, + op: F, +) -> Box { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary_scalar(lhs, rhs, op).boxed() +} + +// Macro to create a `match` statement with dynamic dispatch to functions based on +// the array's logical types +macro_rules! arith_scalar { + ($lhs:expr, $rhs:expr, $op:tt $(, decimal = $op_decimal:tt )? $(, duration = $op_duration:tt )? $(, interval = $op_interval:tt )? $(, timestamp = $op_timestamp:tt )?) => {{ + let lhs = $lhs; + let rhs = $rhs; + use DataType::*; + match (lhs.data_type(), rhs.data_type()) { + (Int8, Int8) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int16, Int16) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int32, Int32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int64, Int64) | (Duration(_), Duration(_)) => { + binary_scalar_dyn::(lhs, rhs, basic::$op) + } + (UInt8, UInt8) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt16, UInt16) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt32, UInt32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt64, UInt64) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Float32, Float32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Float64, Float64) => binary_scalar_dyn::(lhs, rhs, basic::$op), + $ ( + (Decimal(_, _), Decimal(_, _)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + decimal::$op_decimal(lhs, rhs).boxed() + } + )? 
+ $ ( + (Time32(TimeUnit::Second), Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Date32, Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_duration::(lhs, rhs).boxed() + } + (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Date64, Duration(_)) + | (Timestamp(_, _), Duration(_)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_duration::(lhs, rhs).boxed() + } + )? + $ ( + (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_interval(lhs, rhs).unwrap().boxed() + } + )? + $ ( + (Timestamp(_, None), Timestamp(_, None)) => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + time::$op_timestamp(lhs, rhs).unwrap().boxed() + } + )? + _ => todo!( + "Addition of {:?} with {:?} is not supported", + lhs.data_type(), + rhs.data_type() + ), + } + }}; +} + +/// Adds two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_add`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!( + lhs, + rhs, + add, + duration = add_duration, + interval = add_interval + ) +} + +/// Adds an [`Array`] and a [`Scalar`]. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_add`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn add_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!( + lhs, + rhs, + add_scalar, + duration = add_duration_scalar, + interval = add_interval_scalar + ) +} + +/// Returns whether two [`DataType`]s can be added by [`add`]. +pub fn can_add(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Duration(_), Duration(_)) + | (Decimal(_, _), Decimal(_, _)) + | (Date32, Duration(_)) + | (Date64, Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Time32(TimeUnit::Second), Duration(_)) + | (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Timestamp(_, _), Duration(_)) + | (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) + ) +} + +/// Subtracts two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_sub`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn sub(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!( + lhs, + rhs, + sub, + decimal = sub, + duration = subtract_duration, + timestamp = subtract_timestamps + ) +} + +/// Adds an [`Array`] and a [`Scalar`]. 
+/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_sub`] to check) +/// * the arrays have a different length +/// * one of the arrays is a timestamp with timezone and the timezone is not valid. +pub fn sub_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!( + lhs, + rhs, + sub_scalar, + duration = sub_duration_scalar, + timestamp = sub_timestamps_scalar + ) +} + +/// Returns whether two [`DataType`]s can be subtracted by [`sub`]. +pub fn can_sub(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Duration(_), Duration(_)) + | (Decimal(_, _), Decimal(_, _)) + | (Date32, Duration(_)) + | (Date64, Duration(_)) + | (Time32(TimeUnit::Millisecond), Duration(_)) + | (Time32(TimeUnit::Second), Duration(_)) + | (Time64(TimeUnit::Microsecond), Duration(_)) + | (Time64(TimeUnit::Nanosecond), Duration(_)) + | (Timestamp(_, _), Duration(_)) + | (Timestamp(_, None), Timestamp(_, None)) + ) +} + +/// Multiply two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_mul`] to check) +/// * the arrays have a different length +pub fn mul(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, mul, decimal = mul) +} + +/// Multiply an [`Array`] with a [`Scalar`]. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_mul`] to check) +pub fn mul_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!(lhs, rhs, mul_scalar, decimal = mul_scalar) +} + +/// Returns whether two [`DataType`]s can be multiplied by [`mul`]. +pub fn can_mul(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + | (Decimal(_, _), Decimal(_, _)) + ) +} + +/// Divide of two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_div`] to check) +/// * the arrays have a different length +pub fn div(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, div, decimal = div) +} + +/// Divide an [`Array`] with a [`Scalar`]. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_div`] to check) +pub fn div_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> Box { + arith_scalar!(lhs, rhs, div_scalar, decimal = div_scalar) +} + +/// Returns whether two [`DataType`]s can be divided by [`div`]. +pub fn can_div(lhs: &DataType, rhs: &DataType) -> bool { + can_mul(lhs, rhs) +} + +/// Remainder of two [`Array`]s. +/// # Panic +/// This function panics iff +/// * the operation is not supported for the logical types (use [`can_rem`] to check) +/// * the arrays have a different length +pub fn rem(lhs: &dyn Array, rhs: &dyn Array) -> Box { + arith!(lhs, rhs, rem) +} + +/// Returns whether two [`DataType`]s "can be remainder" by [`rem`]. 
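The plain/checked/saturating/overflowing split that these dispatchers expose is the element-wise lift of the corresponding integer semantics. A scalar-level illustration in plain Rust, with no crate types involved, of what each policy does on overflow:

```rust
fn main() {
    let a: i8 = 120;
    let b: i8 = 10; // 120 + 10 does not fit in an i8

    // checked: the checked kernels write a null where this returns None.
    assert_eq!(a.checked_add(b), None);

    // saturating: the saturating kernels clamp to the type's MAX or MIN.
    assert_eq!(a.saturating_add(b), i8::MAX);

    // overflowing: the overflowing kernels return the wrapped value plus a
    // Bitmap marking which slots overflowed.
    assert_eq!(a.overflowing_add(b), (-126, true));

    // plain: the plain kernels are documented to panic on overflow, just as
    // `a + b` would panic here in a debug build.
}
```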
+pub fn can_rem(lhs: &DataType, rhs: &DataType) -> bool { + use DataType::*; + matches!( + (lhs, rhs), + (Int8, Int8) + | (Int16, Int16) + | (Int32, Int32) + | (Int64, Int64) + | (UInt8, UInt8) + | (UInt16, UInt16) + | (UInt32, UInt32) + | (UInt64, UInt64) + | (Float64, Float64) + | (Float32, Float32) + ) +} + +macro_rules! with_match_negatable {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns, i256}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 | UInt16 | UInt32 | UInt64 | Float16 => todo!(), + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +/// Negates an [`Array`]. +/// # Panic +/// This function panics iff either +/// * the operation is not supported for the logical type (use [`can_neg`] to check) +/// * the operation overflows +pub fn neg(array: &dyn Array) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_negatable!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + + let result = basic::negate::<$T>(array); + Box::new(result) as Box + }), + Dictionary(key) => match_integer_type!(key, |$T| { + let array = array.as_any().downcast_ref::>().unwrap(); + + let values = neg(array.values().as_ref()); + + // safety - this operation only applies to values and thus preserves the dictionary's invariant + unsafe{ + DictionaryArray::<$T>::try_new_unchecked(array.data_type().clone(), array.keys().clone(), values).unwrap().boxed() + } + }), + _ => todo!(), + } +} + +/// Whether [`neg`] is supported for a given [`DataType`] +pub fn can_neg(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { + return can_neg(values.as_ref()); + } + + use crate::datatypes::PhysicalType::*; + use crate::datatypes::PrimitiveType::*; + matches!( + data_type.to_physical_type(), + Primitive(Int8) + | Primitive(Int16) + | Primitive(Int32) + | Primitive(Int64) + | Primitive(Float64) + | Primitive(Float32) + | Primitive(DaysMs) + | Primitive(MonthDayNano) + ) +} + +/// Defines basic addition operation for primitive arrays +pub trait ArrayAdd: Sized { + /// Adds itself to `rhs` + fn add(&self, rhs: &Rhs) -> Self; +} + +/// Defines wrapping addition operation for primitive arrays +pub trait ArrayWrappingAdd: Sized { + /// Adds itself to `rhs` using wrapping addition + fn wrapping_add(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked addition operation for primitive arrays +pub trait ArrayCheckedAdd: Sized { + /// Checked add + fn checked_add(&self, rhs: &Rhs) -> Self; +} + +/// Defines saturating addition operation for primitive arrays +pub trait ArraySaturatingAdd: Sized { + /// Saturating add + fn saturating_add(&self, rhs: &Rhs) -> Self; +} + +/// Defines Overflowing addition operation for primitive arrays +pub trait ArrayOverflowingAdd: Sized { + /// Overflowing add + fn overflowing_add(&self, rhs: &Rhs) -> (Self, Bitmap); +} + +/// Defines basic subtraction operation for primitive arrays +pub trait ArraySub: Sized { + /// subtraction + fn sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines wrapping 
subtraction operation for primitive arrays +pub trait ArrayWrappingSub: Sized { + /// wrapping subtraction + fn wrapping_sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked subtraction operation for primitive arrays +pub trait ArrayCheckedSub: Sized { + /// checked subtraction + fn checked_sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines saturating subtraction operation for primitive arrays +pub trait ArraySaturatingSub: Sized { + /// saturarting subtraction + fn saturating_sub(&self, rhs: &Rhs) -> Self; +} + +/// Defines Overflowing subtraction operation for primitive arrays +pub trait ArrayOverflowingSub: Sized { + /// overflowing subtraction + fn overflowing_sub(&self, rhs: &Rhs) -> (Self, Bitmap); +} + +/// Defines basic multiplication operation for primitive arrays +pub trait ArrayMul: Sized { + /// multiplication + fn mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines wrapping multiplication operation for primitive arrays +pub trait ArrayWrappingMul: Sized { + /// wrapping multiplication + fn wrapping_mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked multiplication operation for primitive arrays +pub trait ArrayCheckedMul: Sized { + /// checked multiplication + fn checked_mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines saturating multiplication operation for primitive arrays +pub trait ArraySaturatingMul: Sized { + /// saturating multiplication + fn saturating_mul(&self, rhs: &Rhs) -> Self; +} + +/// Defines Overflowing multiplication operation for primitive arrays +pub trait ArrayOverflowingMul: Sized { + /// overflowing multiplication + fn overflowing_mul(&self, rhs: &Rhs) -> (Self, Bitmap); +} + +/// Defines basic division operation for primitive arrays +pub trait ArrayDiv: Sized { + /// division + fn div(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked division operation for primitive arrays +pub trait ArrayCheckedDiv: Sized { + /// checked division + fn checked_div(&self, rhs: &Rhs) -> Self; +} + +/// Defines basic reminder operation for primitive arrays +pub trait ArrayRem: Sized { + /// remainder + fn rem(&self, rhs: &Rhs) -> Self; +} + +/// Defines checked reminder operation for primitive arrays +pub trait ArrayCheckedRem: Sized { + /// checked remainder + fn checked_rem(&self, rhs: &Rhs) -> Self; +} diff --git a/crates/nano-arrow/src/compute/arithmetics/time.rs b/crates/nano-arrow/src/compute/arithmetics/time.rs new file mode 100644 index 000000000000..aa2e25e3ab0f --- /dev/null +++ b/crates/nano-arrow/src/compute/arithmetics/time.rs @@ -0,0 +1,432 @@ +//! Defines the arithmetic kernels for adding a Duration to a Timestamp, +//! Time32, Time64, Date32 and Date64. +//! +//! For the purposes of Arrow Implementations, adding this value to a Timestamp +//! ("t1") naively (i.e. simply summing the two number) is acceptable even +//! though in some cases the resulting Timestamp (t2) would not account for +//! leap-seconds during the elapsed time between "t1" and "t2". Similarly, +//! representing the difference between two Unix timestamp is acceptable, but +//! would yield a value that is possibly a few seconds off from the true +//! elapsed time. 
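The duration kernels below all hinge on rescaling the duration to the time array's unit before the naive addition or subtraction. A standalone sketch of that rescaling in plain Rust (the local `timeunit_scale` only mirrors the idea of the helper used below; it is not the crate's function):

```rust
#[allow(dead_code)]
enum TimeUnit { Second, Millisecond, Microsecond, Nanosecond }

/// Length of one tick of `unit`, expressed in seconds.
fn tick_seconds(unit: TimeUnit) -> f64 {
    match unit {
        TimeUnit::Second => 1.0,
        TimeUnit::Millisecond => 1e-3,
        TimeUnit::Microsecond => 1e-6,
        TimeUnit::Nanosecond => 1e-9,
    }
}

/// Multiplier that converts a value counted in `duration_unit` into a value
/// counted in `time_unit`.
fn timeunit_scale(time_unit: TimeUnit, duration_unit: TimeUnit) -> f64 {
    tick_seconds(duration_unit) / tick_seconds(time_unit)
}

fn main() {
    // Timestamp in seconds plus a duration in milliseconds: scale is 0.001.
    let scale = timeunit_scale(TimeUnit::Second, TimeUnit::Millisecond);
    let timestamp_s: i64 = 1_600_000_000;
    let duration_ms: i64 = 1_500;
    let shifted = timestamp_s + (duration_ms as f64 * scale) as i64;
    assert_eq!(shifted, 1_600_000_001); // 1.5 s truncated to whole seconds
}
```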
+ +use std::ops::{Add, Sub}; + +use num_traits::AsPrimitive; + +use crate::array::PrimitiveArray; +use crate::compute::arity::{binary, unary}; +use crate::datatypes::{DataType, TimeUnit}; +use crate::error::{Error, Result}; +use crate::scalar::{PrimitiveScalar, Scalar}; +use crate::temporal_conversions; +use crate::types::{months_days_ns, NativeType}; + +/// Creates the scale required to add or subtract a Duration to a time array +/// (Timestamp, Time, or Date). The resulting scale always multiplies the rhs +/// number (Duration) so it can be added to the lhs number (time array). +fn create_scale(lhs: &DataType, rhs: &DataType) -> Result { + // Matching on both data types from both numbers to calculate the correct + // scale for the operation. The timestamp, Time and duration have a + // Timeunit enum in its data type. This enum is used to describe the + // addition of the duration. The Date32 and Date64 have different rules for + // the scaling. + let scale = match (lhs, rhs) { + (DataType::Timestamp(timeunit_a, _), DataType::Duration(timeunit_b)) + | (DataType::Time32(timeunit_a), DataType::Duration(timeunit_b)) + | (DataType::Time64(timeunit_a), DataType::Duration(timeunit_b)) => { + // The scale is based on the TimeUnit that each of the numbers have. + temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b) + }, + (DataType::Date32, DataType::Duration(timeunit)) => { + // Date32 represents the time elapsed time since UNIX epoch + // (1970-01-01) in days (32 bits). The duration value has to be + // scaled to days to be able to add the value to the Date. + temporal_conversions::timeunit_scale(TimeUnit::Second, *timeunit) + / temporal_conversions::SECONDS_IN_DAY as f64 + }, + (DataType::Date64, DataType::Duration(timeunit)) => { + // Date64 represents the time elapsed time since UNIX epoch + // (1970-01-01) in milliseconds (64 bits). The duration value has + // to be scaled to milliseconds to be able to add the value to the + // Date. + temporal_conversions::timeunit_scale(TimeUnit::Millisecond, *timeunit) + }, + _ => { + return Err(Error::InvalidArgumentError( + "Incorrect data type for the arguments".to_string(), + )); + }, + }; + + Ok(scale) +} + +/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::time::add_duration; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// +/// let timestamp = PrimitiveArray::from([ +/// Some(100000i64), +/// Some(200000i64), +/// None, +/// Some(300000i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) +/// .to(DataType::Duration(TimeUnit::Second)); +/// +/// let result = add_duration(×tamp, &duration); +/// let expected = PrimitiveArray::from([ +/// Some(100010i64), +/// Some(200020i64), +/// None, +/// Some(300030i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// assert_eq!(result, expected); +/// ``` +pub fn add_duration( + time: &PrimitiveArray, + duration: &PrimitiveArray, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Add, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + + // Closure for the binary operation. The closure contains the scale + // required to add a duration to the timestamp array. + let op = move |a: T, b: i64| a + (b as f64 * scale).as_(); + + binary(time, duration, time.data_type().clone(), op) +} + +/// Adds a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. +pub fn add_duration_scalar( + time: &PrimitiveArray, + duration: &PrimitiveScalar, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Add, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + let duration = if let Some(duration) = *duration.value() { + duration + } else { + return PrimitiveArray::::new_null(time.data_type().clone(), time.len()); + }; + + // Closure for the binary operation. The closure contains the scale + // required to add a duration to the timestamp array. + let op = move |a: T| a + (duration as f64 * scale).as_(); + + unary(time, op, time.data_type().clone()) +} + +/// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. 
+/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::time::subtract_duration; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// +/// let timestamp = PrimitiveArray::from([ +/// Some(100000i64), +/// Some(200000i64), +/// None, +/// Some(300000i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// let duration = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) +/// .to(DataType::Duration(TimeUnit::Second)); +/// +/// let result = subtract_duration(×tamp, &duration); +/// let expected = PrimitiveArray::from([ +/// Some(99990i64), +/// Some(199980i64), +/// None, +/// Some(299970i64), +/// ]) +/// .to(DataType::Timestamp( +/// TimeUnit::Second, +/// Some("America/New_York".to_string()), +/// )); +/// +/// assert_eq!(result, expected); +/// +/// ``` +pub fn subtract_duration( + time: &PrimitiveArray, + duration: &PrimitiveArray, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Sub, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + + // Closure for the binary operation. The closure contains the scale + // required to add a duration to the timestamp array. + let op = move |a: T, b: i64| a - (b as f64 * scale).as_(); + + binary(time, duration, time.data_type().clone(), op) +} + +/// Subtract a duration to a time array (Timestamp, Time and Date). The timeunit +/// enum is used to scale correctly both arrays; adding seconds with seconds, +/// or milliseconds with milliseconds. +pub fn sub_duration_scalar( + time: &PrimitiveArray, + duration: &PrimitiveScalar, +) -> PrimitiveArray +where + f64: AsPrimitive, + T: NativeType + Sub, +{ + let scale = create_scale(time.data_type(), duration.data_type()).unwrap(); + let duration = if let Some(duration) = *duration.value() { + duration + } else { + return PrimitiveArray::::new_null(time.data_type().clone(), time.len()); + }; + + let op = move |a: T| a - (duration as f64 * scale).as_(); + + unary(time, op, time.data_type().clone()) +} + +/// Calculates the difference between two timestamps returning an array of type +/// Duration. The timeunit enum is used to scale correctly both arrays; +/// subtracting seconds with seconds, or milliseconds with milliseconds. +/// +/// # Examples +/// ``` +/// use arrow2::compute::arithmetics::time::subtract_timestamps; +/// use arrow2::array::PrimitiveArray; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// let timestamp_a = PrimitiveArray::from([ +/// Some(100_010i64), +/// Some(200_020i64), +/// None, +/// Some(300_030i64), +/// ]) +/// .to(DataType::Timestamp(TimeUnit::Second, None)); +/// +/// let timestamp_b = PrimitiveArray::from([ +/// Some(100_000i64), +/// Some(200_000i64), +/// None, +/// Some(300_000i64), +/// ]) +/// .to(DataType::Timestamp(TimeUnit::Second, None)); +/// +/// let expected = PrimitiveArray::from([Some(10i64), Some(20i64), None, Some(30i64)]) +/// .to(DataType::Duration(TimeUnit::Second)); +/// +/// let result = subtract_timestamps(×tamp_a, &×tamp_b).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn subtract_timestamps( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, +) -> Result> { + // Matching on both data types from both arrays. + // Both timestamps have a Timeunit enum in its data type. + // This enum is used to adjust the scale between the timestamps. + match (lhs.data_type(), rhs.data_type()) { + // Naive timestamp comparison. 
It doesn't take into account timezones + // from the Timestamp timeunit. + (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) => { + // Closure for the binary operation. The closure contains the scale + // required to calculate the difference between the timestamps. + let scale = temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b); + let op = move |a, b| a - (b as f64 * scale) as i64; + + Ok(binary(lhs, rhs, DataType::Duration(*timeunit_a), op)) + }, + _ => Err(Error::InvalidArgumentError( + "Incorrect data type for the arguments".to_string(), + )), + } +} + +/// Calculates the difference between two timestamps as [`DataType::Duration`] with the same time scale. +pub fn sub_timestamps_scalar( + lhs: &PrimitiveArray, + rhs: &PrimitiveScalar, +) -> Result> { + let (scale, timeunit_a) = + if let (DataType::Timestamp(timeunit_a, None), DataType::Timestamp(timeunit_b, None)) = + (lhs.data_type(), rhs.data_type()) + { + ( + temporal_conversions::timeunit_scale(*timeunit_a, *timeunit_b), + timeunit_a, + ) + } else { + return Err(Error::InvalidArgumentError( + "sub_timestamps_scalar requires both arguments to be timestamps without timezone" + .to_string(), + )); + }; + + let rhs = if let Some(value) = *rhs.value() { + value + } else { + return Ok(PrimitiveArray::::new_null( + lhs.data_type().clone(), + lhs.len(), + )); + }; + + let op = move |a| a - (rhs as f64 * scale) as i64; + + Ok(unary(lhs, op, DataType::Duration(*timeunit_a))) +} + +/// Adds an interval to a [`DataType::Timestamp`]. +pub fn add_interval( + timestamp: &PrimitiveArray, + interval: &PrimitiveArray, +) -> Result> { + match timestamp.data_type().to_logical_type() { + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let time_unit = *time_unit; + let timezone = temporal_conversions::parse_offset(timezone_str); + match timezone { + Ok(timezone) => Ok(binary( + timestamp, + interval, + timestamp.data_type().clone(), + |timestamp, interval| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + )), + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(timezone_str)?; + Ok(binary( + timestamp, + interval, + timestamp.data_type().clone(), + |timestamp, interval| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + )) + }, + #[cfg(not(feature = "chrono-tz"))] + _ => Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))), + } + }, + DataType::Timestamp(time_unit, None) => { + let time_unit = *time_unit; + Ok(binary( + timestamp, + interval, + timestamp.data_type().clone(), + |timestamp, interval| { + temporal_conversions::add_naive_interval(timestamp, time_unit, interval) + }, + )) + }, + _ => Err(Error::InvalidArgumentError( + "Adding an interval is only supported for `DataType::Timestamp`".to_string(), + )), + } +} + +/// Adds an interval to a [`DataType::Timestamp`]. 
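The difference of two naive timestamps, as computed by `subtract_timestamps` above, follows the same rescale-then-subtract shape and yields a `Duration` in the left-hand side's unit. A tiny plain-Rust sketch of one slot of that computation (no crate types; the unit handling is reduced to a hard-coded factor):

```rust
fn main() {
    // lhs is a timestamp in seconds, rhs a timestamp in milliseconds.
    let lhs_s: i64 = 100_010;
    let rhs_ms: i64 = 100_000_000; // 100_000 s expressed as milliseconds
    let scale = 1e-3; // milliseconds -> seconds

    // The result is a duration in the lhs unit, here Duration(Second).
    let duration_s = lhs_s - (rhs_ms as f64 * scale) as i64;
    assert_eq!(duration_s, 10);
}
```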
+pub fn add_interval_scalar( + timestamp: &PrimitiveArray, + interval: &PrimitiveScalar, +) -> Result> { + let interval = if let Some(interval) = *interval.value() { + interval + } else { + return Ok(PrimitiveArray::::new_null( + timestamp.data_type().clone(), + timestamp.len(), + )); + }; + + match timestamp.data_type().to_logical_type() { + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let time_unit = *time_unit; + let timezone = temporal_conversions::parse_offset(timezone_str); + match timezone { + Ok(timezone) => Ok(unary( + timestamp, + |timestamp| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + timestamp.data_type().clone(), + )), + #[cfg(feature = "chrono-tz")] + Err(_) => { + let timezone = temporal_conversions::parse_offset_tz(timezone_str)?; + Ok(unary( + timestamp, + |timestamp| { + temporal_conversions::add_interval( + timestamp, time_unit, interval, &timezone, + ) + }, + timestamp.data_type().clone(), + )) + }, + #[cfg(not(feature = "chrono-tz"))] + _ => Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))), + } + }, + DataType::Timestamp(time_unit, None) => { + let time_unit = *time_unit; + Ok(unary( + timestamp, + |timestamp| { + temporal_conversions::add_naive_interval(timestamp, time_unit, interval) + }, + timestamp.data_type().clone(), + )) + }, + _ => Err(Error::InvalidArgumentError( + "Adding an interval is only supported for `DataType::Timestamp`".to_string(), + )), + } +} diff --git a/crates/nano-arrow/src/compute/arity.rs b/crates/nano-arrow/src/compute/arity.rs new file mode 100644 index 000000000000..935970ccdf75 --- /dev/null +++ b/crates/nano-arrow/src/compute/arity.rs @@ -0,0 +1,279 @@ +//! Defines kernels suitable to perform operations to primitive arrays. + +use super::utils::{check_same_len, combine_validities}; +use crate::array::PrimitiveArray; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +/// Applies an unary and infallible function to a [`PrimitiveArray`]. This is the +/// fastest way to perform an operation on a [`PrimitiveArray`] when the benefits +/// of a vectorized operation outweighs the cost of branching nulls and +/// non-nulls. +/// +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the +/// corresponding type or this function may panic. +#[inline] +pub fn unary(array: &PrimitiveArray, op: F, data_type: DataType) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> O, +{ + let values = array.values().iter().map(|v| op(*v)).collect::>(); + + PrimitiveArray::::new(data_type, values.into(), array.validity().cloned()) +} + +/// Version of unary that checks for errors in the closure used to create the +/// buffer +pub fn try_unary( + array: &PrimitiveArray, + op: F, + data_type: DataType, +) -> Result> +where + I: NativeType, + O: NativeType, + F: Fn(I) -> Result, +{ + let values = array + .values() + .iter() + .map(|v| op(*v)) + .collect::>>()? + .into(); + + Ok(PrimitiveArray::::new( + data_type, + values, + array.validity().cloned(), + )) +} + +/// Version of unary that returns an array and bitmap. 
Used when working with +/// overflowing operations +pub fn unary_with_bitmap( + array: &PrimitiveArray, + op: F, + data_type: DataType, +) -> (PrimitiveArray, Bitmap) +where + I: NativeType, + O: NativeType, + F: Fn(I) -> (O, bool), +{ + let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); + + let values = array + .values() + .iter() + .map(|v| { + let (res, over) = op(*v); + mut_bitmap.push(over); + res + }) + .collect::>() + .into(); + + ( + PrimitiveArray::::new(data_type, values, array.validity().cloned()), + mut_bitmap.into(), + ) +} + +/// Version of unary that creates a mutable bitmap that is used to keep track +/// of checked operations. The resulting bitmap is compared with the array +/// bitmap to create the final validity array. +pub fn unary_checked( + array: &PrimitiveArray, + op: F, + data_type: DataType, +) -> PrimitiveArray +where + I: NativeType, + O: NativeType, + F: Fn(I) -> Option, +{ + let mut mut_bitmap = MutableBitmap::with_capacity(array.len()); + + let values = array + .values() + .iter() + .map(|v| match op(*v) { + Some(val) => { + mut_bitmap.push(true); + val + }, + None => { + mut_bitmap.push(false); + O::default() + }, + }) + .collect::>() + .into(); + + // The validity has to be checked against the bitmap created during the + // creation of the values with the iterator. If an error was found during + // the iteration, then the validity is changed to None to mark the value + // as Null + let bitmap: Bitmap = mut_bitmap.into(); + let validity = combine_validities(array.validity(), Some(&bitmap)); + + PrimitiveArray::::new(data_type, values, validity) +} + +/// Applies a binary operations to two primitive arrays. This is the fastest +/// way to perform an operation on two primitive array when the benefits of a +/// vectorized operation outweighs the cost of branching nulls and non-nulls. +/// # Errors +/// This function errors iff the arrays have a different length. +/// # Implementation +/// This will apply the function for all values, including those on null slots. +/// This implies that the operation must be infallible for any value of the +/// corresponding type. +/// The types of the arrays are not checked with this operation. The closure +/// "op" needs to handle the different types in the arrays. The datatype for the +/// resulting array has to be selected by the implementer of the function as +/// an argument for the function. +#[inline] +pub fn binary( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> PrimitiveArray +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> T, +{ + check_same_len(lhs, rhs).unwrap(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>() + .into(); + + PrimitiveArray::::new(data_type, values, validity) +} + +/// Version of binary that checks for errors in the closure used to create the +/// buffer +pub fn try_binary( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> Result> +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> Result, +{ + check_same_len(lhs, rhs)?; + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>>()? + .into(); + + Ok(PrimitiveArray::::new(data_type, values, validity)) +} + +/// Version of binary that returns an array and bitmap. 
Used when working with +/// overflowing operations +pub fn binary_with_bitmap( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> (PrimitiveArray, Bitmap) +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> (T, bool), +{ + check_same_len(lhs, rhs).unwrap(); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| { + let (res, over) = op(*l, *r); + mut_bitmap.push(over); + res + }) + .collect::>() + .into(); + + ( + PrimitiveArray::::new(data_type, values, validity), + mut_bitmap.into(), + ) +} + +/// Version of binary that creates a mutable bitmap that is used to keep track +/// of checked operations. The resulting bitmap is compared with the array +/// bitmap to create the final validity array. +pub fn binary_checked( + lhs: &PrimitiveArray, + rhs: &PrimitiveArray, + data_type: DataType, + op: F, +) -> PrimitiveArray +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> Option, +{ + check_same_len(lhs, rhs).unwrap(); + + let mut mut_bitmap = MutableBitmap::with_capacity(lhs.len()); + + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| match op(*l, *r) { + Some(val) => { + mut_bitmap.push(true); + val + }, + None => { + mut_bitmap.push(false); + T::default() + }, + }) + .collect::>() + .into(); + + let bitmap: Bitmap = mut_bitmap.into(); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + // The validity has to be checked against the bitmap created during the + // creation of the values with the iterator. If an error was found during + // the iteration, then the validity is changed to None to mark the value + // as Null + let validity = combine_validities(validity.as_ref(), Some(&bitmap)); + + PrimitiveArray::::new(data_type, values, validity) +} diff --git a/crates/nano-arrow/src/compute/arity_assign.rs b/crates/nano-arrow/src/compute/arity_assign.rs new file mode 100644 index 000000000000..e1b358d8aebb --- /dev/null +++ b/crates/nano-arrow/src/compute/arity_assign.rs @@ -0,0 +1,96 @@ +//! Defines generics suitable to perform operations to [`PrimitiveArray`] in-place. + +use either::Either; + +use super::utils::check_same_len; +use crate::array::PrimitiveArray; +use crate::types::NativeType; + +/// Applies an unary function to a [`PrimitiveArray`], optionally in-place. +/// +/// # Implementation +/// This function tries to apply the function directly to the values of the array. +/// If that region is shared, this function creates a new region and writes to it. +/// +/// # Panics +/// This function panics iff +/// * the arrays have a different length. +/// * the function itself panics. +#[inline] +pub fn unary(array: &mut PrimitiveArray, op: F) +where + I: NativeType, + F: Fn(I) -> I, +{ + if let Some(values) = array.get_mut_values() { + // mutate in place + values.iter_mut().for_each(|l| *l = op(*l)); + } else { + // alloc and write to new region + let values = array.values().iter().map(|l| op(*l)).collect::>(); + array.set_values(values.into()); + } +} + +/// Applies a binary function to two [`PrimitiveArray`]s, optionally in-place, returning +/// a new [`PrimitiveArray`]. +/// +/// # Implementation +/// This function tries to apply the function directly to the values of the array. +/// If that region is shared, this function creates a new region and writes to it. 
+/// # Panics +/// This function panics iff +/// * the arrays have a different length. +/// * the function itself panics. +#[inline] +pub fn binary(lhs: &mut PrimitiveArray, rhs: &PrimitiveArray, op: F) +where + T: NativeType, + D: NativeType, + F: Fn(T, D) -> T, +{ + check_same_len(lhs, rhs).unwrap(); + + // both for the validity and for the values + // we branch to check if we can mutate in place + // if we can, great that is fastest. + // if we cannot, we allocate a new buffer and assign values to that + // new buffer, that is benchmarked to be ~2x faster than first memcpy and assign in place + // for the validity bits it can be much faster as we might need to iterate all bits if the + // bitmap has an offset. + if let Some(rhs) = rhs.validity() { + if lhs.validity().is_none() { + lhs.set_validity(Some(rhs.clone())); + } else { + lhs.apply_validity(|bitmap| { + match bitmap.into_mut() { + Either::Left(immutable) => { + // alloc new region + &immutable & rhs + }, + Either::Right(mutable) => { + // mutate in place + (mutable & rhs).into() + }, + } + }); + } + }; + + if let Some(values) = lhs.get_mut_values() { + // mutate values in place + values + .iter_mut() + .zip(rhs.values().iter()) + .for_each(|(l, r)| *l = op(*l, *r)); + } else { + // alloc new region + let values = lhs + .values() + .iter() + .zip(rhs.values().iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>(); + lhs.set_values(values.into()); + } +} diff --git a/crates/nano-arrow/src/compute/bitwise.rs b/crates/nano-arrow/src/compute/bitwise.rs new file mode 100644 index 000000000000..37c26542b848 --- /dev/null +++ b/crates/nano-arrow/src/compute/bitwise.rs @@ -0,0 +1,75 @@ +//! Contains bitwise operators: [`or`], [`and`], [`xor`] and [`not`]. +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use crate::array::PrimitiveArray; +use crate::compute::arity::{binary, unary}; +use crate::types::NativeType; + +/// Performs `OR` operation on two [`PrimitiveArray`]s. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn or(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitOr, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a | b) +} + +/// Performs `XOR` operation between two [`PrimitiveArray`]s. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn xor(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitXor, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a ^ b) +} + +/// Performs `AND` operation on two [`PrimitiveArray`]s. +/// # Panic +/// This function panics when the arrays have different lengths. +pub fn and(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + BitAnd, +{ + binary(lhs, rhs, lhs.data_type().clone(), |a, b| a & b) +} + +/// Returns a new [`PrimitiveArray`] with the bitwise `not`. +pub fn not(array: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Not, +{ + let op = move |a: T| !a; + unary(array, op, array.data_type().clone()) +} + +/// Performs `OR` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function errors when the arrays have different lengths. +pub fn or_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitOr, +{ + unary(lhs, |a| a | *rhs, lhs.data_type().clone()) +} + +/// Performs `XOR` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function errors when the arrays have different lengths. 
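Several kernels in `arity.rs` above (`unary_checked`, `binary_checked`) share one pattern: apply a fallible closure per element, record the outcome in a fresh validity mask, and AND that mask with the inputs' validity so failed slots come out null. A plain-Rust sketch of that pattern, with `Vec<bool>` standing in for the crate's `Bitmap`:

```rust
fn binary_checked_sketch<F>(
    lhs: &[i128],
    rhs: &[i128],
    validity: &[bool], // combined input validity, one flag per slot
    op: F,
) -> (Vec<i128>, Vec<bool>)
where
    F: Fn(i128, i128) -> Option<i128>,
{
    let mut out_validity = Vec::with_capacity(lhs.len());
    let values = lhs
        .iter()
        .zip(rhs)
        .zip(validity)
        .map(|((&l, &r), &valid)| match op(l, r) {
            Some(v) => {
                out_validity.push(valid); // keep the input validity
                v
            },
            None => {
                out_validity.push(false); // the op failed: null the slot
                0 // arbitrary filler value under a null
            },
        })
        .collect();
    (values, out_validity)
}

fn main() {
    let lhs = [i128::MAX, 2, 3];
    let rhs = [2, 2, 3];
    let validity = [true, true, false];
    let (values, out_validity) =
        binary_checked_sketch(&lhs, &rhs, &validity, |a, b| a.checked_mul(b));
    assert_eq!(values[1], 4);
    // Slot 0 overflowed and slot 2 was already null; both end up null.
    assert_eq!(out_validity, [false, true, false]);
}
```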
+pub fn xor_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitXor, +{ + unary(lhs, |a| a ^ *rhs, lhs.data_type().clone()) +} + +/// Performs `AND` operation between a [`PrimitiveArray`] and scalar. +/// # Panic +/// This function panics when the arrays have different lengths. +pub fn and_scalar(lhs: &PrimitiveArray, rhs: &T) -> PrimitiveArray +where + T: NativeType + BitAnd, +{ + unary(lhs, |a| a & *rhs, lhs.data_type().clone()) +} diff --git a/crates/nano-arrow/src/compute/boolean.rs b/crates/nano-arrow/src/compute/boolean.rs new file mode 100644 index 000000000000..daf6853c3c29 --- /dev/null +++ b/crates/nano-arrow/src/compute/boolean.rs @@ -0,0 +1,288 @@ +//! null-preserving operators such as [`and`], [`or`] and [`not`]. +use super::utils::combine_validities; +use crate::array::{Array, BooleanArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::scalar::BooleanScalar; + +fn assert_lengths(lhs: &BooleanArray, rhs: &BooleanArray) { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); +} + +/// Helper function to implement binary kernels +pub(crate) fn binary_boolean_kernel( + lhs: &BooleanArray, + rhs: &BooleanArray, + op: F, +) -> BooleanArray +where + F: Fn(&Bitmap, &Bitmap) -> Bitmap, +{ + assert_lengths(lhs, rhs); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + let values = op(left_buffer, right_buffer); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Performs `&&` operation on two [`BooleanArray`], combining the validities. +/// # Panics +/// This function panics iff the arrays have different lengths. +/// # Examples +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::and; +/// +/// let a = BooleanArray::from(&[Some(false), Some(true), None]); +/// let b = BooleanArray::from(&[Some(true), Some(true), Some(false)]); +/// let and_ab = and(&a, &b); +/// assert_eq!(and_ab, BooleanArray::from(&[Some(false), Some(true), None])); +/// ``` +pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + match (left_buffer.unset_bits(), right_buffer.unset_bits()) { + // all values are `true` on both sides + (0, 0) => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `false` on left side + (l, _) if l == lhs.len() => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `false` on right side + (_, r) if r == rhs.len() => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // ignore the rest + _ => {}, + } + } + + binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs & rhs) +} + +/// Performs `||` operation on two [`BooleanArray`], combining the validities. +/// # Panics +/// This function panics iff the arrays have different lengths. 
+/// # Examples +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::or; +/// +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let b = BooleanArray::from(vec![Some(true), Some(true), Some(false)]); +/// let or_ab = or(&a, &b); +/// assert_eq!(or_ab, BooleanArray::from(vec![Some(true), Some(true), None])); +/// ``` +pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + if lhs.null_count() == 0 && rhs.null_count() == 0 { + let left_buffer = lhs.values(); + let right_buffer = rhs.values(); + + match (left_buffer.unset_bits(), right_buffer.unset_bits()) { + // all values are `true` on left side + (0, _) => { + assert_lengths(lhs, rhs); + return lhs.clone(); + }, + // all values are `true` on right side + (_, 0) => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // all values on lhs and rhs are `false` + (l, r) if l == lhs.len() && r == rhs.len() => { + assert_lengths(lhs, rhs); + return rhs.clone(); + }, + // ignore the rest + _ => {}, + } + } + + binary_boolean_kernel(lhs, rhs, |lhs, rhs| lhs | rhs) +} + +/// Performs unary `NOT` operation on an arrays. If value is null then the result is also +/// null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::not; +/// +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let not_a = not(&a); +/// assert_eq!(not_a, BooleanArray::from(vec![Some(true), Some(false), None])); +/// ``` +pub fn not(array: &BooleanArray) -> BooleanArray { + let values = !array.values(); + let validity = array.validity().cloned(); + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Returns a non-null [`BooleanArray`] with whether each value of the array is null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::is_null; +/// # fn main() { +/// let a = BooleanArray::from(vec![Some(false), Some(true), None]); +/// let a_is_null = is_null(&a); +/// assert_eq!(a_is_null, BooleanArray::from_slice(vec![false, false, true])); +/// # } +/// ``` +pub fn is_null(input: &dyn Array) -> BooleanArray { + let len = input.len(); + + let values = match input.validity() { + None => MutableBitmap::from_len_zeroed(len).into(), + Some(buffer) => !buffer, + }; + + BooleanArray::new(DataType::Boolean, values, None) +} + +/// Returns a non-null [`BooleanArray`] with whether each value of the array is not null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::is_not_null; +/// +/// let a = BooleanArray::from(&vec![Some(false), Some(true), None]); +/// let a_is_not_null = is_not_null(&a); +/// assert_eq!(a_is_not_null, BooleanArray::from_slice(&vec![true, true, false])); +/// ``` +pub fn is_not_null(input: &dyn Array) -> BooleanArray { + let values = match input.validity() { + None => { + let mut mutable = MutableBitmap::new(); + mutable.extend_constant(input.len(), true); + mutable.into() + }, + Some(buffer) => buffer.clone(), + }; + BooleanArray::new(DataType::Boolean, values, None) +} + +/// Performs `AND` operation on an array and a scalar value. If either left or right value +/// is null then the result is also null. 
+/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::and_scalar; +/// use arrow2::scalar::BooleanScalar; +/// +/// let array = BooleanArray::from_slice(&[false, false, true, true]); +/// let scalar = BooleanScalar::new(Some(true)); +/// let result = and_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from_slice(&[false, false, true, true])); +/// +/// ``` +pub fn and_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => array.clone(), + Some(false) => { + let values = Bitmap::new_zeroed(array.len()); + BooleanArray::new(DataType::Boolean, values, array.validity().cloned()) + }, + None => BooleanArray::new_null(DataType::Boolean, array.len()), + } +} + +/// Performs `OR` operation on an array and a scalar value. If either left or right value +/// is null then the result is also null. +/// # Example +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::or_scalar; +/// use arrow2::scalar::BooleanScalar; +/// # fn main() { +/// let array = BooleanArray::from_slice(&[false, false, true, true]); +/// let scalar = BooleanScalar::new(Some(true)); +/// let result = or_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from_slice(&[true, true, true, true])); +/// # } +/// ``` +pub fn or_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => { + let mut values = MutableBitmap::new(); + values.extend_constant(array.len(), true); + BooleanArray::new(DataType::Boolean, values.into(), array.validity().cloned()) + }, + Some(false) => array.clone(), + None => BooleanArray::new_null(DataType::Boolean, array.len()), + } +} + +/// Returns whether any of the values in the array are `true`. +/// +/// Null values are ignored. +/// +/// # Example +/// +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::any; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false)]); +/// let b = BooleanArray::from(&[Some(false), Some(false)]); +/// let c = BooleanArray::from(&[None, Some(false)]); +/// +/// assert_eq!(any(&a), true); +/// assert_eq!(any(&b), false); +/// assert_eq!(any(&c), false); +/// ``` +pub fn any(array: &BooleanArray) -> bool { + if array.is_empty() { + false + } else if array.null_count() > 0 { + array.into_iter().any(|v| v == Some(true)) + } else { + let vals = array.values(); + vals.unset_bits() != vals.len() + } +} + +/// Returns whether all values in the array are `true`. +/// +/// Null values are ignored. +/// +/// # Example +/// +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean::all; +/// +/// let a = BooleanArray::from(&[Some(true), Some(true)]); +/// let b = BooleanArray::from(&[Some(false), Some(true)]); +/// let c = BooleanArray::from(&[None, Some(true)]); +/// +/// assert_eq!(all(&a), true); +/// assert_eq!(all(&b), false); +/// assert_eq!(all(&c), true); +/// ``` +pub fn all(array: &BooleanArray) -> bool { + if array.is_empty() { + true + } else if array.null_count() > 0 { + !array.into_iter().any(|v| v == Some(false)) + } else { + let vals = array.values(); + vals.unset_bits() == 0 + } +} diff --git a/crates/nano-arrow/src/compute/boolean_kleene.rs b/crates/nano-arrow/src/compute/boolean_kleene.rs new file mode 100644 index 000000000000..2983c2e31ded --- /dev/null +++ b/crates/nano-arrow/src/compute/boolean_kleene.rs @@ -0,0 +1,301 @@ +//! 
Boolean operators of [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics). +use crate::array::{Array, BooleanArray}; +use crate::bitmap::{binary, quaternary, ternary, unary, Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::scalar::BooleanScalar; + +/// Logical 'or' operation on two arrays with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Panics +/// This function panics iff the arrays have a different length +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean_kleene::or; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false), None]); +/// let b = BooleanArray::from(&[None, None, None]); +/// let or_ab = or(&a, &b); +/// assert_eq!(or_ab, BooleanArray::from(&[Some(true), None, None])); +/// ``` +pub fn or(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); + + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_validity = lhs.validity(); + let rhs_validity = rhs.validity(); + + let validity = match (lhs_validity, rhs_validity) { + (Some(lhs_validity), Some(rhs_validity)) => { + Some(quaternary( + lhs_values, + rhs_values, + lhs_validity, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v, rhs_v| { + // A = T + (lhs & lhs_v) | + // B = T + (rhs & rhs_v) | + // A = F & B = F + (!lhs & lhs_v) & (!rhs & rhs_v) + }, + )) + }, + (Some(lhs_validity), None) => { + // B != U + Some(ternary( + lhs_values, + rhs_values, + lhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v| { + // A = T + (lhs & lhs_v) | + // B = T + rhs | + // A = F & B = F + (!lhs & lhs_v) & !rhs + }, + )) + }, + (None, Some(rhs_validity)) => { + Some(ternary( + lhs_values, + rhs_values, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, rhs_v| { + // A = T + lhs | + // B = T + (rhs & rhs_v) | + // A = F & B = F + !lhs & (!rhs & rhs_v) + }, + )) + }, + (None, None) => None, + }; + BooleanArray::new(DataType::Boolean, lhs_values | rhs_values, validity) +} + +/// Logical 'and' operation on two arrays with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Panics +/// This function panics iff the arrays have a different length +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean_kleene::and; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false), None]); +/// let b = BooleanArray::from(&[None, None, None]); +/// let and_ab = and(&a, &b); +/// assert_eq!(and_ab, BooleanArray::from(&[None, Some(false), None])); +/// ``` +pub fn and(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + assert_eq!( + lhs.len(), + rhs.len(), + "lhs and rhs must have the same length" + ); + + let lhs_values = lhs.values(); + let rhs_values = rhs.values(); + + let lhs_validity = lhs.validity(); + let rhs_validity = rhs.validity(); + + let validity = match (lhs_validity, rhs_validity) { + (Some(lhs_validity), Some(rhs_validity)) => { + Some(quaternary( + lhs_values, + rhs_values, + lhs_validity, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v, rhs_v| { + // B = F + (!rhs & rhs_v) | + // A = F + (!lhs & lhs_v) | 
+ // A = T & B = T + (lhs & lhs_v) & (rhs & rhs_v) + }, + )) + }, + (Some(lhs_validity), None) => { + Some(ternary( + lhs_values, + rhs_values, + lhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, lhs_v| { + // B = F + !rhs | + // A = F + (!lhs & lhs_v) | + // A = T & B = T + (lhs & lhs_v) & rhs + }, + )) + }, + (None, Some(rhs_validity)) => { + Some(ternary( + lhs_values, + rhs_values, + rhs_validity, + // see https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics + |lhs, rhs, rhs_v| { + // B = F + (!rhs & rhs_v) | + // A = F + !lhs | + // A = T & B = T + lhs & (rhs & rhs_v) + }, + )) + }, + (None, None) => None, + }; + BooleanArray::new(DataType::Boolean, lhs_values & rhs_values, validity) +} + +/// Logical 'or' operation on an array and a scalar value with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::scalar::BooleanScalar; +/// use arrow2::compute::boolean_kleene::or_scalar; +/// +/// let array = BooleanArray::from(&[Some(true), Some(false), None]); +/// let scalar = BooleanScalar::new(Some(false)); +/// let result = or_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from(&[Some(true), Some(false), None])); +/// ``` +pub fn or_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => { + let mut values = MutableBitmap::new(); + values.extend_constant(array.len(), true); + BooleanArray::new(DataType::Boolean, values.into(), None) + }, + Some(false) => array.clone(), + None => { + let values = array.values(); + let validity = match array.validity() { + Some(validity) => binary(values, validity, |value, validity| validity & value), + None => unary(values, |value| value), + }; + BooleanArray::new(DataType::Boolean, values.clone(), Some(validity)) + }, + } +} + +/// Logical 'and' operation on an array and a scalar value with [Kleene logic](https://en.wikipedia.org/wiki/Three-valued_logic#Kleene_and_Priest_logics) +/// # Example +/// +/// ```rust +/// use arrow2::array::BooleanArray; +/// use arrow2::scalar::BooleanScalar; +/// use arrow2::compute::boolean_kleene::and_scalar; +/// +/// let array = BooleanArray::from(&[Some(true), Some(false), None]); +/// let scalar = BooleanScalar::new(None); +/// let result = and_scalar(&array, &scalar); +/// assert_eq!(result, BooleanArray::from(&[None, Some(false), None])); +/// ``` +pub fn and_scalar(array: &BooleanArray, scalar: &BooleanScalar) -> BooleanArray { + match scalar.value() { + Some(true) => array.clone(), + Some(false) => { + let values = Bitmap::new_zeroed(array.len()); + BooleanArray::new(DataType::Boolean, values, None) + }, + None => { + let values = array.values(); + let validity = match array.validity() { + Some(validity) => binary(values, validity, |value, validity| validity & !value), + None => unary(values, |value| !value), + }; + BooleanArray::new(DataType::Boolean, array.values().clone(), Some(validity)) + }, + } +} + +/// Returns whether any of the values in the array are `true`. +/// +/// The output is unknown (`None`) if the array contains any null values and +/// no `true` values. 
+/// +/// # Example +/// +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean_kleene::any; +/// +/// let a = BooleanArray::from(&[Some(true), Some(false)]); +/// let b = BooleanArray::from(&[Some(false), Some(false)]); +/// let c = BooleanArray::from(&[None, Some(false)]); +/// +/// assert_eq!(any(&a), Some(true)); +/// assert_eq!(any(&b), Some(false)); +/// assert_eq!(any(&c), None); +/// ``` +pub fn any(array: &BooleanArray) -> Option { + if array.is_empty() { + Some(false) + } else if array.null_count() > 0 { + if array.into_iter().any(|v| v == Some(true)) { + Some(true) + } else { + None + } + } else { + let vals = array.values(); + Some(vals.unset_bits() != vals.len()) + } +} + +/// Returns whether all values in the array are `true`. +/// +/// The output is unknown (`None`) if the array contains any null values and +/// no `false` values. +/// +/// # Example +/// +/// ``` +/// use arrow2::array::BooleanArray; +/// use arrow2::compute::boolean_kleene::all; +/// +/// let a = BooleanArray::from(&[Some(true), Some(true)]); +/// let b = BooleanArray::from(&[Some(false), Some(true)]); +/// let c = BooleanArray::from(&[None, Some(true)]); +/// +/// assert_eq!(all(&a), Some(true)); +/// assert_eq!(all(&b), Some(false)); +/// assert_eq!(all(&c), None); +/// ``` +pub fn all(array: &BooleanArray) -> Option { + if array.is_empty() { + Some(true) + } else if array.null_count() > 0 { + if array.into_iter().any(|v| v == Some(false)) { + Some(false) + } else { + None + } + } else { + let vals = array.values(); + Some(vals.unset_bits() == 0) + } +} diff --git a/crates/nano-arrow/src/compute/cast/binary_to.rs b/crates/nano-arrow/src/compute/cast/binary_to.rs new file mode 100644 index 000000000000..52038f9caefa --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/binary_to.rs @@ -0,0 +1,159 @@ +use super::CastOptions; +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::offset::{Offset, Offsets}; +use crate::types::NativeType; + +/// Conversion of binary +pub fn binary_to_large_binary(from: &BinaryArray, to_data_type: DataType) -> BinaryArray { + let values = from.values().clone(); + BinaryArray::::new( + to_data_type, + from.offsets().into(), + values, + from.validity().cloned(), + ) +} + +/// Conversion of binary +pub fn binary_large_to_binary( + from: &BinaryArray, + to_data_type: DataType, +) -> Result> { + let values = from.values().clone(); + let offsets = from.offsets().try_into()?; + Ok(BinaryArray::::new( + to_data_type, + offsets, + values, + from.validity().cloned(), + )) +} + +/// Conversion to utf8 +pub fn binary_to_utf8( + from: &BinaryArray, + to_data_type: DataType, +) -> Result> { + Utf8Array::::try_new( + to_data_type, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +/// Conversion to utf8 +/// # Errors +/// This function errors if the values are not valid utf8 +pub fn binary_to_large_utf8( + from: &BinaryArray, + to_data_type: DataType, +) -> Result> { + let values = from.values().clone(); + let offsets = from.offsets().into(); + + Utf8Array::::try_new(to_data_type, offsets, values, from.validity().cloned()) +} + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`] at best-effort using `lexical_core::parse_partial`, making any uncastable value as zero. 
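+/// # Example
+/// A minimal sketch, assuming the `arrow2`-style paths and array constructors used by the
+/// other examples in this crate (offsets `i32`, target type `i32`):
+/// ```rust
+/// use arrow2::array::{BinaryArray, Int32Array};
+/// use arrow2::compute::cast::partial_binary_to_primitive;
+/// use arrow2::datatypes::DataType;
+///
+/// let array = BinaryArray::<i32>::from([Some(b"7".as_ref()), Some(b"42abc".as_ref()), None]);
+/// let parsed = partial_binary_to_primitive::<i32, i32>(&array, &DataType::Int32);
+/// // "42abc" is parsed partially to 42; nulls are preserved
+/// assert_eq!(parsed, Int32Array::from([Some(7), Some(42), None]));
+/// ```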
+pub fn partial_binary_to_primitive( + from: &BinaryArray, + to: &DataType, +) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse_partial(x).ok().map(|x| x.0))); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +/// Casts a [`BinaryArray`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub fn binary_to_primitive(from: &BinaryArray, to: &DataType) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse(x).ok())); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn binary_to_primitive_dyn( + from: &dyn Array, + to: &DataType, + options: CastOptions, +) -> Result> +where + T: NativeType + lexical_core::FromLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + if options.partial { + Ok(Box::new(partial_binary_to_primitive::(from, to))) + } else { + Ok(Box::new(binary_to_primitive::(from, to))) + } +} + +/// Cast [`BinaryArray`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. +pub fn binary_to_dictionary( + from: &BinaryArray, +) -> Result> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn binary_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let values = from.as_any().downcast_ref().unwrap(); + binary_to_dictionary::(values).map(|x| Box::new(x) as Box) +} + +fn fixed_size_to_offsets(values_len: usize, fixed_size: usize) -> Offsets { + let offsets = (0..(values_len + 1)) + .step_by(fixed_size) + .map(|v| O::from_usize(v).unwrap()) + .collect(); + // Safety + // * every element is `>= 0` + // * element at position `i` is >= than element at position `i-1`. + unsafe { Offsets::new_unchecked(offsets) } +} + +/// Conversion of `FixedSizeBinary` to `Binary`. +pub fn fixed_size_binary_binary( + from: &FixedSizeBinaryArray, + to_data_type: DataType, +) -> BinaryArray { + let values = from.values().clone(); + let offsets = fixed_size_to_offsets(values.len(), from.size()); + BinaryArray::::new( + to_data_type, + offsets.into(), + values, + from.validity().cloned(), + ) +} + +/// Conversion of binary +pub fn binary_to_list(from: &BinaryArray, to_data_type: DataType) -> ListArray { + let values = from.values().clone(); + let values = PrimitiveArray::new(DataType::UInt8, values, None); + ListArray::::new( + to_data_type, + from.offsets().clone(), + values.boxed(), + from.validity().cloned(), + ) +} diff --git a/crates/nano-arrow/src/compute/cast/boolean_to.rs b/crates/nano-arrow/src/compute/cast/boolean_to.rs new file mode 100644 index 000000000000..8a8cf7089d8f --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/boolean_to.rs @@ -0,0 +1,48 @@ +use crate::array::{Array, BinaryArray, BooleanArray, PrimitiveArray, Utf8Array}; +use crate::error::Result; +use crate::offset::Offset; +use crate::types::NativeType; + +pub(super) fn boolean_to_primitive_dyn(array: &dyn Array) -> Result> +where + T: NativeType + num_traits::One, +{ + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_primitive::(array))) +} + +/// Casts the [`BooleanArray`] to a [`PrimitiveArray`]. 
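+/// # Example
+/// A minimal sketch, assuming the `arrow2`-style paths used by the other examples in this crate:
+/// ```rust
+/// use arrow2::array::{BooleanArray, Int32Array};
+/// use arrow2::compute::cast::boolean_to_primitive;
+///
+/// let array = BooleanArray::from(&[Some(true), Some(false), None]);
+/// let casted = boolean_to_primitive::<i32>(&array);
+/// assert_eq!(casted, Int32Array::from(&[Some(1), Some(0), None]));
+/// ```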
+pub fn boolean_to_primitive(from: &BooleanArray) -> PrimitiveArray +where + T: NativeType + num_traits::One, +{ + let values = from + .values() + .iter() + .map(|x| if x { T::one() } else { T::default() }) + .collect::>(); + + PrimitiveArray::::new(T::PRIMITIVE.into(), values.into(), from.validity().cloned()) +} + +/// Casts the [`BooleanArray`] to a [`Utf8Array`], casting trues to `"1"` and falses to `"0"` +pub fn boolean_to_utf8(from: &BooleanArray) -> Utf8Array { + let iter = from.values().iter().map(|x| if x { "1" } else { "0" }); + Utf8Array::from_trusted_len_values_iter(iter) +} + +pub(super) fn boolean_to_utf8_dyn(array: &dyn Array) -> Result> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_utf8::(array))) +} + +/// Casts the [`BooleanArray`] to a [`BinaryArray`], casting trues to `"1"` and falses to `"0"` +pub fn boolean_to_binary(from: &BooleanArray) -> BinaryArray { + let iter = from.values().iter().map(|x| if x { b"1" } else { b"0" }); + BinaryArray::from_trusted_len_values_iter(iter) +} + +pub(super) fn boolean_to_binary_dyn(array: &dyn Array) -> Result> { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean_to_binary::(array))) +} diff --git a/crates/nano-arrow/src/compute/cast/decimal_to.rs b/crates/nano-arrow/src/compute/cast/decimal_to.rs new file mode 100644 index 000000000000..ba9995c86c12 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/decimal_to.rs @@ -0,0 +1,137 @@ +use num_traits::{AsPrimitive, Float, NumCast}; + +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +#[inline] +fn decimal_to_decimal_impl Option>( + from: &PrimitiveArray, + op: F, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + op(*x).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +/// Returns a [`PrimitiveArray`] with the casted values. 
Values are `None` on overflow +pub fn decimal_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let (from_precision, from_scale) = + if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + if to_scale == from_scale && to_precision >= from_precision { + // fast path + return from.clone().to(DataType::Decimal(to_precision, to_scale)); + } + // todo: other fast paths include increasing scale and precision by so that + // a number will never overflow (validity is preserved) + + if from_scale > to_scale { + let factor = 10_i128.pow((from_scale - to_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_div(factor), + to_precision, + to_scale, + ) + } else { + let factor = 10_i128.pow((to_scale - from_scale) as u32); + decimal_to_decimal_impl( + from, + |x: i128| x.checked_mul(factor), + to_precision, + to_scale, + ) + } +} + +pub(super) fn decimal_to_decimal_dyn( + from: &dyn Array, + to_precision: usize, + to_scale: usize, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_decimal(from, to_precision, to_scale))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_float(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let div = 10_f64.powi(from_scale as i32); + let values = from + .values() + .iter() + .map(|x| (*x as f64 / div).as_()) + .collect(); + + PrimitiveArray::::new(T::PRIMITIVE.into(), values, from.validity().cloned()) +} + +pub(super) fn decimal_to_float_dyn(from: &dyn Array) -> Result> +where + T: NativeType + Float, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_float::(from))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn decimal_to_integer(from: &PrimitiveArray) -> PrimitiveArray +where + T: NativeType + NumCast, +{ + let (_, from_scale) = if let DataType::Decimal(p, s) = from.data_type().to_logical_type() { + (*p, *s) + } else { + panic!("internal error: i128 is always a decimal") + }; + + let factor = 10_i128.pow(from_scale as u32); + let values = from.iter().map(|x| x.and_then(|x| T::from(*x / factor))); + + PrimitiveArray::from_trusted_len_iter(values) +} + +pub(super) fn decimal_to_integer_dyn(from: &dyn Array) -> Result> +where + T: NativeType + NumCast, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(decimal_to_integer::(from))) +} diff --git a/crates/nano-arrow/src/compute/cast/dictionary_to.rs b/crates/nano-arrow/src/compute/cast/dictionary_to.rs new file mode 100644 index 000000000000..4126e4a3d589 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/dictionary_to.rs @@ -0,0 +1,183 @@ +use super::{primitive_as_primitive, primitive_to_primitive, CastOptions}; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::compute::cast::cast; +use crate::compute::take::take; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +macro_rules! 
key_cast { + ($keys:expr, $values:expr, $array:expr, $to_keys_type:expr, $to_type:ty, $to_datatype:expr) => {{ + let cast_keys = primitive_to_primitive::<_, $to_type>($keys, $to_keys_type); + + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > $keys.null_count() { + return Err(Error::Overflow); + } + // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { + DictionaryArray::try_new_unchecked($to_datatype, cast_keys, $values.clone()) + } + .map(|x| x.boxed()) + }}; +} + +/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] by keeping the +/// keys and casting the values to `values_type`. +/// # Errors +/// This function errors if the values are not castable to `values_type` +pub fn dictionary_to_dictionary_values( + from: &DictionaryArray, + values_type: &DataType, +) -> Result> { + let keys = from.keys(); + let values = from.values(); + let length = values.len(); + + let values = cast(values.as_ref(), values_type, CastOptions::default())?; + + assert_eq!(values.len(), length); // this is guaranteed by `cast` + unsafe { + DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + } +} + +/// Similar to dictionary_to_dictionary_values, but overflowing cast is wrapped +pub fn wrapping_dictionary_to_dictionary_values( + from: &DictionaryArray, + values_type: &DataType, +) -> Result> { + let keys = from.keys(); + let values = from.values(); + let length = values.len(); + + let values = cast( + values.as_ref(), + values_type, + CastOptions { + wrapped: true, + partial: false, + }, + )?; + assert_eq!(values.len(), length); // this is guaranteed by `cast` + unsafe { + DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + } +} + +/// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] backed by a +/// different physical type of the keys, while keeping the values equal. +/// # Errors +/// Errors if any of the old keys' values is larger than the maximum value +/// supported by the new physical type. 
+pub fn dictionary_to_dictionary_keys( + from: &DictionaryArray, +) -> Result> +where + K1: DictionaryKey + num_traits::NumCast, + K2: DictionaryKey + num_traits::NumCast, +{ + let keys = from.keys(); + let values = from.values(); + let is_ordered = from.is_ordered(); + + let casted_keys = primitive_to_primitive::(keys, &K2::PRIMITIVE.into()); + + if casted_keys.null_count() > keys.null_count() { + Err(Error::Overflow) + } else { + let data_type = DataType::Dictionary( + K2::KEY_TYPE, + Box::new(values.data_type().clone()), + is_ordered, + ); + // Safety: this is safe because given a type `T` that fits in a `usize`, casting it to type `P` either overflows or also fits in a `usize` + unsafe { DictionaryArray::try_new_unchecked(data_type, casted_keys, values.clone()) } + } +} + +/// Similar to dictionary_to_dictionary_keys, but overflowing cast is wrapped +pub fn wrapping_dictionary_to_dictionary_keys( + from: &DictionaryArray, +) -> Result> +where + K1: DictionaryKey + num_traits::AsPrimitive, + K2: DictionaryKey, +{ + let keys = from.keys(); + let values = from.values(); + let is_ordered = from.is_ordered(); + + let casted_keys = primitive_as_primitive::(keys, &K2::PRIMITIVE.into()); + + if casted_keys.null_count() > keys.null_count() { + Err(Error::Overflow) + } else { + let data_type = DataType::Dictionary( + K2::KEY_TYPE, + Box::new(values.data_type().clone()), + is_ordered, + ); + // some of the values may not fit in `usize` and thus this needs to be checked + DictionaryArray::try_new(data_type, casted_keys, values.clone()) + } +} + +pub(super) fn dictionary_cast_dyn( + array: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> { + let array = array.as_any().downcast_ref::>().unwrap(); + let keys = array.keys(); + let values = array.values(); + + match to_type { + DataType::Dictionary(to_keys_type, to_values_type, _) => { + let values = cast(values.as_ref(), to_values_type, options)?; + + // create the appropriate array type + let to_key_type = (*to_keys_type).into(); + + // Safety: + // we return an error on overflow so the integers remain within bounds + match_integer_type!(to_keys_type, |$T| { + key_cast!(keys, values, array, &to_key_type, $T, to_type.clone()) + }) + }, + _ => unpack_dictionary::(keys, values.as_ref(), to_type, options), + } +} + +// Unpack the dictionary +fn unpack_dictionary( + keys: &PrimitiveArray, + values: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> +where + K: DictionaryKey + num_traits::NumCast, +{ + // attempt to cast the dict values to the target type + // use the take kernel to expand out the dictionary + let values = cast(values, to_type, options)?; + + // take requires first casting i32 + let indices = primitive_to_primitive::<_, i32>(keys, &DataType::Int32); + + take(values.as_ref(), &indices) +} + +/// Casts a [`DictionaryArray`] to its values' [`DataType`], also known as unpacking. +/// The resulting array has the same length. +pub fn dictionary_to_values(from: &DictionaryArray) -> Box +where + K: DictionaryKey + num_traits::NumCast, +{ + // take requires first casting i64 + let indices = primitive_to_primitive::<_, i64>(from.keys(), &DataType::Int64); + + // unwrap: The dictionary guarantees that the keys are not out-of-bounds. 
+ take(from.values().as_ref(), &indices).unwrap() +} diff --git a/crates/nano-arrow/src/compute/cast/mod.rs b/crates/nano-arrow/src/compute/cast/mod.rs new file mode 100644 index 000000000000..f13a638a9c0d --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/mod.rs @@ -0,0 +1,989 @@ +//! Defines different casting operators such as [`cast`] or [`primitive_to_binary`]. + +mod binary_to; +mod boolean_to; +mod decimal_to; +mod dictionary_to; +mod primitive_to; +mod utf8_to; + +pub use binary_to::*; +pub use boolean_to::*; +pub use decimal_to::*; +pub use dictionary_to::*; +pub use primitive_to::*; +pub use utf8_to::*; + +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::offset::{Offset, Offsets}; + +/// options defining how Cast kernels behave +#[derive(Clone, Copy, Debug, Default)] +pub struct CastOptions { + /// default to false + /// whether an overflowing cast should be converted to `None` (default), or be wrapped (i.e. `256i16 as u8 = 0` vectorized). + /// Settings this to `true` is 5-6x faster for numeric types. + pub wrapped: bool, + /// default to false + /// whether to cast to an integer at the best-effort + pub partial: bool, +} + +impl CastOptions { + fn with_wrapped(&self, v: bool) -> Self { + let mut option = *self; + option.wrapped = v; + option + } +} + +/// Returns true if this type is numeric: (UInt*, Unit*, or Float*). +fn is_numeric(t: &DataType) -> bool { + use DataType::*; + matches!( + t, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 + ) +} + +macro_rules! primitive_dyn { + ($from:expr, $expr:tt) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from))) + }}; + ($from:expr, $expr:tt, $to:expr) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from, $to))) + }}; + ($from:expr, $expr:tt, $from_t:expr, $to:expr) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from, $from_t, $to))) + }}; + ($from:expr, $expr:tt, $arg1:expr, $arg2:expr, $arg3:expr) => {{ + let from = $from.as_any().downcast_ref().unwrap(); + Ok(Box::new($expr(from, $arg1, $arg2, $arg3))) + }}; +} + +/// Return true if a value of type `from_type` can be cast into a +/// value of `to_type`. Note that such as cast may be lossy. +/// +/// If this function returns true to stay consistent with the `cast` kernel below. 
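+/// # Example
+/// A minimal sketch, shown together with the [`cast`] kernel defined below (assuming the
+/// `arrow2`-style paths used by the other examples in this crate):
+/// ```rust
+/// use arrow2::array::{Array, Int32Array};
+/// use arrow2::compute::cast::{can_cast_types, cast, CastOptions};
+/// use arrow2::datatypes::DataType;
+///
+/// assert!(can_cast_types(&DataType::Int32, &DataType::Float64));
+/// assert!(!can_cast_types(&DataType::Int32, &DataType::Struct(vec![])));
+///
+/// let array = Int32Array::from([Some(1), None, Some(3)]);
+/// let casted = cast(&array, &DataType::Float64, CastOptions::default()).unwrap();
+/// assert_eq!(casted.data_type(), &DataType::Float64);
+/// ```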
+pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { + use self::DataType::*; + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + (Null, _) | (_, Null) => true, + (Struct(_), _) => false, + (_, Struct(_)) => false, + (FixedSizeList(list_from, _), List(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (FixedSizeList(list_from, _), LargeList(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (List(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (LargeList(list_from), FixedSizeList(list_to, _)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (List(list_from), List(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (LargeList(list_from), LargeList(list_to)) => { + can_cast_types(&list_from.data_type, &list_to.data_type) + }, + (List(list_from), LargeList(list_to)) if list_from == list_to => true, + (LargeList(list_from), List(list_to)) if list_from == list_to => true, + (_, List(list_to)) => can_cast_types(from_type, &list_to.data_type), + (_, LargeList(list_to)) if from_type != &LargeBinary => { + can_cast_types(from_type, &list_to.data_type) + }, + (Dictionary(_, from_value_type, _), Dictionary(_, to_value_type, _)) => { + can_cast_types(from_value_type, to_value_type) + }, + (Dictionary(_, value_type, _), _) => can_cast_types(value_type, to_type), + (_, Dictionary(_, value_type, _)) => can_cast_types(from_type, value_type), + + (_, Boolean) => is_numeric(from_type), + (Boolean, _) => { + is_numeric(to_type) + || to_type == &Utf8 + || to_type == &LargeUtf8 + || to_type == &Binary + || to_type == &LargeBinary + }, + + (Utf8, to_type) => { + is_numeric(to_type) + || matches!( + to_type, + LargeUtf8 | Binary | Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, _) + ) + }, + (LargeUtf8, to_type) => { + is_numeric(to_type) + || matches!( + to_type, + Utf8 | LargeBinary | Date32 | Date64 | Timestamp(TimeUnit::Nanosecond, _) + ) + }, + + (Binary, to_type) => { + is_numeric(to_type) || matches!(to_type, LargeBinary | Utf8 | LargeUtf8) + }, + (LargeBinary, to_type) => { + is_numeric(to_type) + || match to_type { + Binary | LargeUtf8 => true, + LargeList(field) => matches!(field.data_type, UInt8), + _ => false, + } + }, + (FixedSizeBinary(_), to_type) => matches!(to_type, Binary | LargeBinary), + (Timestamp(_, _), Utf8) => true, + (Timestamp(_, _), LargeUtf8) => true, + (_, Utf8) => is_numeric(from_type) || from_type == &Binary, + (_, LargeUtf8) => is_numeric(from_type) || from_type == &LargeBinary, + + (_, Binary) => is_numeric(from_type), + (_, LargeBinary) => is_numeric(from_type), + + // start numeric casts + (UInt8, UInt16) => true, + (UInt8, UInt32) => true, + (UInt8, UInt64) => true, + (UInt8, Int8) => true, + (UInt8, Int16) => true, + (UInt8, Int32) => true, + (UInt8, Int64) => true, + (UInt8, Float32) => true, + (UInt8, Float64) => true, + (UInt8, Decimal(_, _)) => true, + + (UInt16, UInt8) => true, + (UInt16, UInt32) => true, + (UInt16, UInt64) => true, + (UInt16, Int8) => true, + (UInt16, Int16) => true, + (UInt16, Int32) => true, + (UInt16, Int64) => true, + (UInt16, Float32) => true, + (UInt16, Float64) => true, + (UInt16, Decimal(_, _)) => true, + + (UInt32, UInt8) => true, + (UInt32, UInt16) => true, + (UInt32, UInt64) => true, + (UInt32, Int8) => true, + (UInt32, Int16) => true, + (UInt32, Int32) => true, + (UInt32, Int64) => true, + (UInt32, Float32) => true, + 
(UInt32, Float64) => true, + (UInt32, Decimal(_, _)) => true, + + (UInt64, UInt8) => true, + (UInt64, UInt16) => true, + (UInt64, UInt32) => true, + (UInt64, Int8) => true, + (UInt64, Int16) => true, + (UInt64, Int32) => true, + (UInt64, Int64) => true, + (UInt64, Float32) => true, + (UInt64, Float64) => true, + (UInt64, Decimal(_, _)) => true, + + (Int8, UInt8) => true, + (Int8, UInt16) => true, + (Int8, UInt32) => true, + (Int8, UInt64) => true, + (Int8, Int16) => true, + (Int8, Int32) => true, + (Int8, Int64) => true, + (Int8, Float32) => true, + (Int8, Float64) => true, + (Int8, Decimal(_, _)) => true, + + (Int16, UInt8) => true, + (Int16, UInt16) => true, + (Int16, UInt32) => true, + (Int16, UInt64) => true, + (Int16, Int8) => true, + (Int16, Int32) => true, + (Int16, Int64) => true, + (Int16, Float32) => true, + (Int16, Float64) => true, + (Int16, Decimal(_, _)) => true, + + (Int32, UInt8) => true, + (Int32, UInt16) => true, + (Int32, UInt32) => true, + (Int32, UInt64) => true, + (Int32, Int8) => true, + (Int32, Int16) => true, + (Int32, Int64) => true, + (Int32, Float32) => true, + (Int32, Float64) => true, + (Int32, Decimal(_, _)) => true, + + (Int64, UInt8) => true, + (Int64, UInt16) => true, + (Int64, UInt32) => true, + (Int64, UInt64) => true, + (Int64, Int8) => true, + (Int64, Int16) => true, + (Int64, Int32) => true, + (Int64, Float32) => true, + (Int64, Float64) => true, + (Int64, Decimal(_, _)) => true, + + (Float16, Float32) => true, + + (Float32, UInt8) => true, + (Float32, UInt16) => true, + (Float32, UInt32) => true, + (Float32, UInt64) => true, + (Float32, Int8) => true, + (Float32, Int16) => true, + (Float32, Int32) => true, + (Float32, Int64) => true, + (Float32, Float64) => true, + (Float32, Decimal(_, _)) => true, + + (Float64, UInt8) => true, + (Float64, UInt16) => true, + (Float64, UInt32) => true, + (Float64, UInt64) => true, + (Float64, Int8) => true, + (Float64, Int16) => true, + (Float64, Int32) => true, + (Float64, Int64) => true, + (Float64, Float32) => true, + (Float64, Decimal(_, _)) => true, + + ( + Decimal(_, _), + UInt8 + | UInt16 + | UInt32 + | UInt64 + | Int8 + | Int16 + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal(_, _), + ) => true, + // end numeric casts + + // temporal casts + (Int32, Date32) => true, + (Int32, Time32(_)) => true, + (Date32, Int32) => true, + (Date32, Int64) => true, + (Time32(_), Int32) => true, + (Int64, Date64) => true, + (Int64, Time64(_)) => true, + (Date64, Int32) => true, + (Date64, Int64) => true, + (Time64(_), Int64) => true, + (Date32, Date64) => true, + (Date64, Date32) => true, + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => true, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => true, + (Time32(_), Time64(_)) => true, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => true, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => true, + (Time64(_), Time32(to_unit)) => { + matches!(to_unit, TimeUnit::Second | TimeUnit::Millisecond) + }, + (Timestamp(_, _), Int64) => true, + (Int64, Timestamp(_, _)) => true, + (Timestamp(_, _), Timestamp(_, _)) => true, + (Timestamp(_, _), Date32) => true, + (Timestamp(_, _), Date64) => true, + (Int64, Duration(_)) => true, + (Duration(_), Int64) => true, + (Interval(_), Interval(IntervalUnit::MonthDayNano)) => true, + (_, _) => false, + } +} + +fn cast_list( + array: &ListArray, + to_type: &DataType, + options: CastOptions, +) -> Result> { + let values = array.values(); + let new_values = cast( + values.as_ref(), + 
ListArray::::get_child_type(to_type), + options, + )?; + + Ok(ListArray::::new( + to_type.clone(), + array.offsets().clone(), + new_values, + array.validity().cloned(), + )) +} + +fn cast_list_to_large_list(array: &ListArray, to_type: &DataType) -> ListArray { + let offsets = array.offsets().into(); + + ListArray::::new( + to_type.clone(), + offsets, + array.values().clone(), + array.validity().cloned(), + ) +} + +fn cast_large_to_list(array: &ListArray, to_type: &DataType) -> ListArray { + let offsets = array.offsets().try_into().expect("Convertme to error"); + + ListArray::::new( + to_type.clone(), + offsets, + array.values().clone(), + array.validity().cloned(), + ) +} + +fn cast_fixed_size_list_to_list( + fixed: &FixedSizeListArray, + to_type: &DataType, + options: CastOptions, +) -> Result> { + let new_values = cast( + fixed.values().as_ref(), + ListArray::::get_child_type(to_type), + options, + )?; + + let offsets = (0..=fixed.len()) + .map(|ix| O::from_as_usize(ix * fixed.size())) + .collect::>(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + Ok(ListArray::::new( + to_type.clone(), + offsets.into(), + new_values, + fixed.validity().cloned(), + )) +} + +fn cast_list_to_fixed_size_list( + list: &ListArray, + inner: &Field, + size: usize, + options: CastOptions, +) -> Result { + let offsets = list.offsets().buffer().iter(); + let expected = (0..list.len()).map(|ix| O::from_as_usize(ix * size)); + + match offsets + .zip(expected) + .find(|(actual, expected)| *actual != expected) + { + Some(_) => Err(Error::InvalidArgumentError( + "incompatible offsets in source list".to_string(), + )), + None => { + let sliced_values = list.values().sliced( + list.offsets().first().to_usize(), + list.offsets().range().to_usize(), + ); + let new_values = cast(sliced_values.as_ref(), inner.data_type(), options)?; + Ok(FixedSizeListArray::new( + DataType::FixedSizeList(Box::new(inner.clone()), size), + new_values, + list.validity().cloned(), + )) + }, + } +} + +/// Cast `array` to the provided data type and return a new [`Array`] with +/// type `to_type`, if possible. +/// +/// Behavior: +/// * PrimitiveArray to PrimitiveArray: overflowing cast will be None +/// * Boolean to Utf8: `true` => '1', `false` => `0` +/// * Utf8 to numeric: strings that can't be parsed to numbers return null, float strings +/// in integer casts return null +/// * Numeric to boolean: 0 returns `false`, any other value returns `true` +/// * List to List: the underlying data type is cast +/// * Fixed Size List to List: the underlying data type is cast +/// * List to Fixed Size List: the offsets are checked for valid order, then the +/// underlying type is cast. 
+/// * PrimitiveArray to List: a list array with 1 value per slot is created +/// * Date32 and Date64: precision lost when going to higher interval +/// * Time32 and Time64: precision lost when going to higher interval +/// * Timestamp and Date{32|64}: precision lost when going to higher interval +/// * Temporal to/from backing primitive: zero-copy with data type change +/// Unsupported Casts +/// * To or from `StructArray` +/// * List to primitive +/// * Utf8 to boolean +/// * Interval and duration +pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Result> { + use DataType::*; + let from_type = array.data_type(); + + // clone array if types are the same + if from_type == to_type { + return Ok(clone(array)); + } + + let as_options = options.with_wrapped(true); + match (from_type, to_type) { + (Null, _) | (_, Null) => Ok(new_null_array(to_type.clone(), array.len())), + (Struct(_), _) => Err(Error::NotYetImplemented( + "Cannot cast from struct to other types".to_string(), + )), + (_, Struct(_)) => Err(Error::NotYetImplemented( + "Cannot cast to struct from other types".to_string(), + )), + (List(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( + array.as_any().downcast_ref().unwrap(), + inner.as_ref(), + *size, + options, + ) + .map(|x| x.boxed()), + (LargeList(_), FixedSizeList(inner, size)) => cast_list_to_fixed_size_list::( + array.as_any().downcast_ref().unwrap(), + inner.as_ref(), + *size, + options, + ) + .map(|x| x.boxed()), + (FixedSizeList(_, _), List(_)) => cast_fixed_size_list_to_list::( + array.as_any().downcast_ref().unwrap(), + to_type, + options, + ) + .map(|x| x.boxed()), + (FixedSizeList(_, _), LargeList(_)) => cast_fixed_size_list_to_list::( + array.as_any().downcast_ref().unwrap(), + to_type, + options, + ) + .map(|x| x.boxed()), + (List(_), List(_)) => { + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) + .map(|x| x.boxed()) + }, + (LargeList(_), LargeList(_)) => { + cast_list::(array.as_any().downcast_ref().unwrap(), to_type, options) + .map(|x| x.boxed()) + }, + (List(lhs), LargeList(rhs)) if lhs == rhs => { + Ok(cast_list_to_large_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + }, + (LargeList(lhs), List(rhs)) if lhs == rhs => { + Ok(cast_large_to_list(array.as_any().downcast_ref().unwrap(), to_type).boxed()) + }, + + (_, List(to)) => { + // cast primitive to list's primitive + let values = cast(array, &to.data_type, options)?; + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets = (0..=array.len() as i32).collect::>(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let list_array = ListArray::::new(to_type.clone(), offsets.into(), values, None); + + Ok(Box::new(list_array)) + }, + + (_, LargeList(to)) if from_type != &LargeBinary => { + // cast primitive to list's primitive + let values = cast(array, &to.data_type, options)?; + // create offsets, where if array.len() = 2, we have [0,1,2] + let offsets = (0..=array.len() as i64).collect::>(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + + let list_array = ListArray::::new(to_type.clone(), offsets.into(), values, None); + + Ok(Box::new(list_array)) + }, + + (Dictionary(index_type, ..), _) => match_integer_type!(index_type, |$T| { + dictionary_cast_dyn::<$T>(array, to_type, options) + }), + (_, Dictionary(index_type, value_type, _)) => match_integer_type!(index_type, |$T| { + 
cast_to_dictionary::<$T>(array, value_type, options) + }), + (_, Boolean) => match from_type { + UInt8 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt16 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt32 => primitive_to_boolean_dyn::(array, to_type.clone()), + UInt64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int8 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int16 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int32 => primitive_to_boolean_dyn::(array, to_type.clone()), + Int64 => primitive_to_boolean_dyn::(array, to_type.clone()), + Float32 => primitive_to_boolean_dyn::(array, to_type.clone()), + Float64 => primitive_to_boolean_dyn::(array, to_type.clone()), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (Boolean, _) => match to_type { + UInt8 => boolean_to_primitive_dyn::(array), + UInt16 => boolean_to_primitive_dyn::(array), + UInt32 => boolean_to_primitive_dyn::(array), + UInt64 => boolean_to_primitive_dyn::(array), + Int8 => boolean_to_primitive_dyn::(array), + Int16 => boolean_to_primitive_dyn::(array), + Int32 => boolean_to_primitive_dyn::(array), + Int64 => boolean_to_primitive_dyn::(array), + Float32 => boolean_to_primitive_dyn::(array), + Float64 => boolean_to_primitive_dyn::(array), + LargeUtf8 => boolean_to_utf8_dyn::(array), + LargeBinary => boolean_to_binary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (Utf8, _) => match to_type { + UInt8 => utf8_to_primitive_dyn::(array, to_type, options), + UInt16 => utf8_to_primitive_dyn::(array, to_type, options), + UInt32 => utf8_to_primitive_dyn::(array, to_type, options), + UInt64 => utf8_to_primitive_dyn::(array, to_type, options), + Int8 => utf8_to_primitive_dyn::(array, to_type, options), + Int16 => utf8_to_primitive_dyn::(array, to_type, options), + Int32 => utf8_to_primitive_dyn::(array, to_type, options), + Int64 => utf8_to_primitive_dyn::(array, to_type, options), + Float32 => utf8_to_primitive_dyn::(array, to_type, options), + Float64 => utf8_to_primitive_dyn::(array, to_type, options), + Date32 => utf8_to_date32_dyn::(array), + Date64 => utf8_to_date64_dyn::(array), + LargeUtf8 => Ok(Box::new(utf8_to_large_utf8( + array.as_any().downcast_ref().unwrap(), + ))), + Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, Some(tz)) => { + utf8_to_timestamp_ns_dyn::(array, tz.clone()) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + (LargeUtf8, _) => match to_type { + UInt8 => utf8_to_primitive_dyn::(array, to_type, options), + UInt16 => utf8_to_primitive_dyn::(array, to_type, options), + UInt32 => utf8_to_primitive_dyn::(array, to_type, options), + UInt64 => utf8_to_primitive_dyn::(array, to_type, options), + Int8 => utf8_to_primitive_dyn::(array, to_type, options), + Int16 => utf8_to_primitive_dyn::(array, to_type, options), + Int32 => utf8_to_primitive_dyn::(array, to_type, options), + Int64 => utf8_to_primitive_dyn::(array, to_type, options), + Float32 => utf8_to_primitive_dyn::(array, to_type, options), + Float64 => utf8_to_primitive_dyn::(array, to_type, options), + Date32 => utf8_to_date32_dyn::(array), + Date64 => utf8_to_date64_dyn::(array), + Utf8 => utf8_large_to_utf8(array.as_any().downcast_ref().unwrap()).map(|x| x.boxed()), + LargeBinary => Ok(utf8_to_binary::( + 
array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + Timestamp(TimeUnit::Nanosecond, None) => utf8_to_naive_timestamp_ns_dyn::(array), + Timestamp(TimeUnit::Nanosecond, Some(tz)) => { + utf8_to_timestamp_ns_dyn::(array, tz.clone()) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, Utf8) => match from_type { + UInt8 => primitive_to_utf8_dyn::(array), + UInt16 => primitive_to_utf8_dyn::(array), + UInt32 => primitive_to_utf8_dyn::(array), + UInt64 => primitive_to_utf8_dyn::(array), + Int8 => primitive_to_utf8_dyn::(array), + Int16 => primitive_to_utf8_dyn::(array), + Int32 => primitive_to_utf8_dyn::(array), + Int64 => primitive_to_utf8_dyn::(array), + Float32 => primitive_to_utf8_dyn::(array), + Float64 => primitive_to_utf8_dyn::(array), + Timestamp(from_unit, Some(tz)) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(timestamp_to_utf8::(from, *from_unit, tz)?)) + }, + Timestamp(from_unit, None) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(naive_timestamp_to_utf8::(from, *from_unit))) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, LargeUtf8) => match from_type { + UInt8 => primitive_to_utf8_dyn::(array), + UInt16 => primitive_to_utf8_dyn::(array), + UInt32 => primitive_to_utf8_dyn::(array), + UInt64 => primitive_to_utf8_dyn::(array), + Int8 => primitive_to_utf8_dyn::(array), + Int16 => primitive_to_utf8_dyn::(array), + Int32 => primitive_to_utf8_dyn::(array), + Int64 => primitive_to_utf8_dyn::(array), + Float32 => primitive_to_utf8_dyn::(array), + Float64 => primitive_to_utf8_dyn::(array), + LargeBinary => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + Timestamp(from_unit, Some(tz)) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(timestamp_to_utf8::(from, *from_unit, tz)?)) + }, + Timestamp(from_unit, None) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(naive_timestamp_to_utf8::(from, *from_unit))) + }, + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (Binary, _) => match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type, options), + UInt16 => binary_to_primitive_dyn::(array, to_type, options), + UInt32 => binary_to_primitive_dyn::(array, to_type, options), + UInt64 => binary_to_primitive_dyn::(array, to_type, options), + Int8 => binary_to_primitive_dyn::(array, to_type, options), + Int16 => binary_to_primitive_dyn::(array, to_type, options), + Int32 => binary_to_primitive_dyn::(array, to_type, options), + Int64 => binary_to_primitive_dyn::(array, to_type, options), + Float32 => binary_to_primitive_dyn::(array, to_type, options), + Float64 => binary_to_primitive_dyn::(array, to_type, options), + LargeBinary => Ok(Box::new(binary_to_large_binary( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ))), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (LargeBinary, _) => { + match to_type { + UInt8 => binary_to_primitive_dyn::(array, to_type, options), + UInt16 => binary_to_primitive_dyn::(array, to_type, options), + UInt32 => binary_to_primitive_dyn::(array, to_type, options), + UInt64 => binary_to_primitive_dyn::(array, to_type, options), + Int8 => binary_to_primitive_dyn::(array, 
to_type, options), + Int16 => binary_to_primitive_dyn::(array, to_type, options), + Int32 => binary_to_primitive_dyn::(array, to_type, options), + Int64 => binary_to_primitive_dyn::(array, to_type, options), + Float32 => binary_to_primitive_dyn::(array, to_type, options), + Float64 => binary_to_primitive_dyn::(array, to_type, options), + Binary => { + binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + LargeUtf8 => { + binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .map(|x| x.boxed()) + }, + LargeList(inner) if matches!(inner.data_type, DataType::UInt8) => Ok( + binary_to_list::(array.as_any().downcast_ref().unwrap(), to_type.clone()) + .boxed(), + ), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + } + }, + (FixedSizeBinary(_), _) => match to_type { + Binary => Ok(fixed_size_binary_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + LargeBinary => Ok(fixed_size_binary_binary::( + array.as_any().downcast_ref().unwrap(), + to_type.clone(), + ) + .boxed()), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, Binary) => match from_type { + UInt8 => primitive_to_binary_dyn::(array), + UInt16 => primitive_to_binary_dyn::(array), + UInt32 => primitive_to_binary_dyn::(array), + UInt64 => primitive_to_binary_dyn::(array), + Int8 => primitive_to_binary_dyn::(array), + Int16 => primitive_to_binary_dyn::(array), + Int32 => primitive_to_binary_dyn::(array), + Int64 => primitive_to_binary_dyn::(array), + Float32 => primitive_to_binary_dyn::(array), + Float64 => primitive_to_binary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + (_, LargeBinary) => match from_type { + UInt8 => primitive_to_binary_dyn::(array), + UInt16 => primitive_to_binary_dyn::(array), + UInt32 => primitive_to_binary_dyn::(array), + UInt64 => primitive_to_binary_dyn::(array), + Int8 => primitive_to_binary_dyn::(array), + Int16 => primitive_to_binary_dyn::(array), + Int32 => primitive_to_binary_dyn::(array), + Int64 => primitive_to_binary_dyn::(array), + Float32 => primitive_to_binary_dyn::(array), + Float64 => primitive_to_binary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + }, + + // start numeric casts + (UInt8, UInt16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, UInt32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Int8) => 
primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt16, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, UInt64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (UInt64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (UInt64, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (UInt64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int8, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int8, Int16) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int8, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int16, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int16, Int32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int16, Decimal(p, s)) => 
integer_to_decimal_dyn::(array, *p, *s), + + (Int32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int32, Int64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float32) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int32, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Int64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Int64, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Int64, Decimal(p, s)) => integer_to_decimal_dyn::(array, *p, *s), + + (Float16, Float32) => { + let from = array.as_any().downcast_ref().unwrap(); + Ok(f16_to_f32(from).boxed()) + }, + + (Float32, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float32, Float64) => primitive_to_primitive_dyn::(array, to_type, as_options), + (Float32, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Float64, UInt8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, UInt64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int8) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int16) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Float32) => primitive_to_primitive_dyn::(array, to_type, options), + (Float64, Decimal(p, s)) => float_to_decimal_dyn::(array, *p, *s), + + (Decimal(_, _), UInt8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt16) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), UInt64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int8) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int16) => 
decimal_to_integer_dyn::(array), + (Decimal(_, _), Int32) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Int64) => decimal_to_integer_dyn::(array), + (Decimal(_, _), Float32) => decimal_to_float_dyn::(array), + (Decimal(_, _), Float64) => decimal_to_float_dyn::(array), + (Decimal(_, _), Decimal(to_p, to_s)) => decimal_to_decimal_dyn(array, *to_p, *to_s), + // end numeric casts + + // temporal casts + (Int32, Date32) => primitive_to_same_primitive_dyn::(array, to_type), + (Int32, Time32(TimeUnit::Second)) => primitive_to_same_primitive_dyn::(array, to_type), + (Int32, Time32(TimeUnit::Millisecond)) => { + primitive_to_same_primitive_dyn::(array, to_type) + }, + // No support for microsecond/nanosecond with i32 + (Date32, Int32) => primitive_to_same_primitive_dyn::(array, to_type), + (Date32, Int64) => primitive_to_primitive_dyn::(array, to_type, options), + (Time32(_), Int32) => primitive_to_same_primitive_dyn::(array, to_type), + (Int64, Date64) => primitive_to_same_primitive_dyn::(array, to_type), + // No support for second/milliseconds with i64 + (Int64, Time64(TimeUnit::Microsecond)) => { + primitive_to_same_primitive_dyn::(array, to_type) + }, + (Int64, Time64(TimeUnit::Nanosecond)) => { + primitive_to_same_primitive_dyn::(array, to_type) + }, + + (Date64, Int32) => primitive_to_primitive_dyn::(array, to_type, options), + (Date64, Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Time64(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Date32, Date64) => primitive_dyn!(array, date32_to_date64), + (Date64, Date32) => primitive_dyn!(array, date64_to_date32), + (Time32(TimeUnit::Second), Time32(TimeUnit::Millisecond)) => { + primitive_dyn!(array, time32s_to_time32ms) + }, + (Time32(TimeUnit::Millisecond), Time32(TimeUnit::Second)) => { + primitive_dyn!(array, time32ms_to_time32s) + }, + (Time32(from_unit), Time64(to_unit)) => { + primitive_dyn!(array, time32_to_time64, *from_unit, *to_unit) + }, + (Time64(TimeUnit::Microsecond), Time64(TimeUnit::Nanosecond)) => { + primitive_dyn!(array, time64us_to_time64ns) + }, + (Time64(TimeUnit::Nanosecond), Time64(TimeUnit::Microsecond)) => { + primitive_dyn!(array, time64ns_to_time64us) + }, + (Time64(from_unit), Time32(to_unit)) => { + primitive_dyn!(array, time64_to_time32, *from_unit, *to_unit) + }, + (Timestamp(_, _), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + (Int64, Timestamp(_, _)) => primitive_to_same_primitive_dyn::(array, to_type), + (Timestamp(from_unit, _), Timestamp(to_unit, tz)) => { + primitive_dyn!(array, timestamp_to_timestamp, *from_unit, *to_unit, tz) + }, + (Timestamp(from_unit, _), Date32) => primitive_dyn!(array, timestamp_to_date32, *from_unit), + (Timestamp(from_unit, _), Date64) => primitive_dyn!(array, timestamp_to_date64, *from_unit), + + (Int64, Duration(_)) => primitive_to_same_primitive_dyn::(array, to_type), + (Duration(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), + + (Interval(IntervalUnit::DayTime), Interval(IntervalUnit::MonthDayNano)) => { + primitive_dyn!(array, days_ms_to_months_days_ns) + }, + (Interval(IntervalUnit::YearMonth), Interval(IntervalUnit::MonthDayNano)) => { + primitive_dyn!(array, months_to_months_days_ns) + }, + + (_, _) => Err(Error::NotYetImplemented(format!( + "Casting from {from_type:?} to {to_type:?} not supported", + ))), + } +} + +/// Attempts to encode an array into an `ArrayDictionary` with index +/// type K and value (dictionary) type value_type +/// +/// K is the key type +fn cast_to_dictionary( + array: &dyn 
Array, + dict_value_type: &DataType, + options: CastOptions, +) -> Result> { + let array = cast(array, dict_value_type, options)?; + let array = array.as_ref(); + match *dict_value_type { + DataType::Int8 => primitive_to_dictionary_dyn::(array), + DataType::Int16 => primitive_to_dictionary_dyn::(array), + DataType::Int32 => primitive_to_dictionary_dyn::(array), + DataType::Int64 => primitive_to_dictionary_dyn::(array), + DataType::UInt8 => primitive_to_dictionary_dyn::(array), + DataType::UInt16 => primitive_to_dictionary_dyn::(array), + DataType::UInt32 => primitive_to_dictionary_dyn::(array), + DataType::UInt64 => primitive_to_dictionary_dyn::(array), + DataType::Utf8 => utf8_to_dictionary_dyn::(array), + DataType::LargeUtf8 => utf8_to_dictionary_dyn::(array), + DataType::Binary => binary_to_dictionary_dyn::(array), + DataType::LargeBinary => binary_to_dictionary_dyn::(array), + _ => Err(Error::NotYetImplemented(format!( + "Unsupported output type for dictionary packing: {dict_value_type:?}" + ))), + } +} diff --git a/crates/nano-arrow/src/compute/cast/primitive_to.rs b/crates/nano-arrow/src/compute/cast/primitive_to.rs new file mode 100644 index 000000000000..a83569ee165c --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/primitive_to.rs @@ -0,0 +1,584 @@ +use std::hash::Hash; + +use num_traits::{AsPrimitive, Float, ToPrimitive}; + +use super::CastOptions; +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::compute::arity::unary; +use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; +use crate::error::Result; +use crate::offset::{Offset, Offsets}; +use crate::temporal_conversions::*; +use crate::types::{days_ms, f16, months_days_ns, NativeType}; + +/// Returns a [`BinaryArray`] where every element is the binary representation of the number. +pub fn primitive_to_binary( + from: &PrimitiveArray, +) -> BinaryArray { + let mut values: Vec = Vec::with_capacity(from.len()); + let mut offsets: Vec = Vec::with_capacity(from.len() + 1); + offsets.push(O::default()); + + let mut offset: usize = 0; + + unsafe { + for x in from.values().iter() { + values.reserve(offset + T::FORMATTED_SIZE_DECIMAL); + + let bytes = std::slice::from_raw_parts_mut( + values.as_mut_ptr().add(offset), + values.capacity() - offset, + ); + let len = lexical_core::write_unchecked(*x, bytes).len(); + + offset += len; + offsets.push(O::from_usize(offset).unwrap()); + } + values.set_len(offset); + values.shrink_to_fit(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + BinaryArray::::new( + BinaryArray::::default_data_type(), + offsets.into(), + values.into(), + from.validity().cloned(), + ) + } +} + +pub(super) fn primitive_to_binary_dyn(from: &dyn Array) -> Result> +where + O: Offset, + T: NativeType + lexical_core::ToLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_binary::(from))) +} + +/// Returns a [`BooleanArray`] where every element is different from zero. +/// Validity is preserved. 
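// Illustrative sketch (not part of this patch) of how the numeric -> boolean
// kernel defined just below behaves. The array constructors mirror the doctests
// used elsewhere in this crate; the concrete values are hypothetical.
//
//     let ints = PrimitiveArray::<i32>::from([Some(0), Some(3), None]);
//     let bools = primitive_to_boolean(&ints, DataType::Boolean);
//     // zero becomes false, any non-zero value becomes true, nulls stay null:
//     assert_eq!(bools, BooleanArray::from([Some(false), Some(true), None]));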
+pub fn primitive_to_boolean( + from: &PrimitiveArray, + to_type: DataType, +) -> BooleanArray { + let iter = from.values().iter().map(|v| *v != T::default()); + let values = Bitmap::from_trusted_len_iter(iter); + + BooleanArray::new(to_type, values, from.validity().cloned()) +} + +pub(super) fn primitive_to_boolean_dyn( + from: &dyn Array, + to_type: DataType, +) -> Result> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_boolean::(from, to_type))) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the number. +pub fn primitive_to_utf8( + from: &PrimitiveArray, +) -> Utf8Array { + let mut values: Vec = Vec::with_capacity(from.len()); + let mut offsets: Vec = Vec::with_capacity(from.len() + 1); + offsets.push(O::default()); + + let mut offset: usize = 0; + + unsafe { + for x in from.values().iter() { + values.reserve(offset + T::FORMATTED_SIZE_DECIMAL); + + let bytes = std::slice::from_raw_parts_mut( + values.as_mut_ptr().add(offset), + values.capacity() - offset, + ); + let len = lexical_core::write_unchecked(*x, bytes).len(); + + offset += len; + offsets.push(O::from_usize(offset).unwrap()); + } + values.set_len(offset); + values.shrink_to_fit(); + // Safety: offsets _are_ monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }; + Utf8Array::::new_unchecked( + Utf8Array::::default_data_type(), + offsets.into(), + values.into(), + from.validity().cloned(), + ) + } +} + +pub(super) fn primitive_to_utf8_dyn(from: &dyn Array) -> Result> +where + O: Offset, + T: NativeType + lexical_core::ToLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_utf8::(from))) +} + +pub(super) fn primitive_to_primitive_dyn( + from: &dyn Array, + to_type: &DataType, + options: CastOptions, +) -> Result> +where + I: NativeType + num_traits::NumCast + num_traits::AsPrimitive, + O: NativeType + num_traits::NumCast, +{ + let from = from.as_any().downcast_ref::>().unwrap(); + if options.wrapped { + Ok(Box::new(primitive_as_primitive::(from, to_type))) + } else { + Ok(Box::new(primitive_to_primitive::(from, to_type))) + } +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of another physical type via numeric conversion. +pub fn primitive_to_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + I: NativeType + num_traits::NumCast, + O: NativeType + num_traits::NumCast, +{ + let iter = from + .iter() + .map(|v| v.and_then(|x| num_traits::cast::cast::(*x))); + PrimitiveArray::::from_trusted_len_iter(iter).to(to_type.clone()) +} + +/// Returns a [`PrimitiveArray`] with the casted values. 
Values are `None` on overflow +pub fn integer_to_decimal>( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray { + let multiplier = 10_i128.pow(to_scale as u32); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + x.as_().checked_mul(multiplier).and_then(|x| { + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn integer_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> Result> +where + T: NativeType + AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(integer_to_decimal::(from, precision, scale))) +} + +/// Returns a [`PrimitiveArray`] with the casted values. Values are `None` on overflow +pub fn float_to_decimal( + from: &PrimitiveArray, + to_precision: usize, + to_scale: usize, +) -> PrimitiveArray +where + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, +{ + // 1.2 => 12 + let multiplier: T = (10_f64).powi(to_scale as i32).as_(); + + let min_for_precision = 9_i128 + .saturating_pow(1 + to_precision as u32) + .saturating_neg(); + let max_for_precision = 9_i128.saturating_pow(1 + to_precision as u32); + + let values = from.iter().map(|x| { + x.and_then(|x| { + let x = (*x * multiplier).to_i128().unwrap(); + if x > max_for_precision || x < min_for_precision { + None + } else { + Some(x) + } + }) + }); + + PrimitiveArray::::from_trusted_len_iter(values) + .to(DataType::Decimal(to_precision, to_scale)) +} + +pub(super) fn float_to_decimal_dyn( + from: &dyn Array, + precision: usize, + scale: usize, +) -> Result> +where + T: NativeType + Float + ToPrimitive, + f64: AsPrimitive, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(float_to_decimal::(from, precision, scale))) +} + +/// Cast [`PrimitiveArray`] as a [`PrimitiveArray`] +/// Same as `number as to_number_type` in rust +pub fn primitive_as_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + I: NativeType + num_traits::AsPrimitive, + O: NativeType, +{ + unary(from, num_traits::AsPrimitive::::as_, to_type.clone()) +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. +/// This is O(1). +pub fn primitive_to_same_primitive( + from: &PrimitiveArray, + to_type: &DataType, +) -> PrimitiveArray +where + T: NativeType, +{ + PrimitiveArray::::new( + to_type.clone(), + from.values().clone(), + from.validity().cloned(), + ) +} + +/// Cast [`PrimitiveArray`] to a [`PrimitiveArray`] of the same physical type. +/// This is O(1). +pub(super) fn primitive_to_same_primitive_dyn( + from: &dyn Array, + to_type: &DataType, +) -> Result> +where + T: NativeType, +{ + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive_to_same_primitive::(from, to_type))) +} + +pub(super) fn primitive_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + primitive_to_dictionary::(from).map(|x| Box::new(x) as Box) +} + +/// Cast [`PrimitiveArray`] to [`DictionaryArray`]. Also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. 
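// Illustrative sketch (not part of this patch) of dictionary packing for
// primitive arrays. The key type is picked by the binding's type annotation and
// only needs to be wide enough for the number of distinct values; the sample
// values are hypothetical.
//
//     let ints = PrimitiveArray::<i32>::from([Some(1), Some(1), None, Some(2)]);
//     let dict: DictionaryArray<u32> = primitive_to_dictionary(&ints).unwrap();
//     // every distinct value is stored once; the keys index into those values:
//     assert_eq!(dict.len(), 4);
//     assert_eq!(dict.values().len(), 2);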
+pub fn primitive_to_dictionary( + from: &PrimitiveArray, +) -> Result> { + let iter = from.iter().map(|x| x.copied()); + let mut array = MutableDictionaryArray::::try_empty(MutablePrimitiveArray::::from( + from.data_type().clone(), + ))?; + array.try_extend(iter)?; + + Ok(array.into()) +} + +/// Get the time unit as a multiple of a second +const fn time_unit_multiple(unit: TimeUnit) -> i64 { + match unit { + TimeUnit::Second => 1, + TimeUnit::Millisecond => MILLISECONDS, + TimeUnit::Microsecond => MICROSECONDS, + TimeUnit::Nanosecond => NANOSECONDS, + } +} + +/// Conversion of dates +pub fn date32_to_date64(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x as i64 * MILLISECONDS_IN_DAY, DataType::Date64) +} + +/// Conversion of dates +pub fn date64_to_date32(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| (x / MILLISECONDS_IN_DAY) as i32, DataType::Date32) +} + +/// Conversion of times +pub fn time32s_to_time32ms(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x * 1000, DataType::Time32(TimeUnit::Millisecond)) +} + +/// Conversion of times +pub fn time32ms_to_time32s(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x / 1000, DataType::Time32(TimeUnit::Second)) +} + +/// Conversion of times +pub fn time64us_to_time64ns(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x * 1000, DataType::Time64(TimeUnit::Nanosecond)) +} + +/// Conversion of times +pub fn time64ns_to_time64us(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x / 1000, DataType::Time64(TimeUnit::Microsecond)) +} + +/// Conversion of timestamp +pub fn timestamp_to_date64(from: &PrimitiveArray, from_unit: TimeUnit) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = MILLISECONDS; + let to_type = DataType::Date64; + + // Scale time_array by (to_size / from_size) using a + // single integer operation, but need to avoid integer + // math rounding down to zero + + match to_size.cmp(&from_size) { + std::cmp::Ordering::Less => unary(from, |x| (x / (from_size / to_size)), to_type), + std::cmp::Ordering::Equal => primitive_to_same_primitive(from, &to_type), + std::cmp::Ordering::Greater => unary(from, |x| (x * (to_size / from_size)), to_type), + } +} + +/// Conversion of timestamp +pub fn timestamp_to_date32(from: &PrimitiveArray, from_unit: TimeUnit) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY; + unary(from, |x| (x / from_size) as i32, DataType::Date32) +} + +/// Conversion of time +pub fn time32_to_time64( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let divisor = to_size / from_size; + unary(from, |x| (x as i64 * divisor), DataType::Time64(to_unit)) +} + +/// Conversion of time +pub fn time64_to_time32( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let divisor = from_size / to_size; + unary(from, |x| (x / divisor) as i32, DataType::Time32(to_unit)) +} + +/// Conversion of timestamp +pub fn timestamp_to_timestamp( + from: &PrimitiveArray, + from_unit: TimeUnit, + to_unit: TimeUnit, + tz: &Option, +) -> PrimitiveArray { + let from_size = time_unit_multiple(from_unit); + let to_size = time_unit_multiple(to_unit); + let to_type = DataType::Timestamp(to_unit, tz.clone()); + // we either divide or multiply, 
depending on size of each unit + if from_size >= to_size { + unary(from, |x| (x / (from_size / to_size)), to_type) + } else { + unary(from, |x| (x * (to_size / from_size)), to_type) + } +} + +fn timestamp_to_utf8_impl( + from: &PrimitiveArray, + time_unit: TimeUnit, + timezone: T, +) -> Utf8Array +where + T::Offset: std::fmt::Display, +{ + match time_unit { + TimeUnit::Nanosecond => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_ns_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Microsecond => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_us_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Millisecond => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_ms_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Second => { + let iter = from.iter().map(|x| { + x.map(|x| { + let datetime = timestamp_s_to_datetime(*x); + let offset = timezone.offset_from_utc_datetime(&datetime); + chrono::DateTime::::from_naive_utc_and_offset(datetime, offset).to_rfc3339() + }) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + } +} + +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +fn chrono_tz_timestamp_to_utf8( + from: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, +) -> Result> { + let timezone = parse_offset_tz(timezone_str)?; + Ok(timestamp_to_utf8_impl::( + from, time_unit, timezone, + )) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz_timestamp_to_utf8( + _: &PrimitiveArray, + _: TimeUnit, + timezone_str: &str, +) -> Result> { + use crate::error::Error; + Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))) +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. +pub fn timestamp_to_utf8( + from: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, +) -> Result> { + let timezone = parse_offset(timezone_str); + + if let Ok(timezone) = timezone { + Ok(timestamp_to_utf8_impl::( + from, time_unit, timezone, + )) + } else { + chrono_tz_timestamp_to_utf8(from, time_unit, timezone_str) + } +} + +/// Returns a [`Utf8Array`] where every element is the utf8 representation of the timestamp in the rfc3339 format. 
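// Illustrative sketch (not part of this patch): formatting naive
// (timezone-less) timestamps as strings. The output relies on chrono's default
// `NaiveDateTime` formatting; the input values are hypothetical.
//
//     let ts = PrimitiveArray::<i64>::from([Some(0), None]);
//     let formatted: Utf8Array<i32> = naive_timestamp_to_utf8(&ts, TimeUnit::Second);
//     assert_eq!(formatted.value(0), "1970-01-01 00:00:00");
//     assert!(formatted.is_null(1));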
+pub fn naive_timestamp_to_utf8( + from: &PrimitiveArray, + time_unit: TimeUnit, +) -> Utf8Array { + match time_unit { + TimeUnit::Nanosecond => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_ns_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Microsecond => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_us_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Millisecond => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_ms_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + TimeUnit::Second => { + let iter = from.iter().map(|x| { + x.copied() + .map(timestamp_s_to_datetime) + .map(|x| x.to_string()) + }); + Utf8Array::from_trusted_len_iter(iter) + }, + } +} + +#[inline] +fn days_ms_to_months_days_ns_scalar(from: days_ms) -> months_days_ns { + months_days_ns::new(0, from.days(), from.milliseconds() as i64 * 1000) +} + +/// Casts [`days_ms`]s to [`months_days_ns`]. This operation is infalible and lossless. +pub fn days_ms_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + days_ms_to_months_days_ns_scalar, + DataType::Interval(IntervalUnit::MonthDayNano), + ) +} + +#[inline] +fn months_to_months_days_ns_scalar(from: i32) -> months_days_ns { + months_days_ns::new(from, 0, 0) +} + +/// Casts months represented as [`i32`]s to [`months_days_ns`]. This operation is infalible and lossless. +pub fn months_to_months_days_ns(from: &PrimitiveArray) -> PrimitiveArray { + unary( + from, + months_to_months_days_ns_scalar, + DataType::Interval(IntervalUnit::MonthDayNano), + ) +} + +/// Casts f16 into f32 +pub fn f16_to_f32(from: &PrimitiveArray) -> PrimitiveArray { + unary(from, |x| x.to_f32(), DataType::Float32) +} diff --git a/crates/nano-arrow/src/compute/cast/utf8_to.rs b/crates/nano-arrow/src/compute/cast/utf8_to.rs new file mode 100644 index 000000000000..9c86ff85da54 --- /dev/null +++ b/crates/nano-arrow/src/compute/cast/utf8_to.rs @@ -0,0 +1,176 @@ +use chrono::Datelike; + +use super::CastOptions; +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::offset::Offset; +use crate::temporal_conversions::{ + utf8_to_naive_timestamp_ns as utf8_to_naive_timestamp_ns_, + utf8_to_timestamp_ns as utf8_to_timestamp_ns_, EPOCH_DAYS_FROM_CE, +}; +use crate::types::NativeType; + +const RFC3339: &str = "%Y-%m-%dT%H:%M:%S%.f%:z"; + +/// Casts a [`Utf8Array`] to a [`PrimitiveArray`], making any uncastable value a Null. +pub fn utf8_to_primitive(from: &Utf8Array, to: &DataType) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from + .iter() + .map(|x| x.and_then::(|x| lexical_core::parse(x.as_bytes()).ok())); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +/// Casts a [`Utf8Array`] to a [`PrimitiveArray`] at best-effort using `lexical_core::parse_partial`, making any uncastable value as zero. 
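// Illustrative sketch (not part of this patch) contrasting the strict and the
// best-effort string parsers; `options.partial` selects between them in
// `utf8_to_primitive_dyn` further down. The sample strings are hypothetical.
//
//     let strings = Utf8Array::<i32>::from([Some("3.7km"), Some("nope"), None]);
//     // strict parsing rejects trailing garbage and yields a null:
//     let strict: PrimitiveArray<f64> = utf8_to_primitive(&strings, &DataType::Float64);
//     assert!(strict.is_null(0));
//     // best-effort parsing keeps the numeric prefix instead:
//     let partial: PrimitiveArray<f64> = partial_utf8_to_primitive(&strings, &DataType::Float64);
//     assert_eq!(partial.value(0), 3.7);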
+pub fn partial_utf8_to_primitive( + from: &Utf8Array, + to: &DataType, +) -> PrimitiveArray +where + T: NativeType + lexical_core::FromLexical, +{ + let iter = from.iter().map(|x| { + x.and_then::(|x| lexical_core::parse_partial(x.as_bytes()).ok().map(|x| x.0)) + }); + + PrimitiveArray::::from_trusted_len_iter(iter).to(to.clone()) +} + +pub(super) fn utf8_to_primitive_dyn( + from: &dyn Array, + to: &DataType, + options: CastOptions, +) -> Result> +where + T: NativeType + lexical_core::FromLexical, +{ + let from = from.as_any().downcast_ref().unwrap(); + if options.partial { + Ok(Box::new(partial_utf8_to_primitive::(from, to))) + } else { + Ok(Box::new(utf8_to_primitive::(from, to))) + } +} + +/// Casts a [`Utf8Array`] to a Date32 primitive, making any uncastable value a Null. +pub fn utf8_to_date32(from: &Utf8Array) -> PrimitiveArray { + let iter = from.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) + }) + }); + PrimitiveArray::::from_trusted_len_iter(iter).to(DataType::Date32) +} + +pub(super) fn utf8_to_date32_dyn(from: &dyn Array) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8_to_date32::(from))) +} + +/// Casts a [`Utf8Array`] to a Date64 primitive, making any uncastable value a Null. +pub fn utf8_to_date64(from: &Utf8Array) -> PrimitiveArray { + let iter = from.iter().map(|x| { + x.and_then(|x| { + x.parse::() + .ok() + .map(|x| (x.num_days_from_ce() - EPOCH_DAYS_FROM_CE) as i64 * 86400000) + }) + }); + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Date64) +} + +pub(super) fn utf8_to_date64_dyn(from: &dyn Array) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8_to_date64::(from))) +} + +pub(super) fn utf8_to_dictionary_dyn( + from: &dyn Array, +) -> Result> { + let values = from.as_any().downcast_ref().unwrap(); + utf8_to_dictionary::(values).map(|x| Box::new(x) as Box) +} + +/// Cast [`Utf8Array`] to [`DictionaryArray`], also known as packing. +/// # Errors +/// This function errors if the maximum key is smaller than the number of distinct elements +/// in the array. 
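// Illustrative sketch (not part of this patch) of string dictionary packing,
// mirroring the primitive variant in primitive_to.rs. Values are hypothetical.
//
//     let strings = Utf8Array::<i32>::from([Some("a"), Some("b"), Some("a"), None]);
//     let dict: DictionaryArray<u8> = utf8_to_dictionary(&strings).unwrap();
//     // "a" and "b" are stored once each; the four keys point back at them:
//     assert_eq!(dict.len(), 4);
//     assert_eq!(dict.values().len(), 2);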
+pub fn utf8_to_dictionary( + from: &Utf8Array, +) -> Result> { + let mut array = MutableDictionaryArray::>::new(); + array.try_extend(from.iter())?; + + Ok(array.into()) +} + +pub(super) fn utf8_to_naive_timestamp_ns_dyn( + from: &dyn Array, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8_to_naive_timestamp_ns::(from))) +} + +/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting +pub fn utf8_to_naive_timestamp_ns(from: &Utf8Array) -> PrimitiveArray { + utf8_to_naive_timestamp_ns_(from, RFC3339) +} + +pub(super) fn utf8_to_timestamp_ns_dyn( + from: &dyn Array, + timezone: String, +) -> Result> { + let from = from.as_any().downcast_ref().unwrap(); + utf8_to_timestamp_ns::(from, timezone) + .map(Box::new) + .map(|x| x as Box) +} + +/// [`crate::temporal_conversions::utf8_to_timestamp_ns`] applied for RFC3339 formatting +pub fn utf8_to_timestamp_ns( + from: &Utf8Array, + timezone: String, +) -> Result> { + utf8_to_timestamp_ns_(from, RFC3339, timezone) +} + +/// Conversion of utf8 +pub fn utf8_to_large_utf8(from: &Utf8Array) -> Utf8Array { + let data_type = Utf8Array::::default_data_type(); + let validity = from.validity().cloned(); + let values = from.values().clone(); + + let offsets = from.offsets().into(); + // Safety: sound because `values` fulfills the same invariants as `from.values()` + unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } +} + +/// Conversion of utf8 +pub fn utf8_large_to_utf8(from: &Utf8Array) -> Result> { + let data_type = Utf8Array::::default_data_type(); + let validity = from.validity().cloned(); + let values = from.values().clone(); + let offsets = from.offsets().try_into()?; + + // Safety: sound because `values` fulfills the same invariants as `from.values()` + Ok(unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) }) +} + +/// Conversion to binary +pub fn utf8_to_binary(from: &Utf8Array, to_data_type: DataType) -> BinaryArray { + // Safety: erasure of an invariant is always safe + unsafe { + BinaryArray::::new( + to_data_type, + from.offsets().clone(), + from.values().clone(), + from.validity().cloned(), + ) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/binary.rs b/crates/nano-arrow/src/compute/comparison/binary.rs new file mode 100644 index 000000000000..af87362a7841 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/binary.rs @@ -0,0 +1,238 @@ +//! Comparison functions for [`BinaryArray`] +use super::super::utils::combine_validities; +use crate::array::{BinaryArray, BooleanArray}; +use crate::bitmap::Bitmap; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`]s using a specified +/// comparison function. +fn compare_op(lhs: &BinaryArray, rhs: &BinaryArray, op: F) -> BooleanArray +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + assert_eq!(lhs.len(), rhs.len()); + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values_iter() + .zip(rhs.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`] and scalar using +/// a specified comparison function. 
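// Illustrative sketch (not part of this patch) of the scalar comparisons this
// helper powers, e.g. `eq_scalar` below. The byte strings are hypothetical and
// follow the same construction pattern as the tests at the end of this file.
//
//     let data: Vec<&[u8]> = vec![b"arrow", b"polars", b"arrow"];
//     let array = BinaryArray::<i32>::from_slice(&data);
//     assert_eq!(
//         eq_scalar(&array, b"arrow"),
//         BooleanArray::from_slice([true, false, true])
//     );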
+fn compare_op_scalar(lhs: &BinaryArray, rhs: &[u8], op: F) -> BooleanArray +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + let validity = lhs.validity().cloned(); + + let values = lhs.values_iter().map(|lhs| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`]. +/// # Panic +/// iff the arrays do not have the same length. +pub fn eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and include validities in comparison. +/// # Panic +/// iff the arrays do not have the same length. +pub fn eq_and_validity(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a == b); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar. +pub fn eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar and include validities in comparison. +pub fn eq_scalar_and_validity(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a == b); + + finish_eq_validities(out, validity, None) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`]. +/// # Panic +/// iff the arrays do not have the same length. +pub fn neq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`]. +/// # Panic +/// iff the arrays do not have the same length and include validities in comparison. +pub fn neq_and_validity(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + + let out = compare_op(&lhs, &rhs, |a, b| a != b); + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar. +pub fn neq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar and include validities in comparison. +pub fn neq_scalar_and_validity(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a != b); + + finish_neq_validities(out, validity, None) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`]. +pub fn lt(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`] and a scalar. +pub fn lt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`]. 
+pub fn lt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`] and a scalar. +pub fn lt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`]. +pub fn gt(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`] and a scalar. +pub fn gt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`]. +pub fn gt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a >= b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`] and a scalar. +pub fn gt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a >= b) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_generic, &BinaryArray) -> BooleanArray>( + lhs: Vec<&[u8]>, + rhs: Vec<&[u8]>, + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let rhs = BinaryArray::::from_slice(rhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, &rhs), expected); + } + + fn test_generic_scalar, &[u8]) -> BooleanArray>( + lhs: Vec<&[u8]>, + rhs: &[u8], + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, rhs), expected); + } + + #[test] + fn test_gt_eq() { + test_generic::( + vec![b"arrow", b"datafusion", b"flight", b"parquet"], + vec![b"flight", b"flight", b"flight", b"flight"], + gt_eq, + vec![false, false, true, true], + ) + } + + #[test] + fn test_gt_eq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"datafusion", b"flight", b"parquet"], + b"flight", + gt_eq_scalar, + vec![false, false, true, true], + ) + } + + #[test] + fn test_eq() { + test_generic::( + vec![b"arrow", b"arrow", b"arrow", b"arrow"], + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + eq, + vec![true, false, false, false], + ) + } + + #[test] + fn test_eq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + b"arrow", + eq_scalar, + vec![true, false, false, false], + ) + } + + #[test] + fn test_neq() { + test_generic::( + vec![b"arrow", b"arrow", b"arrow", b"arrow"], + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + neq, + vec![false, true, true, true], + ) + } + + #[test] + fn test_neq_scalar() { + test_generic_scalar::( + vec![b"arrow", b"parquet", b"datafusion", b"flight"], + b"arrow", + neq_scalar, + vec![false, true, true, true], + ) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/boolean.rs b/crates/nano-arrow/src/compute/comparison/boolean.rs new file mode 100644 index 000000000000..6b62f7fc6b00 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/boolean.rs @@ -0,0 +1,172 @@ +//! Comparison functions for [`BooleanArray`] +use super::super::utils::combine_validities; +use crate::array::BooleanArray; +use crate::bitmap::{binary, unary, Bitmap}; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; + +/// Evaluate `op(lhs, rhs)` for [`BooleanArray`]s using a specified +/// comparison function. 
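// Illustrative sketch (not part of this patch): these kernels operate on 64-bit
// words of the underlying bitmaps rather than element by element. For equality,
// `!(a ^ b)` sets a bit exactly where the two inputs agree:
//
//     let a: u64 = 0b1100;
//     let b: u64 = 0b1010;
//     assert_eq!(!(a ^ b) & 0b1111, 0b1001); // bits 0 and 3 agree
//
// At the array level (hypothetical values) that corresponds to:
//
//     let lhs = BooleanArray::from_slice([false, false, true, true]);
//     let rhs = BooleanArray::from_slice([false, true, false, true]);
//     assert_eq!(eq(&lhs, &rhs), BooleanArray::from_slice([true, false, false, true]));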
+fn compare_op(lhs: &BooleanArray, rhs: &BooleanArray, op: F) -> BooleanArray +where + F: Fn(u64, u64) -> u64, +{ + assert_eq!(lhs.len(), rhs.len()); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = binary(lhs.values(), rhs.values(), op); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Evaluate `op(left, right)` for [`BooleanArray`] and scalar using +/// a specified comparison function. +pub fn compare_op_scalar(lhs: &BooleanArray, rhs: bool, op: F) -> BooleanArray +where + F: Fn(u64, u64) -> u64, +{ + let rhs = if rhs { !0 } else { 0 }; + + let values = unary(lhs.values(), |x| op(x, rhs)); + BooleanArray::new(DataType::Boolean, values, lhs.validity().cloned()) +} + +/// Perform `lhs == rhs` operation on two [`BooleanArray`]s. +pub fn eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| !(a ^ b)) +} + +/// Perform `lhs == rhs` operation on two [`BooleanArray`]s and include validities in comparison. +pub fn eq_and_validity(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| !(a ^ b)); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value. +pub fn eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + lhs.clone() + } else { + compare_op_scalar(lhs, rhs, |a, _| !a) + } +} + +/// Perform `lhs == rhs` operation on a [`BooleanArray`] and a scalar value and include validities in comparison. +pub fn eq_scalar_and_validity(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + if rhs { + finish_eq_validities(lhs, validity, None) + } else { + let lhs = lhs.with_validity(None); + + let out = compare_op_scalar(&lhs, rhs, |a, _| !a); + + finish_eq_validities(out, validity, None) + } +} + +/// `lhs != rhs` for [`BooleanArray`] +pub fn neq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a ^ b) +} + +/// `lhs != rhs` for [`BooleanArray`] and include validities in comparison. +pub fn neq_and_validity(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a ^ b); + + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + eq_scalar(lhs, !rhs) +} + +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar_and_validity(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = eq_scalar(&lhs, !rhs); + finish_neq_validities(out, validity, None) +} + +/// Perform `left < right` operation on two arrays. +pub fn lt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| !a & b) +} + +/// Perform `left < right` operation on an array and a scalar value. 
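// Illustrative sketch (not part of this patch) of the boolean/scalar ordering
// implemented below: `x < true` reduces to `!x`, while `x < false` can never
// hold. Values are hypothetical.
//
//     let values = BooleanArray::from_slice([true, false]);
//     assert_eq!(lt_scalar(&values, true), BooleanArray::from_slice([false, true]));
//     assert_eq!(lt_scalar(&values, false), BooleanArray::from_slice([false, false]));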
+pub fn lt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + compare_op_scalar(lhs, rhs, |a, _| !a) + } else { + BooleanArray::new( + DataType::Boolean, + Bitmap::new_zeroed(lhs.len()), + lhs.validity().cloned(), + ) + } +} + +/// Perform `left <= right` operation on two arrays. +pub fn lt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| !a | b) +} + +/// Perform `left <= right` operation on an array and a scalar value. +/// Null values are less than non-null values. +pub fn lt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + let all_ones = !0; + compare_op_scalar(lhs, rhs, |_, _| all_ones) + } else { + compare_op_scalar(lhs, rhs, |a, _| !a) + } +} + +/// Perform `left > right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a & !b) +} + +/// Perform `left > right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + BooleanArray::new( + DataType::Boolean, + Bitmap::new_zeroed(lhs.len()), + lhs.validity().cloned(), + ) + } else { + lhs.clone() + } +} + +/// Perform `left >= right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt_eq(lhs: &BooleanArray, rhs: &BooleanArray) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a | !b) +} + +/// Perform `left >= right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_eq_scalar(lhs: &BooleanArray, rhs: bool) -> BooleanArray { + if rhs { + lhs.clone() + } else { + let all_ones = !0; + compare_op_scalar(lhs, rhs, |_, _| all_ones) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/mod.rs b/crates/nano-arrow/src/compute/comparison/mod.rs new file mode 100644 index 000000000000..96627ef2a5e1 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/mod.rs @@ -0,0 +1,613 @@ +//! Contains comparison operators +//! +//! The module contains functions that compare either an [`Array`] and a [`Scalar`] +//! or two [`Array`]s (of the same [`DataType`]). The scalar-oriented functions are +//! suffixed with `_scalar`. +//! +//! The functions are organized in two variants: +//! * statically typed +//! * dynamically typed +//! The statically typed are available under each module of this module (e.g. [`primitive::eq`], [`primitive::lt_scalar`]) +//! The dynamically typed are available in this module (e.g. [`eq`] or [`lt_scalar`]). +//! +//! # Examples +//! +//! Compare two [`PrimitiveArray`]s: +//! ``` +//! use arrow2::array::{BooleanArray, PrimitiveArray}; +//! use arrow2::compute::comparison::primitive::gt; +//! +//! let array1 = PrimitiveArray::::from([Some(1), None, Some(2)]); +//! let array2 = PrimitiveArray::::from([Some(1), Some(3), Some(1)]); +//! let result = gt(&array1, &array2); +//! assert_eq!(result, BooleanArray::from([Some(false), None, Some(true)])); +//! ``` +//! +//! Compare two dynamically-typed [`Array`]s (trait objects): +//! ``` +//! use arrow2::array::{Array, BooleanArray, PrimitiveArray}; +//! use arrow2::compute::comparison::eq; +//! +//! let array1: &dyn Array = &PrimitiveArray::::from(&[Some(10.0), None, Some(20.0)]); +//! let array2: &dyn Array = &PrimitiveArray::::from(&[Some(10.0), None, Some(10.0)]); +//! let result = eq(array1, array2); +//! 
assert_eq!(result, BooleanArray::from([Some(true), None, Some(false)])); +//! ``` +//! +//! Compare (not equal) a [`Utf8Array`] to a word: +//! ``` +//! use arrow2::array::{BooleanArray, Utf8Array}; +//! use arrow2::compute::comparison::utf8::neq_scalar; +//! +//! let array = Utf8Array::::from([Some("compute"), None, Some("compare")]); +//! let result = neq_scalar(&array, "compare"); +//! assert_eq!(result, BooleanArray::from([Some(true), None, Some(false)])); +//! ``` + +use crate::array::*; +use crate::datatypes::{DataType, IntervalUnit}; +use crate::scalar::*; + +pub mod binary; +pub mod boolean; +pub mod primitive; +pub mod utf8; + +mod simd; +pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; + +use super::take::take_boolean; +use crate::bitmap::{binary, Bitmap}; +use crate::compute; + +macro_rules! match_eq_ord {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::i256; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => todo!(), + MonthDayNano => todo!(), + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float16 => todo!(), + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +macro_rules! match_eq {( + $key_type:expr, | $_:tt $T:ident | $($body:tt)* +) => ({ + macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )} + use crate::datatypes::PrimitiveType::*; + use crate::types::{days_ms, months_days_ns, f16, i256}; + match $key_type { + Int8 => __with_ty__! { i8 }, + Int16 => __with_ty__! { i16 }, + Int32 => __with_ty__! { i32 }, + Int64 => __with_ty__! { i64 }, + Int128 => __with_ty__! { i128 }, + Int256 => __with_ty__! { i256 }, + DaysMs => __with_ty__! { days_ms }, + MonthDayNano => __with_ty__! { months_days_ns }, + UInt8 => __with_ty__! { u8 }, + UInt16 => __with_ty__! { u16 }, + UInt32 => __with_ty__! { u32 }, + UInt64 => __with_ty__! { u64 }, + Float16 => __with_ty__! { f16 }, + Float32 => __with_ty__! { f32 }, + Float64 => __with_ty__! { f64 }, + } +})} + +macro_rules! compare { + ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ + let lhs = $lhs; + let rhs = $rhs; + assert_eq!( + lhs.data_type().to_logical_type(), + rhs.data_type().to_logical_type() + ); + + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + boolean::$op(lhs, rhs) + }, + Primitive(primitive) => $p!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + primitive::$op::<$T>(lhs, rhs) + }), + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + utf8::$op::(lhs, rhs) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::$op::(lhs, rhs) + }, + _ => todo!( + "Comparison between {:?} are not yet supported", + lhs.data_type() + ), + } + }}; +} + +/// `==` between two [`Array`]s. 
+/// Use [`can_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, eq, match_eq) +} + +/// `==` between two [`Array`]s and includes validities in comparison. +/// Use [`can_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn eq_and_validity(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, eq_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`eq`]. +pub fn can_eq(data_type: &DataType) -> bool { + can_partial_eq(data_type) +} + +/// `!=` between two [`Array`]s. +/// Use [`can_neq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, neq, match_eq) +} + +/// `!=` between two [`Array`]s and includes validities in comparison. +/// Use [`can_neq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn neq_and_validity(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, neq_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`neq`]. +pub fn can_neq(data_type: &DataType) -> bool { + can_partial_eq(data_type) +} + +/// `<` between two [`Array`]s. +/// Use [`can_lt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, lt, match_eq_ord) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`lt`]. +pub fn can_lt(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +/// `<=` between two [`Array`]s. +/// Use [`can_lt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, lt_eq, match_eq_ord) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`lt`]. +pub fn can_lt_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +/// `>` between two [`Array`]s. 
+/// Use [`can_gt`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, gt, match_eq_ord) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`gt`]. +pub fn can_gt(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +/// `>=` between two [`Array`]s. +/// Use [`can_gt_eq`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * the arrays do not have have the same logical type +/// * the arrays do not have the same length +/// * the operation is not supported for the logical type +pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray { + compare!(lhs, rhs, gt_eq, match_eq_ord) +} + +/// Returns whether a [`DataType`] is comparable is supported by [`gt_eq`]. +pub fn can_gt_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) +} + +macro_rules! compare_scalar { + ($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{ + let lhs = $lhs; + let rhs = $rhs; + assert_eq!( + lhs.data_type().to_logical_type(), + rhs.data_type().to_logical_type() + ); + if !rhs.is_valid() { + return BooleanArray::new_null(DataType::Boolean, lhs.len()); + } + + use crate::datatypes::PhysicalType::*; + match lhs.data_type().to_physical_type() { + Boolean => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::().unwrap(); + // validity checked above + boolean::$op(lhs, rhs.value().unwrap()) + }, + Primitive(primitive) => $p!(primitive, |$T| { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + primitive::$op::<$T>(lhs, rhs.value().unwrap()) + }), + LargeUtf8 => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + utf8::$op::(lhs, rhs.value().unwrap()) + }, + LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref::>().unwrap(); + binary::$op::(lhs, rhs.value().unwrap()) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let lhs = lhs.as_any().downcast_ref::>().unwrap(); + let values = $op(lhs.values().as_ref(), rhs); + + take_boolean(&values, lhs.keys()) + }) + }, + _ => todo!("Comparisons of {:?} are not yet supported", lhs.data_type()), + } + }}; +} + +/// `==` between an [`Array`] and a [`Scalar`]. +/// Use [`can_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, eq_scalar, match_eq) +} + +/// `==` between an [`Array`] and a [`Scalar`] and includes validities in comparison. +/// Use [`can_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn eq_scalar_and_validity(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, eq_scalar_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is supported by [`eq_scalar`]. 
+pub fn can_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_scalar(data_type) +} + +/// `!=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_neq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, neq_scalar, match_eq) +} + +/// `!=` between an [`Array`] and a [`Scalar`] and includes validities in comparison. +/// Use [`can_neq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn neq_scalar_and_validity(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, neq_scalar_and_validity, match_eq) +} + +/// Returns whether a [`DataType`] is supported by [`neq_scalar`]. +pub fn can_neq_scalar(data_type: &DataType) -> bool { + can_partial_eq_scalar(data_type) +} + +/// `<` between an [`Array`] and a [`Scalar`]. +/// Use [`can_lt_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, lt_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`lt_scalar`]. +pub fn can_lt_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +/// `<=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_lt_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, lt_eq_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`lt_eq_scalar`]. +pub fn can_lt_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +/// `>` between an [`Array`] and a [`Scalar`]. +/// Use [`can_gt_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, gt_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`gt_scalar`]. +pub fn can_gt_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +/// `>=` between an [`Array`] and a [`Scalar`]. +/// Use [`can_gt_eq_scalar`] to check whether the operation is valid +/// # Panic +/// Panics iff either: +/// * they do not have have the same logical type +/// * the operation is not supported for the logical type +pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray { + compare_scalar!(lhs, rhs, gt_eq_scalar, match_eq_ord) +} + +/// Returns whether a [`DataType`] is supported by [`gt_eq_scalar`]. +pub fn can_gt_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) +} + +// The list of operations currently supported. 
+fn can_partial_eq_and_ord_scalar(data_type: &DataType) -> bool { + if let DataType::Dictionary(_, values, _) = data_type.to_logical_type() { + return can_partial_eq_and_ord_scalar(values.as_ref()); + } + can_partial_eq_and_ord(data_type) +} + +// The list of operations currently supported. +fn can_partial_eq_and_ord(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) + | DataType::Int64 + | DataType::Timestamp(_, _) + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Decimal(_, _) + | DataType::Binary + | DataType::LargeBinary + ) +} + +// The list of operations currently supported. +fn can_partial_eq(data_type: &DataType) -> bool { + can_partial_eq_and_ord(data_type) + || matches!( + data_type.to_logical_type(), + DataType::Float16 + | DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) + ) +} + +// The list of operations currently supported. +fn can_partial_eq_scalar(data_type: &DataType) -> bool { + can_partial_eq_and_ord_scalar(data_type) + || matches!( + data_type.to_logical_type(), + DataType::Interval(IntervalUnit::DayTime) + | DataType::Interval(IntervalUnit::MonthDayNano) + ) +} + +/// Utility for low level end users that implement their own comparison functions +/// A comparison on the data column can be applied on masked out values +/// This function will correct equality for the validities. +pub fn finish_eq_validities( + output_without_validities: BooleanArray, + validity_lhs: Option, + validity_rhs: Option, +) -> BooleanArray { + match (validity_lhs, validity_rhs) { + (None, None) => output_without_validities, + (Some(lhs), None) => compute::boolean::and( + &BooleanArray::new(DataType::Boolean, lhs, None), + &output_without_validities, + ), + (None, Some(rhs)) => compute::boolean::and( + &output_without_validities, + &BooleanArray::new(DataType::Boolean, rhs, None), + ), + (Some(lhs), Some(rhs)) => { + let lhs_validity_unset_bits = lhs.unset_bits(); + let rhs_validity_unset_bits = rhs.unset_bits(); + + // this branch is a bit more complicated as both arrays can have masked out values + // these masked out values might differ and lead to a `eq == false` that has to + // be corrected as both should be `null == null = true` + + let lhs = BooleanArray::new(DataType::Boolean, lhs, None); + let rhs = BooleanArray::new(DataType::Boolean, rhs, None); + let eq_validities = compute::comparison::boolean::eq(&lhs, &rhs); + + // validity_bits are equal AND values are equal + let equal = compute::boolean::and(&output_without_validities, &eq_validities); + + match (lhs_validity_unset_bits, rhs_validity_unset_bits) { + // there is at least one side with all values valid + // so we don't have to correct. + (0, _) | (_, 0) => equal, + _ => { + // we use the binary kernel here to save allocations + // and apply `!(lhs | rhs)` in one step + let both_sides_invalid = + compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| { + binary(lhs, rhs, |lhs, rhs| !(lhs | rhs)) + }); + // this still might include incorrect masked out values + // under the validity bits, so we must correct for that + + // if not all true, e.g. at least one is set. 
+ // then we propagate that null as `true` in equality + if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() { + compute::boolean::or(&equal, &both_sides_invalid) + } else { + equal + } + }, + } + }, + } +} + +/// Utility for low level end users that implement their own comparison functions +/// A comparison on the data column can be applied on masked out values +/// This function will correct non-equality for the validities. +pub fn finish_neq_validities( + output_without_validities: BooleanArray, + validity_lhs: Option, + validity_rhs: Option, +) -> BooleanArray { + match (validity_lhs, validity_rhs) { + (None, None) => output_without_validities, + (Some(lhs), None) => { + let lhs_negated = + compute::boolean::not(&BooleanArray::new(DataType::Boolean, lhs, None)); + compute::boolean::or(&lhs_negated, &output_without_validities) + }, + (None, Some(rhs)) => { + let rhs_negated = + compute::boolean::not(&BooleanArray::new(DataType::Boolean, rhs, None)); + compute::boolean::or(&output_without_validities, &rhs_negated) + }, + (Some(lhs), Some(rhs)) => { + let lhs_validity_unset_bits = lhs.unset_bits(); + let rhs_validity_unset_bits = rhs.unset_bits(); + + // this branch is a bit more complicated as both arrays can have masked out values + // these masked out values might differ and lead to a `neq == true` that has to + // be corrected as both should be `null != null = false` + let lhs = BooleanArray::new(DataType::Boolean, lhs, None); + let rhs = BooleanArray::new(DataType::Boolean, rhs, None); + let neq_validities = compute::comparison::boolean::neq(&lhs, &rhs); + + // validity_bits are not equal OR values not equal + let or = compute::boolean::or(&output_without_validities, &neq_validities); + + match (lhs_validity_unset_bits, rhs_validity_unset_bits) { + // there is at least one side with all values valid + // so we don't have to correct. + (0, _) | (_, 0) => or, + _ => { + // we use the binary kernel here to save allocations + // and apply `!(lhs | rhs)` in one step + let both_sides_invalid = + compute::boolean::binary_boolean_kernel(&lhs, &rhs, |lhs, rhs| { + binary(lhs, rhs, |lhs, rhs| !(lhs | rhs)) + }); + // this still might include incorrect masked out values + // under the validity bits, so we must correct for that + + // if not all true, e.g. at least one is set. + // then we propagate that null as `false` as the nulls are equal + if both_sides_invalid.values().unset_bits() != both_sides_invalid.len() { + // we use the `binary` kernel directly to save allocations + // and apply `lhs & !rhs)` in one shot. + + compute::boolean::binary_boolean_kernel( + &or, + &both_sides_invalid, + |lhs, rhs| binary(lhs, rhs, |lhs, rhs| (lhs & !rhs)), + ) + } else { + or + } + }, + } + }, + } +} diff --git a/crates/nano-arrow/src/compute/comparison/primitive.rs b/crates/nano-arrow/src/compute/comparison/primitive.rs new file mode 100644 index 000000000000..5ecda063cd22 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/primitive.rs @@ -0,0 +1,590 @@ +//! 
Comparison functions for [`PrimitiveArray`] +use super::super::utils::combine_validities; +use super::simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use crate::array::{BooleanArray, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; +use crate::types::NativeType; + +pub(crate) fn compare_values_op(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + assert_eq!(lhs.len(), rhs.len()); + + let lhs_chunks_iter = lhs.chunks_exact(8); + let lhs_remainder = lhs_chunks_iter.remainder(); + let rhs_chunks_iter = rhs.chunks_exact(8); + let rhs_remainder = rhs_chunks_iter.remainder(); + + let mut values = Vec::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.zip(rhs_chunks_iter).map(|(lhs, rhs)| { + let lhs = T::Simd::from_chunk(lhs); + let rhs = T::Simd::from_chunk(rhs); + op(lhs, rhs) + }); + values.extend(iterator); + + if !lhs_remainder.is_empty() { + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + let rhs = T::Simd::from_incomplete_chunk(rhs_remainder, T::default()); + values.push(op(lhs, rhs)) + }; + MutableBitmap::from_vec(values, lhs.len()) +} + +pub(crate) fn compare_values_op_scalar(lhs: &[T], rhs: T, op: F) -> MutableBitmap +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + let rhs = T::Simd::from_chunk(&[rhs; 8]); + + let lhs_chunks_iter = lhs.chunks_exact(8); + let lhs_remainder = lhs_chunks_iter.remainder(); + + let mut values = Vec::with_capacity((lhs.len() + 7) / 8); + let iterator = lhs_chunks_iter.map(|lhs| { + let lhs = T::Simd::from_chunk(lhs); + op(lhs, rhs) + }); + values.extend(iterator); + + if !lhs_remainder.is_empty() { + let lhs = T::Simd::from_incomplete_chunk(lhs_remainder, T::default()); + values.push(op(lhs, rhs)) + }; + + MutableBitmap::from_vec(values, lhs.len()) +} + +/// Evaluate `op(lhs, rhs)` for [`PrimitiveArray`]s using a specified +/// comparison function. +fn compare_op(lhs: &PrimitiveArray, rhs: &PrimitiveArray, op: F) -> BooleanArray +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = compare_values_op(lhs.values(), rhs.values(), op); + + BooleanArray::new(DataType::Boolean, values.into(), validity) +} + +/// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using +/// a specified comparison function. +pub fn compare_op_scalar(lhs: &PrimitiveArray, rhs: T, op: F) -> BooleanArray +where + T: NativeType + Simd8, + F: Fn(T::Simd, T::Simd) -> u8, +{ + let validity = lhs.validity().cloned(); + + let values = compare_values_op_scalar(lhs.values(), rhs, op); + + BooleanArray::new(DataType::Boolean, values.into(), validity) +} + +/// Perform `lhs == rhs` operation on two arrays. +pub fn eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op(lhs, rhs, |a, b| a.eq(b)) +} + +/// Perform `lhs == rhs` operation on two arrays and include validities in comparison. 
+pub fn eq_and_validity(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a.eq(b)); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `left == right` operation on an array and a scalar value. +pub fn eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op_scalar(lhs, rhs, |a, b| a.eq(b)) +} + +/// Perform `left == right` operation on an array and a scalar value and include validities in comparison. +pub fn eq_scalar_and_validity(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a.eq(b)); + + finish_eq_validities(out, validity, None) +} + +/// Perform `left != right` operation on two arrays. +pub fn neq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op(lhs, rhs, |a, b| a.neq(b)) +} + +/// Perform `left != right` operation on two arrays and include validities in comparison. +pub fn neq_and_validity(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a.neq(b)); + + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `left != right` operation on an array and a scalar value. +pub fn neq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + compare_op_scalar(lhs, rhs, |a, b| a.neq(b)) +} + +/// Perform `left != right` operation on an array and a scalar value and include validities in comparison. +pub fn neq_scalar_and_validity(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialEq, +{ + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a.neq(b)); + + finish_neq_validities(out, validity, None) +} + +/// Perform `left < right` operation on two arrays. +pub fn lt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.lt(b)) +} + +/// Perform `left < right` operation on an array and a scalar value. +pub fn lt_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.lt(b)) +} + +/// Perform `left <= right` operation on two arrays. +pub fn lt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.lt_eq(b)) +} + +/// Perform `left <= right` operation on an array and a scalar value. +/// Null values are less than non-null values. 
+pub fn lt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b)) +} + +/// Perform `left > right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.gt(b)) +} + +/// Perform `left > right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.gt(b)) +} + +/// Perform `left >= right` operation on two arrays. Non-null values are greater than null +/// values. +pub fn gt_eq(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op(lhs, rhs, |a, b| a.gt_eq(b)) +} + +/// Perform `left >= right` operation on an array and a scalar value. +/// Non-null values are greater than null values. +pub fn gt_eq_scalar(lhs: &PrimitiveArray, rhs: T) -> BooleanArray +where + T: NativeType + Simd8, + T::Simd: Simd8PartialOrd, +{ + compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b)) +} + +// disable wrapping inside literal vectors used for test data and assertions +#[rustfmt::skip::macros(vec)] +#[cfg(test)] +mod tests { + use super::*; + use crate::array::{Int64Array, Int8Array}; + + /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. + /// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>`. + /// `EXPECTED` can be either `Vec` or `Vec>`. + /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. + macro_rules! cmp_i64 { + ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + let a = Int64Array::from_slice($A_VEC); + let b = Int64Array::from_slice($B_VEC); + let c = $KERNEL(&a, &b); + assert_eq!(BooleanArray::from_slice($EXPECTED), c); + }; + } + + macro_rules! cmp_i64_options { + ($KERNEL:ident, $A_VEC:expr, $B_VEC:expr, $EXPECTED:expr) => { + let a = Int64Array::from($A_VEC); + let b = Int64Array::from($B_VEC); + let c = $KERNEL(&a, &b); + assert_eq!(BooleanArray::from($EXPECTED), c); + }; + } + + /// Evaluate `KERNEL` with one vectors and one scalar as inputs and assert against the expected output. + /// `A_VEC` can be of type `Vec` or `Vec>`. + /// `EXPECTED` can be either `Vec` or `Vec>`. + /// The main reason for this macro is that inputs and outputs align nicely after `cargo fmt`. + macro_rules! cmp_i64_scalar_options { + ($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => { + let a = Int64Array::from($A_VEC); + let c = $KERNEL(&a, $B); + assert_eq!(BooleanArray::from($EXPECTED), c); + }; + } + + macro_rules! 
cmp_i64_scalar { + ($KERNEL:ident, $A_VEC:expr, $B:literal, $EXPECTED:expr) => { + let a = Int64Array::from_slice($A_VEC); + let c = $KERNEL(&a, $B); + assert_eq!(BooleanArray::from_slice($EXPECTED), c); + }; + } + + #[test] + fn test_primitive_array_eq() { + cmp_i64!( + eq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, false, false, false, false, true, false, false] + ); + } + + #[test] + fn test_primitive_array_eq_scalar() { + cmp_i64_scalar!( + eq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![false, false, true, false, false, false, false, true, false, false] + ); + } + + #[test] + fn test_primitive_array_eq_with_slice() { + let a = Int64Array::from_slice([6, 7, 8, 8, 10]); + let mut b = Int64Array::from_slice([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + b.slice(5, 5); + let d = eq(&b, &a); + assert_eq!(d, BooleanArray::from_slice([true, true, true, false, true])); + } + + #[test] + fn test_primitive_array_neq() { + cmp_i64!( + neq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, false, true, true, true, true, false, true, true] + ); + } + + #[test] + fn test_primitive_array_neq_scalar() { + cmp_i64_scalar!( + neq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![true, true, false, true, true, true, true, false, true, true] + ); + } + + #[test] + fn test_primitive_array_lt() { + cmp_i64!( + lt, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, false, true, true, false, false, false, true, true] + ); + } + + #[test] + fn test_primitive_array_lt_scalar() { + cmp_i64_scalar!( + lt_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![true, true, false, false, false, true, true, false, false, false] + ); + } + + #[test] + fn test_primitive_array_lt_nulls() { + cmp_i64_options!( + lt, + &[None, None, Some(1), Some(1), None, None, Some(2), Some(2),], + &[None, Some(1), None, Some(1), None, Some(3), None, Some(3),], + vec![None, None, None, Some(false), None, None, None, Some(true)] + ); + } + + #[test] + fn test_primitive_array_lt_scalar_nulls() { + cmp_i64_scalar_options!( + lt_scalar, + &[None, Some(1), Some(2), Some(3), None, Some(1), Some(2), Some(3), Some(2), None], + 2, + vec![None, Some(true), Some(false), Some(false), None, Some(true), Some(false), Some(false), Some(false), None] + ); + } + + #[test] + fn test_primitive_array_lt_eq() { + cmp_i64!( + lt_eq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![false, false, true, true, true, false, false, true, true, true] + ); + } + + #[test] + fn test_primitive_array_lt_eq_scalar() { + cmp_i64_scalar!( + lt_eq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![true, true, true, false, false, true, true, true, false, false] + ); + } + + #[test] + fn test_primitive_array_lt_eq_nulls() { + cmp_i64_options!( + lt_eq, + &[ + None, + None, + Some(1), + None, + None, + Some(1), + None, + None, + Some(1) + ], + &[ + None, + Some(1), + Some(0), + None, + Some(1), + Some(2), + None, + None, + Some(3) + ], + vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)] + ); + } + + #[test] + fn test_primitive_array_lt_eq_scalar_nulls() { + cmp_i64_scalar_options!( + lt_eq_scalar, + &[None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)], + 1, + vec![None, Some(true), Some(false), None, Some(true), Some(false), None, Some(true), Some(false)] + ); + } + + #[test] + fn test_primitive_array_gt() { + cmp_i64!( + gt, + &[8, 8, 8, 8, 8, 8, 
8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, false, false, false, true, true, false, false, false] + ); + } + + #[test] + fn test_primitive_array_gt_scalar() { + cmp_i64_scalar!( + gt_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![false, false, false, true, true, false, false, false, true, true] + ); + } + + #[test] + fn test_primitive_array_gt_nulls() { + cmp_i64_options!( + gt, + &[ + None, + None, + Some(1), + None, + None, + Some(2), + None, + None, + Some(3) + ], + &[ + None, + Some(1), + Some(1), + None, + Some(1), + Some(1), + None, + Some(1), + Some(1) + ], + vec![None, None, Some(false), None, None, Some(true), None, None, Some(true)] + ); + } + + #[test] + fn test_primitive_array_gt_scalar_nulls() { + cmp_i64_scalar_options!( + gt_scalar, + &[None, Some(1), Some(2), None, Some(1), Some(2), None, Some(1), Some(2)], + 1, + vec![None, Some(false), Some(true), None, Some(false), Some(true), None, Some(false), Some(true)] + ); + } + + #[test] + fn test_primitive_array_gt_eq() { + cmp_i64!( + gt_eq, + &[8, 8, 8, 8, 8, 8, 8, 8, 8, 8], + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + vec![true, true, true, false, false, true, true, true, false, false] + ); + } + + #[test] + fn test_primitive_array_gt_eq_scalar() { + cmp_i64_scalar!( + gt_eq_scalar, + &[6, 7, 8, 9, 10, 6, 7, 8, 9, 10], + 8, + vec![false, false, true, true, true, false, false, true, true, true] + ); + } + + #[test] + fn test_primitive_array_gt_eq_nulls() { + cmp_i64_options!( + gt_eq, + vec![None, None, Some(1), None, Some(1), Some(2), None, None, Some(1)], + vec![None, Some(1), None, None, Some(1), Some(1), None, Some(2), Some(2)], + vec![None, None, None, None, Some(true), Some(true), None, None, Some(false)] + ); + } + + #[test] + fn test_primitive_array_gt_eq_scalar_nulls() { + cmp_i64_scalar_options!( + gt_eq_scalar, + vec![None, Some(1), Some(2), None, Some(2), Some(3), None, Some(3), Some(4)], + 2, + vec![None, Some(false), Some(true), None, Some(true), Some(true), None, Some(true), Some(true)] + ); + } + + #[test] + fn test_primitive_array_compare_slice() { + let mut a = (0..100).map(Some).collect::>(); + a.slice(50, 50); + let mut b = (100..200).map(Some).collect::>(); + b.slice(50, 50); + let actual = lt(&a, &b); + let expected: BooleanArray = (0..50).map(|_| Some(true)).collect(); + assert_eq!(expected, actual); + } + + #[test] + fn test_primitive_array_compare_scalar_slice() { + let mut a = (0..100).map(Some).collect::>(); + a.slice(50, 50); + let actual = lt_scalar(&a, 200); + let expected: BooleanArray = (0..50).map(|_| Some(true)).collect(); + assert_eq!(expected, actual); + } + + #[test] + fn test_length_of_result_buffer() { + // `item_count` is chosen to not be a multiple of 64. + const ITEM_COUNT: usize = 130; + + let array_a = Int8Array::from_slice([1; ITEM_COUNT]); + let array_b = Int8Array::from_slice([2; ITEM_COUNT]); + let expected = BooleanArray::from_slice([false; ITEM_COUNT]); + let result = gt_eq(&array_a, &array_b); + + assert_eq!(result, expected) + } +} diff --git a/crates/nano-arrow/src/compute/comparison/simd/mod.rs b/crates/nano-arrow/src/compute/comparison/simd/mod.rs new file mode 100644 index 000000000000..30d9773cd4c9 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/simd/mod.rs @@ -0,0 +1,133 @@ +use crate::types::NativeType; + +/// [`NativeType`] that supports a representation of 8 lanes +pub trait Simd8: NativeType { + /// The 8 lane representation of `Self` + type Simd: Simd8Lanes; +} + +/// Trait declaring an 8-lane multi-data. 
+pub trait Simd8Lanes: Copy { + /// loads a complete chunk + fn from_chunk(v: &[T]) -> Self; + /// loads an incomplete chunk, filling the remaining items with `remaining`. + fn from_incomplete_chunk(v: &[T], remaining: T) -> Self; +} + +/// Trait implemented by implementors of [`Simd8Lanes`] whose [`Simd8`] implements [PartialEq]. +pub trait Simd8PartialEq: Copy { + /// Equal + fn eq(self, other: Self) -> u8; + /// Not equal + fn neq(self, other: Self) -> u8; +} + +/// Trait implemented by implementors of [`Simd8Lanes`] whose [`Simd8`] implements [PartialOrd]. +pub trait Simd8PartialOrd: Copy { + /// Less than or equal to + fn lt_eq(self, other: Self) -> u8; + /// Less than + fn lt(self, other: Self) -> u8; + /// Greater than + fn gt(self, other: Self) -> u8; + /// Greater than or equal to + fn gt_eq(self, other: Self) -> u8; +} + +#[inline] +pub(super) fn set bool>(lhs: [T; 8], rhs: [T; 8], op: F) -> u8 { + let mut byte = 0u8; + lhs.iter() + .zip(rhs.iter()) + .enumerate() + .for_each(|(i, (lhs, rhs))| { + byte |= if op(*lhs, *rhs) { 1 << i } else { 0 }; + }); + byte +} + +/// Types that implement Simd8 +macro_rules! simd8_native { + ($type:ty) => { + impl Simd8 for $type { + type Simd = [$type; 8]; + } + + impl Simd8Lanes<$type> for [$type; 8] { + #[inline] + fn from_chunk(v: &[$type]) -> Self { + v.try_into().unwrap() + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; 8]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + a + } + } + }; +} + +/// Types that implement PartialEq +macro_rules! simd8_native_partial_eq { + ($type:ty) => { + impl Simd8PartialEq for [$type; 8] { + #[inline] + fn eq(self, other: Self) -> u8 { + set(self, other, |x, y| x == y) + } + + #[inline] + fn neq(self, other: Self) -> u8 { + #[allow(clippy::float_cmp)] + set(self, other, |x, y| x != y) + } + } + }; +} + +/// Types that implement PartialOrd +macro_rules! simd8_native_partial_ord { + ($type:ty) => { + impl Simd8PartialOrd for [$type; 8] { + #[inline] + fn lt_eq(self, other: Self) -> u8 { + set(self, other, |x, y| x <= y) + } + + #[inline] + fn lt(self, other: Self) -> u8 { + set(self, other, |x, y| x < y) + } + + #[inline] + fn gt_eq(self, other: Self) -> u8 { + set(self, other, |x, y| x >= y) + } + + #[inline] + fn gt(self, other: Self) -> u8 { + set(self, other, |x, y| x > y) + } + } + }; +} + +/// Types that implement simd8, PartialEq and PartialOrd +macro_rules! simd8_native_all { + ($type:ty) => { + simd8_native! {$type} + simd8_native_partial_eq! {$type} + simd8_native_partial_ord! 
{$type} + }; +} + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +pub use packed::*; diff --git a/crates/nano-arrow/src/compute/comparison/simd/native.rs b/crates/nano-arrow/src/compute/comparison/simd/native.rs new file mode 100644 index 000000000000..b8bbf9b17d66 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/simd/native.rs @@ -0,0 +1,23 @@ +use std::convert::TryInto; + +use super::{set, Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd}; +use crate::types::{days_ms, f16, i256, months_days_ns}; + +simd8_native_all!(u8); +simd8_native_all!(u16); +simd8_native_all!(u32); +simd8_native_all!(u64); +simd8_native_all!(i8); +simd8_native_all!(i16); +simd8_native_all!(i32); +simd8_native_all!(i128); +simd8_native_all!(i256); +simd8_native_all!(i64); +simd8_native!(f16); +simd8_native_partial_eq!(f16); +simd8_native_all!(f32); +simd8_native_all!(f64); +simd8_native!(days_ms); +simd8_native_partial_eq!(days_ms); +simd8_native!(months_days_ns); +simd8_native_partial_eq!(months_days_ns); diff --git a/crates/nano-arrow/src/compute/comparison/simd/packed.rs b/crates/nano-arrow/src/compute/comparison/simd/packed.rs new file mode 100644 index 000000000000..707d875deef0 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/simd/packed.rs @@ -0,0 +1,81 @@ +use std::convert::TryInto; +use std::simd::{SimdPartialEq, SimdPartialOrd, ToBitMask}; + +use super::*; +use crate::types::simd::*; +use crate::types::{days_ms, f16, i256, months_days_ns}; + +macro_rules! simd8 { + ($type:ty, $md:ty) => { + impl Simd8 for $type { + type Simd = $md; + } + + impl Simd8Lanes<$type> for $md { + #[inline] + fn from_chunk(v: &[$type]) -> Self { + <$md>::from_slice(v) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; 8]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + Self::from_array(a) + } + } + + impl Simd8PartialEq for $md { + #[inline] + fn eq(self, other: Self) -> u8 { + self.simd_eq(other).to_bitmask() + } + + #[inline] + fn neq(self, other: Self) -> u8 { + self.simd_ne(other).to_bitmask() + } + } + + impl Simd8PartialOrd for $md { + #[inline] + fn lt_eq(self, other: Self) -> u8 { + self.simd_le(other).to_bitmask() + } + + #[inline] + fn lt(self, other: Self) -> u8 { + self.simd_lt(other).to_bitmask() + } + + #[inline] + fn gt_eq(self, other: Self) -> u8 { + self.simd_ge(other).to_bitmask() + } + + #[inline] + fn gt(self, other: Self) -> u8 { + self.simd_gt(other).to_bitmask() + } + } + }; +} + +simd8!(u8, u8x8); +simd8!(u16, u16x8); +simd8!(u32, u32x8); +simd8!(u64, u64x8); +simd8!(i8, i8x8); +simd8!(i16, i16x8); +simd8!(i32, i32x8); +simd8!(i64, i64x8); +simd8_native_all!(i128); +simd8_native_all!(i256); +simd8_native!(f16); +simd8_native_partial_eq!(f16); +simd8!(f32, f32x8); +simd8!(f64, f64x8); +simd8_native!(days_ms); +simd8_native_partial_eq!(days_ms); +simd8_native!(months_days_ns); +simd8_native_partial_eq!(months_days_ns); diff --git a/crates/nano-arrow/src/compute/comparison/utf8.rs b/crates/nano-arrow/src/compute/comparison/utf8.rs new file mode 100644 index 000000000000..cba683c7b869 --- /dev/null +++ b/crates/nano-arrow/src/compute/comparison/utf8.rs @@ -0,0 +1,291 @@ +//! 
Comparison functions for [`Utf8Array`] +use super::super::utils::combine_validities; +use crate::array::{BooleanArray, Utf8Array}; +use crate::bitmap::Bitmap; +use crate::compute::comparison::{finish_eq_validities, finish_neq_validities}; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// Evaluate `op(lhs, rhs)` for [`Utf8Array`]s using a specified +/// comparison function. +fn compare_op(lhs: &Utf8Array, rhs: &Utf8Array, op: F) -> BooleanArray +where + O: Offset, + F: Fn(&str, &str) -> bool, +{ + assert_eq!(lhs.len(), rhs.len()); + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values_iter() + .zip(rhs.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Evaluate `op(lhs, rhs)` for [`Utf8Array`] and scalar using +/// a specified comparison function. +fn compare_op_scalar(lhs: &Utf8Array, rhs: &str, op: F) -> BooleanArray +where + O: Offset, + F: Fn(&str, &str) -> bool, +{ + let validity = lhs.validity().cloned(); + + let values = lhs.values_iter().map(|lhs| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::new(DataType::Boolean, values, validity) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`]. +pub fn eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`] and include validities in comparison. +pub fn eq_and_validity(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a == b); + + finish_eq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`] and include validities in comparison. +pub fn neq_and_validity(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + let validity_lhs = lhs.validity().cloned(); + let validity_rhs = rhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let rhs = rhs.clone().with_validity(None); + let out = compare_op(&lhs, &rhs, |a, b| a != b); + + finish_neq_validities(out, validity_lhs, validity_rhs) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`] and a scalar. +pub fn eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`Utf8Array`] and a scalar. Also includes null values in comparison. +pub fn eq_scalar_and_validity(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a == b); + + finish_eq_validities(out, validity, None) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`] and a scalar. Also includes null values in comparison. +pub fn neq_scalar_and_validity(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + let validity = lhs.validity().cloned(); + let lhs = lhs.clone().with_validity(None); + let out = compare_op_scalar(&lhs, rhs, |a, b| a != b); + + finish_neq_validities(out, validity, None) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`]. +pub fn neq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`Utf8Array`] and a scalar. 
+pub fn neq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs < rhs` operation on [`Utf8Array`]. +pub fn lt(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs < rhs` operation on [`Utf8Array`] and a scalar. +pub fn lt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs <= rhs` operation on [`Utf8Array`]. +pub fn lt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs <= rhs` operation on [`Utf8Array`] and a scalar. +pub fn lt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs > rhs` operation on [`Utf8Array`]. +pub fn gt(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs > rhs` operation on [`Utf8Array`] and a scalar. +pub fn gt_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs >= rhs` operation on [`Utf8Array`]. +pub fn gt_eq(lhs: &Utf8Array, rhs: &Utf8Array) -> BooleanArray { + compare_op(lhs, rhs, |a, b| a >= b) +} + +/// Perform `lhs >= rhs` operation on [`Utf8Array`] and a scalar. +pub fn gt_eq_scalar(lhs: &Utf8Array, rhs: &str) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a >= b) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_generic, &Utf8Array) -> BooleanArray>( + lhs: Vec<&str>, + rhs: Vec<&str>, + op: F, + expected: Vec, + ) { + let lhs = Utf8Array::::from_slice(lhs); + let rhs = Utf8Array::::from_slice(rhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, &rhs), expected); + } + + fn test_generic_scalar, &str) -> BooleanArray>( + lhs: Vec<&str>, + rhs: &str, + op: F, + expected: Vec, + ) { + let lhs = Utf8Array::::from_slice(lhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, rhs), expected); + } + + #[test] + fn test_gt_eq() { + test_generic::( + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + gt_eq, + vec![false, false, true, true], + ) + } + + #[test] + fn test_gt_eq_scalar() { + test_generic_scalar::( + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_eq_scalar, + vec![false, false, true, true], + ) + } + + #[test] + fn test_eq() { + test_generic::( + vec!["arrow", "arrow", "arrow", "arrow"], + vec!["arrow", "parquet", "datafusion", "flight"], + eq, + vec![true, false, false, false], + ) + } + + #[test] + fn test_eq_scalar() { + test_generic_scalar::( + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + eq_scalar, + vec![true, false, false, false], + ) + } + + #[test] + fn test_neq() { + test_generic::( + vec!["arrow", "arrow", "arrow", "arrow"], + vec!["arrow", "parquet", "datafusion", "flight"], + neq, + vec![false, true, true, true], + ) + } + + #[test] + fn test_neq_scalar() { + test_generic_scalar::( + vec!["arrow", "parquet", "datafusion", "flight"], + "arrow", + neq_scalar, + vec![false, true, true, true], + ) + } + + /* + test_utf8!( + test_utf8_array_lt, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + lt_utf8, + vec![true, true, false, false] + ); + test_utf8_scalar!( + test_utf8_array_lt_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + lt_utf8_scalar, + vec![true, true, 
false, false] + ); + + test_utf8!( + test_utf8_array_lt_eq, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + lt_eq_utf8, + vec![true, true, true, false] + ); + test_utf8_scalar!( + test_utf8_array_lt_eq_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + lt_eq_utf8_scalar, + vec![true, true, true, false] + ); + + test_utf8!( + test_utf8_array_gt, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + gt_utf8, + vec![false, false, false, true] + ); + test_utf8_scalar!( + test_utf8_array_gt_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_utf8_scalar, + vec![false, false, false, true] + ); + + test_utf8!( + test_utf8_array_gt_eq, + vec!["arrow", "datafusion", "flight", "parquet"], + vec!["flight", "flight", "flight", "flight"], + gt_eq_utf8, + vec![false, false, true, true] + ); + test_utf8_scalar!( + test_utf8_array_gt_eq_scalar, + vec!["arrow", "datafusion", "flight", "parquet"], + "flight", + gt_eq_utf8_scalar, + vec![false, false, true, true] + ); + */ +} diff --git a/crates/nano-arrow/src/compute/concatenate.rs b/crates/nano-arrow/src/compute/concatenate.rs new file mode 100644 index 000000000000..5e38731a2fe9 --- /dev/null +++ b/crates/nano-arrow/src/compute/concatenate.rs @@ -0,0 +1,47 @@ +//! Contains the concatenate kernel +//! +//! Example: +//! +//! ``` +//! use arrow2::array::Utf8Array; +//! use arrow2::compute::concatenate::concatenate; +//! +//! let arr = concatenate(&[ +//! &Utf8Array::::from_slice(["hello", "world"]), +//! &Utf8Array::::from_slice(["!"]), +//! ]).unwrap(); +//! assert_eq!(arr.len(), 3); +//! ``` + +use crate::array::growable::make_growable; +use crate::array::Array; +use crate::error::{Error, Result}; + +/// Concatenate multiple [Array] of the same type into a single [`Array`]. +pub fn concatenate(arrays: &[&dyn Array]) -> Result> { + if arrays.is_empty() { + return Err(Error::InvalidArgumentError( + "concat requires input of at least one array".to_string(), + )); + } + + if arrays + .iter() + .any(|array| array.data_type() != arrays[0].data_type()) + { + return Err(Error::InvalidArgumentError( + "It is not possible to concatenate arrays of different data types.".to_string(), + )); + } + + let lengths = arrays.iter().map(|array| array.len()).collect::>(); + let capacity = lengths.iter().sum(); + + let mut mutable = make_growable(arrays, false, capacity); + + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } + + Ok(mutable.as_box()) +} diff --git a/crates/nano-arrow/src/compute/filter.rs b/crates/nano-arrow/src/compute/filter.rs new file mode 100644 index 000000000000..90ddf4b4d158 --- /dev/null +++ b/crates/nano-arrow/src/compute/filter.rs @@ -0,0 +1,321 @@ +//! Contains operators to filter arrays such as [`filter`]. 
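A minimal usage sketch of the prepared-filter path that this file introduces further down (`build_filter`), showing one boolean mask applied to several columns. The `arrow2::...` paths follow the doc examples elsewhere in this diff and are an assumption; the vendored crate name and re-exports may differ.

```rust
// Sketch only; `arrow2::...` paths are assumed from the doc examples in this diff,
// not verified against the vendored crate layout.
use arrow2::array::{Array, BooleanArray, Int32Array, Utf8Array};
use arrow2::compute::filter::build_filter;
use arrow2::error::Result;

fn main() -> Result<()> {
    let mask = BooleanArray::from_slice(&[true, false, true]);

    // Building the closure walks the mask once up front...
    let filter = build_filter(&mask)?;

    // ...and can then be applied cheaply to every column of the same length.
    let ints = Int32Array::from_slice(&[1, 2, 3]);
    let names = Utf8Array::<i32>::from_slice(["a", "b", "c"]);

    let ints_kept = filter(&ints);
    let names_kept = filter(&names);

    assert_eq!(ints_kept.len(), 2);
    assert_eq!(names_kept.len(), 2);
    Ok(())
}
```

This mirrors the design choice stated in the `build_filter` doc below: paying the slice-collection cost once is worthwhile when the same mask is applied to many arrays, for example to all columns of a chunk.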
+use crate::array::growable::{make_growable, Growable}; +use crate::array::*; +use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact, SlicesIterator}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::chunk::Chunk; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::simd::Simd; +use crate::types::{BitChunkOnes, NativeType}; + +/// Function that can filter arbitrary arrays +pub type Filter<'a> = Box Box + 'a + Send + Sync>; + +#[inline] +fn get_leading_ones(chunk: u64) -> u32 { + if cfg!(target_endian = "little") { + chunk.trailing_ones() + } else { + chunk.leading_ones() + } +} + +/// # Safety +/// This assumes that the `mask_chunks` contains a number of set/true items equal +/// to `filter_count` +unsafe fn nonnull_filter_impl(values: &[T], mut mask_chunks: I, filter_count: usize) -> Vec +where + T: NativeType + Simd, + I: BitChunkIterExact, +{ + let mut chunks = values.chunks_exact(64); + let mut new = Vec::::with_capacity(filter_count); + let mut dst = new.as_mut_ptr(); + + chunks + .by_ref() + .zip(mask_chunks.by_ref()) + .for_each(|(chunk, mask_chunk)| { + let ones = mask_chunk.count_ones(); + let leading_ones = get_leading_ones(mask_chunk); + + if ones == leading_ones { + let size = leading_ones as usize; + unsafe { + std::ptr::copy(chunk.as_ptr(), dst, size); + dst = dst.add(size); + } + return; + } + + let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); + for pos in ones_iter { + dst.write(*chunk.get_unchecked(pos)); + dst = dst.add(1); + } + }); + + chunks + .remainder() + .iter() + .zip(mask_chunks.remainder_iter()) + .for_each(|(value, b)| { + if b { + unsafe { + dst.write(*value); + dst = dst.add(1); + }; + } + }); + + unsafe { new.set_len(filter_count) }; + new +} + +/// # Safety +/// This assumes that the `mask_chunks` contains a number of set/true items equal +/// to `filter_count` +unsafe fn null_filter_impl( + values: &[T], + validity: &Bitmap, + mut mask_chunks: I, + filter_count: usize, +) -> (Vec, MutableBitmap) +where + T: NativeType + Simd, + I: BitChunkIterExact, +{ + let mut chunks = values.chunks_exact(64); + + let mut validity_chunks = validity.chunks::(); + + let mut new = Vec::::with_capacity(filter_count); + let mut dst = new.as_mut_ptr(); + let mut new_validity = MutableBitmap::with_capacity(filter_count); + + chunks + .by_ref() + .zip(validity_chunks.by_ref()) + .zip(mask_chunks.by_ref()) + .for_each(|((chunk, validity_chunk), mask_chunk)| { + let ones = mask_chunk.count_ones(); + let leading_ones = get_leading_ones(mask_chunk); + + if ones == leading_ones { + let size = leading_ones as usize; + unsafe { + std::ptr::copy(chunk.as_ptr(), dst, size); + dst = dst.add(size); + + // safety: invariant offset + length <= slice.len() + new_validity.extend_from_slice_unchecked( + validity_chunk.to_ne_bytes().as_ref(), + 0, + size, + ); + } + return; + } + + // this triggers a bitcount + let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize); + for pos in ones_iter { + dst.write(*chunk.get_unchecked(pos)); + dst = dst.add(1); + new_validity.push_unchecked(validity_chunk & (1 << pos) > 0); + } + }); + + chunks + .remainder() + .iter() + .zip(validity_chunks.remainder_iter()) + .zip(mask_chunks.remainder_iter()) + .for_each(|((value, is_valid), is_selected)| { + if is_selected { + unsafe { + dst.write(*value); + dst = dst.add(1); + new_validity.push_unchecked(is_valid); + }; + } + }); + + unsafe { new.set_len(filter_count) }; + (new, new_validity) +} + +fn null_filter_simd( + values: &[T], 
+ validity: &Bitmap, + mask: &Bitmap, +) -> (Vec, MutableBitmap) { + assert_eq!(values.len(), mask.len()); + let filter_count = mask.len() - mask.unset_bits(); + + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } + } else { + let mask_chunks = mask.chunks::(); + unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) } + } +} + +fn nonnull_filter_simd(values: &[T], mask: &Bitmap) -> Vec { + assert_eq!(values.len(), mask.len()); + let filter_count = mask.len() - mask.unset_bits(); + + let (slice, offset, length) = mask.as_slice(); + if offset == 0 { + let mask_chunks = BitChunksExact::::new(slice, length); + unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } + } else { + let mask_chunks = mask.chunks::(); + unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) } + } +} + +fn filter_nonnull_primitive( + array: &PrimitiveArray, + mask: &Bitmap, +) -> PrimitiveArray { + assert_eq!(array.len(), mask.len()); + + if let Some(validity) = array.validity() { + let (values, validity) = null_filter_simd(array.values(), validity, mask); + PrimitiveArray::::new(array.data_type().clone(), values.into(), validity.into()) + } else { + let values = nonnull_filter_simd(array.values(), mask); + PrimitiveArray::::new(array.data_type().clone(), values.into(), None) + } +} + +fn filter_primitive( + array: &PrimitiveArray, + mask: &BooleanArray, +) -> PrimitiveArray { + // todo: branch on mask.validity() + filter_nonnull_primitive(array, mask.values()) +} + +fn filter_growable<'a>(growable: &mut impl Growable<'a>, chunks: &[(usize, usize)]) { + chunks + .iter() + .for_each(|(start, len)| growable.extend(0, *start, *len)); +} + +/// Returns a prepared function optimized to filter multiple arrays. +/// Creating this function requires time, but using it is faster than [filter] when the +/// same filter needs to be applied to multiple arrays (e.g. a multiple columns). +pub fn build_filter(filter: &BooleanArray) -> Result { + let iter = SlicesIterator::new(filter.values()); + let filter_count = iter.slots(); + let chunks = iter.collect::>(); + + use crate::datatypes::PhysicalType::*; + Ok(Box::new(move |array: &dyn Array| { + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + let mut growable = + growable::GrowablePrimitive::<$T>::new(vec![array], false, filter_count); + filter_growable(&mut growable, &chunks); + let array: PrimitiveArray<$T> = growable.into(); + Box::new(array) + }), + LargeUtf8 => { + let array = array.as_any().downcast_ref::>().unwrap(); + let mut growable = growable::GrowableUtf8::new(vec![array], false, filter_count); + filter_growable(&mut growable, &chunks); + let array: Utf8Array = growable.into(); + Box::new(array) + }, + _ => { + let mut mutable = make_growable(&[array], false, filter_count); + chunks + .iter() + .for_each(|(start, len)| mutable.extend(0, *start, *len)); + mutable.as_box() + }, + } + })) +} + +/// Filters an [Array], returning elements matching the filter (i.e. where the values are true). +/// +/// Note that the nulls of `filter` are interpreted as `false` will lead to these elements being +/// masked out. 
+/// +/// # Example +/// ```rust +/// # use arrow2::array::{Int32Array, PrimitiveArray, BooleanArray}; +/// # use arrow2::error::Result; +/// # use arrow2::compute::filter::filter; +/// # fn main() -> Result<()> { +/// let array = PrimitiveArray::from_slice([5, 6, 7, 8, 9]); +/// let filter_array = BooleanArray::from_slice(&vec![true, false, false, true, false]); +/// let c = filter(&array, &filter_array)?; +/// let c = c.as_any().downcast_ref::().unwrap(); +/// assert_eq!(c, &PrimitiveArray::from_slice(vec![5, 8])); +/// # Ok(()) +/// # } +/// ``` +pub fn filter(array: &dyn Array, filter: &BooleanArray) -> Result> { + // The validities may be masking out `true` bits, making the filter operation + // based on the values incorrect + if let Some(validities) = filter.validity() { + let values = filter.values(); + let new_values = values & validities; + let filter = BooleanArray::new(DataType::Boolean, new_values, None); + return crate::compute::filter::filter(array, &filter); + } + + let false_count = filter.values().unset_bits(); + if false_count == filter.len() { + assert_eq!(array.len(), filter.len()); + return Ok(new_empty_array(array.data_type().clone())); + } + if false_count == 0 { + assert_eq!(array.len(), filter.len()); + return Ok(array.to_boxed()); + } + + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + Ok(Box::new(filter_primitive::<$T>(array, filter))) + }), + _ => { + let iter = SlicesIterator::new(filter.values()); + let mut mutable = make_growable(&[array], false, iter.slots()); + iter.for_each(|(start, len)| mutable.extend(0, start, len)); + Ok(mutable.as_box()) + }, + } +} + +/// Returns a new [Chunk] with arrays containing only values matching the filter. +/// This is a convenience function: filter multiple columns is embarrassingly parallel. +pub fn filter_chunk>( + columns: &Chunk, + filter_values: &BooleanArray, +) -> Result>> { + let arrays = columns.arrays(); + + let num_columns = arrays.len(); + + let filtered_arrays = match num_columns { + 1 => { + vec![filter(columns.arrays()[0].as_ref(), filter_values)?] + }, + _ => { + let filter = build_filter(filter_values)?; + arrays.iter().map(|a| filter(a.as_ref())).collect() + }, + }; + Chunk::try_new(filtered_arrays) +} diff --git a/crates/nano-arrow/src/compute/if_then_else.rs b/crates/nano-arrow/src/compute/if_then_else.rs new file mode 100644 index 000000000000..86c46b29d040 --- /dev/null +++ b/crates/nano-arrow/src/compute/if_then_else.rs @@ -0,0 +1,75 @@ +//! Contains the operator [`if_then_else`]. +use crate::array::{growable, Array, BooleanArray}; +use crate::bitmap::utils::SlicesIterator; +use crate::error::{Error, Result}; + +/// Returns the values from `lhs` if the predicate is `true` or from the `rhs` if the predicate is false +/// Returns `None` if the predicate is `None`. 
+/// # Example +/// ```rust +/// # use arrow2::error::Result; +/// use arrow2::compute::if_then_else::if_then_else; +/// use arrow2::array::{Int32Array, BooleanArray}; +/// +/// # fn main() -> Result<()> { +/// let lhs = Int32Array::from_slice(&[1, 2, 3]); +/// let rhs = Int32Array::from_slice(&[4, 5, 6]); +/// let predicate = BooleanArray::from(&[Some(true), None, Some(false)]); +/// let result = if_then_else(&predicate, &lhs, &rhs)?; +/// +/// let expected = Int32Array::from(&[Some(1), None, Some(6)]); +/// +/// assert_eq!(expected, result.as_ref()); +/// # Ok(()) +/// # } +/// ``` +pub fn if_then_else( + predicate: &BooleanArray, + lhs: &dyn Array, + rhs: &dyn Array, +) -> Result> { + if lhs.data_type() != rhs.data_type() { + return Err(Error::InvalidArgumentError(format!( + "If then else requires the arguments to have the same datatypes ({:?} != {:?})", + lhs.data_type(), + rhs.data_type() + ))); + } + if (lhs.len() != rhs.len()) | (lhs.len() != predicate.len()) { + return Err(Error::InvalidArgumentError(format!( + "If then else requires all arguments to have the same length (predicate = {}, lhs = {}, rhs = {})", + predicate.len(), + lhs.len(), + rhs.len() + ))); + } + + let result = if predicate.null_count() > 0 { + let mut growable = growable::make_growable(&[lhs, rhs], true, lhs.len()); + for (i, v) in predicate.iter().enumerate() { + match v { + Some(v) => growable.extend(!v as usize, i, 1), + None => growable.extend_validity(1), + } + } + growable.as_box() + } else { + let mut growable = growable::make_growable(&[lhs, rhs], false, lhs.len()); + let mut start_falsy = 0; + let mut total_len = 0; + for (start, len) in SlicesIterator::new(predicate.values()) { + if start != start_falsy { + growable.extend(1, start_falsy, start - start_falsy); + total_len += start - start_falsy; + }; + growable.extend(0, start, len); + total_len += len; + start_falsy = start + len; + } + if total_len != lhs.len() { + growable.extend(1, total_len, lhs.len() - total_len); + } + growable.as_box() + }; + Ok(result) +} diff --git a/crates/nano-arrow/src/compute/mod.rs b/crates/nano-arrow/src/compute/mod.rs new file mode 100644 index 000000000000..a40e4dcbb558 --- /dev/null +++ b/crates/nano-arrow/src/compute/mod.rs @@ -0,0 +1,52 @@ +//! contains a wide range of compute operations (e.g. +//! [`arithmetics`], [`aggregate`], +//! [`filter`], [`comparison`], and [`sort`]) +//! +//! This module's general design is +//! that each operator has two interfaces, a statically-typed version and a dynamically-typed +//! version. +//! The statically-typed version expects concrete arrays (such as [`PrimitiveArray`](crate::array::PrimitiveArray)); +//! the dynamically-typed version expects `&dyn Array` and errors if the the type is not +//! supported. +//! Some dynamically-typed operators have an auxiliary function, `can_*`, that returns +//! true if the operator can be applied to the particular `DataType`. 
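A short sketch of the `can_*` guard pattern described in the module doc above, using the dynamically-typed comparison kernel added earlier in this diff. The `arrow2::...` paths mirror the doc examples in this diff and are an assumption for the vendored crate.

```rust
// Sketch only; paths assumed from the doc examples in this diff.
use arrow2::array::{Array, BooleanArray, Int32Array};
use arrow2::compute::comparison::{can_eq, eq};

// The dynamically-typed kernel panics on unsupported logical types,
// so callers are expected to check `can_eq` first.
fn try_eq(lhs: &dyn Array, rhs: &dyn Array) -> Option<BooleanArray> {
    if !can_eq(lhs.data_type()) {
        return None;
    }
    Some(eq(lhs, rhs))
}

fn main() {
    let a = Int32Array::from_slice(&[1, 2, 3]);
    let b = Int32Array::from_slice(&[1, 5, 3]);
    let mask = try_eq(&a, &b).expect("Int32 supports `eq`");
    assert_eq!(mask, BooleanArray::from_slice(&[true, false, true]));
}
```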
+ +#[cfg(any(feature = "compute_aggregate", feature = "io_parquet"))] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_aggregate")))] +pub mod aggregate; +#[cfg(feature = "compute_arithmetics")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_arithmetics")))] +pub mod arithmetics; +pub mod arity; +pub mod arity_assign; +#[cfg(feature = "compute_bitwise")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_bitwise")))] +pub mod bitwise; +#[cfg(feature = "compute_boolean")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_boolean")))] +pub mod boolean; +#[cfg(feature = "compute_boolean_kleene")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_boolean_kleene")))] +pub mod boolean_kleene; +#[cfg(feature = "compute_cast")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_cast")))] +pub mod cast; +#[cfg(feature = "compute_comparison")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_comparison")))] +pub mod comparison; +#[cfg(feature = "compute_concatenate")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_concatenate")))] +pub mod concatenate; +#[cfg(feature = "compute_filter")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_filter")))] +pub mod filter; +#[cfg(feature = "compute_if_then_else")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_if_then_else")))] +pub mod if_then_else; +#[cfg(feature = "compute_take")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_take")))] +pub mod take; +#[cfg(feature = "compute_temporal")] +#[cfg_attr(docsrs, doc(cfg(feature = "compute_temporal")))] +pub mod temporal; +mod utils; diff --git a/crates/nano-arrow/src/compute/take/binary.rs b/crates/nano-arrow/src/compute/take/binary.rs new file mode 100644 index 000000000000..0e6460206f0e --- /dev/null +++ b/crates/nano-arrow/src/compute/take/binary.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use super::generic_binary::*; +use super::Index; +use crate::array::{Array, BinaryArray, PrimitiveArray}; +use crate::offset::Offset; + +/// `take` implementation for utf8 arrays +pub fn take( + values: &BinaryArray, + indices: &PrimitiveArray, +) -> BinaryArray { + let data_type = values.data_type().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => { + take_no_validity::(values.offsets(), values.values(), indices.values()) + }, + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.offsets(), values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + BinaryArray::::new(data_type, offsets, values, validity) +} diff --git a/crates/nano-arrow/src/compute/take/boolean.rs b/crates/nano-arrow/src/compute/take/boolean.rs new file mode 100644 index 000000000000..62be88e46226 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/boolean.rs @@ -0,0 +1,138 @@ +use super::Index; +use crate::array::{Array, BooleanArray, PrimitiveArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; + +// take implementation when neither values nor indices contain nulls +fn take_no_validity(values: &Bitmap, indices: &[I]) -> (Bitmap, Option) { + let values = indices.iter().map(|index| values.get_bit(index.to_usize())); + let buffer = Bitmap::from_trusted_len_iter(values); + + (buffer, None) +} + +// take implementation when only values contain nulls +fn take_values_validity( + values: &BooleanArray, + indices: &[I], +) -> (Bitmap, Option) { + let validity_values = values.validity().unwrap(); + let validity = indices + .iter() + .map(|index| validity_values.get_bit(index.to_usize())); + let validity = Bitmap::from_trusted_len_iter(validity); + + let values_values = values.values(); + let values = indices + .iter() + .map(|index| values_values.get_bit(index.to_usize())); + let buffer = Bitmap::from_trusted_len_iter(values); + + (buffer, validity.into()) +} + +// take implementation when only indices contain nulls +fn take_indices_validity( + values: &Bitmap, + indices: &PrimitiveArray, +) -> (Bitmap, Option) { + let validity = indices.validity().unwrap(); + + let values = indices.values().iter().enumerate().map(|(i, index)| { + let index = index.to_usize(); + match values.get(index) { + Some(value) => value, + None => { + if !validity.get_bit(i) { + false + } else { + panic!("Out-of-bounds index {index}") + } + }, + } + }); + + let buffer = Bitmap::from_trusted_len_iter(values); + + (buffer, indices.validity().cloned()) +} + +// take implementation when both values and indices contain nulls +fn take_values_indices_validity( + values: &BooleanArray, + indices: &PrimitiveArray, +) -> (Bitmap, Option) { + let mut validity = MutableBitmap::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + + let values_values = values.values(); + let values = indices.iter().map(|index| match index { + Some(index) => { + let index = index.to_usize(); + validity.push(values_validity.get_bit(index)); + values_values.get_bit(index) + }, + None => { + validity.push(false); + false + }, + }); + let values = Bitmap::from_trusted_len_iter(values); + (values, validity.into()) +} + +/// `take` implementation for boolean arrays +pub fn take(values: &BooleanArray, indices: &PrimitiveArray) -> BooleanArray { + let data_type = 
values.data_type().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => take_no_validity(values.values(), indices.values()), + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + + BooleanArray::new(data_type, values, validity) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Int32Array; + + fn _all_cases() -> Vec<(Int32Array, BooleanArray, BooleanArray)> { + vec![ + ( + Int32Array::from(&[Some(1), Some(0)]), + BooleanArray::from(vec![Some(true), Some(false)]), + BooleanArray::from(vec![Some(false), Some(true)]), + ), + ( + Int32Array::from(&[Some(1), None]), + BooleanArray::from(vec![Some(true), Some(false)]), + BooleanArray::from(vec![Some(false), None]), + ), + ( + Int32Array::from(&[Some(1), Some(0)]), + BooleanArray::from(vec![None, Some(false)]), + BooleanArray::from(vec![Some(false), None]), + ), + ( + Int32Array::from(&[Some(1), None, Some(0)]), + BooleanArray::from(vec![None, Some(false)]), + BooleanArray::from(vec![Some(false), None, None]), + ), + ] + } + + #[test] + fn all_cases() { + let cases = _all_cases(); + for (indices, input, expected) in cases { + let output = take(&input, &indices); + assert_eq!(expected, output); + } + } +} diff --git a/crates/nano-arrow/src/compute/take/dict.rs b/crates/nano-arrow/src/compute/take/dict.rs new file mode 100644 index 000000000000..bb60c09193f7 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/dict.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
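[Editor's note: a small sketch, not part of the diff, of how the boolean `take` kernel above behaves when indices contain nulls, mirroring the constructors used in its tests; `arrow2`-style paths are assumed.]

```rust
use arrow2::array::{Array, BooleanArray, Int32Array};
use arrow2::compute::take::take;
use arrow2::error::Result;

fn main() -> Result<()> {
    let values = BooleanArray::from(vec![Some(true), None, Some(false)]);
    // A null index produces a null slot in the output.
    let indices = Int32Array::from(&[Some(2), None, Some(0)]);
    let taken = take(&values, &indices)?;
    let taken = taken.as_any().downcast_ref::<BooleanArray>().unwrap();
    assert_eq!(taken, &BooleanArray::from(vec![Some(false), None, Some(true)]));
    Ok(())
}
```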
+ +use super::primitive::take as take_primitive; +use super::Index; +use crate::array::{DictionaryArray, DictionaryKey, PrimitiveArray}; + +/// `take` implementation for dictionary arrays +/// +/// applies `take` to the keys of the dictionary array and returns a new dictionary array +/// with the same dictionary values and reordered keys +pub fn take(values: &DictionaryArray, indices: &PrimitiveArray) -> DictionaryArray +where + K: DictionaryKey, + I: Index, +{ + let keys = take_primitive::(values.keys(), indices); + // safety - this operation takes a subset of keys and thus preserves the dictionary's invariant + unsafe { + DictionaryArray::::try_new_unchecked( + values.data_type().clone(), + keys, + values.values().clone(), + ) + .unwrap() + } +} diff --git a/crates/nano-arrow/src/compute/take/fixed_size_list.rs b/crates/nano-arrow/src/compute/take/fixed_size_list.rs new file mode 100644 index 000000000000..6e7e74b91720 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/fixed_size_list.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
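[Editor's note: a sketch, not part of the diff, of the dictionary `take` above: only the keys are gathered, while the dictionary values are reused unchanged. It assumes the `arrow2`-style API, including the `DictionaryArray::try_from_keys` constructor, which is not shown in this diff.]

```rust
use arrow2::array::{Array, DictionaryArray, Int32Array, PrimitiveArray, Utf8Array};
use arrow2::compute::take::take;
use arrow2::error::Result;

fn main() -> Result<()> {
    let values = Utf8Array::<i32>::from_slice(["low", "high"]);
    let keys = PrimitiveArray::from_slice([0u32, 1, 1, 0]);
    let dict = DictionaryArray::try_from_keys(keys, values.boxed())?;

    let indices = Int32Array::from_slice([3, 0]);
    let taken = take(&dict, &indices)?;
    let taken = taken.as_any().downcast_ref::<DictionaryArray<u32>>().unwrap();
    // The keys were reordered; the two dictionary values are untouched.
    assert_eq!(taken.keys(), &PrimitiveArray::from_slice([0u32, 0]));
    Ok(())
}
```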
+ +use super::Index; +use crate::array::growable::{Growable, GrowableFixedSizeList}; +use crate::array::{FixedSizeListArray, PrimitiveArray}; + +/// `take` implementation for FixedSizeListArrays +pub fn take( + values: &FixedSizeListArray, + indices: &PrimitiveArray, +) -> FixedSizeListArray { + let mut capacity = 0; + let arrays = indices + .values() + .iter() + .map(|index| { + let index = index.to_usize(); + let slice = values.clone().sliced(index, 1); + capacity += slice.len(); + slice + }) + .collect::>(); + + let arrays = arrays.iter().collect(); + + if let Some(validity) = indices.validity() { + let mut growable: GrowableFixedSizeList = + GrowableFixedSizeList::new(arrays, true, capacity); + + for index in 0..indices.len() { + if validity.get_bit(index) { + growable.extend(index, 0, 1); + } else { + growable.extend_validity(1) + } + } + + growable.into() + } else { + let mut growable: GrowableFixedSizeList = + GrowableFixedSizeList::new(arrays, false, capacity); + for index in 0..indices.len() { + growable.extend(index, 0, 1); + } + + growable.into() + } +} diff --git a/crates/nano-arrow/src/compute/take/generic_binary.rs b/crates/nano-arrow/src/compute/take/generic_binary.rs new file mode 100644 index 000000000000..9f6658c7d5a0 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/generic_binary.rs @@ -0,0 +1,155 @@ +use super::Index; +use crate::array::{GenericBinaryArray, PrimitiveArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::buffer::Buffer; +use crate::offset::{Offset, Offsets, OffsetsBuffer}; + +pub fn take_values( + length: O, + starts: &[O], + offsets: &OffsetsBuffer, + values: &[u8], +) -> Buffer { + let new_len = length.to_usize(); + let mut buffer = Vec::with_capacity(new_len); + starts + .iter() + .map(|start| start.to_usize()) + .zip(offsets.lengths()) + .for_each(|(start, length)| { + let end = start + length; + buffer.extend_from_slice(&values[start..end]); + }); + buffer.into() +} + +// take implementation when neither values nor indices contain nulls +pub fn take_no_validity( + offsets: &OffsetsBuffer, + values: &[u8], + indices: &[I], +) -> (OffsetsBuffer, Buffer, Option) { + let mut buffer = Vec::::new(); + let lengths = indices.iter().map(|index| index.to_usize()).map(|index| { + let (start, end) = offsets.start_end(index); + // todo: remove this bound check + buffer.extend_from_slice(&values[start..end]); + end - start + }); + let offsets = Offsets::try_from_lengths(lengths).expect(""); + + (offsets.into(), buffer.into(), None) +} + +// take implementation when only values contain nulls +pub fn take_values_validity>( + values: &A, + indices: &[I], +) -> (OffsetsBuffer, Buffer, Option) { + let validity_values = values.validity().unwrap(); + let validity = indices + .iter() + .map(|index| validity_values.get_bit(index.to_usize())); + let validity = Bitmap::from_trusted_len_iter(validity); + + let mut length = O::default(); + + let offsets = values.offsets(); + let values_values = values.values(); + + let mut starts = Vec::::with_capacity(indices.len()); + let offsets = indices.iter().map(|index| { + let index = index.to_usize(); + let start = offsets[index]; + length += offsets[index + 1] - start; + starts.push(start); + length + }); + let offsets = std::iter::once(O::default()) + .chain(offsets) + .collect::>(); + // Safety: by construction offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + + let buffer = take_values(length, starts.as_slice(), &offsets, values_values); + + (offsets, buffer, 
validity.into()) +} + +// take implementation when only indices contain nulls +pub fn take_indices_validity( + offsets: &OffsetsBuffer, + values: &[u8], + indices: &PrimitiveArray, +) -> (OffsetsBuffer, Buffer, Option) { + let mut length = O::default(); + + let offsets = offsets.buffer(); + + let mut starts = Vec::::with_capacity(indices.len()); + let offsets = indices.values().iter().map(|index| { + let index = index.to_usize(); + match offsets.get(index + 1) { + Some(&next) => { + let start = offsets[index]; + length += next - start; + starts.push(start); + }, + None => starts.push(O::default()), + }; + length + }); + let offsets = std::iter::once(O::default()) + .chain(offsets) + .collect::>(); + // Safety: by construction offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + + let buffer = take_values(length, &starts, &offsets, values); + + (offsets, buffer, indices.validity().cloned()) +} + +// take implementation when both indices and values contain nulls +pub fn take_values_indices_validity>( + values: &A, + indices: &PrimitiveArray, +) -> (OffsetsBuffer, Buffer, Option) { + let mut length = O::default(); + let mut validity = MutableBitmap::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + let offsets = values.offsets(); + let values_values = values.values(); + + let mut starts = Vec::::with_capacity(indices.len()); + let offsets = indices.iter().map(|index| { + match index { + Some(index) => { + let index = index.to_usize(); + if values_validity.get_bit(index) { + validity.push(true); + length += offsets[index + 1] - offsets[index]; + starts.push(offsets[index]); + } else { + validity.push(false); + starts.push(O::default()); + } + }, + None => { + validity.push(false); + starts.push(O::default()); + }, + }; + length + }); + let offsets = std::iter::once(O::default()) + .chain(offsets) + .collect::>(); + // Safety: by construction offsets are monotonically increasing + let offsets = unsafe { Offsets::new_unchecked(offsets) }.into(); + + let buffer = take_values(length, &starts, &offsets, values_values); + + (offsets, buffer, validity.into()) +} diff --git a/crates/nano-arrow/src/compute/take/list.rs b/crates/nano-arrow/src/compute/take/list.rs new file mode 100644 index 000000000000..58fb9d6fd788 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/list.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
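[Editor's note: a sketch, not part of the diff, of what the generic binary helpers above compute in the end: the value slices selected by the indices are gathered into fresh offsets and values buffers. It assumes the `arrow2`-style API; `LargeBinary` is used because that is the variant dispatched by the top-level `take`.]

```rust
use arrow2::array::{Array, BinaryArray, Int32Array};
use arrow2::compute::take::take;
use arrow2::error::Result;

fn main() -> Result<()> {
    let values = BinaryArray::<i64>::from_slice([&b"aa"[..], &b"b"[..], &b"cccc"[..]]);
    let indices = Int32Array::from_slice([2, 0]);
    let taken = take(&values, &indices)?;
    let taken = taken.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
    // The selected slices are copied in index order.
    assert_eq!(taken.value(0), &b"cccc"[..]);
    assert_eq!(taken.value(1), &b"aa"[..]);
    Ok(())
}
```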
+ +use super::Index; +use crate::array::growable::{Growable, GrowableList}; +use crate::array::{ListArray, PrimitiveArray}; +use crate::offset::Offset; + +/// `take` implementation for ListArrays +pub fn take( + values: &ListArray, + indices: &PrimitiveArray, +) -> ListArray { + let mut capacity = 0; + let arrays = indices + .values() + .iter() + .map(|index| { + let index = index.to_usize(); + let slice = values.clone().sliced(index, 1); + capacity += slice.len(); + slice + }) + .collect::>>(); + + let arrays = arrays.iter().collect(); + + if let Some(validity) = indices.validity() { + let mut growable: GrowableList = GrowableList::new(arrays, true, capacity); + + for index in 0..indices.len() { + if validity.get_bit(index) { + growable.extend(index, 0, 1); + } else { + growable.extend_validity(1) + } + } + + growable.into() + } else { + let mut growable: GrowableList = GrowableList::new(arrays, false, capacity); + for index in 0..indices.len() { + growable.extend(index, 0, 1); + } + + growable.into() + } +} diff --git a/crates/nano-arrow/src/compute/take/mod.rs b/crates/nano-arrow/src/compute/take/mod.rs new file mode 100644 index 000000000000..d526713a4327 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/mod.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines take kernel for [`Array`] + +use crate::array::{new_empty_array, Array, NullArray, PrimitiveArray}; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::Index; + +mod binary; +mod boolean; +mod dict; +mod fixed_size_list; +mod generic_binary; +mod list; +mod primitive; +mod structure; +mod utf8; + +pub(crate) use boolean::take as take_boolean; + +/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls. +/// The returned array has a length equal to `indices.len()`. 
+pub fn take(values: &dyn Array, indices: &PrimitiveArray) -> Result> { + if indices.len() == 0 { + return Ok(new_empty_array(values.data_type().clone())); + } + + use crate::datatypes::PhysicalType::*; + match values.data_type().to_physical_type() { + Null => Ok(Box::new(NullArray::new( + values.data_type().clone(), + indices.len(), + ))), + Boolean => { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(boolean::take::(values, indices))) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(primitive::take::<$T, _>(&values, indices))) + }), + LargeUtf8 => { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(utf8::take::(values, indices))) + }, + LargeBinary => { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(binary::take::(values, indices))) + }, + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + let values = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(dict::take::<$T, _>(&values, indices))) + }) + }, + Struct => { + let array = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(structure::take::<_>(array, indices)?)) + }, + LargeList => { + let array = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(list::take::(array, indices))) + }, + FixedSizeList => { + let array = values.as_any().downcast_ref().unwrap(); + Ok(Box::new(fixed_size_list::take::(array, indices))) + }, + t => unimplemented!("Take not supported for data type {:?}", t), + } +} + +/// Checks if an array of type `datatype` can perform take operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::take::can_take; +/// use arrow2::datatypes::{DataType}; +/// +/// let data_type = DataType::Int8; +/// assert_eq!(can_take(&data_type), true); +/// ``` +pub fn can_take(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Null + | DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Interval(_) + | DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Timestamp(_, _) + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Struct(_) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) + | DataType::Dictionary(..) 
+ ) +} diff --git a/crates/nano-arrow/src/compute/take/primitive.rs b/crates/nano-arrow/src/compute/take/primitive.rs new file mode 100644 index 000000000000..5ce53ba7cc20 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/primitive.rs @@ -0,0 +1,112 @@ +use super::Index; +use crate::array::{Array, PrimitiveArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::buffer::Buffer; +use crate::types::NativeType; + +// take implementation when neither values nor indices contain nulls +fn take_no_validity( + values: &[T], + indices: &[I], +) -> (Buffer, Option) { + let values = indices + .iter() + .map(|index| values[index.to_usize()]) + .collect::>(); + + (values.into(), None) +} + +// take implementation when only values contain nulls +fn take_values_validity( + values: &PrimitiveArray, + indices: &[I], +) -> (Buffer, Option) { + let values_validity = values.validity().unwrap(); + + let validity = indices + .iter() + .map(|index| values_validity.get_bit(index.to_usize())); + let validity = MutableBitmap::from_trusted_len_iter(validity); + + let values_values = values.values(); + + let values = indices + .iter() + .map(|index| values_values[index.to_usize()]) + .collect::>(); + + (values.into(), validity.into()) +} + +// take implementation when only indices contain nulls +fn take_indices_validity( + values: &[T], + indices: &PrimitiveArray, +) -> (Buffer, Option) { + let validity = indices.validity().unwrap(); + let values = indices + .values() + .iter() + .enumerate() + .map(|(i, index)| { + let index = index.to_usize(); + match values.get(index) { + Some(value) => *value, + None => { + if !validity.get_bit(i) { + T::default() + } else { + panic!("Out-of-bounds index {index}") + } + }, + } + }) + .collect::>(); + + (values.into(), indices.validity().cloned()) +} + +// take implementation when both values and indices contain nulls +fn take_values_indices_validity( + values: &PrimitiveArray, + indices: &PrimitiveArray, +) -> (Buffer, Option) { + let mut bitmap = MutableBitmap::with_capacity(indices.len()); + + let values_validity = values.validity().unwrap(); + + let values_values = values.values(); + let values = indices + .iter() + .map(|index| match index { + Some(index) => { + let index = index.to_usize(); + bitmap.push(values_validity.get_bit(index)); + values_values[index] + }, + None => { + bitmap.push(false); + T::default() + }, + }) + .collect::>(); + (values.into(), bitmap.into()) +} + +/// `take` implementation for primitive arrays +pub fn take( + values: &PrimitiveArray, + indices: &PrimitiveArray, +) -> PrimitiveArray { + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + let (buffer, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => take_no_validity::(values.values(), indices.values()), + (true, false) => take_values_validity::(values, indices.values()), + (false, true) => take_indices_validity::(values.values(), indices), + (true, true) => take_values_indices_validity::(values, indices), + }; + + PrimitiveArray::::new(values.data_type().clone(), buffer, validity) +} diff --git a/crates/nano-arrow/src/compute/take/structure.rs b/crates/nano-arrow/src/compute/take/structure.rs new file mode 100644 index 000000000000..e0a2717f5746 --- /dev/null +++ b/crates/nano-arrow/src/compute/take/structure.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Index; +use crate::array::{Array, PrimitiveArray, StructArray}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::error::Result; + +#[inline] +fn take_validity( + validity: Option<&Bitmap>, + indices: &PrimitiveArray, +) -> Result> { + let indices_validity = indices.validity(); + match (validity, indices_validity) { + (None, _) => Ok(indices_validity.cloned()), + (Some(validity), None) => { + let iter = indices.values().iter().map(|index| { + let index = index.to_usize(); + validity.get_bit(index) + }); + Ok(MutableBitmap::from_trusted_len_iter(iter).into()) + }, + (Some(validity), _) => { + let iter = indices.iter().map(|x| match x { + Some(index) => { + let index = index.to_usize(); + validity.get_bit(index) + }, + None => false, + }); + Ok(MutableBitmap::from_trusted_len_iter(iter).into()) + }, + } +} + +pub fn take(array: &StructArray, indices: &PrimitiveArray) -> Result { + let values: Vec> = array + .values() + .iter() + .map(|a| super::take(a.as_ref(), indices)) + .collect::>()?; + let validity = take_validity(array.validity(), indices)?; + Ok(StructArray::new( + array.data_type().clone(), + values, + validity, + )) +} diff --git a/crates/nano-arrow/src/compute/take/utf8.rs b/crates/nano-arrow/src/compute/take/utf8.rs new file mode 100644 index 000000000000..3f5f5877c12f --- /dev/null +++ b/crates/nano-arrow/src/compute/take/utf8.rs @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
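[Editor's note: a sketch, not part of the diff, of the struct `take` above: every child is gathered with the same indices and the validities are combined, so a null index yields a null struct row. It assumes the `arrow2`-style API; the field names are illustrative.]

```rust
use arrow2::array::{Array, Int32Array, StructArray, Utf8Array};
use arrow2::compute::take::take;
use arrow2::datatypes::{DataType, Field};
use arrow2::error::Result;

fn main() -> Result<()> {
    let fields = vec![
        Field::new("id", DataType::Int32, true),
        Field::new("name", DataType::LargeUtf8, true),
    ];
    let ids = Int32Array::from_slice([1, 2, 3]).boxed();
    let names = Utf8Array::<i64>::from_slice(["a", "b", "c"]).boxed();
    let array = StructArray::new(DataType::Struct(fields), vec![ids, names], None);

    let indices = Int32Array::from(&[Some(2), None]);
    let taken = take(&array, &indices)?;
    // Row 2 is selected; the null index becomes a null struct row.
    assert_eq!(taken.len(), 2);
    assert_eq!(taken.null_count(), 1);
    Ok(())
}
```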
+ +use super::generic_binary::*; +use super::Index; +use crate::array::{Array, PrimitiveArray, Utf8Array}; +use crate::offset::Offset; + +/// `take` implementation for utf8 arrays +pub fn take( + values: &Utf8Array, + indices: &PrimitiveArray, +) -> Utf8Array { + let data_type = values.data_type().clone(); + let indices_has_validity = indices.null_count() > 0; + let values_has_validity = values.null_count() > 0; + + let (offsets, values, validity) = match (values_has_validity, indices_has_validity) { + (false, false) => { + take_no_validity::(values.offsets(), values.values(), indices.values()) + }, + (true, false) => take_values_validity(values, indices.values()), + (false, true) => take_indices_validity(values.offsets(), values.values(), indices), + (true, true) => take_values_indices_validity(values, indices), + }; + unsafe { Utf8Array::::new_unchecked(data_type, offsets, values, validity) } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Int32Array; + + fn _all_cases() -> Vec<(Int32Array, Utf8Array, Utf8Array)> { + vec![ + ( + Int32Array::from(&[Some(1), Some(0)]), + Utf8Array::::from(vec![Some("one"), Some("two")]), + Utf8Array::::from(vec![Some("two"), Some("one")]), + ), + ( + Int32Array::from(&[Some(1), None]), + Utf8Array::::from(vec![Some("one"), Some("two")]), + Utf8Array::::from(vec![Some("two"), None]), + ), + ( + Int32Array::from(&[Some(1), Some(0)]), + Utf8Array::::from(vec![None, Some("two")]), + Utf8Array::::from(vec![Some("two"), None]), + ), + ( + Int32Array::from(&[Some(1), None, Some(0)]), + Utf8Array::::from(vec![None, Some("two")]), + Utf8Array::::from(vec![Some("two"), None, None]), + ), + ] + } + + #[test] + fn all_cases() { + let cases = _all_cases::(); + for (indices, input, expected) in cases { + let output = take(&input, &indices); + assert_eq!(expected, output); + } + let cases = _all_cases::(); + for (indices, input, expected) in cases { + let output = take(&input, &indices); + assert_eq!(expected, output); + } + } +} diff --git a/crates/nano-arrow/src/compute/temporal.rs b/crates/nano-arrow/src/compute/temporal.rs new file mode 100644 index 000000000000..132492f58b6e --- /dev/null +++ b/crates/nano-arrow/src/compute/temporal.rs @@ -0,0 +1,410 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines temporal kernels for time and date related functions. 
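[Editor's note: a sketch, not part of the diff, of the temporal kernels defined in this module, extracting the year and month of a `Date32` column after checking support with `can_year`. It assumes the `arrow2`-style API, including `PrimitiveArray::to` to assign the `Date32` logical type.]

```rust
use arrow2::array::{Array, Int32Array, PrimitiveArray};
use arrow2::compute::temporal::{can_year, month, year};
use arrow2::datatypes::DataType;
use arrow2::error::Result;

fn main() -> Result<()> {
    // 2021-01-01 is 18628 days after the UNIX epoch.
    let dates = Int32Array::from_slice([18628, 18629]).to(DataType::Date32);
    assert!(can_year(dates.data_type()));

    let years: PrimitiveArray<i32> = year(&dates)?;
    let months: PrimitiveArray<u32> = month(&dates)?;
    assert_eq!(years.value(0), 2021);
    assert_eq!(months.value(0), 1);
    Ok(())
}
```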
+ +use chrono::{Datelike, Timelike}; + +use super::arity::unary; +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::temporal_conversions::*; +use crate::types::NativeType; + +// Create and implement a trait that converts chrono's `Weekday` +// type into `u32` +trait U32Weekday: Datelike { + fn u32_weekday(&self) -> u32 { + self.weekday().number_from_monday() + } +} + +impl U32Weekday for chrono::NaiveDateTime {} +impl U32Weekday for chrono::DateTime {} + +// Create and implement a trait that converts chrono's `IsoWeek` +// type into `u32` +trait U32IsoWeek: Datelike { + fn u32_iso_week(&self) -> u32 { + self.iso_week().week() + } +} + +impl U32IsoWeek for chrono::NaiveDateTime {} +impl U32IsoWeek for chrono::DateTime {} + +// Macro to avoid repetition in functions, that apply +// `chrono::Datelike` methods on Arrays +macro_rules! date_like { + ($extract:ident, $array:ident, $data_type:path) => { + match $array.data_type().to_logical_type() { + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, None) => { + date_variants($array, $data_type, |x| x.$extract()) + }, + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let array = $array.as_any().downcast_ref().unwrap(); + + if let Ok(timezone) = parse_offset(timezone_str) { + Ok(extract_impl(array, *time_unit, timezone, |x| x.$extract())) + } else { + chrono_tz(array, *time_unit, timezone_str, |x| x.$extract()) + } + }, + dt => Err(Error::NotYetImplemented(format!( + "\"{}\" does not support type {:?}", + stringify!($extract), + dt + ))), + } + }; +} + +/// Extracts the years of a temporal array as [`PrimitiveArray`]. +/// Use [`can_year`] to check if this operation is supported for the target [`DataType`]. +pub fn year(array: &dyn Array) -> Result> { + date_like!(year, array, DataType::Int32) +} + +/// Extracts the months of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 1 to 12. +/// Use [`can_month`] to check if this operation is supported for the target [`DataType`]. +pub fn month(array: &dyn Array) -> Result> { + date_like!(month, array, DataType::UInt32) +} + +/// Extracts the days of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 1 to 32 (Last day depends on month). +/// Use [`can_day`] to check if this operation is supported for the target [`DataType`]. +pub fn day(array: &dyn Array) -> Result> { + date_like!(day, array, DataType::UInt32) +} + +/// Extracts weekday of a temporal array as [`PrimitiveArray`]. +/// Monday is 1, Tuesday is 2, ..., Sunday is 7. +/// Use [`can_weekday`] to check if this operation is supported for the target [`DataType`] +pub fn weekday(array: &dyn Array) -> Result> { + date_like!(u32_weekday, array, DataType::UInt32) +} + +/// Extracts ISO week of a temporal array as [`PrimitiveArray`] +/// Value ranges from 1 to 53 (Last week depends on the year). +/// Use [`can_iso_week`] to check if this operation is supported for the target [`DataType`] +pub fn iso_week(array: &dyn Array) -> Result> { + date_like!(u32_iso_week, array, DataType::UInt32) +} + +// Macro to avoid repetition in functions, that apply +// `chrono::Timelike` methods on Arrays +macro_rules! 
time_like { + ($extract:ident, $array:ident, $data_type:path) => { + match $array.data_type().to_logical_type() { + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, None) => { + date_variants($array, $data_type, |x| x.$extract()) + }, + DataType::Time32(_) | DataType::Time64(_) => { + time_variants($array, DataType::UInt32, |x| x.$extract()) + }, + DataType::Timestamp(time_unit, Some(timezone_str)) => { + let array = $array.as_any().downcast_ref().unwrap(); + + if let Ok(timezone) = parse_offset(timezone_str) { + Ok(extract_impl(array, *time_unit, timezone, |x| x.$extract())) + } else { + chrono_tz(array, *time_unit, timezone_str, |x| x.$extract()) + } + }, + dt => Err(Error::NotYetImplemented(format!( + "\"{}\" does not support type {:?}", + stringify!($extract), + dt + ))), + } + }; +} + +/// Extracts the hours of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 23. +/// Use [`can_hour`] to check if this operation is supported for the target [`DataType`]. +pub fn hour(array: &dyn Array) -> Result> { + time_like!(hour, array, DataType::UInt32) +} + +/// Extracts the minutes of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 59. +/// Use [`can_minute`] to check if this operation is supported for the target [`DataType`]. +pub fn minute(array: &dyn Array) -> Result> { + time_like!(minute, array, DataType::UInt32) +} + +/// Extracts the seconds of a temporal array as [`PrimitiveArray`]. +/// Value ranges from 0 to 59. +/// Use [`can_second`] to check if this operation is supported for the target [`DataType`]. +pub fn second(array: &dyn Array) -> Result> { + time_like!(second, array, DataType::UInt32) +} + +/// Extracts the nanoseconds of a temporal array as [`PrimitiveArray`]. +/// Use [`can_nanosecond`] to check if this operation is supported for the target [`DataType`]. 
+pub fn nanosecond(array: &dyn Array) -> Result> { + time_like!(nanosecond, array, DataType::UInt32) +} + +fn date_variants(array: &dyn Array, data_type: DataType, op: F) -> Result> +where + O: NativeType, + F: Fn(chrono::NaiveDateTime) -> O, +{ + match array.data_type().to_logical_type() { + DataType::Date32 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(date32_to_datetime(x)), data_type)) + }, + DataType::Date64 => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(date64_to_datetime(x)), data_type)) + }, + DataType::Timestamp(time_unit, None) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let func = match time_unit { + TimeUnit::Second => timestamp_s_to_datetime, + TimeUnit::Millisecond => timestamp_ms_to_datetime, + TimeUnit::Microsecond => timestamp_us_to_datetime, + TimeUnit::Nanosecond => timestamp_ns_to_datetime, + }; + Ok(unary(array, |x| op(func(x)), data_type)) + }, + _ => unreachable!(), + } +} + +fn time_variants(array: &dyn Array, data_type: DataType, op: F) -> Result> +where + O: NativeType, + F: Fn(chrono::NaiveTime) -> O, +{ + match array.data_type().to_logical_type() { + DataType::Time32(TimeUnit::Second) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time32s_to_time(x)), data_type)) + }, + DataType::Time32(TimeUnit::Millisecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time32ms_to_time(x)), data_type)) + }, + DataType::Time64(TimeUnit::Microsecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time64us_to_time(x)), data_type)) + }, + DataType::Time64(TimeUnit::Nanosecond) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(unary(array, |x| op(time64ns_to_time(x)), data_type)) + }, + _ => unreachable!(), + } +} + +#[cfg(feature = "chrono-tz")] +fn chrono_tz( + array: &PrimitiveArray, + time_unit: TimeUnit, + timezone_str: &str, + op: F, +) -> Result> +where + O: NativeType, + F: Fn(chrono::DateTime) -> O, +{ + let timezone = parse_offset_tz(timezone_str)?; + Ok(extract_impl(array, time_unit, timezone, op)) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz( + _: &PrimitiveArray, + _: TimeUnit, + timezone_str: &str, + _: F, +) -> Result> +where + O: NativeType, + F: Fn(chrono::DateTime) -> O, +{ + Err(Error::InvalidArgumentError(format!( + "timezone \"{}\" cannot be parsed (feature chrono-tz is not active)", + timezone_str + ))) +} + +fn extract_impl( + array: &PrimitiveArray, + time_unit: TimeUnit, + timezone: T, + extract: F, +) -> PrimitiveArray +where + T: chrono::TimeZone, + A: NativeType, + F: Fn(chrono::DateTime) -> A, +{ + match time_unit { + TimeUnit::Second => { + let op = |x| { + let datetime = timestamp_s_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Millisecond => { + let op = |x| { + let datetime = timestamp_ms_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Microsecond => { + let op = |x| { + let datetime = timestamp_us_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + 
extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + TimeUnit::Nanosecond => { + let op = |x| { + let datetime = timestamp_ns_to_datetime(x); + let offset = timezone.offset_from_utc_datetime(&datetime); + extract(chrono::DateTime::::from_naive_utc_and_offset( + datetime, offset, + )) + }; + unary(array, op, A::PRIMITIVE.into()) + }, + } +} + +/// Checks if an array of type `datatype` can perform year operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::temporal::can_year; +/// use arrow2::datatypes::{DataType}; +/// +/// assert_eq!(can_year(&DataType::Date32), true); +/// assert_eq!(can_year(&DataType::Int8), false); +/// ``` +pub fn can_year(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `datatype` can perform month operation +pub fn can_month(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `datatype` can perform day operation +pub fn can_day(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `data_type` can perform weekday operation +pub fn can_weekday(data_type: &DataType) -> bool { + can_date(data_type) +} + +/// Checks if an array of type `data_type` can perform ISO week operation +pub fn can_iso_week(data_type: &DataType) -> bool { + can_date(data_type) +} + +fn can_date(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) + ) +} + +/// Checks if an array of type `datatype` can perform hour operation +/// +/// # Examples +/// ``` +/// use arrow2::compute::temporal::can_hour; +/// use arrow2::datatypes::{DataType, TimeUnit}; +/// +/// assert_eq!(can_hour(&DataType::Time32(TimeUnit::Second)), true); +/// assert_eq!(can_hour(&DataType::Int8), false); +/// ``` +pub fn can_hour(data_type: &DataType) -> bool { + can_time(data_type) +} + +/// Checks if an array of type `datatype` can perform minute operation +pub fn can_minute(data_type: &DataType) -> bool { + can_time(data_type) +} + +/// Checks if an array of type `datatype` can perform second operation +pub fn can_second(data_type: &DataType) -> bool { + can_time(data_type) +} + +/// Checks if an array of type `datatype` can perform nanosecond operation +pub fn can_nanosecond(data_type: &DataType) -> bool { + can_time(data_type) +} + +fn can_time(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Date32 + | DataType::Date64 + | DataType::Timestamp(_, _) + ) +} diff --git a/crates/nano-arrow/src/compute/utils.rs b/crates/nano-arrow/src/compute/utils.rs new file mode 100644 index 000000000000..e06acdcd470c --- /dev/null +++ b/crates/nano-arrow/src/compute/utils.rs @@ -0,0 +1,23 @@ +use crate::array::Array; +use crate::bitmap::Bitmap; +use crate::error::{Error, Result}; + +pub fn combine_validities(lhs: Option<&Bitmap>, rhs: Option<&Bitmap>) -> Option { + match (lhs, rhs) { + (Some(lhs), None) => Some(lhs.clone()), + (None, Some(rhs)) => Some(rhs.clone()), + (None, None) => None, + (Some(lhs), Some(rhs)) => Some(lhs & rhs), + } +} + +// Errors iff the two arrays have a different length. 
+#[inline] +pub fn check_same_len(lhs: &dyn Array, rhs: &dyn Array) -> Result<()> { + if lhs.len() != rhs.len() { + return Err(Error::InvalidArgumentError( + "Arrays must have the same length".to_string(), + )); + } + Ok(()) +} diff --git a/crates/nano-arrow/src/datatypes/field.rs b/crates/nano-arrow/src/datatypes/field.rs new file mode 100644 index 000000000000..489cacb0b5b5 --- /dev/null +++ b/crates/nano-arrow/src/datatypes/field.rs @@ -0,0 +1,96 @@ +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +use super::{DataType, Metadata}; + +/// Represents Arrow's metadata of a "column". +/// +/// A [`Field`] is the closest representation of the traditional "column": a logical type +/// ([`DataType`]) with a name and nullability. +/// A Field has optional [`Metadata`] that can be used to annotate the field with custom metadata. +/// +/// Almost all IO in this crate uses [`Field`] to represent logical information about the data +/// to be serialized. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub struct Field { + /// Its name + pub name: String, + /// Its logical [`DataType`] + pub data_type: DataType, + /// Its nullability + pub is_nullable: bool, + /// Additional custom (opaque) metadata. + pub metadata: Metadata, +} + +impl Field { + /// Creates a new [`Field`]. + pub fn new>(name: T, data_type: DataType, is_nullable: bool) -> Self { + Field { + name: name.into(), + data_type, + is_nullable, + metadata: Default::default(), + } + } + + /// Creates a new [`Field`] with metadata. + #[inline] + pub fn with_metadata(self, metadata: Metadata) -> Self { + Self { + name: self.name, + data_type: self.data_type, + is_nullable: self.is_nullable, + metadata, + } + } + + /// Returns the [`Field`]'s [`DataType`]. + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[cfg(feature = "arrow")] +impl From for arrow_schema::Field { + fn from(value: Field) -> Self { + Self::new(value.name, value.data_type.into(), value.is_nullable) + .with_metadata(value.metadata.into_iter().collect()) + } +} + +#[cfg(feature = "arrow")] +impl From for Field { + fn from(value: arrow_schema::Field) -> Self { + (&value).into() + } +} + +#[cfg(feature = "arrow")] +impl From<&arrow_schema::Field> for Field { + fn from(value: &arrow_schema::Field) -> Self { + let data_type = value.data_type().clone().into(); + let metadata = value + .metadata() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + Self::new(value.name(), data_type, value.is_nullable()).with_metadata(metadata) + } +} + +#[cfg(feature = "arrow")] +impl From for Field { + fn from(value: arrow_schema::FieldRef) -> Self { + value.as_ref().into() + } +} + +#[cfg(feature = "arrow")] +impl From<&arrow_schema::FieldRef> for Field { + fn from(value: &arrow_schema::FieldRef) -> Self { + value.as_ref().into() + } +} diff --git a/crates/nano-arrow/src/datatypes/mod.rs b/crates/nano-arrow/src/datatypes/mod.rs new file mode 100644 index 000000000000..7487af3a0a9a --- /dev/null +++ b/crates/nano-arrow/src/datatypes/mod.rs @@ -0,0 +1,513 @@ +#![forbid(unsafe_code)] +//! Contains all metadata, such as [`PhysicalType`], [`DataType`], [`Field`] and [`Schema`]. 
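[Editor's note: a small sketch, not part of the diff, showing how the `Field` defined above is constructed and annotated with custom metadata via `with_metadata`; the field name and metadata key are illustrative.]

```rust
use arrow2::datatypes::{DataType, Field, Metadata};

fn main() {
    // `Metadata` is a BTreeMap<String, String>.
    let mut metadata = Metadata::new();
    metadata.insert("description".to_string(), "customer id".to_string());

    let field = Field::new("id", DataType::UInt64, false).with_metadata(metadata);
    assert_eq!(field.name, "id");
    assert_eq!(field.data_type(), &DataType::UInt64);
    assert!(!field.is_nullable);
}
```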
+ +mod field; +mod physical_type; +mod schema; + +use std::collections::BTreeMap; +use std::sync::Arc; + +pub use field::Field; +pub use physical_type::*; +pub use schema::Schema; +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +/// typedef for [BTreeMap] denoting [`Field`]'s and [`Schema`]'s metadata. +pub type Metadata = BTreeMap; +/// typedef for [Option<(String, Option)>] descr +pub(crate) type Extension = Option<(String, Option)>; + +/// The set of supported logical types in this crate. +/// +/// Each variant uniquely identifies a logical type, which define specific semantics to the data +/// (e.g. how it should be represented). +/// Each variant has a corresponding [`PhysicalType`], obtained via [`DataType::to_physical_type`], +/// which declares the in-memory representation of data. +/// The [`DataType::Extension`] is special in that it augments a [`DataType`] with metadata to support custom types. +/// Use `to_logical_type` to desugar such type and return its corresponding logical type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum DataType { + /// Null type + Null, + /// `true` and `false`. + Boolean, + /// An [`i8`] + Int8, + /// An [`i16`] + Int16, + /// An [`i32`] + Int32, + /// An [`i64`] + Int64, + /// An [`u8`] + UInt8, + /// An [`u16`] + UInt16, + /// An [`u32`] + UInt32, + /// An [`u64`] + UInt64, + /// An 16-bit float + Float16, + /// A [`f32`] + Float32, + /// A [`f64`] + Float64, + /// A [`i64`] representing a timestamp measured in [`TimeUnit`] with an optional timezone. + /// + /// Time is measured as a Unix epoch, counting the seconds from + /// 00:00:00.000 on 1 January 1970, excluding leap seconds, + /// as a 64-bit signed integer. + /// + /// The time zone is a string indicating the name of a time zone, one of: + /// + /// * As used in the Olson time zone database (the "tz database" or + /// "tzdata"), such as "America/New_York" + /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + /// When the timezone is not specified, the timestamp is considered to have no timezone + /// and is represented _as is_ + Timestamp(TimeUnit, Option), + /// An [`i32`] representing the elapsed time since UNIX epoch (1970-01-01) + /// in days. + Date32, + /// An [`i64`] representing the elapsed time since UNIX epoch (1970-01-01) + /// in milliseconds. Values are evenly divisible by 86400000. + Date64, + /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + /// Only [`TimeUnit::Second`] and [`TimeUnit::Millisecond`] are supported on this variant. + Time32(TimeUnit), + /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. + /// Only [`TimeUnit::Microsecond`] and [`TimeUnit::Nanosecond`] are supported on this variant. + Time64(TimeUnit), + /// Measure of elapsed time. This elapsed time is a physical duration (i.e. 1s as defined in S.I.) + Duration(TimeUnit), + /// A "calendar" interval modeling elapsed time that takes into account calendar shifts. + /// For example an interval of 1 day may represent more than 24 hours. + Interval(IntervalUnit), + /// Opaque binary data of variable length whose offsets are represented as [`i32`]. + Binary, + /// Opaque binary data of fixed size. + /// Enum parameter specifies the number of bytes per value. + FixedSizeBinary(usize), + /// Opaque binary data of variable length whose offsets are represented as [`i64`]. 
+ LargeBinary, + /// A variable-length UTF-8 encoded string whose offsets are represented as [`i32`]. + Utf8, + /// A variable-length UTF-8 encoded string whose offsets are represented as [`i64`]. + LargeUtf8, + /// A list of some logical data type whose offsets are represented as [`i32`]. + List(Box), + /// A list of some logical data type with a fixed number of elements. + FixedSizeList(Box, usize), + /// A list of some logical data type whose offsets are represented as [`i64`]. + LargeList(Box), + /// A nested [`DataType`] with a given number of [`Field`]s. + Struct(Vec), + /// A nested datatype that can represent slots of differing types. + /// Third argument represents mode + Union(Vec, Option>, UnionMode), + /// A nested type that is represented as + /// + /// List> + /// + /// In this layout, the keys and values are each respectively contiguous. We do + /// not constrain the key and value types, so the application is responsible + /// for ensuring that the keys are hashable and unique. Whether the keys are sorted + /// may be set in the metadata for this field. + /// + /// In a field with Map type, the field has a child Struct field, which then + /// has two children: key type and the second the value type. The names of the + /// child fields may be respectively "entries", "key", and "value", but this is + /// not enforced. + /// + /// Map + /// ```text + /// - child[0] entries: Struct + /// - child[0] key: K + /// - child[1] value: V + /// ``` + /// Neither the "entries" field nor the "key" field may be nullable. + /// + /// The metadata is structured so that Arrow systems without special handling + /// for Map can make Map an alias for List. The "layout" attribute for the Map + /// field must have the same contents as a List. + Map(Box, bool), + /// A dictionary encoded array (`key_type`, `value_type`), where + /// each array element is an index of `key_type` into an + /// associated dictionary of `value_type`. + /// + /// Dictionary arrays are used to store columns of `value_type` + /// that contain many repeated values using less memory, but with + /// a higher CPU overhead for some operations. + /// + /// This type mostly used to represent low cardinality string + /// arrays or a limited set of primitive types as integers. + /// + /// The `bool` value indicates the `Dictionary` is sorted if set to `true`. + Dictionary(IntegerType, Box, bool), + /// Decimal value with precision and scale + /// precision is the number of digits in the number and + /// scale is the number of decimal places. + /// The number 999.99 has a precision of 5 and scale of 2. + Decimal(usize, usize), + /// Decimal backed by 256 bits + Decimal256(usize, usize), + /// Extension type. 
+ Extension(String, Box, Option), +} + +#[cfg(feature = "arrow")] +impl From for arrow_schema::DataType { + fn from(value: DataType) -> Self { + use arrow_schema::{Field as ArrowField, UnionFields}; + + match value { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean, + DataType::Int8 => Self::Int8, + DataType::Int16 => Self::Int16, + DataType::Int32 => Self::Int32, + DataType::Int64 => Self::Int64, + DataType::UInt8 => Self::UInt8, + DataType::UInt16 => Self::UInt16, + DataType::UInt32 => Self::UInt32, + DataType::UInt64 => Self::UInt64, + DataType::Float16 => Self::Float16, + DataType::Float32 => Self::Float32, + DataType::Float64 => Self::Float64, + DataType::Timestamp(unit, tz) => Self::Timestamp(unit.into(), tz.map(Into::into)), + DataType::Date32 => Self::Date32, + DataType::Date64 => Self::Date64, + DataType::Time32(unit) => Self::Time32(unit.into()), + DataType::Time64(unit) => Self::Time64(unit.into()), + DataType::Duration(unit) => Self::Duration(unit.into()), + DataType::Interval(unit) => Self::Interval(unit.into()), + DataType::Binary => Self::Binary, + DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(size as _), + DataType::LargeBinary => Self::LargeBinary, + DataType::Utf8 => Self::Utf8, + DataType::LargeUtf8 => Self::LargeUtf8, + DataType::List(f) => Self::List(Arc::new((*f).into())), + DataType::FixedSizeList(f, size) => { + Self::FixedSizeList(Arc::new((*f).into()), size as _) + }, + DataType::LargeList(f) => Self::LargeList(Arc::new((*f).into())), + DataType::Struct(f) => Self::Struct(f.into_iter().map(ArrowField::from).collect()), + DataType::Union(fields, Some(ids), mode) => { + let ids = ids.into_iter().map(|x| x as _); + let fields = fields.into_iter().map(ArrowField::from); + Self::Union(UnionFields::new(ids, fields), mode.into()) + }, + DataType::Union(fields, None, mode) => { + let ids = 0..fields.len() as i8; + let fields = fields.into_iter().map(ArrowField::from); + Self::Union(UnionFields::new(ids, fields), mode.into()) + }, + DataType::Map(f, ordered) => Self::Map(Arc::new((*f).into()), ordered), + DataType::Dictionary(key, value, _) => Self::Dictionary( + Box::new(DataType::from(key).into()), + Box::new((*value).into()), + ), + DataType::Decimal(precision, scale) => Self::Decimal128(precision as _, scale as _), + DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _), + DataType::Extension(_, d, _) => (*d).into(), + } + } +} + +#[cfg(feature = "arrow")] +impl From for DataType { + fn from(value: arrow_schema::DataType) -> Self { + use arrow_schema::DataType; + match value { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean, + DataType::Int8 => Self::Int8, + DataType::Int16 => Self::Int16, + DataType::Int32 => Self::Int32, + DataType::Int64 => Self::Int64, + DataType::UInt8 => Self::UInt8, + DataType::UInt16 => Self::UInt16, + DataType::UInt32 => Self::UInt32, + DataType::UInt64 => Self::UInt64, + DataType::Float16 => Self::Float16, + DataType::Float32 => Self::Float32, + DataType::Float64 => Self::Float64, + DataType::Timestamp(unit, tz) => { + Self::Timestamp(unit.into(), tz.map(|x| x.to_string())) + }, + DataType::Date32 => Self::Date32, + DataType::Date64 => Self::Date64, + DataType::Time32(unit) => Self::Time32(unit.into()), + DataType::Time64(unit) => Self::Time64(unit.into()), + DataType::Duration(unit) => Self::Duration(unit.into()), + DataType::Interval(unit) => Self::Interval(unit.into()), + DataType::Binary => Self::Binary, + DataType::FixedSizeBinary(size) => 
Self::FixedSizeBinary(size as _), + DataType::LargeBinary => Self::LargeBinary, + DataType::Utf8 => Self::Utf8, + DataType::LargeUtf8 => Self::LargeUtf8, + DataType::List(f) => Self::List(Box::new(f.into())), + DataType::FixedSizeList(f, size) => Self::FixedSizeList(Box::new(f.into()), size as _), + DataType::LargeList(f) => Self::LargeList(Box::new(f.into())), + DataType::Struct(f) => Self::Struct(f.into_iter().map(Into::into).collect()), + DataType::Union(fields, mode) => { + let ids = fields.iter().map(|(x, _)| x as _).collect(); + let fields = fields.iter().map(|(_, f)| f.into()).collect(); + Self::Union(fields, Some(ids), mode.into()) + }, + DataType::Map(f, ordered) => Self::Map(Box::new(f.into()), ordered), + DataType::Dictionary(key, value) => { + let key = match *key { + DataType::Int8 => IntegerType::Int8, + DataType::Int16 => IntegerType::Int16, + DataType::Int32 => IntegerType::Int32, + DataType::Int64 => IntegerType::Int64, + DataType::UInt8 => IntegerType::UInt8, + DataType::UInt16 => IntegerType::UInt16, + DataType::UInt32 => IntegerType::UInt32, + DataType::UInt64 => IntegerType::UInt64, + d => panic!("illegal dictionary key type: {d}"), + }; + Self::Dictionary(key, Box::new((*value).into()), false) + }, + DataType::Decimal128(precision, scale) => Self::Decimal(precision as _, scale as _), + DataType::Decimal256(precision, scale) => Self::Decimal256(precision as _, scale as _), + DataType::RunEndEncoded(_, _) => panic!("Run-end encoding not supported by arrow2"), + } + } +} + +/// Mode of [`DataType::Union`] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum UnionMode { + /// Dense union + Dense, + /// Sparse union + Sparse, +} + +#[cfg(feature = "arrow")] +impl From for arrow_schema::UnionMode { + fn from(value: UnionMode) -> Self { + match value { + UnionMode::Dense => Self::Dense, + UnionMode::Sparse => Self::Sparse, + } + } +} + +#[cfg(feature = "arrow")] +impl From for UnionMode { + fn from(value: arrow_schema::UnionMode) -> Self { + match value { + arrow_schema::UnionMode::Dense => Self::Dense, + arrow_schema::UnionMode::Sparse => Self::Sparse, + } + } +} + +impl UnionMode { + /// Constructs a [`UnionMode::Sparse`] if the input bool is true, + /// or otherwise constructs a [`UnionMode::Dense`] + pub fn sparse(is_sparse: bool) -> Self { + if is_sparse { + Self::Sparse + } else { + Self::Dense + } + } + + /// Returns whether the mode is sparse + pub fn is_sparse(&self) -> bool { + matches!(self, Self::Sparse) + } + + /// Returns whether the mode is dense + pub fn is_dense(&self) -> bool { + matches!(self, Self::Dense) + } +} + +/// The time units defined in Arrow. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum TimeUnit { + /// Time in seconds. + Second, + /// Time in milliseconds. + Millisecond, + /// Time in microseconds. + Microsecond, + /// Time in nanoseconds. 
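Editor's aside: as a sanity check on the two `From` conversions above, a round trip through the arrow-rs types is lossless except for the dictionary `is_sorted` flag, which arrow-rs does not model. A sketch, assuming the crate's `arrow` feature is enabled and `arrow_schema` is available as a dependency:

```rust
use arrow2::datatypes::{DataType, IntegerType};

fn main() {
    let dt = DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false);

    // Into arrow-rs and back, via the conversions defined above.
    let arrow_dt: arrow_schema::DataType = dt.clone().into();
    let back: DataType = arrow_dt.into();

    // The dictionary `is_sorted` flag is dropped on export and always
    // re-imported as `false`, so this round trip only holds for `false`.
    assert_eq!(back, dt);
}
```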
+ Nanosecond, +} + +#[cfg(feature = "arrow")] +impl From for arrow_schema::TimeUnit { + fn from(value: TimeUnit) -> Self { + match value { + TimeUnit::Nanosecond => Self::Nanosecond, + TimeUnit::Millisecond => Self::Millisecond, + TimeUnit::Microsecond => Self::Microsecond, + TimeUnit::Second => Self::Second, + } + } +} + +#[cfg(feature = "arrow")] +impl From for TimeUnit { + fn from(value: arrow_schema::TimeUnit) -> Self { + match value { + arrow_schema::TimeUnit::Nanosecond => Self::Nanosecond, + arrow_schema::TimeUnit::Millisecond => Self::Millisecond, + arrow_schema::TimeUnit::Microsecond => Self::Microsecond, + arrow_schema::TimeUnit::Second => Self::Second, + } + } +} + +/// Interval units defined in Arrow +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum IntervalUnit { + /// The number of elapsed whole months. + YearMonth, + /// The number of elapsed days and milliseconds, + /// stored as 2 contiguous `i32` + DayTime, + /// The number of elapsed months (i32), days (i32) and nanoseconds (i64). + MonthDayNano, +} + +#[cfg(feature = "arrow")] +impl From for arrow_schema::IntervalUnit { + fn from(value: IntervalUnit) -> Self { + match value { + IntervalUnit::YearMonth => Self::YearMonth, + IntervalUnit::DayTime => Self::DayTime, + IntervalUnit::MonthDayNano => Self::MonthDayNano, + } + } +} + +#[cfg(feature = "arrow")] +impl From for IntervalUnit { + fn from(value: arrow_schema::IntervalUnit) -> Self { + match value { + arrow_schema::IntervalUnit::YearMonth => Self::YearMonth, + arrow_schema::IntervalUnit::DayTime => Self::DayTime, + arrow_schema::IntervalUnit::MonthDayNano => Self::MonthDayNano, + } + } +} + +impl DataType { + /// the [`PhysicalType`] of this [`DataType`]. 
+ pub fn to_physical_type(&self) -> PhysicalType { + use DataType::*; + match self { + Null => PhysicalType::Null, + Boolean => PhysicalType::Boolean, + Int8 => PhysicalType::Primitive(PrimitiveType::Int8), + Int16 => PhysicalType::Primitive(PrimitiveType::Int16), + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + PhysicalType::Primitive(PrimitiveType::Int32) + }, + Int64 | Date64 | Timestamp(_, _) | Time64(_) | Duration(_) => { + PhysicalType::Primitive(PrimitiveType::Int64) + }, + Decimal(_, _) => PhysicalType::Primitive(PrimitiveType::Int128), + Decimal256(_, _) => PhysicalType::Primitive(PrimitiveType::Int256), + UInt8 => PhysicalType::Primitive(PrimitiveType::UInt8), + UInt16 => PhysicalType::Primitive(PrimitiveType::UInt16), + UInt32 => PhysicalType::Primitive(PrimitiveType::UInt32), + UInt64 => PhysicalType::Primitive(PrimitiveType::UInt64), + Float16 => PhysicalType::Primitive(PrimitiveType::Float16), + Float32 => PhysicalType::Primitive(PrimitiveType::Float32), + Float64 => PhysicalType::Primitive(PrimitiveType::Float64), + Interval(IntervalUnit::DayTime) => PhysicalType::Primitive(PrimitiveType::DaysMs), + Interval(IntervalUnit::MonthDayNano) => { + PhysicalType::Primitive(PrimitiveType::MonthDayNano) + }, + Binary => PhysicalType::Binary, + FixedSizeBinary(_) => PhysicalType::FixedSizeBinary, + LargeBinary => PhysicalType::LargeBinary, + Utf8 => PhysicalType::Utf8, + LargeUtf8 => PhysicalType::LargeUtf8, + List(_) => PhysicalType::List, + FixedSizeList(_, _) => PhysicalType::FixedSizeList, + LargeList(_) => PhysicalType::LargeList, + Struct(_) => PhysicalType::Struct, + Union(_, _, _) => PhysicalType::Union, + Map(_, _) => PhysicalType::Map, + Dictionary(key, _, _) => PhysicalType::Dictionary(*key), + Extension(_, key, _) => key.to_physical_type(), + } + } + + /// Returns `&self` for all but [`DataType::Extension`]. For [`DataType::Extension`], + /// (recursively) returns the inner [`DataType`]. + /// Never returns the variant [`DataType::Extension`]. + pub fn to_logical_type(&self) -> &DataType { + use DataType::*; + match self { + Extension(_, key, _) => key.to_logical_type(), + _ => self, + } + } +} + +impl From for DataType { + fn from(item: IntegerType) -> Self { + match item { + IntegerType::Int8 => DataType::Int8, + IntegerType::Int16 => DataType::Int16, + IntegerType::Int32 => DataType::Int32, + IntegerType::Int64 => DataType::Int64, + IntegerType::UInt8 => DataType::UInt8, + IntegerType::UInt16 => DataType::UInt16, + IntegerType::UInt32 => DataType::UInt32, + IntegerType::UInt64 => DataType::UInt64, + } + } +} + +impl From for DataType { + fn from(item: PrimitiveType) -> Self { + match item { + PrimitiveType::Int8 => DataType::Int8, + PrimitiveType::Int16 => DataType::Int16, + PrimitiveType::Int32 => DataType::Int32, + PrimitiveType::Int64 => DataType::Int64, + PrimitiveType::UInt8 => DataType::UInt8, + PrimitiveType::UInt16 => DataType::UInt16, + PrimitiveType::UInt32 => DataType::UInt32, + PrimitiveType::UInt64 => DataType::UInt64, + PrimitiveType::Int128 => DataType::Decimal(32, 32), + PrimitiveType::Int256 => DataType::Decimal256(32, 32), + PrimitiveType::Float16 => DataType::Float16, + PrimitiveType::Float32 => DataType::Float32, + PrimitiveType::Float64 => DataType::Float64, + PrimitiveType::DaysMs => DataType::Interval(IntervalUnit::DayTime), + PrimitiveType::MonthDayNano => DataType::Interval(IntervalUnit::MonthDayNano), + } + } +} + +/// typedef for [`Arc`]. 
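Editor's aside: the mapping performed by `to_physical_type` is many-to-one, and `to_logical_type` strips extension wrappers. A small sketch of both, assuming the upstream `arrow2` path:

```rust
use arrow2::datatypes::{DataType, PhysicalType, PrimitiveType, TimeUnit};

fn main() {
    // Several logical types share the same in-memory representation.
    assert_eq!(
        DataType::Date32.to_physical_type(),
        PhysicalType::Primitive(PrimitiveType::Int32)
    );
    assert_eq!(
        DataType::Timestamp(TimeUnit::Millisecond, None).to_physical_type(),
        PhysicalType::Primitive(PrimitiveType::Int64)
    );

    // Extension types delegate to their inner type.
    let uuid = DataType::Extension(
        "uuid".to_string(),
        Box::new(DataType::FixedSizeBinary(16)),
        None,
    );
    assert_eq!(uuid.to_logical_type(), &DataType::FixedSizeBinary(16));
    assert_eq!(uuid.to_physical_type(), PhysicalType::FixedSizeBinary);
}
```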
+pub type SchemaRef = Arc; + +/// support get extension for metadata +pub fn get_extension(metadata: &Metadata) -> Extension { + if let Some(name) = metadata.get("ARROW:extension:name") { + let metadata = metadata.get("ARROW:extension:metadata").cloned(); + Some((name.clone(), metadata)) + } else { + None + } +} diff --git a/crates/nano-arrow/src/datatypes/physical_type.rs b/crates/nano-arrow/src/datatypes/physical_type.rs new file mode 100644 index 000000000000..1e57fcf936bc --- /dev/null +++ b/crates/nano-arrow/src/datatypes/physical_type.rs @@ -0,0 +1,76 @@ +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +pub use crate::types::PrimitiveType; + +/// The set of physical types: unique in-memory representations of an Arrow array. +/// A physical type has a one-to-many relationship with a [`crate::datatypes::DataType`] and +/// a one-to-one mapping to each struct in this crate that implements [`crate::array::Array`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum PhysicalType { + /// A Null with no allocation. + Null, + /// A boolean represented as a single bit. + Boolean, + /// An array where each slot has a known compile-time size. + Primitive(PrimitiveType), + /// Opaque binary data of variable length. + Binary, + /// Opaque binary data of fixed size. + FixedSizeBinary, + /// Opaque binary data of variable length and 64-bit offsets. + LargeBinary, + /// A variable-length string in Unicode with UTF-8 encoding. + Utf8, + /// A variable-length string in Unicode with UFT-8 encoding and 64-bit offsets. + LargeUtf8, + /// A list of some data type with variable length. + List, + /// A list of some data type with fixed length. + FixedSizeList, + /// A list of some data type with variable length and 64-bit offsets. + LargeList, + /// A nested type that contains an arbitrary number of fields. + Struct, + /// A nested type that represents slots of differing types. + Union, + /// A nested type. + Map, + /// A dictionary encoded array by `IntegerType`. + Dictionary(IntegerType), +} + +impl PhysicalType { + /// Whether this physical type equals [`PhysicalType::Primitive`] of type `primitive`. + pub fn eq_primitive(&self, primitive: PrimitiveType) -> bool { + if let Self::Primitive(o) = self { + o == &primitive + } else { + false + } + } +} + +/// the set of valid indices types of a dictionary-encoded Array. +/// Each type corresponds to a variant of [`crate::array::DictionaryArray`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum IntegerType { + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, +} diff --git a/crates/nano-arrow/src/datatypes/schema.rs b/crates/nano-arrow/src/datatypes/schema.rs new file mode 100644 index 000000000000..d01f1937d2ed --- /dev/null +++ b/crates/nano-arrow/src/datatypes/schema.rs @@ -0,0 +1,60 @@ +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +use super::{Field, Metadata}; + +/// An ordered sequence of [`Field`]s with associated [`Metadata`]. +/// +/// [`Schema`] is an abstraction used to read from, and write to, Arrow IPC format, +/// Apache Parquet, and Apache Avro. 
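Editor's aside: the `get_extension` helper defined a little earlier relies on the `ARROW:extension:*` metadata convention, which the FFI schema code later in this diff also uses. A minimal sketch, assuming `Metadata` is the crate's `BTreeMap<String, String>` alias and the upstream `arrow2` path:

```rust
use arrow2::datatypes::{get_extension, Metadata};

fn main() {
    let mut metadata = Metadata::new();
    metadata.insert("ARROW:extension:name".to_string(), "uuid".to_string());
    metadata.insert("ARROW:extension:metadata".to_string(), "{}".to_string());

    // Some((name, optional serialized metadata)) if the name key is present.
    assert_eq!(
        get_extension(&metadata),
        Some(("uuid".to_string(), Some("{}".to_string())))
    );
    assert_eq!(get_extension(&Metadata::new()), None);
}
```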
All these formats have a concept of a schema +/// with fields and metadata. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub struct Schema { + /// The fields composing this schema. + pub fields: Vec, + /// Optional metadata. + pub metadata: Metadata, +} + +impl Schema { + /// Attaches a [`Metadata`] to [`Schema`] + #[inline] + pub fn with_metadata(self, metadata: Metadata) -> Self { + Self { + fields: self.fields, + metadata, + } + } + + /// Returns a new [`Schema`] with a subset of all fields whose `predicate` + /// evaluates to true. + pub fn filter bool>(self, predicate: F) -> Self { + let fields = self + .fields + .into_iter() + .enumerate() + .filter_map(|(index, f)| { + if (predicate)(index, &f) { + Some(f) + } else { + None + } + }) + .collect(); + + Schema { + fields, + metadata: self.metadata, + } + } +} + +impl From> for Schema { + fn from(fields: Vec) -> Self { + Self { + fields, + ..Default::default() + } + } +} diff --git a/crates/nano-arrow/src/doc/lib.md b/crates/nano-arrow/src/doc/lib.md new file mode 100644 index 000000000000..a1b57945c020 --- /dev/null +++ b/crates/nano-arrow/src/doc/lib.md @@ -0,0 +1,87 @@ +Welcome to arrow2's documentation. Thanks for checking it out! + +This is a library for efficient in-memory data operations with +[Arrow in-memory format](https://arrow.apache.org/docs/format/Columnar.html). +It is a re-write from the bottom up of the official `arrow` crate with soundness +and type safety in mind. + +Check out [the guide](https://jorgecarleitao.github.io/arrow2/main/guide/) for an introduction. +Below is an example of some of the things you can do with it: + +```rust +use std::sync::Arc; + +use arrow2::array::*; +use arrow2::datatypes::{Field, DataType, Schema}; +use arrow2::compute::arithmetics; +use arrow2::error::Result; +use arrow2::io::parquet::write::*; +use arrow2::chunk::Chunk; + +fn main() -> Result<()> { + // declare arrays + let a = Int32Array::from(&[Some(1), None, Some(3)]); + let b = Int32Array::from(&[Some(2), None, Some(6)]); + + // compute (probably the fastest implementation of a nullable op you can find out there) + let c = arithmetics::basic::mul_scalar(&a, &2); + assert_eq!(c, b); + + // declare a schema with fields + let schema = Schema::from(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]); + + // declare chunk + let chunk = Chunk::new(vec![a.arced(), b.arced()]); + + // write to parquet (probably the fastest implementation of writing to parquet out there) + + let options = WriteOptions { + write_statistics: true, + compression: CompressionOptions::Snappy, + version: Version::V1, + data_pagesize_limit: None, + }; + + let row_groups = RowGroupIterator::try_new( + vec![Ok(chunk)].into_iter(), + &schema, + options, + vec![vec![Encoding::Plain], vec![Encoding::Plain]], + )?; + + // anything implementing `std::io::Write` works + let mut file = vec![]; + + let mut writer = FileWriter::try_new(file, schema, options)?; + + // Write the file. + for group in row_groups { + writer.write(group?)?; + } + let _ = writer.end(None)?; + Ok(()) +} +``` + +## Cargo features + +This crate has a significant number of cargo features to reduce compilation +time and number of dependencies. 
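Editor's aside: the `Schema` helpers introduced just above (`with_metadata`, `filter`, `From<Vec<Field>>`) compose as in this sketch, assuming the upstream `arrow2` path:

```rust
use arrow2::datatypes::{DataType, Field, Schema};

fn main() {
    let schema = Schema::from(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
        Field::new("score", DataType::Float64, true),
    ]);

    // Keep only the non-nullable fields; the index refers to the original order.
    let required_only = schema.filter(|_index, field| !field.is_nullable);
    assert_eq!(required_only.fields.len(), 1);
    assert_eq!(required_only.fields[0].name, "id");
}
```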
The feature `"full"` activates most +functionality, such as: + +- `io_ipc`: to interact with the Arrow IPC format +- `io_ipc_compression`: to read and write compressed Arrow IPC (v2) +- `io_csv` to read and write CSV +- `io_json` to read and write JSON +- `io_flight` to read and write to Arrow's Flight protocol +- `io_parquet` to read and write parquet +- `io_parquet_compression` to read and write compressed parquet +- `io_print` to write batches to formatted ASCII tables +- `compute` to operate on arrays (addition, sum, sort, etc.) + +The feature `simd` (not part of `full`) produces more explicit SIMD instructions +via [`std::simd`](https://doc.rust-lang.org/nightly/std/simd/index.html), but requires the +nightly channel. diff --git a/crates/nano-arrow/src/error.rs b/crates/nano-arrow/src/error.rs new file mode 100644 index 000000000000..e6455d6f055d --- /dev/null +++ b/crates/nano-arrow/src/error.rs @@ -0,0 +1,100 @@ +//! Defines [`Error`], representing all errors returned by this crate. +use std::fmt::{Debug, Display, Formatter}; + +/// Enum with all errors in this crate. +#[derive(Debug)] +#[non_exhaustive] +pub enum Error { + /// Returned when functionality is not yet available. + NotYetImplemented(String), + /// Wrapper for an error triggered by a dependency + External(String, Box), + /// Wrapper for IO errors + Io(std::io::Error), + /// When an invalid argument is passed to a function. + InvalidArgumentError(String), + /// Error during import or export to/from a format + ExternalFormat(String), + /// Whenever pushing to a container fails because it does not support more entries. + /// The solution is usually to use a higher-capacity container-backing type. + Overflow, + /// Whenever incoming data from the C data interface, IPC or Flight does not fulfil the Arrow specification. + OutOfSpec(String), +} + +impl Error { + /// Wraps an external error in an `Error`. 
+ pub fn from_external_error(error: impl std::error::Error + Send + Sync + 'static) -> Self { + Self::External("".to_string(), Box::new(error)) + } + + pub(crate) fn oos>(msg: A) -> Self { + Self::OutOfSpec(msg.into()) + } + + #[allow(dead_code)] + pub(crate) fn nyi>(msg: A) -> Self { + Self::NotYetImplemented(msg.into()) + } +} + +impl From<::std::io::Error> for Error { + fn from(error: std::io::Error) -> Self { + Error::Io(error) + } +} + +impl From for Error { + fn from(error: std::str::Utf8Error) -> Self { + Error::External("".to_string(), Box::new(error)) + } +} + +impl From for Error { + fn from(error: std::string::FromUtf8Error) -> Self { + Error::External("".to_string(), Box::new(error)) + } +} + +impl From for Error { + fn from(error: simdutf8::basic::Utf8Error) -> Self { + Error::External("".to_string(), Box::new(error)) + } +} + +impl From for Error { + fn from(_: std::collections::TryReserveError) -> Error { + Error::Overflow + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Error::NotYetImplemented(source) => { + write!(f, "Not yet implemented: {}", &source) + }, + Error::External(message, source) => { + write!(f, "External error{}: {}", message, &source) + }, + Error::Io(desc) => write!(f, "Io error: {desc}"), + Error::InvalidArgumentError(desc) => { + write!(f, "Invalid argument error: {desc}") + }, + Error::ExternalFormat(desc) => { + write!(f, "External format error: {desc}") + }, + Error::Overflow => { + write!(f, "Operation overflew the backing container.") + }, + Error::OutOfSpec(message) => { + write!(f, "{message}") + }, + } + } +} + +impl std::error::Error for Error {} + +/// Typedef for a [`std::result::Result`] of an [`Error`]. +pub type Result = std::result::Result; diff --git a/crates/nano-arrow/src/ffi/array.rs b/crates/nano-arrow/src/ffi/array.rs new file mode 100644 index 000000000000..f87f7e66a10c --- /dev/null +++ b/crates/nano-arrow/src/ffi/array.rs @@ -0,0 +1,568 @@ +//! Contains functionality to load an ArrayData from the C Data Interface +use std::sync::Arc; + +use super::ArrowArray; +use crate::array::*; +use crate::bitmap::utils::{bytes_for, count_zeros}; +use crate::bitmap::Bitmap; +use crate::buffer::{Buffer, Bytes, BytesAllocator}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::{Error, Result}; +use crate::ffi::schema::get_child; +use crate::types::NativeType; + +/// Reads a valid `ffi` interface into a `Box` +/// # Errors +/// If and only if: +/// * the interface is not valid (e.g. a null pointer) +pub unsafe fn try_from(array: A) -> Result> { + use PhysicalType::*; + Ok(match array.data_type().to_physical_type() { + Null => Box::new(NullArray::try_from_ffi(array)?), + Boolean => Box::new(BooleanArray::try_from_ffi(array)?), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(PrimitiveArray::<$T>::try_from_ffi(array)?) 
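Editor's aside: the `From` impls in `error.rs` above exist so that `?` can convert common failure types into the crate's `Error`. A small sketch, assuming the upstream `arrow2` path:

```rust
use arrow2::error::{Error, Result};

// `?` converts `std::str::Utf8Error` into `Error` via the `From` impl above.
fn parse_name(bytes: &[u8]) -> Result<&str> {
    let name = std::str::from_utf8(bytes)?;
    if name.is_empty() {
        return Err(Error::InvalidArgumentError(
            "name may not be empty".to_string(),
        ));
    }
    Ok(name)
}

fn main() {
    assert!(parse_name(b"polars").is_ok());
    assert!(parse_name(&[0xff, 0xfe]).is_err()); // invalid UTF-8
}
```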
+ }), + Utf8 => Box::new(Utf8Array::::try_from_ffi(array)?), + LargeUtf8 => Box::new(Utf8Array::::try_from_ffi(array)?), + Binary => Box::new(BinaryArray::::try_from_ffi(array)?), + LargeBinary => Box::new(BinaryArray::::try_from_ffi(array)?), + FixedSizeBinary => Box::new(FixedSizeBinaryArray::try_from_ffi(array)?), + List => Box::new(ListArray::::try_from_ffi(array)?), + LargeList => Box::new(ListArray::::try_from_ffi(array)?), + FixedSizeList => Box::new(FixedSizeListArray::try_from_ffi(array)?), + Struct => Box::new(StructArray::try_from_ffi(array)?), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + Box::new(DictionaryArray::<$T>::try_from_ffi(array)?) + }) + }, + Union => Box::new(UnionArray::try_from_ffi(array)?), + Map => Box::new(MapArray::try_from_ffi(array)?), + }) +} + +// Sound because the arrow specification does not allow multiple implementations +// to change this struct +// This is intrinsically impossible to prove because the implementations agree +// on this as part of the Arrow specification +unsafe impl Send for ArrowArray {} +unsafe impl Sync for ArrowArray {} + +impl Drop for ArrowArray { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +// callback used to drop [ArrowArray] when it is exported +unsafe extern "C" fn c_release_array(array: *mut ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it + let private = Box::from_raw(array.private_data as *mut PrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + + array.release = None; +} + +#[allow(dead_code)] +struct PrivateData { + array: Box, + buffers_ptr: Box<[*const std::os::raw::c_void]>, + children_ptr: Box<[*mut ArrowArray]>, + dictionary_ptr: Option<*mut ArrowArray>, +} + +impl ArrowArray { + /// creates a new `ArrowArray` from existing data. + /// # Safety + /// This method releases `buffers`. Consumers of this struct *must* call `release` before + /// releasing this struct, or contents in `buffers` leak. 
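Editor's aside: the `c_release_array` callback above follows the usual FFI ownership pattern: everything handed out as a raw pointer was produced with `Box::into_raw`, and the release callback reclaims it with `Box::from_raw` so it is freed exactly once. A plain-std sketch of that pattern:

```rust
struct PrivateData {
    values: Vec<u8>,
}

fn main() {
    // Export: leak the allocation so a foreign consumer can hold the pointer.
    let raw: *mut PrivateData =
        Box::into_raw(Box::new(PrivateData { values: vec![1, 2, 3] }));

    // ... the pointer crosses the FFI boundary; eventually the consumer invokes
    // the release callback, which performs the inverse operation:
    let reclaimed = unsafe { Box::from_raw(raw) };
    assert_eq!(reclaimed.values, vec![1, 2, 3]);
    // `reclaimed` is dropped here, freeing the allocation exactly once.
}
```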
+ pub(crate) fn new(array: Box) -> Self { + let (offset, buffers, children, dictionary) = + offset_buffers_children_dictionary(array.as_ref()); + + let buffers_ptr = buffers + .iter() + .map(|maybe_buffer| match maybe_buffer { + Some(b) => *b as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let n_buffers = buffers.len() as i64; + + let children_ptr = children + .into_iter() + .map(|child| Box::into_raw(Box::new(ArrowArray::new(child)))) + .collect::>(); + let n_children = children_ptr.len() as i64; + + let dictionary_ptr = + dictionary.map(|array| Box::into_raw(Box::new(ArrowArray::new(array)))); + + let length = array.len() as i64; + let null_count = array.null_count() as i64; + + let mut private_data = Box::new(PrivateData { + array, + buffers_ptr, + children_ptr, + dictionary_ptr, + }); + + Self { + length, + null_count, + offset: offset as i64, + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children_ptr.as_mut_ptr(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), + release: Some(c_release_array), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } + } + + /// creates an empty [`ArrowArray`], which can be used to import data into + pub fn empty() -> Self { + Self { + length: 0, + null_count: 0, + offset: 0, + n_buffers: 0, + n_children: 0, + buffers: std::ptr::null_mut(), + children: std::ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// the length of the array + pub(crate) fn len(&self) -> usize { + self.length as usize + } + + /// the offset of the array + pub(crate) fn offset(&self) -> usize { + self.offset as usize + } + + /// the null count of the array + pub(crate) fn null_count(&self) -> usize { + self.null_count as usize + } +} + +/// # Safety +/// The caller must ensure that the buffer at index `i` is not mutably shared. +unsafe fn get_buffer_ptr( + array: &ArrowArray, + data_type: &DataType, + index: usize, +) -> Result<*mut T> { + if array.buffers.is_null() { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} must have non-null buffers" + ))); + } + + if array + .buffers + .align_offset(std::mem::align_of::<*mut *const u8>()) + != 0 + { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} + must have buffer {index} aligned to type {}", + std::any::type_name::<*mut *const u8>() + ))); + } + let buffers = array.buffers as *mut *const u8; + + if index >= array.n_buffers as usize { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} + must have buffer {index}." + ))); + } + + let ptr = *buffers.add(index); + if ptr.is_null() { + return Err(Error::oos(format!( + "An array of type {data_type:?} + must have a non-null buffer {index}" + ))); + } + + // note: we can't prove that this pointer is not mutably shared - part of the safety invariant + Ok(ptr as *mut T) +} + +/// returns the buffer `i` of `array` interpreted as a [`Buffer`]. 
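Editor's aside: `ArrowArray::empty` above (and the matching `ArrowSchema::empty` later in this diff) exist for the import direction: the consumer allocates zeroed structs and hands out pointers for an external producer to fill. A sketch of that calling pattern, with the producer side elided and paths assuming the upstream `arrow2` naming:

```rust
use arrow2::ffi;

fn main() {
    // Allocate empty ABI structs for a foreign producer (C, Python, ...) to fill.
    let mut c_array = ffi::ArrowArray::empty();
    let mut c_schema = ffi::ArrowSchema::empty();

    let array_ptr = &mut c_array as *mut ffi::ArrowArray;
    let schema_ptr = &mut c_schema as *mut ffi::ArrowSchema;
    // ... pass `array_ptr` / `schema_ptr` across the FFI boundary here ...
    let _ = (array_ptr, schema_ptr);

    // Once filled, the structs are consumed with the unsafe import functions:
    // let field = unsafe { ffi::import_field_from_c(&c_schema)? };
    // let array = unsafe { ffi::import_array_from_c(c_array, field.data_type().clone())? };
}
```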
+/// # Safety +/// This function is safe iff: +/// * the buffers up to position `index` are valid for the declared length +/// * the buffers' pointers are not mutably shared for the lifetime of `owner` +unsafe fn create_buffer( + array: &ArrowArray, + data_type: &DataType, + owner: InternalArrowArray, + index: usize, +) -> Result> { + let len = buffer_len(array, data_type, index)?; + + if len == 0 { + return Ok(Buffer::new()); + } + + let offset = buffer_offset(array, data_type, index); + let ptr: *mut T = get_buffer_ptr(array, data_type, index)?; + + // We have to check alignment. + // This is the zero-copy path. + if ptr.align_offset(std::mem::align_of::()) == 0 { + let bytes = Bytes::from_foreign(ptr, len, BytesAllocator::InternalArrowArray(owner)); + Ok(Buffer::from_bytes(bytes).sliced(offset, len - offset)) + } + // This is the path where alignment isn't correct. + // We copy the data to a new vec + else { + let buf = std::slice::from_raw_parts(ptr, len - offset).to_vec(); + Ok(Buffer::from(buf)) + } +} + +/// returns the buffer `i` of `array` interpreted as a [`Bitmap`]. +/// # Safety +/// This function is safe iff: +/// * the buffer at position `index` is valid for the declared length +/// * the buffers' pointer is not mutable for the lifetime of `owner` +unsafe fn create_bitmap( + array: &ArrowArray, + data_type: &DataType, + owner: InternalArrowArray, + index: usize, + // if this is the validity bitmap + // we can use the null count directly + is_validity: bool, +) -> Result { + let len: usize = array.length.try_into().expect("length to fit in `usize`"); + if len == 0 { + return Ok(Bitmap::new()); + } + let ptr = get_buffer_ptr(array, data_type, index)?; + + // Pointer of u8 has alignment 1, so we don't have to check alignment. + + let offset: usize = array.offset.try_into().expect("offset to fit in `usize`"); + let bytes_len = bytes_for(offset + len); + let bytes = Bytes::from_foreign(ptr, bytes_len, BytesAllocator::InternalArrowArray(owner)); + + let null_count: usize = if is_validity { + array.null_count() + } else { + count_zeros(bytes.as_ref(), offset, len) + }; + Bitmap::from_inner(Arc::new(bytes), offset, len, null_count) +} + +fn buffer_offset(array: &ArrowArray, data_type: &DataType, i: usize) -> usize { + use PhysicalType::*; + match (data_type.to_physical_type(), i) { + (LargeUtf8, 2) | (LargeBinary, 2) | (Utf8, 2) | (Binary, 2) => 0, + (FixedSizeBinary, 1) => { + if let DataType::FixedSizeBinary(size) = data_type.to_logical_type() { + let offset: usize = array.offset.try_into().expect("Offset to fit in `usize`"); + offset * *size + } else { + unreachable!() + } + }, + _ => array.offset.try_into().expect("Offset to fit in `usize`"), + } +} + +/// Returns the length, in slots, of the buffer `i` (indexed according to the C data interface) +unsafe fn buffer_len(array: &ArrowArray, data_type: &DataType, i: usize) -> Result { + Ok(match (data_type.to_physical_type(), i) { + (PhysicalType::FixedSizeBinary, 1) => { + if let DataType::FixedSizeBinary(size) = data_type.to_logical_type() { + *size * (array.offset as usize + array.length as usize) + } else { + unreachable!() + } + }, + (PhysicalType::FixedSizeList, 1) => { + if let DataType::FixedSizeList(_, size) = data_type.to_logical_type() { + *size * (array.offset as usize + array.length as usize) + } else { + unreachable!() + } + }, + (PhysicalType::Utf8, 1) + | (PhysicalType::LargeUtf8, 1) + | (PhysicalType::Binary, 1) + | (PhysicalType::LargeBinary, 1) + | (PhysicalType::List, 1) + | (PhysicalType::LargeList, 1) + | 
(PhysicalType::Map, 1) => { + // the len of the offset buffer (buffer 1) equals length + 1 + array.offset as usize + array.length as usize + 1 + }, + (PhysicalType::Utf8, 2) | (PhysicalType::Binary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = buffer_len(array, data_type, 1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; + // interpret as i32 + let offset_buffer = offset_buffer as *const i32; + // get last offset + + (unsafe { *offset_buffer.add(len - 1) }) as usize + }, + (PhysicalType::LargeUtf8, 2) | (PhysicalType::LargeBinary, 2) => { + // the len of the data buffer (buffer 2) equals the last value of the offset buffer (buffer 1) + let len = buffer_len(array, data_type, 1)?; + // first buffer is the null buffer => add(1) + let offset_buffer = unsafe { *(array.buffers as *mut *const u8).add(1) }; + // interpret as i64 + let offset_buffer = offset_buffer as *const i64; + // get last offset + (unsafe { *offset_buffer.add(len - 1) }) as usize + }, + // buffer len of primitive types + _ => array.offset as usize + array.length as usize, + }) +} + +/// Safety +/// This function is safe iff: +/// * `array.children` at `index` is valid +/// * `array.children` is not mutably shared for the lifetime of `parent` +/// * the pointer of `array.children` at `index` is valid +/// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` +unsafe fn create_child( + array: &ArrowArray, + data_type: &DataType, + parent: InternalArrowArray, + index: usize, +) -> Result> { + let data_type = get_child(data_type, index)?; + + // catch what we can + if array.children.is_null() { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} must have non-null children" + ))); + } + + if index >= array.n_children as usize { + return Err(Error::oos(format!( + "An ArrowArray of type {data_type:?} + must have child {index}." + ))); + } + + // Safety - part of the invariant + let arr_ptr = unsafe { *array.children.add(index) }; + + // catch what we can + if arr_ptr.is_null() { + return Err(Error::oos(format!( + "An array of type {data_type:?} + must have a non-null child {index}" + ))); + } + + // Safety - invariant of this function + let arr_ptr = unsafe { &*arr_ptr }; + Ok(ArrowArrayChild::new(arr_ptr, data_type, parent)) +} + +/// Safety +/// This function is safe iff: +/// * `array.dictionary` is valid +/// * `array.dictionary` is not mutably shared for the lifetime of `parent` +unsafe fn create_dictionary( + array: &ArrowArray, + data_type: &DataType, + parent: InternalArrowArray, +) -> Result>> { + if let DataType::Dictionary(_, values, _) = data_type { + let data_type = values.as_ref().clone(); + // catch what we can + if array.dictionary.is_null() { + return Err(Error::oos(format!( + "An array of type {data_type:?} + must have a non-null dictionary" + ))); + } + + // safety: part of the invariant + let array = unsafe { &*array.dictionary }; + Ok(Some(ArrowArrayChild::new(array, data_type, parent))) + } else { + Ok(None) + } +} + +pub trait ArrowArrayRef: std::fmt::Debug { + fn owner(&self) -> InternalArrowArray { + (*self.parent()).clone() + } + + /// returns the null bit buffer. + /// Rust implementation uses a buffer that is not part of the array of buffers. + /// The C Data interface's null buffer is part of the array of buffers. 
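Editor's aside: `buffer_len` above encodes the standard Arrow offset-buffer arithmetic: for variable-size types the offsets buffer holds `length + 1` entries and the data buffer is as long as the last offset. A plain-std sketch of that invariant for a small Utf8 array:

```rust
fn main() {
    // Logical values: ["ab", "", "cde"]
    let offsets: [i32; 4] = [0, 2, 2, 5];
    let values = b"abcde";

    let len = 3;
    assert_eq!(offsets.len(), len + 1); // buffer 1: length + 1 offsets
    assert_eq!(*offsets.last().unwrap() as usize, values.len()); // buffer 2: last offset in bytes

    // The value at index i is values[offsets[i]..offsets[i + 1]].
    let i = 2;
    let (start, end) = (offsets[i] as usize, offsets[i + 1] as usize);
    assert_eq!(&values[start..end], b"cde");
}
```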
+ /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a bitmap. + /// This function assumes that the bitmap created from FFI is valid; this is impossible to prove. + unsafe fn validity(&self) -> Result> { + if self.array().null_count() == 0 { + Ok(None) + } else { + create_bitmap(self.array(), self.data_type(), self.owner(), 0, true).map(Some) + } + } + + /// # Safety + /// The caller must guarantee that the buffer `index` corresponds to a buffer. + /// This function assumes that the buffer created from FFI is valid; this is impossible to prove. + unsafe fn buffer(&self, index: usize) -> Result> { + create_buffer::(self.array(), self.data_type(), self.owner(), index) + } + + /// # Safety + /// This function is safe iff: + /// * the buffer at position `index` is valid for the declared length + /// * the buffers' pointer is not mutable for the lifetime of `owner` + unsafe fn bitmap(&self, index: usize) -> Result { + create_bitmap(self.array(), self.data_type(), self.owner(), index, false) + } + + /// # Safety + /// * `array.children` at `index` is valid + /// * `array.children` is not mutably shared for the lifetime of `parent` + /// * the pointer of `array.children` at `index` is valid + /// * the pointer of `array.children` at `index` is not mutably shared for the lifetime of `parent` + unsafe fn child(&self, index: usize) -> Result { + create_child(self.array(), self.data_type(), self.parent().clone(), index) + } + + unsafe fn dictionary(&self) -> Result> { + create_dictionary(self.array(), self.data_type(), self.parent().clone()) + } + + fn n_buffers(&self) -> usize; + + fn parent(&self) -> &InternalArrowArray; + fn array(&self) -> &ArrowArray; + fn data_type(&self) -> &DataType; +} + +/// Struct used to move an Array from and to the C Data Interface. +/// Its main responsibility is to expose functionality that requires +/// both [ArrowArray] and [ArrowSchema]. +/// +/// This struct has two main paths: +/// +/// ## Import from the C Data Interface +/// * [InternalArrowArray::empty] to allocate memory to be filled by an external call +/// * [InternalArrowArray::try_from_raw] to consume two non-null allocated pointers +/// ## Export to the C Data Interface +/// * [InternalArrowArray::try_new] to create a new [InternalArrowArray] from Rust-specific information +/// * [InternalArrowArray::into_raw] to expose two pointers for [ArrowArray] and [ArrowSchema]. +/// +/// # Safety +/// Whoever creates this struct is responsible for releasing their resources. Specifically, +/// consumers *must* call [InternalArrowArray::into_raw] and take ownership of the individual pointers, +/// calling [ArrowArray::release] and [ArrowSchema::release] accordingly. +/// +/// Furthermore, this struct assumes that the incoming data agrees with the C data interface. 
+#[derive(Debug, Clone)] +pub struct InternalArrowArray { + // Arc is used for sharability since this is immutable + array: Arc, + // Arced to reduce cost of cloning + data_type: Arc, +} + +impl InternalArrowArray { + pub fn new(array: ArrowArray, data_type: DataType) -> Self { + Self { + array: Arc::new(array), + data_type: Arc::new(data_type), + } + } +} + +impl ArrowArrayRef for InternalArrowArray { + /// the data_type as declared in the schema + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn parent(&self) -> &InternalArrowArray { + self + } + + fn array(&self) -> &ArrowArray { + self.array.as_ref() + } + + fn n_buffers(&self) -> usize { + self.array.n_buffers as usize + } +} + +#[derive(Debug)] +pub struct ArrowArrayChild<'a> { + array: &'a ArrowArray, + data_type: DataType, + parent: InternalArrowArray, +} + +impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { + /// the data_type as declared in the schema + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn parent(&self) -> &InternalArrowArray { + &self.parent + } + + fn array(&self) -> &ArrowArray { + self.array + } + + fn n_buffers(&self) -> usize { + self.array.n_buffers as usize + } +} + +impl<'a> ArrowArrayChild<'a> { + fn new(array: &'a ArrowArray, data_type: DataType, parent: InternalArrowArray) -> Self { + Self { + array, + data_type, + parent, + } + } +} diff --git a/crates/nano-arrow/src/ffi/bridge.rs b/crates/nano-arrow/src/ffi/bridge.rs new file mode 100644 index 000000000000..7a7b9a86ca3a --- /dev/null +++ b/crates/nano-arrow/src/ffi/bridge.rs @@ -0,0 +1,39 @@ +use crate::array::*; + +macro_rules! ffi_dyn { + ($array:expr, $ty:ty) => {{ + let a = $array.as_any().downcast_ref::<$ty>().unwrap(); + if a.offset().is_some() { + $array + } else { + Box::new(a.to_ffi_aligned()) + } + }}; +} + +pub fn align_to_c_data_interface(array: Box) -> Box { + use crate::datatypes::PhysicalType::*; + match array.data_type().to_physical_type() { + Null => ffi_dyn!(array, NullArray), + Boolean => ffi_dyn!(array, BooleanArray), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + ffi_dyn!(array, PrimitiveArray<$T>) + }), + Binary => ffi_dyn!(array, BinaryArray), + LargeBinary => ffi_dyn!(array, BinaryArray), + FixedSizeBinary => ffi_dyn!(array, FixedSizeBinaryArray), + Utf8 => ffi_dyn!(array, Utf8Array::), + LargeUtf8 => ffi_dyn!(array, Utf8Array::), + List => ffi_dyn!(array, ListArray::), + LargeList => ffi_dyn!(array, ListArray::), + FixedSizeList => ffi_dyn!(array, FixedSizeListArray), + Struct => ffi_dyn!(array, StructArray), + Union => ffi_dyn!(array, UnionArray), + Map => ffi_dyn!(array, MapArray), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + ffi_dyn!(array, DictionaryArray<$T>) + }) + }, + } +} diff --git a/crates/nano-arrow/src/ffi/generated.rs b/crates/nano-arrow/src/ffi/generated.rs new file mode 100644 index 000000000000..cd4953b7198a --- /dev/null +++ b/crates/nano-arrow/src/ffi/generated.rs @@ -0,0 +1,55 @@ +/* automatically generated by rust-bindgen 0.59.2 */ + +/// ABI-compatible struct for [`ArrowSchema`](https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions) +#[repr(C)] +#[derive(Debug)] +pub struct ArrowSchema { + pub(super) format: *const ::std::os::raw::c_char, + pub(super) name: *const ::std::os::raw::c_char, + pub(super) metadata: *const ::std::os::raw::c_char, + pub(super) flags: i64, + pub(super) n_children: i64, + pub(super) children: *mut *mut ArrowSchema, + pub(super) dictionary: *mut ArrowSchema, + pub(super) release: 
::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} + +/// ABI-compatible struct for [`ArrowArray`](https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions) +#[repr(C)] +#[derive(Debug)] +pub struct ArrowArray { + pub(super) length: i64, + pub(super) null_count: i64, + pub(super) offset: i64, + pub(super) n_buffers: i64, + pub(super) n_children: i64, + pub(super) buffers: *mut *const ::std::os::raw::c_void, + pub(super) children: *mut *mut ArrowArray, + pub(super) dictionary: *mut ArrowArray, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} + +/// ABI-compatible struct for [`ArrowArrayStream`](https://arrow.apache.org/docs/format/CStreamInterface.html). +#[repr(C)] +#[derive(Debug)] +pub struct ArrowArrayStream { + pub(super) get_schema: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut ArrowArrayStream, + out: *mut ArrowSchema, + ) -> ::std::os::raw::c_int, + >, + pub(super) get_next: ::std::option::Option< + unsafe extern "C" fn( + arg1: *mut ArrowArrayStream, + out: *mut ArrowArray, + ) -> ::std::os::raw::c_int, + >, + pub(super) get_last_error: ::std::option::Option< + unsafe extern "C" fn(arg1: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char, + >, + pub(super) release: ::std::option::Option, + pub(super) private_data: *mut ::std::os::raw::c_void, +} diff --git a/crates/nano-arrow/src/ffi/mmap.rs b/crates/nano-arrow/src/ffi/mmap.rs new file mode 100644 index 000000000000..03c1ac9aa30a --- /dev/null +++ b/crates/nano-arrow/src/ffi/mmap.rs @@ -0,0 +1,164 @@ +//! Functionality to mmap in-memory data regions. +use std::sync::Arc; + +use super::{ArrowArray, InternalArrowArray}; +use crate::array::{BooleanArray, FromFfi, PrimitiveArray}; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::types::NativeType; + +#[allow(dead_code)] +struct PrivateData { + // the owner of the pointers' regions + data: T, + buffers_ptr: Box<[*const std::os::raw::c_void]>, + children_ptr: Box<[*mut ArrowArray]>, + dictionary_ptr: Option<*mut ArrowArray>, +} + +pub(crate) unsafe fn create_array< + T: AsRef<[u8]>, + I: Iterator>, + II: Iterator, +>( + data: Arc, + num_rows: usize, + null_count: usize, + buffers: I, + children: II, + dictionary: Option, + offset: Option, +) -> ArrowArray { + let buffers_ptr = buffers + .map(|maybe_buffer| match maybe_buffer { + Some(b) => b as *const std::os::raw::c_void, + None => std::ptr::null(), + }) + .collect::>(); + let n_buffers = buffers_ptr.len() as i64; + + let children_ptr = children + .map(|child| Box::into_raw(Box::new(child))) + .collect::>(); + let n_children = children_ptr.len() as i64; + + let dictionary_ptr = dictionary.map(|array| Box::into_raw(Box::new(array))); + + let mut private_data = Box::new(PrivateData::> { + data, + buffers_ptr, + children_ptr, + dictionary_ptr, + }); + + ArrowArray { + length: num_rows as i64, + null_count: null_count as i64, + offset: offset.unwrap_or(0) as i64, // Unwrap: IPC files are by definition not offset + n_buffers, + n_children, + buffers: private_data.buffers_ptr.as_mut_ptr(), + children: private_data.children_ptr.as_mut_ptr(), + dictionary: private_data.dictionary_ptr.unwrap_or(std::ptr::null_mut()), + release: Some(release::>), + private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + } +} + +/// callback used to drop [`ArrowArray`] when it is exported specified for [`PrivateData`]. 
+unsafe extern "C" fn release(array: *mut ArrowArray) { + if array.is_null() { + return; + } + let array = &mut *array; + + // take ownership of `private_data`, therefore dropping it + let private = Box::from_raw(array.private_data as *mut PrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary_ptr { + let _ = Box::from_raw(ptr); + } + + array.release = None; +} + +/// Creates a (non-null) [`PrimitiveArray`] from a slice of values. +/// This does not have memcopy and is the fastest way to create a [`PrimitiveArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without incurring +/// a memcopy cost. +/// +/// # Safety +/// +/// Using this function is not unsafe, but the returned PrimitiveArray's lifetime is bound to the lifetime +/// of the slice. The returned [`PrimitiveArray`] _must not_ outlive the passed slice. +pub unsafe fn slice(slice: &[T]) -> PrimitiveArray { + let num_rows = slice.len(); + let null_count = 0; + let validity = None; + + let data: &[u8] = bytemuck::cast_slice(slice); + let ptr = data.as_ptr(); + let data = Arc::new(data); + + // safety: the underlying assumption of this function: the array will not be used + // beyond the + let array = create_array( + data, + num_rows, + null_count, + [validity, Some(ptr)].into_iter(), + [].into_iter(), + None, + None, + ); + let array = InternalArrowArray::new(array, T::PRIMITIVE.into()); + + // safety: we just created a valid array + unsafe { PrimitiveArray::::try_from_ffi(array) }.unwrap() +} + +/// Creates a (non-null) [`BooleanArray`] from a slice of bits. +/// This does not have memcopy and is the fastest way to create a [`BooleanArray`]. +/// +/// This can be useful if you want to apply arrow kernels on slices without incurring +/// a memcopy cost. +/// +/// The `offset` indicates where the first bit starts in the first byte. +/// +/// # Safety +/// +/// Using this function is not unsafe, but the returned BooleanArrays's lifetime is bound to the lifetime +/// of the slice. The returned [`BooleanArray`] _must not_ outlive the passed slice. +pub unsafe fn bitmap(data: &[u8], offset: usize, length: usize) -> Result { + if offset >= 8 { + return Err(Error::InvalidArgumentError("offset should be < 8".into())); + }; + if length > data.len() * 8 - offset { + return Err(Error::InvalidArgumentError("given length is oob".into())); + } + let null_count = 0; + let validity = None; + + let ptr = data.as_ptr(); + let data = Arc::new(data); + + // safety: the underlying assumption of this function: the array will not be used + // beyond the + let array = create_array( + data, + length, + null_count, + [validity, Some(ptr)].into_iter(), + [].into_iter(), + None, + Some(offset), + ); + let array = InternalArrowArray::new(array, DataType::Boolean); + + // safety: we just created a valid array + Ok(unsafe { BooleanArray::try_from_ffi(array) }.unwrap()) +} diff --git a/crates/nano-arrow/src/ffi/mod.rs b/crates/nano-arrow/src/ffi/mod.rs new file mode 100644 index 000000000000..b1a1ac3c1210 --- /dev/null +++ b/crates/nano-arrow/src/ffi/mod.rs @@ -0,0 +1,46 @@ +//! contains FFI bindings to import and export [`Array`](crate::array::Array) via +//! 
Arrow's [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html) +mod array; +mod bridge; +mod generated; +pub mod mmap; +mod schema; +mod stream; + +pub(crate) use array::{try_from, ArrowArrayRef, InternalArrowArray}; +pub use generated::{ArrowArray, ArrowArrayStream, ArrowSchema}; +pub use stream::{export_iterator, ArrowArrayStreamReader}; + +use self::schema::to_field; +use crate::array::Array; +use crate::datatypes::{DataType, Field}; +use crate::error::Result; + +/// Exports an [`Box`] to the C data interface. +pub fn export_array_to_c(array: Box) -> ArrowArray { + ArrowArray::new(bridge::align_to_c_data_interface(array)) +} + +/// Exports a [`Field`] to the C data interface. +pub fn export_field_to_c(field: &Field) -> ArrowSchema { + ArrowSchema::new(field) +} + +/// Imports a [`Field`] from the C data interface. +/// # Safety +/// This function is intrinsically `unsafe` and relies on a [`ArrowSchema`] +/// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub unsafe fn import_field_from_c(field: &ArrowSchema) -> Result { + to_field(field) +} + +/// Imports an [`Array`] from the C data interface. +/// # Safety +/// This function is intrinsically `unsafe` and relies on a [`ArrowArray`] +/// being valid according to the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html) (FFI). +pub unsafe fn import_array_from_c( + array: ArrowArray, + data_type: DataType, +) -> Result> { + try_from(InternalArrowArray::new(array, data_type)) +} diff --git a/crates/nano-arrow/src/ffi/schema.rs b/crates/nano-arrow/src/ffi/schema.rs new file mode 100644 index 000000000000..332410b0b6c5 --- /dev/null +++ b/crates/nano-arrow/src/ffi/schema.rs @@ -0,0 +1,633 @@ +use std::collections::BTreeMap; +use std::convert::TryInto; +use std::ffi::{CStr, CString}; +use std::ptr; + +use super::ArrowSchema; +use crate::datatypes::{ + DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, TimeUnit, UnionMode, +}; +use crate::error::{Error, Result}; + +#[allow(dead_code)] +struct SchemaPrivateData { + name: CString, + format: CString, + metadata: Option>, + children_ptr: Box<[*mut ArrowSchema]>, + dictionary: Option<*mut ArrowSchema>, +} + +// callback used to drop [ArrowSchema] when it is exported. 
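Editor's aside: putting the four public entry points of `ffi/mod.rs` above together, an in-process round trip through the C data interface looks roughly like this sketch, with paths assuming the upstream `arrow2` naming used elsewhere in this diff:

```rust
use arrow2::array::{Array, Int32Array};
use arrow2::datatypes::Field;
use arrow2::ffi;

fn main() {
    let array = Int32Array::from(&[Some(1), None, Some(3)]);
    let field = Field::new("a", array.data_type().clone(), true);

    // Export: the ABI structs now own the data and carry a release callback.
    let c_array = ffi::export_array_to_c(Box::new(array));
    let c_schema = ffi::export_field_to_c(&field);

    // Import: unsafe because we must trust that the structs follow the spec.
    let imported_field = unsafe { ffi::import_field_from_c(&c_schema) }.unwrap();
    let imported = unsafe {
        ffi::import_array_from_c(c_array, imported_field.data_type().clone())
    }
    .unwrap();

    assert_eq!(imported.len(), 3);
    assert_eq!(imported.null_count(), 1);
}
```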
+unsafe extern "C" fn c_release_schema(schema: *mut ArrowSchema) { + if schema.is_null() { + return; + } + let schema = &mut *schema; + + let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData); + for child in private.children_ptr.iter() { + let _ = Box::from_raw(*child); + } + + if let Some(ptr) = private.dictionary { + let _ = Box::from_raw(ptr); + } + + schema.release = None; +} + +/// allocate (and hold) the children +fn schema_children(data_type: &DataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> { + match data_type { + DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => { + Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))]) + }, + DataType::Map(field, is_sorted) => { + *flags += (*is_sorted as i64) * 4; + Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))]) + }, + DataType::Struct(fields) | DataType::Union(fields, _, _) => fields + .iter() + .map(|field| Box::into_raw(Box::new(ArrowSchema::new(field)))) + .collect::>(), + DataType::Extension(_, inner, _) => schema_children(inner, flags), + _ => Box::new([]), + } +} + +impl ArrowSchema { + /// creates a new [ArrowSchema] + pub(crate) fn new(field: &Field) -> Self { + let format = to_format(field.data_type()); + let name = field.name.clone(); + + let mut flags = field.is_nullable as i64 * 2; + + // note: this cannot be done along with the above because the above is fallible and this op leaks. + let children_ptr = schema_children(field.data_type(), &mut flags); + let n_children = children_ptr.len() as i64; + + let dictionary = if let DataType::Dictionary(_, values, is_ordered) = field.data_type() { + flags += *is_ordered as i64; + // we do not store field info in the dict values, so can't recover it all :( + let field = Field::new("", values.as_ref().clone(), true); + Some(Box::new(ArrowSchema::new(&field))) + } else { + None + }; + + let metadata = &field.metadata; + + let metadata = if let DataType::Extension(name, _, extension_metadata) = field.data_type() { + // append extension information. 
+ let mut metadata = metadata.clone(); + + // metadata + if let Some(extension_metadata) = extension_metadata { + metadata.insert( + "ARROW:extension:metadata".to_string(), + extension_metadata.clone(), + ); + } + + metadata.insert("ARROW:extension:name".to_string(), name.clone()); + + Some(metadata_to_bytes(&metadata)) + } else if !metadata.is_empty() { + Some(metadata_to_bytes(metadata)) + } else { + None + }; + + let name = CString::new(name).unwrap(); + let format = CString::new(format).unwrap(); + + let mut private = Box::new(SchemaPrivateData { + name, + format, + metadata, + children_ptr, + dictionary: dictionary.map(Box::into_raw), + }); + + // + Self { + format: private.format.as_ptr(), + name: private.name.as_ptr(), + metadata: private + .metadata + .as_ref() + .map(|x| x.as_ptr()) + .unwrap_or(std::ptr::null()) as *const ::std::os::raw::c_char, + flags, + n_children, + children: private.children_ptr.as_mut_ptr(), + dictionary: private.dictionary.unwrap_or(std::ptr::null_mut()), + release: Some(c_release_schema), + private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void, + } + } + + /// create an empty [ArrowSchema] + pub fn empty() -> Self { + Self { + format: std::ptr::null_mut(), + name: std::ptr::null_mut(), + metadata: std::ptr::null_mut(), + flags: 0, + n_children: 0, + children: ptr::null_mut(), + dictionary: std::ptr::null_mut(), + release: None, + private_data: std::ptr::null_mut(), + } + } + + /// returns the format of this schema. + pub(crate) fn format(&self) -> &str { + assert!(!self.format.is_null()); + // safe because the lifetime of `self.format` equals `self` + unsafe { CStr::from_ptr(self.format) } + .to_str() + .expect("The external API has a non-utf8 as format") + } + + /// returns the name of this schema. + /// + /// Since this field is optional, `""` is returned if it is not set (as per the spec). + pub(crate) fn name(&self) -> &str { + if self.name.is_null() { + return ""; + } + // safe because the lifetime of `self.name` equals `self` + unsafe { CStr::from_ptr(self.name) }.to_str().unwrap() + } + + pub(crate) fn child(&self, index: usize) -> &'static Self { + assert!(index < self.n_children as usize); + unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } + } + + pub(crate) fn dictionary(&self) -> Option<&'static Self> { + if self.dictionary.is_null() { + return None; + }; + Some(unsafe { self.dictionary.as_ref().unwrap() }) + } + + pub(crate) fn nullable(&self) -> bool { + (self.flags / 2) & 1 == 1 + } +} + +impl Drop for ArrowSchema { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> Result { + let dictionary = schema.dictionary(); + let data_type = if let Some(dictionary) = dictionary { + let indices = to_integer_type(schema.format())?; + let values = to_field(dictionary)?; + let is_ordered = schema.flags & 1 == 1; + DataType::Dictionary(indices, Box::new(values.data_type().clone()), is_ordered) + } else { + to_data_type(schema)? 
+ }; + let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) }; + + let data_type = if let Some((name, extension_metadata)) = extension { + DataType::Extension(name, Box::new(data_type), extension_metadata) + } else { + data_type + }; + + Ok(Field::new(schema.name(), data_type, schema.nullable()).with_metadata(metadata)) +} + +fn to_integer_type(format: &str) -> Result { + use IntegerType::*; + Ok(match format { + "c" => Int8, + "C" => UInt8, + "s" => Int16, + "S" => UInt16, + "i" => Int32, + "I" => UInt32, + "l" => Int64, + "L" => UInt64, + _ => { + return Err(Error::OutOfSpec( + "Dictionary indices can only be integers".to_string(), + )) + }, + }) +} + +unsafe fn to_data_type(schema: &ArrowSchema) -> Result { + Ok(match schema.format() { + "n" => DataType::Null, + "b" => DataType::Boolean, + "c" => DataType::Int8, + "C" => DataType::UInt8, + "s" => DataType::Int16, + "S" => DataType::UInt16, + "i" => DataType::Int32, + "I" => DataType::UInt32, + "l" => DataType::Int64, + "L" => DataType::UInt64, + "e" => DataType::Float16, + "f" => DataType::Float32, + "g" => DataType::Float64, + "z" => DataType::Binary, + "Z" => DataType::LargeBinary, + "u" => DataType::Utf8, + "U" => DataType::LargeUtf8, + "tdD" => DataType::Date32, + "tdm" => DataType::Date64, + "tts" => DataType::Time32(TimeUnit::Second), + "ttm" => DataType::Time32(TimeUnit::Millisecond), + "ttu" => DataType::Time64(TimeUnit::Microsecond), + "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "tDs" => DataType::Duration(TimeUnit::Second), + "tDm" => DataType::Duration(TimeUnit::Millisecond), + "tDu" => DataType::Duration(TimeUnit::Microsecond), + "tDn" => DataType::Duration(TimeUnit::Nanosecond), + "tiM" => DataType::Interval(IntervalUnit::YearMonth), + "tiD" => DataType::Interval(IntervalUnit::DayTime), + "+l" => { + let child = schema.child(0); + DataType::List(Box::new(to_field(child)?)) + }, + "+L" => { + let child = schema.child(0); + DataType::LargeList(Box::new(to_field(child)?)) + }, + "+m" => { + let child = schema.child(0); + + let is_sorted = (schema.flags & 4) != 0; + DataType::Map(Box::new(to_field(child)?), is_sorted) + }, + "+s" => { + let children = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + DataType::Struct(children) + }, + other => { + match other.splitn(2, ':').collect::>()[..] 
{ + // Timestamps with no timezone + ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), + + // Timestamps with timezone + ["tss", tz] => DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())), + ["tsm", tz] => DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())), + ["tsu", tz] => DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())), + ["tsn", tz] => DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())), + + ["w", size_raw] => { + // Example: "w:42" fixed-width binary [42 bytes] + let size = size_raw + .parse::() + .map_err(|_| Error::OutOfSpec("size is not a valid integer".to_string()))?; + DataType::FixedSizeBinary(size) + }, + ["+w", size_raw] => { + // Example: "+w:123" fixed-sized list [123 items] + let size = size_raw + .parse::() + .map_err(|_| Error::OutOfSpec("size is not a valid integer".to_string()))?; + let child = to_field(schema.child(0))?; + DataType::FixedSizeList(Box::new(child), size) + }, + ["d", raw] => { + // Decimal + let (precision, scale) = match raw.split(',').collect::>()[..] { + [precision_raw, scale_raw] => { + // Example: "d:19,10" decimal128 [precision 19, scale 10] + (precision_raw, scale_raw) + }, + [precision_raw, scale_raw, width_raw] => { + // Example: "d:19,10,NNN" decimal bitwidth = NNN [precision 19, scale 10] + // Only bitwdth of 128 currently supported + let bit_width = width_raw.parse::().map_err(|_| { + Error::OutOfSpec( + "Decimal bit width is not a valid integer".to_string(), + ) + })?; + if bit_width == 256 { + return Ok(DataType::Decimal256( + precision_raw.parse::().map_err(|_| { + Error::OutOfSpec( + "Decimal precision is not a valid integer".to_string(), + ) + })?, + scale_raw.parse::().map_err(|_| { + Error::OutOfSpec( + "Decimal scale is not a valid integer".to_string(), + ) + })?, + )); + } + (precision_raw, scale_raw) + }, + _ => { + return Err(Error::OutOfSpec( + "Decimal must contain 2 or 3 comma-separated values".to_string(), + )); + }, + }; + + DataType::Decimal( + precision.parse::().map_err(|_| { + Error::OutOfSpec("Decimal precision is not a valid integer".to_string()) + })?, + scale.parse::().map_err(|_| { + Error::OutOfSpec("Decimal scale is not a valid integer".to_string()) + })?, + ) + }, + [union_type @ "+us", union_parts] | [union_type @ "+ud", union_parts] => { + // union, sparse + // Example "+us:I,J,..." sparse union with type ids I,J... + // Example: "+ud:I,J,..." dense union with type ids I,J... 
+ let mode = UnionMode::sparse(union_type == "+us"); + let type_ids = union_parts + .split(',') + .map(|x| { + x.parse::().map_err(|_| { + Error::OutOfSpec("Union type id is not a valid integer".to_string()) + }) + }) + .collect::>>()?; + let fields = (0..schema.n_children as usize) + .map(|x| to_field(schema.child(x))) + .collect::>>()?; + DataType::Union(fields, Some(type_ids), mode) + }, + _ => { + return Err(Error::OutOfSpec(format!( + "The datatype \"{other}\" is still not supported in Rust implementation", + ))); + }, + } + }, + }) +} + +/// the inverse of [to_field] +fn to_format(data_type: &DataType) -> String { + match data_type { + DataType::Null => "n".to_string(), + DataType::Boolean => "b".to_string(), + DataType::Int8 => "c".to_string(), + DataType::UInt8 => "C".to_string(), + DataType::Int16 => "s".to_string(), + DataType::UInt16 => "S".to_string(), + DataType::Int32 => "i".to_string(), + DataType::UInt32 => "I".to_string(), + DataType::Int64 => "l".to_string(), + DataType::UInt64 => "L".to_string(), + DataType::Float16 => "e".to_string(), + DataType::Float32 => "f".to_string(), + DataType::Float64 => "g".to_string(), + DataType::Binary => "z".to_string(), + DataType::LargeBinary => "Z".to_string(), + DataType::Utf8 => "u".to_string(), + DataType::LargeUtf8 => "U".to_string(), + DataType::Date32 => "tdD".to_string(), + DataType::Date64 => "tdm".to_string(), + DataType::Time32(TimeUnit::Second) => "tts".to_string(), + DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(), + DataType::Time32(_) => { + unreachable!("Time32 is only supported for seconds and milliseconds") + }, + DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(), + DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(), + DataType::Time64(_) => { + unreachable!("Time64 is only supported for micro and nanoseconds") + }, + DataType::Duration(TimeUnit::Second) => "tDs".to_string(), + DataType::Duration(TimeUnit::Millisecond) => "tDm".to_string(), + DataType::Duration(TimeUnit::Microsecond) => "tDu".to_string(), + DataType::Duration(TimeUnit::Nanosecond) => "tDn".to_string(), + DataType::Interval(IntervalUnit::YearMonth) => "tiM".to_string(), + DataType::Interval(IntervalUnit::DayTime) => "tiD".to_string(), + DataType::Interval(IntervalUnit::MonthDayNano) => { + todo!("Spec for FFI for MonthDayNano still not defined.") + }, + DataType::Timestamp(unit, tz) => { + let unit = match unit { + TimeUnit::Second => "s", + TimeUnit::Millisecond => "m", + TimeUnit::Microsecond => "u", + TimeUnit::Nanosecond => "n", + }; + format!( + "ts{}:{}", + unit, + tz.as_ref().map(|x| x.as_ref()).unwrap_or("") + ) + }, + DataType::Decimal(precision, scale) => format!("d:{precision},{scale}"), + DataType::Decimal256(precision, scale) => format!("d:{precision},{scale},256"), + DataType::List(_) => "+l".to_string(), + DataType::LargeList(_) => "+L".to_string(), + DataType::Struct(_) => "+s".to_string(), + DataType::FixedSizeBinary(size) => format!("w:{size}"), + DataType::FixedSizeList(_, size) => format!("+w:{size}"), + DataType::Union(f, ids, mode) => { + let sparsness = if mode.is_sparse() { 's' } else { 'd' }; + let mut r = format!("+u{sparsness}:"); + let ids = if let Some(ids) = ids { + ids.iter() + .fold(String::new(), |a, b| a + &b.to_string() + ",") + } else { + (0..f.len()).fold(String::new(), |a, b| a + &b.to_string() + ",") + }; + let ids = &ids[..ids.len() - 1]; // take away last "," + r.push_str(ids); + r + }, + DataType::Map(_, _) => "+m".to_string(), + DataType::Dictionary(index, _, _) => 
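+        // Per the C data interface, a dictionary-encoded type is described by the format of its keys; the value type is carried by the separate dictionary schema.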
to_format(&(*index).into()), + DataType::Extension(_, inner, _) => to_format(inner.as_ref()), + } +} + +pub(super) fn get_child(data_type: &DataType, index: usize) -> Result { + match (index, data_type) { + (0, DataType::List(field)) => Ok(field.data_type().clone()), + (0, DataType::FixedSizeList(field, _)) => Ok(field.data_type().clone()), + (0, DataType::LargeList(field)) => Ok(field.data_type().clone()), + (0, DataType::Map(field, _)) => Ok(field.data_type().clone()), + (index, DataType::Struct(fields)) => Ok(fields[index].data_type().clone()), + (index, DataType::Union(fields, _, _)) => Ok(fields[index].data_type().clone()), + (index, DataType::Extension(_, subtype, _)) => get_child(subtype, index), + (child, data_type) => Err(Error::OutOfSpec(format!( + "Requested child {child} to type {data_type:?} that has no such child", + ))), + } +} + +fn metadata_to_bytes(metadata: &BTreeMap) -> Vec { + let a = (metadata.len() as i32).to_ne_bytes().to_vec(); + metadata.iter().fold(a, |mut acc, (key, value)| { + acc.extend((key.len() as i32).to_ne_bytes()); + acc.extend(key.as_bytes()); + acc.extend((value.len() as i32).to_ne_bytes()); + acc.extend(value.as_bytes()); + acc + }) +} + +unsafe fn read_ne_i32(ptr: *const u8) -> i32 { + let slice = std::slice::from_raw_parts(ptr, 4); + i32::from_ne_bytes(slice.try_into().unwrap()) +} + +unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str { + let slice = std::slice::from_raw_parts(ptr, len); + simdutf8::basic::from_utf8(slice).unwrap() +} + +unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) { + let mut data = data as *const u8; // u8 = i8 + if data.is_null() { + return (Metadata::default(), None); + }; + let len = read_ne_i32(data); + data = data.add(4); + + let mut result = BTreeMap::new(); + let mut extension_name = None; + let mut extension_metadata = None; + for _ in 0..len { + let key_len = read_ne_i32(data) as usize; + data = data.add(4); + let key = read_bytes(data, key_len); + data = data.add(key_len); + let value_len = read_ne_i32(data) as usize; + data = data.add(4); + let value = read_bytes(data, value_len); + data = data.add(value_len); + match key { + "ARROW:extension:name" => { + extension_name = Some(value.to_string()); + }, + "ARROW:extension:metadata" => { + extension_metadata = Some(value.to_string()); + }, + _ => { + result.insert(key.to_string(), value.to_string()); + }, + }; + } + let extension = extension_name.map(|name| (name, extension_metadata)); + (result, extension) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all() { + let mut dts = vec![ + DataType::Null, + DataType::Boolean, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Decimal(5, 5), + DataType::Utf8, + DataType::LargeUtf8, + DataType::Binary, + DataType::LargeBinary, + DataType::FixedSizeBinary(2), + DataType::List(Box::new(Field::new("example", DataType::Boolean, false))), + DataType::FixedSizeList(Box::new(Field::new("example", DataType::Boolean, false)), 2), + DataType::LargeList(Box::new(Field::new("example", DataType::Boolean, false))), + DataType::Struct(vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + 
DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + ]), + DataType::Map(Box::new(Field::new("a", DataType::Int64, true)), true), + DataType::Union( + vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + ], + Some(vec![1, 2]), + UnionMode::Dense, + ), + DataType::Union( + vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + DataType::List(Box::new(Field::new("item", DataType::Int32, true))), + true, + ), + ], + Some(vec![0, 1]), + UnionMode::Sparse, + ), + ]; + for time_unit in [ + TimeUnit::Second, + TimeUnit::Millisecond, + TimeUnit::Microsecond, + TimeUnit::Nanosecond, + ] { + dts.push(DataType::Timestamp(time_unit, None)); + dts.push(DataType::Timestamp(time_unit, Some("00:00".to_string()))); + dts.push(DataType::Duration(time_unit)); + } + for interval_type in [ + IntervalUnit::DayTime, + IntervalUnit::YearMonth, + //IntervalUnit::MonthDayNano, // not yet defined on the C data interface + ] { + dts.push(DataType::Interval(interval_type)); + } + + for expected in dts { + let field = Field::new("a", expected.clone(), true); + let schema = ArrowSchema::new(&field); + let result = unsafe { super::to_data_type(&schema).unwrap() }; + assert_eq!(result, expected); + } + } +} diff --git a/crates/nano-arrow/src/ffi/stream.rs b/crates/nano-arrow/src/ffi/stream.rs new file mode 100644 index 000000000000..4776014bca54 --- /dev/null +++ b/crates/nano-arrow/src/ffi/stream.rs @@ -0,0 +1,226 @@ +use std::ffi::{CStr, CString}; +use std::ops::DerefMut; + +use super::{ + export_array_to_c, export_field_to_c, import_array_from_c, import_field_from_c, ArrowArray, + ArrowArrayStream, ArrowSchema, +}; +use crate::array::Array; +use crate::datatypes::Field; +use crate::error::Error; + +impl Drop for ArrowArrayStream { + fn drop(&mut self) { + match self.release { + None => (), + Some(release) => unsafe { release(self) }, + }; + } +} + +impl ArrowArrayStream { + /// Creates an empty [`ArrowArrayStream`] used to import from a producer. + pub fn empty() -> Self { + Self { + get_schema: None, + get_next: None, + get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + } + } +} + +unsafe fn handle_error(iter: &mut ArrowArrayStream) -> Error { + let error = unsafe { (iter.get_last_error.unwrap())(&mut *iter) }; + + if error.is_null() { + return Error::External( + "C stream".to_string(), + Box::new(Error::ExternalFormat("an unspecified error".to_string())), + ); + } + + let error = unsafe { CStr::from_ptr(error) }; + Error::External( + "C stream".to_string(), + Box::new(Error::ExternalFormat(error.to_str().unwrap().to_string())), + ) +} + +/// Implements an iterator of [`Array`] consumed from the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html). +pub struct ArrowArrayStreamReader> { + iter: Iter, + field: Field, +} + +impl> ArrowArrayStreamReader { + /// Returns a new [`ArrowArrayStreamReader`] + /// # Error + /// Errors iff the [`ArrowArrayStream`] is out of specification, + /// or was already released prior to calling this function. + /// # Safety + /// This method is intrinsically `unsafe` since it assumes that the `ArrowArrayStream` + /// contains a valid Arrow C stream interface. 
+ /// In particular: + /// * The `ArrowArrayStream` fulfills the invariants of the C stream interface + /// * The schema `get_schema` produces fulfills the C data interface + pub unsafe fn try_new(mut iter: Iter) -> Result { + if iter.release.is_none() { + return Err(Error::InvalidArgumentError( + "The C stream was already released".to_string(), + )); + }; + + if iter.get_next.is_none() { + return Err(Error::OutOfSpec( + "The C stream MUST contain a non-null get_next".to_string(), + )); + }; + + if iter.get_last_error.is_none() { + return Err(Error::OutOfSpec( + "The C stream MUST contain a non-null get_last_error".to_string(), + )); + }; + + let mut field = ArrowSchema::empty(); + let status = if let Some(f) = iter.get_schema { + unsafe { (f)(&mut *iter, &mut field) } + } else { + return Err(Error::OutOfSpec( + "The C stream MUST contain a non-null get_schema".to_string(), + )); + }; + + if status != 0 { + return Err(unsafe { handle_error(&mut iter) }); + } + + let field = unsafe { import_field_from_c(&field)? }; + + Ok(Self { iter, field }) + } + + /// Returns the field provided by the stream + pub fn field(&self) -> &Field { + &self.field + } + + /// Advances this iterator by one array + /// # Error + /// Errors iff: + /// * The C stream interface returns an error + /// * The C stream interface returns an invalid array (that we can identify, see Safety below) + /// # Safety + /// Calling this iterator's `next` assumes that the [`ArrowArrayStream`] produces arrow arrays + /// that fulfill the C data interface + pub unsafe fn next(&mut self) -> Option, Error>> { + let mut array = ArrowArray::empty(); + let status = unsafe { (self.iter.get_next.unwrap())(&mut *self.iter, &mut array) }; + + if status != 0 { + return Some(Err(unsafe { handle_error(&mut self.iter) })); + } + + // last paragraph of https://arrow.apache.org/docs/format/CStreamInterface.html#c.ArrowArrayStream.get_next + array.release?; + + // Safety: assumed from the C stream interface + unsafe { import_array_from_c(array, self.field.data_type.clone()) } + .map(Some) + .transpose() + } +} + +struct PrivateData { + iter: Box, Error>>>, + field: Field, + error: Option, +} + +unsafe extern "C" fn get_next(iter: *mut ArrowArrayStream, array: *mut ArrowArray) -> i32 { + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + match private.iter.next() { + Some(Ok(item)) => { + // check that the array has the same data_type as field + let item_dt = item.data_type(); + let expected_dt = private.field.data_type(); + if item_dt != expected_dt { + private.error = Some(CString::new(format!("The iterator produced an item of data type {item_dt:?} but the producer expects data type {expected_dt:?}").as_bytes().to_vec()).unwrap()); + return 2001; // custom application specific error (since this is never a result of this interface) + } + + std::ptr::write(array, export_array_to_c(item)); + + private.error = None; + 0 + }, + Some(Err(err)) => { + private.error = Some(CString::new(err.to_string().as_bytes().to_vec()).unwrap()); + 2001 // custom application specific error (since this is never a result of this interface) + }, + None => { + let a = ArrowArray::empty(); + std::ptr::write_unaligned(array, a); + private.error = None; + 0 + }, + } +} + +unsafe extern "C" fn get_schema(iter: *mut ArrowArrayStream, schema: *mut ArrowSchema) -> i32 { + if iter.is_null() { + return 2001; + } + let private = &mut *((*iter).private_data as *mut PrivateData); + + std::ptr::write(schema, 
export_field_to_c(&private.field));
+    0
+}
+
+unsafe extern "C" fn get_last_error(iter: *mut ArrowArrayStream) -> *const ::std::os::raw::c_char {
+    if iter.is_null() {
+        return std::ptr::null();
+    }
+    let private = &mut *((*iter).private_data as *mut PrivateData);
+
+    private
+        .error
+        .as_ref()
+        .map(|x| x.as_ptr())
+        .unwrap_or(std::ptr::null())
+}
+
+unsafe extern "C" fn release(iter: *mut ArrowArrayStream) {
+    if iter.is_null() {
+        return;
+    }
+    let _ = Box::from_raw((*iter).private_data as *mut PrivateData);
+    (*iter).release = None;
+    // the Box reconstructed above is dropped here, freeing the private data
+}
+
+/// Exports an iterator to the [C stream interface](https://arrow.apache.org/docs/format/CStreamInterface.html)
+pub fn export_iterator(
+    iter: Box<dyn Iterator<Item = Result<Box<dyn Array>, Error>>>,
+    field: Field,
+) -> ArrowArrayStream {
+    let private_data = Box::new(PrivateData {
+        iter,
+        field,
+        error: None,
+    });
+
+    ArrowArrayStream {
+        get_schema: Some(get_schema),
+        get_next: Some(get_next),
+        get_last_error: Some(get_last_error),
+        release: Some(release),
+        private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void,
+    }
+}
diff --git a/crates/nano-arrow/src/io/README.md b/crates/nano-arrow/src/io/README.md
new file mode 100644
index 000000000000..a3c7599b8bdf
--- /dev/null
+++ b/crates/nano-arrow/src/io/README.md
@@ -0,0 +1,24 @@
+# IO module
+
+This document describes the overall design of this module.
+
+## Rules:
+
+- Each directory in this module corresponds to a specific format, such as `csv` and `json`.
+- Directories that depend on external dependencies MUST be feature-gated, with a feature named with the prefix `io_`.
+- Modules MUST re-export any API of external dependencies they require as part of their public API.
+  E.g.
+  - if a module has an API `write(writer: &mut csv::Writer, ...)`, it MUST contain `pub use csv::Writer;`.
+
+  The rationale is that adding this crate to `Cargo.toml` must be sufficient to use it.
+- Each directory SHOULD contain two directories, `read` and `write`, corresponding
+  to functionality for reading from the format and writing to the format respectively.
+- The base module SHOULD contain `pub use read;` and `pub use write;`.
+- Implementations SHOULD separate reading of "data" from reading of "metadata". Examples:
+  - schema reading or inference SHOULD be a separate function
+  - functions that read "data" SHOULD consume a schema that is typically read beforehand.
+- Implementations SHOULD separate IO-bound operations from CPU-bound operations.
+  I.e. implementations SHOULD:
+  - contain functions that consume a `Read` implementor and output a "raw" struct, i.e. a struct that is e.g. compressed and serialized
+  - contain functions that consume a "raw" struct and convert it into Arrow
+  - offer each of these functions as independent public APIs, so that consumers can decide how to balance CPU-bound and IO-bound work.
diff --git a/crates/nano-arrow/src/io/avro/mod.rs b/crates/nano-arrow/src/io/avro/mod.rs
new file mode 100644
index 000000000000..bf7bda85f197
--- /dev/null
+++ b/crates/nano-arrow/src/io/avro/mod.rs
@@ -0,0 +1,42 @@
+//! Read from and write to Apache Avro
+
+pub use avro_schema;
+
+impl From<avro_schema::error::Error> for crate::error::Error {
+    fn from(error: avro_schema::error::Error) -> Self {
+        Self::ExternalFormat(error.to_string())
+    }
+}
+
+pub mod read;
+pub mod write;
+
+// macro that can operate in both sync and async code.
+macro_rules!
avro_decode { + ($reader:ident $($_await:tt)*) => { + { + let mut i = 0u64; + let mut buf = [0u8; 1]; + let mut j = 0; + loop { + if j > 9 { + // if j * 7 > 64 + return Err(Error::ExternalFormat( + "zigzag decoding failed - corrupt avro file".to_string(), + )); + } + $reader.read_exact(&mut buf[..])$($_await)*?; + i |= (u64::from(buf[0] & 0x7F)) << (j * 7); + if (buf[0] >> 7) == 0 { + break; + } else { + j += 1; + } + } + + Ok(i) + } + } +} + +pub(crate) use avro_decode; diff --git a/crates/nano-arrow/src/io/avro/read/deserialize.rs b/crates/nano-arrow/src/io/avro/read/deserialize.rs new file mode 100644 index 000000000000..6cafd9d8c4c1 --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/deserialize.rs @@ -0,0 +1,526 @@ +use std::convert::TryInto; + +use avro_schema::file::Block; +use avro_schema::schema::{Enum, Field as AvroField, Record, Schema as AvroSchema}; + +use super::nested::*; +use super::util; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::types::months_days_ns; + +fn make_mutable( + data_type: &DataType, + avro_field: Option<&AvroSchema>, + capacity: usize, +) -> Result> { + Ok(match data_type.to_physical_type() { + PhysicalType::Boolean => { + Box::new(MutableBooleanArray::with_capacity(capacity)) as Box + }, + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + as Box + }), + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::Utf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::Dictionary(_) => { + if let Some(AvroSchema::Enum(Enum { symbols, .. 
})) = avro_field { + let values = Utf8Array::::from_slice(symbols); + Box::new(FixedItemsUtf8Dictionary::with_capacity(values, capacity)) + as Box + } else { + unreachable!() + } + }, + _ => match data_type { + DataType::List(inner) => { + let values = make_mutable(inner.data_type(), None, 0)?; + Box::new(DynMutableListArray::::new_from( + values, + data_type.clone(), + capacity, + )) as Box + }, + DataType::FixedSizeBinary(size) => { + Box::new(MutableFixedSizeBinaryArray::with_capacity(*size, capacity)) + as Box + }, + DataType::Struct(fields) => { + let values = fields + .iter() + .map(|field| make_mutable(field.data_type(), None, capacity)) + .collect::>>()?; + Box::new(DynMutableStructArray::new(values, data_type.clone())) + as Box + }, + other => { + return Err(Error::NotYetImplemented(format!( + "Deserializing type {other:#?} is still not implemented" + ))) + }, + }, + }) +} + +fn is_union_null_first(avro_field: &AvroSchema) -> bool { + if let AvroSchema::Union(schemas) = avro_field { + schemas[0] == AvroSchema::Null + } else { + unreachable!() + } +} + +fn deserialize_item<'a>( + array: &mut dyn MutableArray, + is_nullable: bool, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> Result<&'a [u8]> { + if is_nullable { + let variant = util::zigzag_i64(&mut block)?; + let is_null_first = is_union_null_first(avro_field); + if is_null_first && variant == 0 || !is_null_first && variant != 0 { + array.push_null(); + return Ok(block); + } + } + deserialize_value(array, avro_field, block) +} + +fn deserialize_value<'a>( + array: &mut dyn MutableArray, + avro_field: &AvroSchema, + mut block: &'a [u8], +) -> Result<&'a [u8]> { + let data_type = array.data_type(); + match data_type { + DataType::List(inner) => { + let is_nullable = inner.is_nullable; + let avro_inner = match avro_field { + AvroSchema::Array(inner) => inner.as_ref(), + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => { + inner.as_ref() + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + // Arrays are encoded as a series of blocks. + loop { + // Each block consists of a long count value, followed by that many array items. + let len = util::zigzag_i64(&mut block)?; + let len = if len < 0 { + // Avro spec: If a block's count is negative, its absolute value is used, + // and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields. + let _ = util::zigzag_i64(&mut block)?; + + -len + } else { + len + }; + + // A block with count zero indicates the end of the array. + if len == 0 { + break; + } + + // Each item is encoded per the array’s item schema. + let values = array.mut_values(); + for _ in 0..len { + block = deserialize_item(values, is_nullable, avro_inner, block)?; + } + } + array.try_push_valid()?; + }, + DataType::Struct(inner_fields) => { + let fields = match avro_field { + AvroSchema::Record(Record { fields, .. }) => fields, + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Record(Record { fields, .. }), _] + | &[_, AvroSchema::Record(Record { fields, .. 
})] => fields, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let is_nullable = inner_fields + .iter() + .map(|x| x.is_nullable) + .collect::>(); + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + + for (index, (field, is_nullable)) in fields.iter().zip(is_nullable.iter()).enumerate() { + let values = array.mut_values(index); + block = deserialize_item(values, *is_nullable, &field.schema, block)?; + } + array.try_push_valid()?; + }, + _ => match data_type.to_physical_type() { + PhysicalType::Boolean => { + let is_valid = block[0] == 1; + block = &block[1..]; + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + array.push(Some(is_valid)) + }, + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int32 => { + let value = util::zigzag_i64(&mut block)? as i32; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Int64 => { + let value = util::zigzag_i64(&mut block)?; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Float32 => { + let value = + f32::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); + block = &block[std::mem::size_of::()..]; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Float64 => { + let value = + f64::from_le_bytes(block[..std::mem::size_of::()].try_into().unwrap()); + block = &block[std::mem::size_of::()..]; + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::MonthDayNano => { + // https://avro.apache.org/docs/current/spec.html#Duration + // 12 bytes, months, days, millis in LE + let data = &block[..12]; + block = &block[12..]; + + let value = months_days_ns::new( + i32::from_le_bytes([data[0], data[1], data[2], data[3]]), + i32::from_le_bytes([data[4], data[5], data[6], data[7]]), + i32::from_le_bytes([data[8], data[9], data[10], data[11]]) as i64 + * 1_000_000, + ); + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(value)) + }, + PrimitiveType::Int128 => { + let avro_inner = match avro_field { + AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field, + AvroSchema::Union(u) => match &u.as_slice() { + &[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let len = match avro_inner { + AvroSchema::Bytes(_) => { + util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })? 
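+                            // Avro `bytes` decimals are length-prefixed; `fixed` decimals use the size declared in the schema.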
+ }, + AvroSchema::Fixed(b) => b.size, + _ => unreachable!(), + }; + if len > 16 { + return Err(Error::ExternalFormat( + "Avro decimal bytes return more than 16 bytes".to_string(), + )); + } + let mut bytes = [0u8; 16]; + bytes[..len].copy_from_slice(&block[..len]); + block = &block[len..]; + let data = i128::from_be_bytes(bytes) >> (8 * (16 - len)); + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)) + }, + _ => unreachable!(), + }, + PhysicalType::Utf8 => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })?; + let data = simdutf8::basic::from_utf8(&block[..len])?; + block = &block[len..]; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)) + }, + PhysicalType::Binary => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })?; + let data = &block[..len]; + block = &block[len..]; + + let array = array + .as_mut_any() + .downcast_mut::>() + .unwrap(); + array.push(Some(data)); + }, + PhysicalType::FixedSizeBinary => { + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + let len = array.size(); + let data = &block[..len]; + block = &block[len..]; + array.push(Some(data)); + }, + PhysicalType::Dictionary(_) => { + let index = util::zigzag_i64(&mut block)? as i32; + let array = array + .as_mut_any() + .downcast_mut::() + .unwrap(); + array.push_valid(index); + }, + _ => todo!(), + }, + }; + Ok(block) +} + +fn skip_item<'a>(field: &Field, avro_field: &AvroSchema, mut block: &'a [u8]) -> Result<&'a [u8]> { + if field.is_nullable { + let variant = util::zigzag_i64(&mut block)?; + let is_null_first = is_union_null_first(avro_field); + if is_null_first && variant == 0 || !is_null_first && variant != 0 { + return Ok(block); + } + } + match &field.data_type { + DataType::List(inner) => { + let avro_inner = match avro_field { + AvroSchema::Array(inner) => inner.as_ref(), + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Array(inner), _] | &[_, AvroSchema::Array(inner)] => { + inner.as_ref() + }, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + loop { + let len = util::zigzag_i64(&mut block)?; + let (len, bytes) = if len < 0 { + // Avro spec: If a block's count is negative, its absolute value is used, + // and the count is followed immediately by a long block size indicating the number of bytes in the block. This block size permits fast skipping through data, e.g., when projecting a record to a subset of its fields. + let bytes = util::zigzag_i64(&mut block)?; + + (-len, Some(bytes)) + } else { + (len, None) + }; + + let bytes: Option = bytes + .map(|bytes| { + bytes + .try_into() + .map_err(|_| Error::oos("Avro block size negative or too large")) + }) + .transpose()?; + + if len == 0 { + break; + } + + if let Some(bytes) = bytes { + block = &block[bytes..]; + } else { + for _ in 0..len { + block = skip_item(inner, avro_inner, block)?; + } + } + } + }, + DataType::Struct(inner_fields) => { + let fields = match avro_field { + AvroSchema::Record(Record { fields, .. }) => fields, + AvroSchema::Union(u) => match &u.as_slice() { + &[AvroSchema::Record(Record { fields, .. }), _] + | &[_, AvroSchema::Record(Record { fields, .. 
})] => fields, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + for (field, avro_field) in inner_fields.iter().zip(fields.iter()) { + block = skip_item(field, &avro_field.schema, block)?; + } + }, + _ => match field.data_type.to_physical_type() { + PhysicalType::Boolean => { + let _ = block[0] == 1; + block = &block[1..]; + }, + PhysicalType::Primitive(primitive) => match primitive { + PrimitiveType::Int32 => { + let _ = util::zigzag_i64(&mut block)?; + }, + PrimitiveType::Int64 => { + let _ = util::zigzag_i64(&mut block)?; + }, + PrimitiveType::Float32 => { + block = &block[std::mem::size_of::()..]; + }, + PrimitiveType::Float64 => { + block = &block[std::mem::size_of::()..]; + }, + PrimitiveType::MonthDayNano => { + block = &block[12..]; + }, + PrimitiveType::Int128 => { + let avro_inner = match avro_field { + AvroSchema::Bytes(_) | AvroSchema::Fixed(_) => avro_field, + AvroSchema::Union(u) => match &u.as_slice() { + &[e, AvroSchema::Null] | &[AvroSchema::Null, e] => e, + _ => unreachable!(), + }, + _ => unreachable!(), + }; + let len = match avro_inner { + AvroSchema::Bytes(_) => { + util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })? + }, + AvroSchema::Fixed(b) => b.size, + _ => unreachable!(), + }; + block = &block[len..]; + }, + _ => unreachable!(), + }, + PhysicalType::Utf8 | PhysicalType::Binary => { + let len: usize = util::zigzag_i64(&mut block)?.try_into().map_err(|_| { + Error::ExternalFormat( + "Avro format contains a non-usize number of bytes".to_string(), + ) + })?; + block = &block[len..]; + }, + PhysicalType::FixedSizeBinary => { + let len = if let DataType::FixedSizeBinary(len) = &field.data_type { + *len + } else { + unreachable!() + }; + + block = &block[len..]; + }, + PhysicalType::Dictionary(_) => { + let _ = util::zigzag_i64(&mut block)? as i32; + }, + _ => todo!(), + }, + } + Ok(block) +} + +/// Deserializes a [`Block`] assumed to be encoded according to [`AvroField`] into [`Chunk`], +/// using `projection` to ignore `avro_fields`. +/// # Panics +/// `fields`, `avro_fields` and `projection` must have the same length. +pub fn deserialize( + block: &Block, + fields: &[Field], + avro_fields: &[AvroField], + projection: &[bool], +) -> Result>> { + assert_eq!(fields.len(), avro_fields.len()); + assert_eq!(fields.len(), projection.len()); + + let rows = block.number_of_rows; + let mut block = block.data.as_ref(); + + // create mutables, one per field + let mut arrays: Vec> = fields + .iter() + .zip(avro_fields.iter()) + .zip(projection.iter()) + .map(|((field, avro_field), projection)| { + if *projection { + make_mutable(&field.data_type, Some(&avro_field.schema), rows) + } else { + // just something; we are not going to use it + make_mutable(&DataType::Int32, None, 0) + } + }) + .collect::>()?; + + // this is _the_ expensive transpose (rows -> columns) + for _ in 0..rows { + let iter = arrays + .iter_mut() + .zip(fields.iter()) + .zip(avro_fields.iter()) + .zip(projection.iter()); + + for (((array, field), avro_field), projection) in iter { + block = if *projection { + deserialize_item(array.as_mut(), field.is_nullable, &avro_field.schema, block) + } else { + skip_item(field, &avro_field.schema, block) + }? 
+ } + } + Chunk::try_new( + arrays + .iter_mut() + .zip(projection.iter()) + .filter_map(|x| x.1.then(|| x.0)) + .map(|array| array.as_box()) + .collect(), + ) +} diff --git a/crates/nano-arrow/src/io/avro/read/mod.rs b/crates/nano-arrow/src/io/avro/read/mod.rs new file mode 100644 index 000000000000..5014499c12a6 --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/mod.rs @@ -0,0 +1,67 @@ +//! APIs to read from Avro format to arrow. +use std::io::Read; + +use avro_schema::file::FileMetadata; +use avro_schema::read::fallible_streaming_iterator::FallibleStreamingIterator; +use avro_schema::read::{block_iterator, BlockStreamingIterator}; +use avro_schema::schema::Field as AvroField; + +mod deserialize; +pub use deserialize::deserialize; +mod nested; +mod schema; +mod util; + +pub use schema::infer_schema; + +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Field; +use crate::error::Result; + +/// Single threaded, blocking reader of Avro; [`Iterator`] of [`Chunk`]. +pub struct Reader { + iter: BlockStreamingIterator, + avro_fields: Vec, + fields: Vec, + projection: Vec, +} + +impl Reader { + /// Creates a new [`Reader`]. + pub fn new( + reader: R, + metadata: FileMetadata, + fields: Vec, + projection: Option>, + ) -> Self { + let projection = projection.unwrap_or_else(|| fields.iter().map(|_| true).collect()); + + Self { + iter: block_iterator(reader, metadata.compression, metadata.marker), + avro_fields: metadata.record.fields, + fields, + projection, + } + } + + /// Deconstructs itself into its internal reader + pub fn into_inner(self) -> R { + self.iter.into_inner() + } +} + +impl Iterator for Reader { + type Item = Result>>; + + fn next(&mut self) -> Option { + let fields = &self.fields[..]; + let avro_fields = &self.avro_fields; + let projection = &self.projection; + + self.iter + .next() + .transpose() + .map(|maybe_block| deserialize(maybe_block?, fields, avro_fields, projection)) + } +} diff --git a/crates/nano-arrow/src/io/avro/read/nested.rs b/crates/nano-arrow/src/io/avro/read/nested.rs new file mode 100644 index 000000000000..056d9a8f836e --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/nested.rs @@ -0,0 +1,309 @@ +use crate::array::*; +use crate::bitmap::*; +use crate::datatypes::*; +use crate::error::*; +use crate::offset::{Offset, Offsets}; + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableListArray { + data_type: DataType, + offsets: Offsets, + values: Box, + validity: Option, +} + +impl DynMutableListArray { + pub fn new_from(values: Box, data_type: DataType, capacity: usize) -> Self { + assert_eq!(values.len(), 0); + ListArray::::get_child_field(&data_type); + Self { + data_type, + offsets: Offsets::::with_capacity(capacity), + values, + validity: None, + } + } + + /// The values + pub fn mut_values(&mut self) -> &mut dyn MutableArray { + self.values.as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> Result<()> { + let total_length = self.values.len(); + let offset = self.offsets.last().to_usize(); + let length = total_length.checked_sub(offset).ok_or(Error::Overflow)?; + + self.offsets.try_push_usize(length)?; + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.offsets.extend_constant(1); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + fn init_validity(&mut self) { + let len = self.offsets.len_proxy(); + + let mut validity = MutableBitmap::new(); + 
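+        // Everything pushed so far was valid: backfill the bitmap with `true`, then mark the entry just pushed as null.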
validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableListArray { + fn len(&self) -> usize { + self.offsets.len_proxy() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .boxed() + } + + fn as_arc(&mut self) -> std::sync::Arc { + ListArray::new( + self.data_type.clone(), + std::mem::take(&mut self.offsets).into(), + self.values.as_box(), + std::mem::take(&mut self.validity).map(|x| x.into()), + ) + .arced() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} + +#[derive(Debug)] +pub struct FixedItemsUtf8Dictionary { + data_type: DataType, + keys: MutablePrimitiveArray, + values: Utf8Array, +} + +impl FixedItemsUtf8Dictionary { + pub fn with_capacity(values: Utf8Array, capacity: usize) -> Self { + Self { + data_type: DataType::Dictionary( + IntegerType::Int32, + Box::new(values.data_type().clone()), + false, + ), + keys: MutablePrimitiveArray::::with_capacity(capacity), + values, + } + } + + pub fn push_valid(&mut self, key: i32) { + self.keys.push(Some(key)) + } + + /// pushes a null value + pub fn push_null(&mut self) { + self.keys.push(None) + } +} + +impl MutableArray for FixedItemsUtf8Dictionary { + fn len(&self) -> usize { + self.keys.len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.keys.validity() + } + + fn as_box(&mut self) -> Box { + Box::new( + DictionaryArray::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) + } + + fn as_arc(&mut self) -> std::sync::Arc { + std::sync::Arc::new( + DictionaryArray::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} + +/// Auxiliary struct +#[derive(Debug)] +pub struct DynMutableStructArray { + data_type: DataType, + values: Vec>, + validity: Option, +} + +impl DynMutableStructArray { + pub fn new(values: Vec>, data_type: DataType) -> Self { + Self { + data_type, + values, + validity: None, + } + } + + /// The values + pub fn mut_values(&mut self, field: usize) -> &mut dyn MutableArray { + self.values[field].as_mut() + } + + #[inline] + pub fn try_push_valid(&mut self) -> Result<()> { + if let Some(validity) = &mut self.validity { + validity.push(true) + } + Ok(()) + } + + #[inline] + fn push_null(&mut self) { + self.values.iter_mut().for_each(|x| x.push_null()); + match &mut self.validity { + Some(validity) => validity.push(false), + None => self.init_validity(), + } + } + + fn init_validity(&mut self) { + let len = self.len(); + + let mut validity = MutableBitmap::new(); + 
validity.extend_constant(len, true); + validity.set(len - 1, false); + self.validity = Some(validity) + } +} + +impl MutableArray for DynMutableStructArray { + fn len(&self) -> usize { + self.values[0].len() + } + + fn validity(&self) -> Option<&MutableBitmap> { + self.validity.as_ref() + } + + fn as_box(&mut self) -> Box { + let values = self.values.iter_mut().map(|x| x.as_box()).collect(); + + Box::new(StructArray::new( + self.data_type.clone(), + values, + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn as_arc(&mut self) -> std::sync::Arc { + let values = self.values.iter_mut().map(|x| x.as_box()).collect(); + + std::sync::Arc::new(StructArray::new( + self.data_type.clone(), + values, + std::mem::take(&mut self.validity).map(|x| x.into()), + )) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + #[inline] + fn push_null(&mut self) { + self.push_null() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!(); + } +} diff --git a/crates/nano-arrow/src/io/avro/read/schema.rs b/crates/nano-arrow/src/io/avro/read/schema.rs new file mode 100644 index 000000000000..ca50c59ca9fa --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/schema.rs @@ -0,0 +1,145 @@ +use avro_schema::schema::{Enum, Fixed, Record, Schema as AvroSchema}; + +use crate::datatypes::*; +use crate::error::{Error, Result}; + +fn external_props(schema: &AvroSchema) -> Metadata { + let mut props = Metadata::new(); + match &schema { + AvroSchema::Record(Record { + doc: Some(ref doc), .. + }) + | AvroSchema::Enum(Enum { + doc: Some(ref doc), .. + }) => { + props.insert("avro::doc".to_string(), doc.clone()); + }, + _ => {}, + } + props +} + +/// Infers an [`Schema`] from the root [`Record`]. +/// This +pub fn infer_schema(record: &Record) -> Result { + Ok(record + .fields + .iter() + .map(|field| { + schema_to_field( + &field.schema, + Some(&field.name), + external_props(&field.schema), + ) + }) + .collect::>>()? 
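+        // the collected `Vec<Field>` is converted into a `Schema`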
+ .into()) +} + +fn schema_to_field(schema: &AvroSchema, name: Option<&str>, props: Metadata) -> Result { + let mut nullable = false; + let data_type = match schema { + AvroSchema::Null => DataType::Null, + AvroSchema::Boolean => DataType::Boolean, + AvroSchema::Int(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::IntLogical::Date => DataType::Date32, + avro_schema::schema::IntLogical::Time => DataType::Time32(TimeUnit::Millisecond), + }, + None => DataType::Int32, + }, + AvroSchema::Long(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::LongLogical::Time => DataType::Time64(TimeUnit::Microsecond), + avro_schema::schema::LongLogical::TimestampMillis => { + DataType::Timestamp(TimeUnit::Millisecond, Some("00:00".to_string())) + }, + avro_schema::schema::LongLogical::TimestampMicros => { + DataType::Timestamp(TimeUnit::Microsecond, Some("00:00".to_string())) + }, + avro_schema::schema::LongLogical::LocalTimestampMillis => { + DataType::Timestamp(TimeUnit::Millisecond, None) + }, + avro_schema::schema::LongLogical::LocalTimestampMicros => { + DataType::Timestamp(TimeUnit::Microsecond, None) + }, + }, + None => DataType::Int64, + }, + AvroSchema::Float => DataType::Float32, + AvroSchema::Double => DataType::Float64, + AvroSchema::Bytes(logical) => match logical { + Some(logical) => match logical { + avro_schema::schema::BytesLogical::Decimal(precision, scale) => { + DataType::Decimal(*precision, *scale) + }, + }, + None => DataType::Binary, + }, + AvroSchema::String(_) => DataType::Utf8, + AvroSchema::Array(item_schema) => DataType::List(Box::new(schema_to_field( + item_schema, + Some("item"), // default name for list items + Metadata::default(), + )?)), + AvroSchema::Map(_) => todo!("Avro maps are mapped to MapArrays"), + AvroSchema::Union(schemas) => { + // If there are only two variants and one of them is null, set the other type as the field data type + let has_nullable = schemas.iter().any(|x| x == &AvroSchema::Null); + if has_nullable && schemas.len() == 2 { + nullable = true; + if let Some(schema) = schemas + .iter() + .find(|&schema| !matches!(schema, AvroSchema::Null)) + { + schema_to_field(schema, None, Metadata::default())?.data_type + } else { + return Err(Error::NotYetImplemented(format!( + "Can't read avro union {schema:?}" + ))); + } + } else { + let fields = schemas + .iter() + .map(|s| schema_to_field(s, None, Metadata::default())) + .collect::>>()?; + DataType::Union(fields, None, UnionMode::Dense) + } + }, + AvroSchema::Record(Record { fields, .. }) => { + let fields = fields + .iter() + .map(|field| { + let mut props = Metadata::new(); + if let Some(doc) = &field.doc { + props.insert("avro::doc".to_string(), doc.clone()); + } + schema_to_field(&field.schema, Some(&field.name), props) + }) + .collect::>()?; + DataType::Struct(fields) + }, + AvroSchema::Enum { .. } => { + return Ok(Field::new( + name.unwrap_or_default(), + DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false), + false, + )) + }, + AvroSchema::Fixed(Fixed { size, logical, .. 
}) => match logical { + Some(logical) => match logical { + avro_schema::schema::FixedLogical::Decimal(precision, scale) => { + DataType::Decimal(*precision, *scale) + }, + avro_schema::schema::FixedLogical::Duration => { + DataType::Interval(IntervalUnit::MonthDayNano) + }, + }, + None => DataType::FixedSizeBinary(*size), + }, + }; + + let name = name.unwrap_or_default(); + + Ok(Field::new(name, data_type, nullable).with_metadata(props)) +} diff --git a/crates/nano-arrow/src/io/avro/read/util.rs b/crates/nano-arrow/src/io/avro/read/util.rs new file mode 100644 index 000000000000..a26ee0e005ee --- /dev/null +++ b/crates/nano-arrow/src/io/avro/read/util.rs @@ -0,0 +1,17 @@ +use std::io::Read; + +use super::super::avro_decode; +use crate::error::{Error, Result}; + +pub fn zigzag_i64(reader: &mut R) -> Result { + let z = decode_variable(reader)?; + Ok(if z & 0x1 == 0 { + (z >> 1) as i64 + } else { + !(z >> 1) as i64 + }) +} + +fn decode_variable(reader: &mut R) -> Result { + avro_decode!(reader) +} diff --git a/crates/nano-arrow/src/io/avro/write/mod.rs b/crates/nano-arrow/src/io/avro/write/mod.rs new file mode 100644 index 000000000000..6448782bb44e --- /dev/null +++ b/crates/nano-arrow/src/io/avro/write/mod.rs @@ -0,0 +1,28 @@ +//! APIs to write to Avro format. +use avro_schema::file::Block; + +mod schema; +pub use schema::to_record; +mod serialize; +pub use serialize::{can_serialize, new_serializer, BoxSerializer}; + +/// consumes a set of [`BoxSerializer`] into an [`Block`]. +/// # Panics +/// Panics iff the number of items in any of the serializers is not equal to the number of rows +/// declared in the `block`. +pub fn serialize(serializers: &mut [BoxSerializer], block: &mut Block) { + let Block { + data, + number_of_rows, + } = block; + + data.clear(); // restart it + + // _the_ transpose (columns -> rows) + for _ in 0..*number_of_rows { + for serializer in &mut *serializers { + let item_data = serializer.next().unwrap(); + data.extend(item_data); + } + } +} diff --git a/crates/nano-arrow/src/io/avro/write/schema.rs b/crates/nano-arrow/src/io/avro/write/schema.rs new file mode 100644 index 000000000000..b81cdc77ce3a --- /dev/null +++ b/crates/nano-arrow/src/io/avro/write/schema.rs @@ -0,0 +1,91 @@ +use avro_schema::schema::{ + BytesLogical, Field as AvroField, Fixed, FixedLogical, IntLogical, LongLogical, Record, + Schema as AvroSchema, +}; + +use crate::datatypes::*; +use crate::error::{Error, Result}; + +/// Converts a [`Schema`] to an Avro [`Record`]. +pub fn to_record(schema: &Schema) -> Result { + let mut name_counter: i32 = 0; + let fields = schema + .fields + .iter() + .map(|f| field_to_field(f, &mut name_counter)) + .collect::>()?; + Ok(Record { + name: "".to_string(), + namespace: None, + doc: None, + aliases: vec![], + fields, + }) +} + +fn field_to_field(field: &Field, name_counter: &mut i32) -> Result { + let schema = type_to_schema(field.data_type(), field.is_nullable, name_counter)?; + Ok(AvroField::new(&field.name, schema)) +} + +fn type_to_schema( + data_type: &DataType, + is_nullable: bool, + name_counter: &mut i32, +) -> Result { + Ok(if is_nullable { + AvroSchema::Union(vec![ + AvroSchema::Null, + _type_to_schema(data_type, name_counter)?, + ]) + } else { + _type_to_schema(data_type, name_counter)? 
+ }) +} + +fn _get_field_name(name_counter: &mut i32) -> String { + *name_counter += 1; + format!("r{name_counter}") +} + +fn _type_to_schema(data_type: &DataType, name_counter: &mut i32) -> Result { + Ok(match data_type.to_logical_type() { + DataType::Null => AvroSchema::Null, + DataType::Boolean => AvroSchema::Boolean, + DataType::Int32 => AvroSchema::Int(None), + DataType::Int64 => AvroSchema::Long(None), + DataType::Float32 => AvroSchema::Float, + DataType::Float64 => AvroSchema::Double, + DataType::Binary => AvroSchema::Bytes(None), + DataType::LargeBinary => AvroSchema::Bytes(None), + DataType::Utf8 => AvroSchema::String(None), + DataType::LargeUtf8 => AvroSchema::String(None), + DataType::LargeList(inner) | DataType::List(inner) => AvroSchema::Array(Box::new( + type_to_schema(&inner.data_type, inner.is_nullable, name_counter)?, + )), + DataType::Struct(fields) => AvroSchema::Record(Record::new( + _get_field_name(name_counter), + fields + .iter() + .map(|f| field_to_field(f, name_counter)) + .collect::>>()?, + )), + DataType::Date32 => AvroSchema::Int(Some(IntLogical::Date)), + DataType::Time32(TimeUnit::Millisecond) => AvroSchema::Int(Some(IntLogical::Time)), + DataType::Time64(TimeUnit::Microsecond) => AvroSchema::Long(Some(LongLogical::Time)), + DataType::Timestamp(TimeUnit::Millisecond, None) => { + AvroSchema::Long(Some(LongLogical::LocalTimestampMillis)) + }, + DataType::Timestamp(TimeUnit::Microsecond, None) => { + AvroSchema::Long(Some(LongLogical::LocalTimestampMicros)) + }, + DataType::Interval(IntervalUnit::MonthDayNano) => { + let mut fixed = Fixed::new("", 12); + fixed.logical = Some(FixedLogical::Duration); + AvroSchema::Fixed(fixed) + }, + DataType::FixedSizeBinary(size) => AvroSchema::Fixed(Fixed::new("", *size)), + DataType::Decimal(p, s) => AvroSchema::Bytes(Some(BytesLogical::Decimal(*p, *s))), + other => return Err(Error::NotYetImplemented(format!("write {other:?} to avro"))), + }) +} diff --git a/crates/nano-arrow/src/io/avro/write/serialize.rs b/crates/nano-arrow/src/io/avro/write/serialize.rs new file mode 100644 index 000000000000..888861db376a --- /dev/null +++ b/crates/nano-arrow/src/io/avro/write/serialize.rs @@ -0,0 +1,535 @@ +use avro_schema::schema::{Record, Schema as AvroSchema}; +use avro_schema::write::encode; + +use super::super::super::iterator::*; +use crate::array::*; +use crate::bitmap::utils::ZipValidity; +use crate::datatypes::{DataType, IntervalUnit, PhysicalType, PrimitiveType}; +use crate::offset::Offset; +use crate::types::months_days_ns; + +// Zigzag representation of false and true respectively. +const IS_NULL: u8 = 0; +const IS_VALID: u8 = 2; + +/// A type alias for a boxed [`StreamingIterator`], used to write arrays into avro rows +/// (i.e. 
a column -> row transposition of types known at run-time) +pub type BoxSerializer<'a> = Box + 'a + Send + Sync>; + +fn utf8_required(array: &Utf8Array) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x.as_bytes()); + }, + vec![], + )) +} + +fn utf8_optional(array: &Utf8Array) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x.as_bytes()); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn binary_required(array: &BinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x); + }, + vec![], + )) +} + +fn binary_optional(array: &BinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(x.len() as i64, buf).unwrap(); + buf.extend_from_slice(x); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn fixed_size_binary_required(array: &FixedSizeBinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.values_iter(), + |x, buf| { + buf.extend_from_slice(x); + }, + vec![], + )) +} + +fn fixed_size_binary_optional(array: &FixedSizeBinaryArray) -> BoxSerializer { + Box::new(BufStreamingIterator::new( + array.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend_from_slice(x); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn list_required<'a, O: Offset>(array: &'a ListArray, schema: &AvroSchema) -> BoxSerializer<'a> { + let mut inner = new_serializer(array.values().as_ref(), schema); + let lengths = array + .offsets() + .buffer() + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64); + + Box::new(BufStreamingIterator::new( + lengths, + move |length, buf| { + encode::zigzag_encode(length, buf).unwrap(); + let mut rows = 0; + while let Some(item) = inner.next() { + buf.extend_from_slice(item); + rows += 1; + if rows == length { + encode::zigzag_encode(0, buf).unwrap(); + break; + } + } + }, + vec![], + )) +} + +fn list_optional<'a, O: Offset>(array: &'a ListArray, schema: &AvroSchema) -> BoxSerializer<'a> { + let mut inner = new_serializer(array.values().as_ref(), schema); + let lengths = array + .offsets() + .buffer() + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64); + let lengths = ZipValidity::new_with_validity(lengths, array.validity()); + + Box::new(BufStreamingIterator::new( + lengths, + move |length, buf| { + if let Some(length) = length { + buf.push(IS_VALID); + encode::zigzag_encode(length, buf).unwrap(); + let mut rows = 0; + while let Some(item) = inner.next() { + buf.extend_from_slice(item); + rows += 1; + if rows == length { + encode::zigzag_encode(0, buf).unwrap(); + break; + } + } + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) +} + +fn struct_required<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> { + let schemas = schema.fields.iter().map(|x| &x.schema); + let mut inner = array + .values() + .iter() + .zip(schemas) + .map(|(x, schema)| new_serializer(x.as_ref(), schema)) + .collect::>(); + + Box::new(BufStreamingIterator::new( + 0..array.len(), + move |_, buf| { + inner + .iter_mut() + .for_each(|item| buf.extend_from_slice(item.next().unwrap())) + }, + 
vec![], + )) +} + +fn struct_optional<'a>(array: &'a StructArray, schema: &Record) -> BoxSerializer<'a> { + let schemas = schema.fields.iter().map(|x| &x.schema); + let mut inner = array + .values() + .iter() + .zip(schemas) + .map(|(x, schema)| new_serializer(x.as_ref(), schema)) + .collect::>(); + + let iterator = ZipValidity::new_with_validity(0..array.len(), array.validity()); + + Box::new(BufStreamingIterator::new( + iterator, + move |maybe, buf| { + if maybe.is_some() { + buf.push(IS_VALID); + inner + .iter_mut() + .for_each(|item| buf.extend_from_slice(item.next().unwrap())) + } else { + buf.push(IS_NULL); + // skip the item + inner.iter_mut().for_each(|item| { + let _ = item.next().unwrap(); + }); + } + }, + vec![], + )) +} + +/// Creates a [`StreamingIterator`] trait object that presents items from `array` +/// encoded according to `schema`. +/// # Panic +/// This function panics iff the `data_type` is not supported (use [`can_serialize`] to check) +/// # Implementation +/// This function performs minimal CPU work: it dynamically dispatches based on the schema +/// and arrow type. +pub fn new_serializer<'a>(array: &'a dyn Array, schema: &AvroSchema) -> BoxSerializer<'a> { + let data_type = array.data_type().to_physical_type(); + + match (data_type, schema) { + (PhysicalType::Boolean, AvroSchema::Boolean) => { + let values = array.as_any().downcast_ref::().unwrap(); + Box::new(BufStreamingIterator::new( + values.values_iter(), + |x, buf| { + buf.push(x as u8); + }, + vec![], + )) + }, + (PhysicalType::Boolean, AvroSchema::Union(_)) => { + let values = array.as_any().downcast_ref::().unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.extend_from_slice(&[IS_VALID, x as u8]); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Utf8, AvroSchema::Union(_)) => { + utf8_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeUtf8, AvroSchema::Union(_)) => { + utf8_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Utf8, AvroSchema::String(_)) => { + utf8_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeUtf8, AvroSchema::String(_)) => { + utf8_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Binary, AvroSchema::Union(_)) => { + binary_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeBinary, AvroSchema::Union(_)) => { + binary_optional::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::FixedSizeBinary, AvroSchema::Union(_)) => { + fixed_size_binary_optional(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::Binary, AvroSchema::Bytes(_)) => { + binary_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::LargeBinary, AvroSchema::Bytes(_)) => { + binary_required::(array.as_any().downcast_ref().unwrap()) + }, + (PhysicalType::FixedSizeBinary, AvroSchema::Fixed(_)) => { + fixed_size_binary_required(array.as_any().downcast_ref().unwrap()) + }, + + (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(*x as i64, buf).unwrap(); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int32), AvroSchema::Int(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + 
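+            // Required ints carry no union branch marker; each value is written directly as a zigzag varint.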
Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + encode::zigzag_encode(*x as i64, buf).unwrap(); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + encode::zigzag_encode(*x, buf).unwrap(); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int64), AvroSchema::Long(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + encode::zigzag_encode(*x, buf).unwrap(); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend(x.to_le_bytes()) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float32), AvroSchema::Float) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + buf.extend_from_slice(&x.to_le_bytes()); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + buf.extend(x.to_le_bytes()) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Float64), AvroSchema::Double) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + buf.extend_from_slice(&x.to_le_bytes()); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Bytes(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + |x, buf| { + let len = ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize; + encode::zigzag_encode((16 - len) as i64, buf).unwrap(); + buf.extend_from_slice(&x.to_be_bytes()[len..]); + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::Int128), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + let len = + ((x.leading_zeros() / 8) - ((x.leading_zeros() / 8) % 2)) as usize; + encode::zigzag_encode((16 - len) as i64, buf).unwrap(); + buf.extend_from_slice(&x.to_be_bytes()[len..]); + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Fixed(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.values().iter(), + interval_write, + vec![], + )) + }, + (PhysicalType::Primitive(PrimitiveType::MonthDayNano), AvroSchema::Union(_)) => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap(); + Box::new(BufStreamingIterator::new( + values.iter(), + |x, buf| { + if let Some(x) = x { + buf.push(IS_VALID); + 
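// [Editorial note, not part of the original diff.] `interval_write` (defined at the
// bottom of this file) encodes the Avro `duration` logical type: 12 bytes holding
// months, days and milliseconds as little-endian 32-bit integers. As an illustrative
// example, a `months_days_ns` of 1 month, 2 days and 3_000_000 ns would serialize as
//     01 00 00 00  02 00 00 00  03 00 00 00
// (months = 1, days = 2, millis = 3_000_000 / 1_000_000 = 3); sub-millisecond
// nanosecond precision is truncated by this encoding.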
interval_write(x, buf) + } else { + buf.push(IS_NULL); + } + }, + vec![], + )) + }, + + (PhysicalType::List, AvroSchema::Array(schema)) => { + list_required::(array.as_any().downcast_ref().unwrap(), schema.as_ref()) + }, + (PhysicalType::LargeList, AvroSchema::Array(schema)) => { + list_required::(array.as_any().downcast_ref().unwrap(), schema.as_ref()) + }, + (PhysicalType::List, AvroSchema::Union(inner)) => { + let schema = if let AvroSchema::Array(schema) = &inner[1] { + schema.as_ref() + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + list_optional::(array.as_any().downcast_ref().unwrap(), schema) + }, + (PhysicalType::LargeList, AvroSchema::Union(inner)) => { + let schema = if let AvroSchema::Array(schema) = &inner[1] { + schema.as_ref() + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + list_optional::(array.as_any().downcast_ref().unwrap(), schema) + }, + (PhysicalType::Struct, AvroSchema::Record(inner)) => { + struct_required(array.as_any().downcast_ref().unwrap(), inner) + }, + (PhysicalType::Struct, AvroSchema::Union(inner)) => { + let inner = if let AvroSchema::Record(inner) = &inner[1] { + inner + } else { + unreachable!("The schema declaration does not match the deserialization") + }; + struct_optional(array.as_any().downcast_ref().unwrap(), inner) + }, + (a, b) => todo!("{:?} -> {:?} not supported", a, b), + } +} + +/// Whether [`new_serializer`] supports `data_type`. +pub fn can_serialize(data_type: &DataType) -> bool { + use DataType::*; + match data_type.to_logical_type() { + List(inner) => return can_serialize(&inner.data_type), + LargeList(inner) => return can_serialize(&inner.data_type), + Struct(inner) => return inner.iter().all(|inner| can_serialize(&inner.data_type)), + _ => {}, + }; + + matches!( + data_type, + Boolean + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal(_, _) + | Utf8 + | Binary + | FixedSizeBinary(_) + | LargeUtf8 + | LargeBinary + | Interval(IntervalUnit::MonthDayNano) + ) +} + +#[inline] +fn interval_write(x: &months_days_ns, buf: &mut Vec) { + // https://avro.apache.org/docs/current/spec.html#Duration + // 12 bytes, months, days, millis in LE + buf.reserve(12); + buf.extend(x.months().to_le_bytes()); + buf.extend(x.days().to_le_bytes()); + buf.extend(((x.ns() / 1_000_000) as i32).to_le_bytes()); +} diff --git a/crates/nano-arrow/src/io/flight/mod.rs b/crates/nano-arrow/src/io/flight/mod.rs new file mode 100644 index 000000000000..0cce1774568f --- /dev/null +++ b/crates/nano-arrow/src/io/flight/mod.rs @@ -0,0 +1,243 @@ +//! Serialization and deserialization to Arrow's flight protocol + +use arrow_format::flight::data::{FlightData, SchemaResult}; +use arrow_format::ipc; +use arrow_format::ipc::planus::ReadAsRoot; + +use super::ipc::read::Dictionaries; +pub use super::ipc::write::default_ipc_fields; +use super::ipc::{IpcField, IpcSchema}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +pub use crate::io::ipc::write::common::WriteOptions; +use crate::io::ipc::write::common::{encode_chunk, DictionaryTracker, EncodedData}; +use crate::io::ipc::{read, write}; + +/// Serializes [`Chunk`] to a vector of [`FlightData`] representing the serialized dictionaries +/// and a [`FlightData`] representing the batch. 
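// [Editorial sketch, not part of the original diff.] A hypothetical sender built from
// the functions in this module; `schema` and `chunk` are assumed to exist, and the
// transport that actually sends the resulting `FlightData` messages is out of scope:
//
//     let ipc_fields = default_ipc_fields(&schema.fields);
//     let options = WriteOptions { compression: None };
//     let schema_msg = serialize_schema(&schema, Some(&ipc_fields));
//     let (dict_msgs, batch_msg) = serialize_batch(&chunk, &ipc_fields, &options)?;
//     // send `schema_msg`, then every message in `dict_msgs`, then `batch_msg`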
+/// # Errors +/// This function errors iff `fields` is not consistent with `columns` +pub fn serialize_batch( + chunk: &Chunk>, + fields: &[IpcField], + options: &WriteOptions, +) -> Result<(Vec, FlightData)> { + if fields.len() != chunk.arrays().len() { + return Err(Error::InvalidArgumentError("The argument `fields` must be consistent with the columns' schema. Use e.g. &arrow2::io::flight::default_ipc_fields(&schema.fields)".to_string())); + } + + let mut dictionary_tracker = DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }; + + let (encoded_dictionaries, encoded_batch) = + encode_chunk(chunk, fields, &mut dictionary_tracker, options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + Ok((flight_dictionaries, flight_batch)) +} + +impl From for FlightData { + fn from(data: EncodedData) -> Self { + FlightData { + data_header: data.ipc_message, + data_body: data.arrow_data, + ..Default::default() + } + } +} + +/// Serializes a [`Schema`] to [`SchemaResult`]. +pub fn serialize_schema_to_result( + schema: &Schema, + ipc_fields: Option<&[IpcField]>, +) -> SchemaResult { + SchemaResult { + schema: _serialize_schema(schema, ipc_fields), + } +} + +/// Serializes a [`Schema`] to [`FlightData`]. +pub fn serialize_schema(schema: &Schema, ipc_fields: Option<&[IpcField]>) -> FlightData { + FlightData { + data_header: _serialize_schema(schema, ipc_fields), + ..Default::default() + } +} + +/// Convert a [`Schema`] to bytes in the format expected in [`arrow_format::flight::data::FlightInfo`]. +pub fn serialize_schema_to_info( + schema: &Schema, + ipc_fields: Option<&[IpcField]>, +) -> Result> { + let encoded_data = if let Some(ipc_fields) = ipc_fields { + schema_as_encoded_data(schema, ipc_fields) + } else { + let ipc_fields = default_ipc_fields(&schema.fields); + schema_as_encoded_data(schema, &ipc_fields) + }; + + let mut schema = vec![]; + write::common_sync::write_message(&mut schema, &encoded_data)?; + Ok(schema) +} + +fn _serialize_schema(schema: &Schema, ipc_fields: Option<&[IpcField]>) -> Vec { + if let Some(ipc_fields) = ipc_fields { + write::schema_to_bytes(schema, ipc_fields) + } else { + let ipc_fields = default_ipc_fields(&schema.fields); + write::schema_to_bytes(schema, &ipc_fields) + } +} + +fn schema_as_encoded_data(schema: &Schema, ipc_fields: &[IpcField]) -> EncodedData { + EncodedData { + ipc_message: write::schema_to_bytes(schema, ipc_fields), + arrow_data: vec![], + } +} + +/// Deserialize an IPC message into [`Schema`], [`IpcSchema`]. +/// Use to deserialize [`FlightData::data_header`] and [`SchemaResult::schema`]. +pub fn deserialize_schemas(bytes: &[u8]) -> Result<(Schema, IpcSchema)> { + read::deserialize_schema(bytes) +} + +/// Deserializes [`FlightData`] representing a record batch message to [`Chunk`]. 
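// [Editorial sketch, not part of the original diff.] The matching receiver, mirroring
// the sender sketch above; `schema_msg`, `dict_msgs` and `batch_msg` are the messages
// produced there:
//
//     let (schema, ipc_schema) = deserialize_schemas(&schema_msg.data_header)?;
//     let mut dictionaries = Dictionaries::default();
//     for msg in &dict_msgs {
//         deserialize_dictionary(msg, &schema.fields, &ipc_schema, &mut dictionaries)?;
//     }
//     let chunk = deserialize_batch(&batch_msg, &schema.fields, &ipc_schema, &dictionaries)?;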
+pub fn deserialize_batch( + data: &FlightData, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &read::Dictionaries, +) -> Result>> { + // check that the data_header is a record batch message + let message = arrow_format::ipc::MessageRef::read_as_root(&data.data_header) + .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {err:?}")))?; + + let length = data.data_body.len(); + let mut reader = std::io::Cursor::new(&data.data_body); + + match message.header()?.ok_or_else(|| { + Error::oos("Unable to convert flight data header to a record batch".to_string()) + })? { + ipc::MessageHeaderRef::RecordBatch(batch) => read::read_record_batch( + batch, + fields, + ipc_schema, + None, + None, + dictionaries, + message.version()?, + &mut reader, + 0, + length as u64, + &mut Default::default(), + ), + _ => Err(Error::nyi( + "flight currently only supports reading RecordBatch messages", + )), + } +} + +/// Deserializes [`FlightData`], assuming it to be a dictionary message, into `dictionaries`. +pub fn deserialize_dictionary( + data: &FlightData, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut read::Dictionaries, +) -> Result<()> { + let message = ipc::MessageRef::read_as_root(&data.data_header)?; + + let chunk = if let ipc::MessageHeaderRef::DictionaryBatch(chunk) = message + .header()? + .ok_or_else(|| Error::oos("Header is required"))? + { + chunk + } else { + return Ok(()); + }; + + let length = data.data_body.len(); + let mut reader = std::io::Cursor::new(&data.data_body); + read::read_dictionary( + chunk, + fields, + ipc_schema, + dictionaries, + &mut reader, + 0, + length as u64, + &mut Default::default(), + )?; + + Ok(()) +} + +/// Deserializes [`FlightData`] into either a [`Chunk`] (when the message is a record batch) +/// or by upserting into `dictionaries` (when the message is a dictionary) +pub fn deserialize_message( + data: &FlightData, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, +) -> Result>>> { + let FlightData { + data_header, + data_body, + .. + } = data; + + let message = arrow_format::ipc::MessageRef::read_as_root(data_header)?; + let header = message + .header()? + .ok_or_else(|| Error::oos("IPC Message must contain a header"))?; + + match header { + ipc::MessageHeaderRef::RecordBatch(batch) => { + let length = data_body.len(); + let mut reader = std::io::Cursor::new(data_body); + + let chunk = read::read_record_batch( + batch, + fields, + ipc_schema, + None, + None, + dictionaries, + arrow_format::ipc::MetadataVersion::V5, + &mut reader, + 0, + length as u64, + &mut Default::default(), + )?; + + Ok(chunk.into()) + }, + ipc::MessageHeaderRef::DictionaryBatch(dict_batch) => { + let length = data_body.len(); + let mut reader = std::io::Cursor::new(data_body); + + read::read_dictionary( + dict_batch, + fields, + ipc_schema, + dictionaries, + &mut reader, + 0, + length as u64, + &mut Default::default(), + )?; + Ok(None) + }, + t => Err(Error::nyi(format!( + "Reading types other than record batches not yet supported, unable to read {t:?}" + ))), + } +} diff --git a/crates/nano-arrow/src/io/ipc/append/mod.rs b/crates/nano-arrow/src/io/ipc/append/mod.rs new file mode 100644 index 000000000000..1acb39a931ef --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/append/mod.rs @@ -0,0 +1,72 @@ +//! 
A struct adapter of Read+Seek+Write to append to IPC files +// read header and convert to writer information +// seek to first byte of header - 1 +// write new batch +// write new footer +use std::io::{Read, Seek, SeekFrom, Write}; + +use super::endianness::is_native_little_endian; +use super::read::{self, FileMetadata}; +use super::write::common::DictionaryTracker; +use super::write::writer::*; +use super::write::*; +use crate::error::{Error, Result}; + +impl<R: Read + Seek + Write> FileWriter<R> { + /// Creates a new [`FileWriter`] from an existing file, seeking to the last message + /// and appending new messages afterwards. Users call `finish` to write the footer (with both + /// the existing and appended messages on it). + /// # Error + /// This function errors iff: + /// * the file's endianness is not the native endianness (not yet supported) + /// * the file is not a valid Arrow IPC file + pub fn try_from_file( + mut writer: R, + metadata: FileMetadata, + options: WriteOptions, + ) -> Result<FileWriter<R>> { + if metadata.ipc_schema.is_little_endian != is_native_little_endian() { + return Err(Error::nyi( + "Appending to a file of a non-native endianness is still not supported", + )); + } + + let dictionaries = + read::read_file_dictionaries(&mut writer, &metadata, &mut Default::default())?; + + let last_block = metadata.blocks.last().ok_or_else(|| { + Error::oos("An Arrow IPC file must have at least 1 message (the schema message)") + })?; + let offset: u64 = last_block + .offset + .try_into() + .map_err(|_| Error::oos("The block's offset must be a positive number"))?; + let meta_data_length: u64 = last_block + .meta_data_length + .try_into() + .map_err(|_| Error::oos("The block's meta length must be a positive number"))?; + let body_length: u64 = last_block + .body_length + .try_into() + .map_err(|_| Error::oos("The block's body length must be a positive number"))?; + let offset: u64 = offset + meta_data_length + body_length; + + writer.seek(SeekFrom::Start(offset))?; + + Ok(FileWriter { + writer, + options, + schema: metadata.schema, + ipc_fields: metadata.ipc_schema.fields, + block_offsets: offset as usize, + dictionary_blocks: metadata.dictionaries.unwrap_or_default(), + record_blocks: metadata.blocks, + state: State::Started, // file already exists, so we are ready + dictionary_tracker: DictionaryTracker { + dictionaries, + cannot_replace: true, + }, + encoded_message: Default::default(), + }) + } +} diff --git a/crates/nano-arrow/src/io/ipc/compression.rs b/crates/nano-arrow/src/io/ipc/compression.rs new file mode 100644 index 000000000000..9a69deb8248a --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/compression.rs @@ -0,0 +1,91 @@ +use crate::error::Result; + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn decompress_lz4(input_buf: &[u8], output_buf: &mut [u8]) -> Result<()> { + use std::io::Read; + let mut decoder = lz4::Decoder::new(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn decompress_zstd(input_buf: &[u8], output_buf: &mut [u8]) -> Result<()> { + use std::io::Read; + let mut decoder = zstd::Decoder::new(input_buf)?; + decoder.read_exact(output_buf).map_err(|e| e.into()) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn decompress_lz4(_input_buf: &[u8], _output_buf: &mut [u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. 
Use `io_ipc_compression` to read compressed IPC.".to_string())) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn decompress_zstd(_input_buf: &[u8], _output_buf: &mut [u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to read compressed IPC.".to_string())) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn compress_lz4(input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { + use std::io::Write; + + use crate::error::Error; + let mut encoder = lz4::EncoderBuilder::new() + .build(output_buf) + .map_err(Error::from)?; + encoder.write_all(input_buf)?; + encoder.finish().1.map_err(|e| e.into()) +} + +#[cfg(feature = "io_ipc_compression")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_compression")))] +pub fn compress_zstd(input_buf: &[u8], output_buf: &mut Vec) -> Result<()> { + zstd::stream::copy_encode(input_buf, output_buf, 0).map_err(|e| e.into()) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn compress_lz4(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string())) +} + +#[cfg(not(feature = "io_ipc_compression"))] +pub fn compress_zstd(_input_buf: &[u8], _output_buf: &[u8]) -> Result<()> { + use crate::error::Error; + Err(Error::OutOfSpec("The crate was compiled without IPC compression. Use `io_ipc_compression` to write compressed IPC.".to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "io_ipc_compression")] + #[test] + #[cfg_attr(miri, ignore)] // ZSTD uses foreign calls that miri does not support + fn round_trip_zstd() { + let data: Vec = (0..200u8).map(|x| x % 10).collect(); + let mut buffer = vec![]; + compress_zstd(&data, &mut buffer).unwrap(); + + let mut result = vec![0; 200]; + decompress_zstd(&buffer, &mut result).unwrap(); + assert_eq!(data, result); + } + + #[cfg(feature = "io_ipc_compression")] + #[test] + #[cfg_attr(miri, ignore)] // LZ4 uses foreign calls that miri does not support + fn round_trip_lz4() { + let data: Vec = (0..200u8).map(|x| x % 10).collect(); + let mut buffer = vec![]; + compress_lz4(&data, &mut buffer).unwrap(); + + let mut result = vec![0; 200]; + decompress_lz4(&buffer, &mut result).unwrap(); + assert_eq!(data, result); + } +} diff --git a/crates/nano-arrow/src/io/ipc/endianness.rs b/crates/nano-arrow/src/io/ipc/endianness.rs new file mode 100644 index 000000000000..61b3f9b7c51c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/endianness.rs @@ -0,0 +1,11 @@ +#[cfg(target_endian = "little")] +#[inline] +pub fn is_native_little_endian() -> bool { + true +} + +#[cfg(target_endian = "big")] +#[inline] +pub fn is_native_little_endian() -> bool { + false +} diff --git a/crates/nano-arrow/src/io/ipc/mod.rs b/crates/nano-arrow/src/io/ipc/mod.rs new file mode 100644 index 000000000000..7da03e5c0abb --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/mod.rs @@ -0,0 +1,104 @@ +//! APIs to read from and write to Arrow's IPC format. +//! +//! Inter-process communication is a method through which different processes +//! share and pass data between them. Its use-cases include parallel +//! processing of chunks of data across different CPU cores, transferring +//! data between different Apache Arrow implementations in other languages and +//! more. 
Under the hood Apache Arrow uses [FlatBuffers](https://google.github.io/flatbuffers/) +//! as its binary protocol, so every Arrow-centered streaming or serialization +//! problem that could be solved using FlatBuffers could probably be solved +//! using the more integrated approach that is exposed in this module. +//! +//! [Arrow's IPC protocol](https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc) +//! allows only batch or dictionary columns to be passed +//! around due to its reliance on a pre-defined data scheme. This constraint +//! provides a large performance gain because serialized data will always have a +//! known structure, i.e. the same fields and datatypes, with the only variance +//! being the number of rows and the actual data inside the Batch. This dramatically +//! increases the deserialization rate, as the bytes in the file or stream are already +//! structured "correctly". +//! +//! Reading and writing IPC messages is done using one of two variants - either +//! [`FileReader`](read::FileReader) <-> [`FileWriter`](struct@write::FileWriter) or +//! [`StreamReader`](read::StreamReader) <-> [`StreamWriter`](struct@write::StreamWriter). +//! These two variants wrap a type `T` that implements [`Read`](std::io::Read), and in +//! the case of the `File` variant it also implements [`Seek`](std::io::Seek). In +//! practice it means that `File`s can be arbitrarily accessed while `Stream`s are only +//! read in a certain order - the one they were written in (first in, first out). +//! +//! # Examples +//! Read and write to a file: +//! ``` +//! use arrow2::io::ipc::{{read::{FileReader, read_file_metadata}}, {write::{FileWriter, WriteOptions}}}; +//! # use std::fs::File; +//! # use arrow2::datatypes::{Field, Schema, DataType}; +//! # use arrow2::array::{Int32Array, Array}; +//! # use arrow2::chunk::Chunk; +//! # use arrow2::error::Error; +//! // Setup the writer +//! let path = "example.arrow".to_string(); +//! let mut file = File::create(&path)?; +//! let x_coord = Field::new("x", DataType::Int32, false); +//! let y_coord = Field::new("y", DataType::Int32, false); +//! let schema = Schema::from(vec![x_coord, y_coord]); +//! let options = WriteOptions {compression: None}; +//! let mut writer = FileWriter::try_new(file, schema, None, options)?; +//! +//! // Setup the data +//! let x_data = Int32Array::from_slice([-1i32, 1]); +//! let y_data = Int32Array::from_slice([1i32, -1]); +//! let chunk = Chunk::try_new(vec![x_data.boxed(), y_data.boxed()])?; +//! +//! // Write the messages and finalize the stream +//! for _ in 0..5 { +//! writer.write(&chunk, None)?; +//! } +//! writer.finish()?; +//! +//! // Fetch some of the data and get the reader back +//! let mut reader = File::open(&path)?; +//! let metadata = read_file_metadata(&mut reader)?; +//! let mut reader = FileReader::new(reader, metadata, None, None); +//! let row1 = reader.next().unwrap(); // [[-1, 1], [1, -1]] +//! let row2 = reader.next().unwrap(); // [[-1, 1], [1, -1]] +//! let mut reader = reader.into_inner(); +//! // Do more stuff with the reader, like seeking ahead. +//! # Ok::<(), Error>(()) +//! ``` +//! +//! For further information and examples please consult the +//! [user guide](https://jorgecarleitao.github.io/arrow2/io/index.html). +//! For even more examples check the `examples` folder in the main repository +//! ([1](https://github.com/jorgecarleitao/arrow2/blob/main/examples/ipc_file_read.rs), +//! 
[2](https://github.com/jorgecarleitao/arrow2/blob/main/examples/ipc_file_write.rs), +//! [3](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow)). + +mod compression; +mod endianness; + +pub mod append; +pub mod read; +pub mod write; + +const ARROW_MAGIC_V1: [u8; 4] = [b'F', b'E', b'A', b'1']; +const ARROW_MAGIC_V2: [u8; 6] = [b'A', b'R', b'R', b'O', b'W', b'1']; +pub(crate) const CONTINUATION_MARKER: [u8; 4] = [0xff; 4]; + +/// Struct containing `dictionary_id` and nested `IpcField`, allowing users +/// to specify the dictionary ids of the IPC fields when writing to IPC. +#[derive(Debug, Clone, PartialEq, Default)] +pub struct IpcField { + /// optional children + pub fields: Vec, + /// dictionary id + pub dictionary_id: Option, +} + +/// Struct containing fields and whether the file is written in little or big endian. +#[derive(Debug, Clone, PartialEq)] +pub struct IpcSchema { + /// The fields in the schema + pub fields: Vec, + /// Endianness of the file + pub is_little_endian: bool, +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/binary.rs b/crates/nano-arrow/src/io/ipc/read/array/binary.rs new file mode 100644 index 000000000000..52a5c4b7b7b0 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/binary.rs @@ -0,0 +1,91 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::BinaryArray; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_binary( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets: Buffer = read_buffer( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + let values = read_buffer( + buffers, + last_offset, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + BinaryArray::::try_new(data_type, offsets.try_into()?, values, validity) +} + +pub fn skip_binary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for binary. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/boolean.rs b/crates/nano-arrow/src/io/ipc/read/array/boolean.rs new file mode 100644 index 000000000000..6d78c184b168 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/boolean.rs @@ -0,0 +1,72 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::BooleanArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_boolean( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let values = read_bitmap( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + BooleanArray::try_new(data_type, values, validity) +} + +pub fn skip_boolean( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for boolean. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/dictionary.rs b/crates/nano-arrow/src/io/ipc/read/array/dictionary.rs new file mode 100644 index 000000000000..554e6d32dcbf --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/dictionary.rs @@ -0,0 +1,65 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek}; + +use ahash::HashSet; + +use super::super::{Compression, Dictionaries, IpcBuffer, Node}; +use super::{read_primitive, skip_primitive}; +use crate::array::{DictionaryArray, DictionaryKey}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_dictionary( + field_nodes: &mut VecDeque, + data_type: DataType, + id: Option, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + compression: Option, + limit: Option, + is_little_endian: bool, + scratch: &mut Vec, +) -> Result> +where + Vec: TryInto, +{ + let id = if let Some(id) = id { + id + } else { + return Err(Error::OutOfSpec("Dictionary has no id.".to_string())); + }; + let values = dictionaries + .get(&id) + .ok_or_else(|| { + let valid_ids = dictionaries.keys().collect::>(); + Error::OutOfSpec(format!( + "Dictionary id {id} not found. Valid ids: {valid_ids:?}" + )) + })? + .clone(); + + let keys = read_primitive( + field_nodes, + T::PRIMITIVE.into(), + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + DictionaryArray::::try_new(data_type, keys, values) +} + +pub fn skip_dictionary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + skip_primitive(field_nodes, buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/fixed_size_binary.rs b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_binary.rs new file mode 100644 index 000000000000..ed0d0049ffb2 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_binary.rs @@ -0,0 +1,76 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::FixedSizeBinaryArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_fixed_size_binary( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." 
+ )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let length = length.saturating_mul(FixedSizeBinaryArray::maybe_get_size(&data_type)?); + let values = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + FixedSizeBinaryArray::try_new(data_type, values, validity) +} + +pub fn skip_fixed_size_binary( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos( + "IPC: unable to fetch the field for fixed-size binary. The file or stream is corrupted.", + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/fixed_size_list.rs b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_list.rs new file mode 100644 index 000000000000..5553c1f478ff --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/fixed_size_list.rs @@ -0,0 +1,83 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::FixedSizeListArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_fixed_size_list( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let (field, size) = FixedSizeListArray::get_child_and_size(&data_type); + + let limit = limit.map(|x| x.saturating_mul(size)); + + let values = read( + field_nodes, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + )?; + FixedSizeListArray::try_new(data_type, values, validity) +} + +pub fn skip_fixed_size_list( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos( + "IPC: unable to fetch the field for fixed-size list. 
The file or stream is corrupted.", + ) + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + + let (field, _) = FixedSizeListArray::get_child_and_size(data_type); + + skip(field_nodes, field.data_type(), buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/list.rs b/crates/nano-arrow/src/io/ipc/read/array/list.rs new file mode 100644 index 000000000000..83809cf995c1 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/list.rs @@ -0,0 +1,108 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, OutOfSpecKind, Version}; +use crate::array::ListArray; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_list( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result> +where + Vec: TryInto, +{ + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets = read_buffer::( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + + let field = ListArray::::get_child_field(&data_type); + + let values = read( + field_nodes, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + Some(last_offset), + version, + scratch, + )?; + ListArray::try_new(data_type, offsets.try_into()?, values, validity) +} + +pub fn skip_list( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for list. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + + let data_type = ListArray::::get_child_type(data_type); + + skip(field_nodes, data_type, buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/map.rs b/crates/nano-arrow/src/io/ipc/read/array/map.rs new file mode 100644 index 000000000000..cf383407a8c0 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/map.rs @@ -0,0 +1,103 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, OutOfSpecKind, Version}; +use crate::array::MapArray; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_map( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets = read_buffer::( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![0i32])))?; + + let field = MapArray::get_field(&data_type); + + let last_offset: usize = offsets.last().copied().unwrap() as usize; + + let field = read( + field_nodes, + field, + &ipc_field.fields[0], + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + Some(last_offset), + version, + scratch, + )?; + MapArray::try_new(data_type, offsets.try_into()?, field, validity) +} + +pub fn skip_map( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for map. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + + let data_type = MapArray::get_field(data_type).data_type(); + + skip(field_nodes, data_type, buffers) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/mod.rs b/crates/nano-arrow/src/io/ipc/read/array/mod.rs new file mode 100644 index 000000000000..249e5e05e165 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/mod.rs @@ -0,0 +1,24 @@ +mod primitive; +pub use primitive::*; +mod boolean; +pub use boolean::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod fixed_size_binary; +pub use fixed_size_binary::*; +mod list; +pub use list::*; +mod fixed_size_list; +pub use fixed_size_list::*; +mod struct_; +pub use struct_::*; +mod null; +pub use null::*; +mod dictionary; +pub use dictionary::*; +mod union; +pub use union::*; +mod map; +pub use map::*; diff --git a/crates/nano-arrow/src/io/ipc/read/array/null.rs b/crates/nano-arrow/src/io/ipc/read/array/null.rs new file mode 100644 index 000000000000..e56f1886112d --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/null.rs @@ -0,0 +1,28 @@ +use std::collections::VecDeque; + +use super::super::{Node, OutOfSpecKind}; +use crate::array::NullArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +pub fn read_null(field_nodes: &mut VecDeque, data_type: DataType) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + NullArray::try_new(data_type, length) +} + +pub fn skip_null(field_nodes: &mut VecDeque) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for null. The file or stream is corrupted.") + })?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/primitive.rs b/crates/nano-arrow/src/io/ipc/read/array/primitive.rs new file mode 100644 index 000000000000..d6ccb581ffe5 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/primitive.rs @@ -0,0 +1,77 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::PrimitiveArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::types::NativeType; + +#[allow(clippy::too_many_arguments)] +pub fn read_primitive( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> +where + Vec: TryInto, +{ + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." 
+ )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let values = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + PrimitiveArray::::try_new(data_type, values, validity) +} + +pub fn skip_primitive( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for primitive. The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/struct_.rs b/crates/nano-arrow/src/io/ipc/read/array/struct_.rs new file mode 100644 index 000000000000..9a5084a8783f --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/struct_.rs @@ -0,0 +1,88 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, Version}; +use crate::array::StructArray; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_struct( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let fields = StructArray::get_fields(&data_type); + + let values = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { + read( + field_nodes, + field, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + }) + .collect::>>()?; + + StructArray::try_new(data_type, values, validity) +} + +pub fn skip_struct( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for struct. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + + let fields = StructArray::get_fields(data_type); + + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.data_type(), buffers)) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/union.rs b/crates/nano-arrow/src/io/ipc/read/array/union.rs new file mode 100644 index 000000000000..ac1eb9b02527 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/union.rs @@ -0,0 +1,125 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::super::IpcField; +use super::super::deserialize::{read, skip}; +use super::super::read_basic::*; +use super::super::{Compression, Dictionaries, IpcBuffer, Node, OutOfSpecKind, Version}; +use crate::array::UnionArray; +use crate::datatypes::DataType; +use crate::datatypes::UnionMode::Dense; +use crate::error::{Error, Result}; + +#[allow(clippy::too_many_arguments)] +pub fn read_union( + field_nodes: &mut VecDeque, + data_type: DataType, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: Version, + scratch: &mut Vec, +) -> Result { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + if version != Version::V5 { + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + }; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let types = read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + let offsets = if let DataType::Union(_, _, mode) = data_type { + if !mode.is_sparse() { + Some(read_buffer( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?) + } else { + None + } + } else { + unreachable!() + }; + + let fields = UnionArray::get_fields(&data_type); + + let fields = fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc_field)| { + read( + field_nodes, + field, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + None, + version, + scratch, + ) + }) + .collect::>>()?; + + UnionArray::try_new(data_type, types, fields, offsets) +} + +pub fn skip_union( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for struct. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + if let DataType::Union(_, _, Dense) = data_type { + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + } else { + unreachable!() + }; + + let fields = UnionArray::get_fields(data_type); + + fields + .iter() + .try_for_each(|field| skip(field_nodes, field.data_type(), buffers)) +} diff --git a/crates/nano-arrow/src/io/ipc/read/array/utf8.rs b/crates/nano-arrow/src/io/ipc/read/array/utf8.rs new file mode 100644 index 000000000000..21e54480e48e --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/array/utf8.rs @@ -0,0 +1,92 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use super::super::read_basic::*; +use super::super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::array::Utf8Array; +use crate::buffer::Buffer; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[allow(clippy::too_many_arguments)] +pub fn read_utf8( + field_nodes: &mut VecDeque, + data_type: DataType, + buffers: &mut VecDeque, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> { + let field_node = field_nodes.pop_front().ok_or_else(|| { + Error::oos(format!( + "IPC: unable to fetch the field for {data_type:?}. The file or stream is corrupted." + )) + })?; + + let validity = read_validity( + buffers, + field_node, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + )?; + + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + let offsets: Buffer = read_buffer( + buffers, + 1 + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + ) + // Older versions of the IPC format sometimes do not report an offset + .or_else(|_| Result::Ok(Buffer::::from(vec![O::default()])))?; + + let last_offset = offsets.last().unwrap().to_usize(); + let values = read_buffer( + buffers, + last_offset, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?; + + Utf8Array::::try_new(data_type, offsets.try_into()?, values, validity) +} + +pub fn skip_utf8( + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result<()> { + let _ = field_nodes.pop_front().ok_or_else(|| { + Error::oos("IPC: unable to fetch the field for utf8. 
The file or stream is corrupted.") + })?; + + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing validity buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing offsets buffer."))?; + let _ = buffers + .pop_front() + .ok_or_else(|| Error::oos("IPC: missing values buffer."))?; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/common.rs b/crates/nano-arrow/src/io/ipc/read/common.rs new file mode 100644 index 000000000000..f890562ed41c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/common.rs @@ -0,0 +1,363 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use ahash::AHashMap; +use arrow_format; + +use super::deserialize::{read, skip}; +use super::Dictionaries; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; +use crate::io::ipc::read::OutOfSpecKind; +use crate::io::ipc::{IpcField, IpcSchema}; + +#[derive(Debug, Eq, PartialEq, Hash)] +enum ProjectionResult { + Selected(A), + NotSelected(A), +} + +/// An iterator adapter that will return `Some(x)` or `None` +/// # Panics +/// The iterator panics iff the `projection` is not strictly increasing. +struct ProjectionIter<'a, A, I: Iterator> { + projection: &'a [usize], + iter: I, + current_count: usize, + current_projection: usize, +} + +impl<'a, A, I: Iterator> ProjectionIter<'a, A, I> { + /// # Panics + /// iff `projection` is empty + pub fn new(projection: &'a [usize], iter: I) -> Self { + Self { + projection: &projection[1..], + iter, + current_count: 0, + current_projection: projection[0], + } + } +} + +impl<'a, A, I: Iterator> Iterator for ProjectionIter<'a, A, I> { + type Item = ProjectionResult; + + fn next(&mut self) -> Option { + if let Some(item) = self.iter.next() { + let result = if self.current_count == self.current_projection { + if !self.projection.is_empty() { + assert!(self.projection[0] > self.current_projection); + self.current_projection = self.projection[0]; + self.projection = &self.projection[1..]; + } else { + self.current_projection = 0 // a value that most likely already passed + }; + Some(ProjectionResult::Selected(item)) + } else { + Some(ProjectionResult::NotSelected(item)) + }; + self.current_count += 1; + result + } else { + None + } + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } +} + +/// Returns a [`Chunk`] from a reader. +/// # Panic +/// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) +#[allow(clippy::too_many_arguments)] +pub fn read_record_batch( + batch: arrow_format::ipc::RecordBatchRef, + fields: &[Field], + ipc_schema: &IpcSchema, + projection: Option<&[usize]>, + limit: Option, + dictionaries: &Dictionaries, + version: arrow_format::ipc::MetadataVersion, + reader: &mut R, + block_offset: u64, + file_size: u64, + scratch: &mut Vec, +) -> Result>> { + assert_eq!(fields.len(), ipc_schema.fields.len()); + let buffers = batch + .buffers() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBuffers(err)))? 
+ .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageBuffers))?; + let mut buffers: VecDeque = buffers.iter().collect(); + + // check that the sum of the sizes of all buffers is <= than the size of the file + let buffers_size = buffers + .iter() + .map(|buffer| { + let buffer_size: u64 = buffer + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + Ok(buffer_size) + }) + .sum::>()?; + if buffers_size > file_size { + return Err(Error::from(OutOfSpecKind::InvalidBuffersLength { + buffers_size, + file_size, + })); + } + + let field_nodes = batch + .nodes() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageNodes))?; + let mut field_nodes = field_nodes.iter().collect::>(); + + let columns = if let Some(projection) = projection { + let projection = + ProjectionIter::new(projection, fields.iter().zip(ipc_schema.fields.iter())); + + projection + .map(|maybe_field| match maybe_field { + ProjectionResult::Selected((field, ipc_field)) => Ok(Some(read( + &mut field_nodes, + field, + ipc_field, + &mut buffers, + reader, + dictionaries, + block_offset, + ipc_schema.is_little_endian, + batch.compression().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, + limit, + version, + scratch, + )?)), + ProjectionResult::NotSelected((field, _)) => { + skip(&mut field_nodes, &field.data_type, &mut buffers)?; + Ok(None) + }, + }) + .filter_map(|x| x.transpose()) + .collect::>>()? + } else { + fields + .iter() + .zip(ipc_schema.fields.iter()) + .map(|(field, ipc_field)| { + read( + &mut field_nodes, + field, + ipc_field, + &mut buffers, + reader, + dictionaries, + block_offset, + ipc_schema.is_little_endian, + batch.compression().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)) + })?, + limit, + version, + scratch, + ) + }) + .collect::>>()? + }; + Chunk::try_new(columns) +} + +fn find_first_dict_field_d<'a>( + id: i64, + data_type: &'a DataType, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + use DataType::*; + match data_type { + Dictionary(_, inner, _) => find_first_dict_field_d(id, inner.as_ref(), ipc_field), + List(field) | LargeList(field) | FixedSizeList(field, ..) | Map(field, ..) => { + find_first_dict_field(id, field.as_ref(), &ipc_field.fields[0]) + }, + Union(fields, ..) 
| Struct(fields) => { + for (field, ipc_field) in fields.iter().zip(ipc_field.fields.iter()) { + if let Some(f) = find_first_dict_field(id, field, ipc_field) { + return Some(f); + } + } + None + }, + _ => None, + } +} + +fn find_first_dict_field<'a>( + id: i64, + field: &'a Field, + ipc_field: &'a IpcField, +) -> Option<(&'a Field, &'a IpcField)> { + if let Some(field_id) = ipc_field.dictionary_id { + if id == field_id { + return Some((field, ipc_field)); + } + } + find_first_dict_field_d(id, &field.data_type, ipc_field) +} + +pub(crate) fn first_dict_field<'a>( + id: i64, + fields: &'a [Field], + ipc_fields: &'a [IpcField], +) -> Result<(&'a Field, &'a IpcField)> { + assert_eq!(fields.len(), ipc_fields.len()); + for (field, ipc_field) in fields.iter().zip(ipc_fields.iter()) { + if let Some(field) = find_first_dict_field(id, field, ipc_field) { + return Ok(field); + } + } + Err(Error::from(OutOfSpecKind::InvalidId { requested_id: id })) +} + +/// Reads a dictionary from the reader, +/// updating `dictionaries` with the resulting dictionary +#[allow(clippy::too_many_arguments)] +pub fn read_dictionary( + batch: arrow_format::ipc::DictionaryBatchRef, + fields: &[Field], + ipc_schema: &IpcSchema, + dictionaries: &mut Dictionaries, + reader: &mut R, + block_offset: u64, + file_size: u64, + scratch: &mut Vec, +) -> Result<()> { + if batch + .is_delta() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferIsDelta(err)))? + { + return Err(Error::NotYetImplemented( + "delta dictionary batches not supported".to_string(), + )); + } + + let id = batch + .id() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferId(err)))?; + let (first_field, first_ipc_field) = first_dict_field(id, fields, &ipc_schema.fields)?; + + let batch = batch + .data() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingData))?; + + let value_type = + if let DataType::Dictionary(_, value_type, _) = first_field.data_type.to_logical_type() { + value_type.as_ref() + } else { + return Err(Error::from(OutOfSpecKind::InvalidIdDataType { + requested_id: id, + })); + }; + + // Make a fake schema for the dictionary batch. 
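+    // A dictionary batch stores its values as an ordinary record batch with a
+    // single column, so we wrap the value type in a one-field "fake" schema and
+    // reuse `read_record_batch` to deserialize it. No row limit is applied
+    // below, since dictionary values must always be read in full.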
+ let fields = vec![Field::new("", value_type.clone(), false)]; + let ipc_schema = IpcSchema { + fields: vec![first_ipc_field.clone()], + is_little_endian: ipc_schema.is_little_endian, + }; + let chunk = read_record_batch( + batch, + &fields, + &ipc_schema, + None, + None, // we must read the whole dictionary + dictionaries, + arrow_format::ipc::MetadataVersion::V5, + reader, + block_offset, + file_size, + scratch, + )?; + + dictionaries.insert(id, chunk.into_arrays().pop().unwrap()); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn project_iter() { + let iter = 1..6; + let iter = ProjectionIter::new(&[0, 2, 4], iter); + let result: Vec<_> = iter.collect(); + use ProjectionResult::*; + assert_eq!( + result, + vec![ + Selected(1), + NotSelected(2), + Selected(3), + NotSelected(4), + Selected(5) + ] + ) + } +} + +pub fn prepare_projection( + fields: &[Field], + mut projection: Vec, +) -> (Vec, AHashMap, Vec) { + let fields = projection.iter().map(|x| fields[*x].clone()).collect(); + + // todo: find way to do this more efficiently + let mut indices = (0..projection.len()).collect::>(); + indices.sort_unstable_by_key(|&i| &projection[i]); + let map = indices.iter().copied().enumerate().fold( + AHashMap::default(), + |mut acc, (index, new_index)| { + acc.insert(index, new_index); + acc + }, + ); + projection.sort_unstable(); + + // check unique + if !projection.is_empty() { + let mut previous = projection[0]; + + for &i in &projection[1..] { + assert!( + previous < i, + "The projection on IPC must not contain duplicates" + ); + previous = i; + } + } + + (projection, map, fields) +} + +pub fn apply_projection( + chunk: Chunk>, + map: &AHashMap, +) -> Chunk> { + // re-order according to projection + let arrays = chunk.into_arrays(); + let mut new_arrays = arrays.clone(); + + map.iter() + .for_each(|(old, new)| new_arrays[*new] = arrays[*old].clone()); + + Chunk::new(new_arrays) +} diff --git a/crates/nano-arrow/src/io/ipc/read/deserialize.rs b/crates/nano-arrow/src/io/ipc/read/deserialize.rs new file mode 100644 index 000000000000..28f8b9e68191 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/deserialize.rs @@ -0,0 +1,251 @@ +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +use arrow_format::ipc::{BodyCompressionRef, MetadataVersion}; + +use super::array::*; +use super::{Dictionaries, IpcBuffer, Node}; +use crate::array::*; +use crate::datatypes::{DataType, Field, PhysicalType}; +use crate::error::Result; +use crate::io::ipc::IpcField; + +#[allow(clippy::too_many_arguments)] +pub fn read( + field_nodes: &mut VecDeque, + field: &Field, + ipc_field: &IpcField, + buffers: &mut VecDeque, + reader: &mut R, + dictionaries: &Dictionaries, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + version: MetadataVersion, + scratch: &mut Vec, +) -> Result> { + use PhysicalType::*; + let data_type = field.data_type.clone(); + + match data_type.to_physical_type() { + Null => read_null(field_nodes, data_type).map(|x| x.boxed()), + Boolean => read_boolean( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + read_primitive::<$T, _>( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()) + }), + Binary => read_binary::( + field_nodes, + data_type, + buffers, + reader, + 
block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + LargeBinary => read_binary::( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + FixedSizeBinary => read_fixed_size_binary( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + Utf8 => read_utf8::( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + LargeUtf8 => read_utf8::( + field_nodes, + data_type, + buffers, + reader, + block_offset, + is_little_endian, + compression, + limit, + scratch, + ) + .map(|x| x.boxed()), + List => read_list::( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + LargeList => read_list::( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + FixedSizeList => read_fixed_size_list( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Struct => read_struct( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Dictionary(key_type) => { + match_integer_type!(key_type, |$T| { + read_dictionary::<$T, _>( + field_nodes, + data_type, + ipc_field.dictionary_id, + buffers, + reader, + dictionaries, + block_offset, + compression, + limit, + is_little_endian, + scratch, + ) + .map(|x| x.boxed()) + }) + }, + Union => read_union( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + Map => read_map( + field_nodes, + data_type, + ipc_field, + buffers, + reader, + dictionaries, + block_offset, + is_little_endian, + compression, + limit, + version, + scratch, + ) + .map(|x| x.boxed()), + } +} + +pub fn skip( + field_nodes: &mut VecDeque, + data_type: &DataType, + buffers: &mut VecDeque, +) -> Result<()> { + use PhysicalType::*; + match data_type.to_physical_type() { + Null => skip_null(field_nodes), + Boolean => skip_boolean(field_nodes, buffers), + Primitive(_) => skip_primitive(field_nodes, buffers), + LargeBinary | Binary => skip_binary(field_nodes, buffers), + LargeUtf8 | Utf8 => skip_utf8(field_nodes, buffers), + FixedSizeBinary => skip_fixed_size_binary(field_nodes, buffers), + List => skip_list::(field_nodes, data_type, buffers), + LargeList => skip_list::(field_nodes, data_type, buffers), + FixedSizeList => skip_fixed_size_list(field_nodes, data_type, buffers), + Struct => skip_struct(field_nodes, data_type, buffers), + Dictionary(_) => skip_dictionary(field_nodes, buffers), + Union => skip_union(field_nodes, data_type, buffers), + Map => skip_map(field_nodes, data_type, buffers), + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/error.rs b/crates/nano-arrow/src/io/ipc/read/error.rs new file mode 100644 index 000000000000..cbac69aef2e3 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/error.rs @@ -0,0 +1,112 @@ 
+use crate::error::Error; + +/// The different types of errors that reading from IPC can cause +#[derive(Debug)] +#[non_exhaustive] +pub enum OutOfSpecKind { + /// The IPC file does not start with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidHeader, + /// The IPC file does not end with [b'A', b'R', b'R', b'O', b'W', b'1'] + InvalidFooter, + /// The first 4 bytes of the last 10 bytes is < 0 + NegativeFooterLength, + /// The footer is an invalid flatbuffer + InvalidFlatbufferFooter(arrow_format::ipc::planus::Error), + /// The file's footer does not contain record batches + MissingRecordBatches, + /// The footer's record batches is an invalid flatbuffer + InvalidFlatbufferRecordBatches(arrow_format::ipc::planus::Error), + /// The file's footer does not contain a schema + MissingSchema, + /// The footer's schema is an invalid flatbuffer + InvalidFlatbufferSchema(arrow_format::ipc::planus::Error), + /// The file's schema does not contain fields + MissingFields, + /// The footer's dictionaries is an invalid flatbuffer + InvalidFlatbufferDictionaries(arrow_format::ipc::planus::Error), + /// The block is an invalid flatbuffer + InvalidFlatbufferBlock(arrow_format::ipc::planus::Error), + /// The dictionary message is an invalid flatbuffer + InvalidFlatbufferMessage(arrow_format::ipc::planus::Error), + /// The message does not contain a header + MissingMessageHeader, + /// The message's header is an invalid flatbuffer + InvalidFlatbufferHeader(arrow_format::ipc::planus::Error), + /// Relative positions in the file is < 0 + UnexpectedNegativeInteger, + /// dictionaries can only contain dictionary messages; record batches can only contain records + UnexpectedMessageType, + /// RecordBatch messages do not contain buffers + MissingMessageBuffers, + /// The message's buffers is an invalid flatbuffer + InvalidFlatbufferBuffers(arrow_format::ipc::planus::Error), + /// RecordBatch messages does not contain nodes + MissingMessageNodes, + /// The message's nodes is an invalid flatbuffer + InvalidFlatbufferNodes(arrow_format::ipc::planus::Error), + /// The message's body length is an invalid flatbuffer + InvalidFlatbufferBodyLength(arrow_format::ipc::planus::Error), + /// The message does not contain data + MissingData, + /// The message's data is an invalid flatbuffer + InvalidFlatbufferData(arrow_format::ipc::planus::Error), + /// The version is an invalid flatbuffer + InvalidFlatbufferVersion(arrow_format::ipc::planus::Error), + /// The compression is an invalid flatbuffer + InvalidFlatbufferCompression(arrow_format::ipc::planus::Error), + /// The record contains a number of buffers that does not match the required number by the data type + ExpectedBuffer, + /// A buffer's size is smaller than the required for the number of elements + InvalidBuffer { + /// Declared number of elements in the buffer + length: usize, + /// The name of the `NativeType` + type_name: &'static str, + /// Bytes required for the `length` and `type` + required_number_of_bytes: usize, + /// The size of the IPC buffer + buffer_length: usize, + }, + /// A buffer's size is larger than the file size + InvalidBuffersLength { + /// number of bytes of all buffers in the record + buffers_size: u64, + /// the size of the file + file_size: u64, + }, + /// A bitmap's size is smaller than the required for the number of elements + InvalidBitmap { + /// Declared length of the bitmap + length: usize, + /// Number of bits on the IPC buffer + number_of_bits: usize, + }, + /// The dictionary is_delta is an invalid flatbuffer + 
InvalidFlatbufferIsDelta(arrow_format::ipc::planus::Error), + /// The dictionary id is an invalid flatbuffer + InvalidFlatbufferId(arrow_format::ipc::planus::Error), + /// Invalid dictionary id + InvalidId { + /// The requested dictionary id + requested_id: i64, + }, + /// Field id is not a dictionary + InvalidIdDataType { + /// The requested dictionary id + requested_id: i64, + }, + /// FixedSizeBinaryArray has invalid datatype. + InvalidDataType, +} + +impl From for Error { + fn from(kind: OutOfSpecKind) -> Self { + Error::OutOfSpec(format!("{kind:?}")) + } +} + +impl From for Error { + fn from(error: arrow_format::ipc::planus::Error) -> Self { + Error::OutOfSpec(error.to_string()) + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/file.rs b/crates/nano-arrow/src/io/ipc/read/file.rs new file mode 100644 index 000000000000..ec0084a08614 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/file.rs @@ -0,0 +1,321 @@ +use std::convert::TryInto; +use std::io::{Read, Seek, SeekFrom}; + +use ahash::AHashMap; +use arrow_format::ipc::planus::ReadAsRoot; + +use super::super::{ARROW_MAGIC_V1, ARROW_MAGIC_V2, CONTINUATION_MARKER}; +use super::common::*; +use super::schema::fb_to_schema; +use super::{Dictionaries, OutOfSpecKind}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; +use crate::io::ipc::IpcSchema; + +/// Metadata of an Arrow IPC file, written in the footer of the file. +#[derive(Debug, Clone)] +pub struct FileMetadata { + /// The schema that is read from the file footer + pub schema: Schema, + + /// The files' [`IpcSchema`] + pub ipc_schema: IpcSchema, + + /// The blocks in the file + /// + /// A block indicates the regions in the file to read to get data + pub blocks: Vec, + + /// Dictionaries associated to each dict_id + pub(crate) dictionaries: Option>, + + /// The total size of the file in bytes + pub size: u64, +} + +fn read_dictionary_message( + reader: &mut R, + offset: u64, + data: &mut Vec, +) -> Result<()> { + let mut message_size: [u8; 4] = [0; 4]; + reader.seek(SeekFrom::Start(offset))?; + reader.read_exact(&mut message_size)?; + if message_size == CONTINUATION_MARKER { + reader.read_exact(&mut message_size)?; + }; + let message_length = i32::from_le_bytes(message_size); + + let message_length: usize = message_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + data.clear(); + data.try_reserve(message_length)?; + reader + .by_ref() + .take(message_length as u64) + .read_to_end(data)?; + + Ok(()) +} + +pub(crate) fn get_dictionary_batch<'a>( + message: &'a arrow_format::ipc::MessageRef, +) -> Result> { + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? 
+ .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + match header { + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => Ok(batch), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +fn read_dictionary_block( + reader: &mut R, + metadata: &FileMetadata, + block: &arrow_format::ipc::Block, + dictionaries: &mut Dictionaries, + message_scratch: &mut Vec, + dictionary_scratch: &mut Vec, +) -> Result<()> { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + read_dictionary_message(reader, offset, message_scratch)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_scratch.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let batch = get_dictionary_batch(&message)?; + + read_dictionary( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + dictionaries, + reader, + offset + length, + metadata.size, + dictionary_scratch, + ) +} + +/// Reads all file's dictionaries, if any +/// This function is IO-bounded +pub fn read_file_dictionaries( + reader: &mut R, + metadata: &FileMetadata, + scratch: &mut Vec, +) -> Result { + let mut dictionaries = Default::default(); + + let blocks = if let Some(blocks) = &metadata.dictionaries { + blocks + } else { + return Ok(AHashMap::new()); + }; + // use a temporary smaller scratch for the messages + let mut message_scratch = Default::default(); + + for block in blocks { + read_dictionary_block( + reader, + metadata, + block, + &mut dictionaries, + &mut message_scratch, + scratch, + )?; + } + Ok(dictionaries) +} + +/// Reads the footer's length and magic number in footer +fn read_footer_len(reader: &mut R) -> Result<(u64, usize)> { + // read footer length and magic number in footer + let end = reader.seek(SeekFrom::End(-10))? + 10; + + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer)?; + let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); + + if footer[4..] != ARROW_MAGIC_V2 { + return Err(Error::from(OutOfSpecKind::InvalidFooter)); + } + let footer_len = footer_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok((end, footer_len)) +} + +pub(super) fn deserialize_footer(footer_data: &[u8], size: u64) -> Result { + let footer = arrow_format::ipc::FooterRef::read_as_root(footer_data) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferFooter(err)))?; + + let blocks = footer + .record_batches() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingRecordBatches))?; + + let blocks = blocks + .iter() + .map(|block| { + block + .try_into() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err))) + }) + .collect::>>()?; + + let ipc_schema = footer + .schema() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferSchema(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingSchema))?; + let (schema, ipc_schema) = fb_to_schema(ipc_schema)?; + + let dictionaries = footer + .dictionaries() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferDictionaries(err)))? 
+ .map(|dictionaries| { + dictionaries + .into_iter() + .map(|block| { + block.try_into().map_err(|err| { + Error::from(OutOfSpecKind::InvalidFlatbufferRecordBatches(err)) + }) + }) + .collect::>>() + }) + .transpose()?; + + Ok(FileMetadata { + schema, + ipc_schema, + blocks, + dictionaries, + size, + }) +} + +/// Read the Arrow IPC file's metadata +pub fn read_file_metadata(reader: &mut R) -> Result { + // check if header contain the correct magic bytes + let mut magic_buffer: [u8; 6] = [0; 6]; + let start = reader.stream_position()?; + reader.read_exact(&mut magic_buffer)?; + if magic_buffer != ARROW_MAGIC_V2 { + if magic_buffer[..4] == ARROW_MAGIC_V1 { + return Err(Error::NotYetImplemented("feather v1 not supported".into())); + } + return Err(Error::from(OutOfSpecKind::InvalidHeader)); + } + + let (end, footer_len) = read_footer_len(reader)?; + + // read footer + reader.seek(SeekFrom::End(-10 - footer_len as i64))?; + + let mut serialized_footer = vec![]; + serialized_footer.try_reserve(footer_len)?; + reader + .by_ref() + .take(footer_len as u64) + .read_to_end(&mut serialized_footer)?; + + deserialize_footer(&serialized_footer, end - start) +} + +pub(crate) fn get_record_batch( + message: arrow_format::ipc::MessageRef, +) -> Result { + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => Ok(batch), + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +/// Reads the record batch at position `index` from the reader. +/// +/// This function is useful for random access to the file. For example, if +/// you have indexed the file somewhere else, this allows pruning +/// certain parts of the file. 
+/// # Panics +/// This function panics iff `index >= metadata.blocks.len()` +#[allow(clippy::too_many_arguments)] +pub fn read_batch( + reader: &mut R, + dictionaries: &Dictionaries, + metadata: &FileMetadata, + projection: Option<&[usize]>, + limit: Option, + index: usize, + message_scratch: &mut Vec, + data_scratch: &mut Vec, +) -> Result>> { + let block = metadata.blocks[index]; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: u64 = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + // read length + reader.seek(SeekFrom::Start(offset))?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf)?; + if meta_buf == CONTINUATION_MARKER { + // continuation marker encountered, read message next + reader.read_exact(&mut meta_buf)?; + } + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + message_scratch.clear(); + message_scratch.try_reserve(meta_len)?; + reader + .by_ref() + .take(meta_len as u64) + .read_to_end(message_scratch)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_scratch.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let batch = get_record_batch(message)?; + + read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection, + limit, + dictionaries, + message + .version() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferVersion(err)))?, + reader, + offset + length, + metadata.size, + data_scratch, + ) +} diff --git a/crates/nano-arrow/src/io/ipc/read/file_async.rs b/crates/nano-arrow/src/io/ipc/read/file_async.rs new file mode 100644 index 000000000000..df1895021282 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/file_async.rs @@ -0,0 +1,349 @@ +//! Async reader for Arrow IPC files +use std::io::SeekFrom; + +use ahash::AHashMap; +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, MessageHeaderRef}; +use futures::stream::BoxStream; +use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, Stream, StreamExt}; + +use super::common::{apply_projection, prepare_projection, read_dictionary, read_record_batch}; +use super::file::{deserialize_footer, get_record_batch}; +use super::{Dictionaries, FileMetadata, OutOfSpecKind}; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::{Field, Schema}; +use crate::error::{Error, Result}; +use crate::io::ipc::{IpcSchema, ARROW_MAGIC_V2, CONTINUATION_MARKER}; + +/// Async reader for Arrow IPC files +pub struct FileStream<'a> { + stream: BoxStream<'a, Result>>>, + schema: Option, + metadata: FileMetadata, +} + +impl<'a> FileStream<'a> { + /// Create a new IPC file reader. + /// + /// # Examples + /// See [`FileSink`](crate::io::ipc::write::file_async::FileSink). 
+ pub fn new( + reader: R, + metadata: FileMetadata, + projection: Option>, + limit: Option, + ) -> Self + where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, + { + let (projection, schema) = if let Some(projection) = projection { + let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); + let schema = Schema { + fields, + metadata: metadata.schema.metadata.clone(), + }; + (Some((p, h)), Some(schema)) + } else { + (None, None) + }; + + let stream = Self::stream(reader, None, metadata.clone(), projection, limit); + Self { + stream, + metadata, + schema, + } + } + + /// Get the metadata from the IPC file. + pub fn metadata(&self) -> &FileMetadata { + &self.metadata + } + + /// Get the projected schema from the IPC file. + pub fn schema(&self) -> &Schema { + self.schema.as_ref().unwrap_or(&self.metadata.schema) + } + + fn stream( + mut reader: R, + mut dictionaries: Option, + metadata: FileMetadata, + projection: Option<(Vec, AHashMap)>, + limit: Option, + ) -> BoxStream<'a, Result>>> + where + R: AsyncRead + AsyncSeek + Unpin + Send + 'a, + { + async_stream::try_stream! { + // read dictionaries + cached_read_dictionaries(&mut reader, &metadata, &mut dictionaries).await?; + + let mut meta_buffer = Default::default(); + let mut block_buffer = Default::default(); + let mut scratch = Default::default(); + let mut remaining = limit.unwrap_or(usize::MAX); + for block in 0..metadata.blocks.len() { + let chunk = read_batch( + &mut reader, + dictionaries.as_mut().unwrap(), + &metadata, + projection.as_ref().map(|x| x.0.as_ref()), + Some(remaining), + block, + &mut meta_buffer, + &mut block_buffer, + &mut scratch + ).await?; + remaining -= chunk.len(); + + let chunk = if let Some((_, map)) = &projection { + // re-order according to projection + apply_projection(chunk, map) + } else { + chunk + }; + + yield chunk; + } + } + .boxed() + } +} + +impl<'a> Stream for FileStream<'a> { + type Item = Result>>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().stream.poll_next_unpin(cx) + } +} + +/// Reads the footer's length and magic number in footer +async fn read_footer_len(reader: &mut R) -> Result { + // read footer length and magic number in footer + reader.seek(SeekFrom::End(-10)).await?; + let mut footer: [u8; 10] = [0; 10]; + + reader.read_exact(&mut footer).await?; + let footer_len = i32::from_le_bytes(footer[..4].try_into().unwrap()); + + if footer[4..] != ARROW_MAGIC_V2 { + return Err(Error::from(OutOfSpecKind::InvalidFooter)); + } + footer_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength)) +} + +/// Read the metadata from an IPC file. 
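+/// The reader must support seeking: the footer length and magic number are read
+/// from the last 10 bytes of the file before the footer itself is deserialized.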
+pub async fn read_file_metadata_async(reader: &mut R) -> Result +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let footer_size = read_footer_len(reader).await?; + // Read footer + reader.seek(SeekFrom::End(-10 - footer_size as i64)).await?; + + let mut footer = vec![]; + footer.try_reserve(footer_size)?; + reader + .take(footer_size as u64) + .read_to_end(&mut footer) + .await?; + + deserialize_footer(&footer, u64::MAX) +} + +#[allow(clippy::too_many_arguments)] +async fn read_batch( + mut reader: R, + dictionaries: &mut Dictionaries, + metadata: &FileMetadata, + projection: Option<&[usize]>, + limit: Option, + block: usize, + meta_buffer: &mut Vec, + block_buffer: &mut Vec, + scratch: &mut Vec, +) -> Result>> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let block = metadata.blocks[block]; + + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(offset)).await?; + let mut meta_buf = [0; 4]; + reader.read_exact(&mut meta_buf).await?; + if meta_buf == CONTINUATION_MARKER { + reader.read_exact(&mut meta_buf).await?; + } + + let meta_len = i32::from_le_bytes(meta_buf) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + meta_buffer.clear(); + meta_buffer.try_reserve(meta_len)?; + (&mut reader) + .take(meta_len as u64) + .read_to_end(meta_buffer) + .await?; + + let message = arrow_format::ipc::MessageRef::read_as_root(meta_buffer) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let batch = get_record_batch(message)?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + block_buffer.clear(); + block_buffer.try_reserve(block_length)?; + reader + .take(block_length as u64) + .read_to_end(block_buffer) + .await?; + + let mut cursor = std::io::Cursor::new(&block_buffer); + + read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection, + limit, + dictionaries, + message + .version() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferVersion(err)))?, + &mut cursor, + 0, + metadata.size, + scratch, + ) +} + +async fn read_dictionaries( + mut reader: R, + fields: &[Field], + ipc_schema: &IpcSchema, + blocks: &[Block], + scratch: &mut Vec, +) -> Result +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut dictionaries = Default::default(); + let mut data: Vec = vec![]; + let mut buffer: Vec = vec![]; + + for block in blocks { + let offset: u64 = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: usize = block + .body_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + read_dictionary_message(&mut reader, offset, &mut data).await?; + + let message = arrow_format::ipc::MessageRef::read_as_root(data.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? 
+ .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + match header { + MessageHeaderRef::DictionaryBatch(batch) => { + buffer.clear(); + buffer.try_reserve(length)?; + (&mut reader) + .take(length as u64) + .read_to_end(&mut buffer) + .await?; + let mut cursor = std::io::Cursor::new(&buffer); + read_dictionary( + batch, + fields, + ipc_schema, + &mut dictionaries, + &mut cursor, + 0, + u64::MAX, + scratch, + )?; + }, + _ => return Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } + } + Ok(dictionaries) +} + +async fn read_dictionary_message(mut reader: R, offset: u64, data: &mut Vec) -> Result<()> +where + R: AsyncRead + AsyncSeek + Unpin, +{ + let mut message_size = [0; 4]; + reader.seek(SeekFrom::Start(offset)).await?; + reader.read_exact(&mut message_size).await?; + if message_size == CONTINUATION_MARKER { + reader.read_exact(&mut message_size).await?; + } + let footer_size = i32::from_le_bytes(message_size); + + let footer_size: usize = footer_size + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + data.clear(); + data.try_reserve(footer_size)?; + (&mut reader) + .take(footer_size as u64) + .read_to_end(data) + .await?; + + Ok(()) +} + +async fn cached_read_dictionaries( + reader: &mut R, + metadata: &FileMetadata, + dictionaries: &mut Option, +) -> Result<()> { + match (&dictionaries, metadata.dictionaries.as_deref()) { + (None, Some(blocks)) => { + let new_dictionaries = read_dictionaries( + reader, + &metadata.schema.fields, + &metadata.ipc_schema, + blocks, + &mut Default::default(), + ) + .await?; + *dictionaries = Some(new_dictionaries); + }, + (None, None) => { + *dictionaries = Some(Default::default()); + }, + _ => {}, + }; + Ok(()) +} diff --git a/crates/nano-arrow/src/io/ipc/read/mod.rs b/crates/nano-arrow/src/io/ipc/read/mod.rs new file mode 100644 index 000000000000..887cf7b36258 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/mod.rs @@ -0,0 +1,45 @@ +//! APIs to read Arrow's IPC format. +//! +//! The two important structs here are the [`FileReader`](reader::FileReader), +//! which provides arbitrary access to any of its messages, and the +//! [`StreamReader`](stream::StreamReader), which only supports reading +//! data in the order it was written in. 
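+//!
+//! A minimal sketch of reading every chunk of an IPC file with a [`FileReader`](reader::FileReader)
+//! (the crate path and file name below are illustrative):
+//!
+//! ```no_run
+//! use std::fs::File;
+//!
+//! use nano_arrow::error::Result;
+//! use nano_arrow::io::ipc::read::{read_file_metadata, FileReader};
+//!
+//! fn read_all(path: &str) -> Result<()> {
+//!     let mut file = File::open(path)?;
+//!     // Parse the footer: schema, record-batch blocks and dictionary blocks.
+//!     let metadata = read_file_metadata(&mut file)?;
+//!     // `None, None`: no column projection and no row limit.
+//!     let reader = FileReader::new(file, metadata, None, None);
+//!     for chunk in reader {
+//!         let chunk = chunk?;
+//!         println!("read a chunk with {} rows", chunk.len());
+//!     }
+//!     Ok(())
+//! }
+//! ```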
+use ahash::AHashMap; + +use crate::array::Array; + +mod array; +mod common; +mod deserialize; +mod error; +pub(crate) mod file; +mod read_basic; +mod reader; +mod schema; +mod stream; + +pub use error::OutOfSpecKind; + +#[cfg(feature = "io_ipc_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] +pub mod stream_async; + +#[cfg(feature = "io_ipc_read_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_read_async")))] +pub mod file_async; + +pub(crate) use common::first_dict_field; +#[cfg(feature = "io_flight")] +pub(crate) use common::{read_dictionary, read_record_batch}; +pub use file::{read_batch, read_file_dictionaries, read_file_metadata, FileMetadata}; +pub use reader::FileReader; +pub use schema::deserialize_schema; +pub use stream::{read_stream_metadata, StreamMetadata, StreamReader, StreamState}; + +/// how dictionaries are tracked in this crate +pub type Dictionaries = AHashMap>; + +pub(crate) type Node<'a> = arrow_format::ipc::FieldNodeRef<'a>; +pub(crate) type IpcBuffer<'a> = arrow_format::ipc::BufferRef<'a>; +pub(crate) type Compression<'a> = arrow_format::ipc::BodyCompressionRef<'a>; +pub(crate) type Version = arrow_format::ipc::MetadataVersion; diff --git a/crates/nano-arrow/src/io/ipc/read/read_basic.rs b/crates/nano-arrow/src/io/ipc/read/read_basic.rs new file mode 100644 index 000000000000..a56ebc81b3c4 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/read_basic.rs @@ -0,0 +1,291 @@ +use std::collections::VecDeque; +use std::convert::TryInto; +use std::io::{Read, Seek, SeekFrom}; + +use super::super::compression; +use super::super::endianness::is_native_little_endian; +use super::{Compression, IpcBuffer, Node, OutOfSpecKind}; +use crate::bitmap::Bitmap; +use crate::buffer::Buffer; +use crate::error::{Error, Result}; +use crate::types::NativeType; + +fn read_swapped( + reader: &mut R, + length: usize, + buffer: &mut Vec, + is_little_endian: bool, +) -> Result<()> { + // slow case where we must reverse bits + let mut slice = vec![0u8; length * std::mem::size_of::()]; + reader.read_exact(&mut slice)?; + + let chunks = slice.chunks_exact(std::mem::size_of::()); + if !is_little_endian { + // machine is little endian, file is big endian + buffer + .as_mut_slice() + .iter_mut() + .zip(chunks) + .try_for_each(|(slot, chunk)| { + let a: T::Bytes = match chunk.try_into() { + Ok(a) => a, + Err(_) => unreachable!(), + }; + *slot = T::from_be_bytes(a); + Result::Ok(()) + })?; + } else { + // machine is big endian, file is little endian + return Err(Error::NotYetImplemented( + "Reading little endian files from big endian machines".to_string(), + )); + } + Ok(()) +} + +fn read_uncompressed_buffer( + reader: &mut R, + buffer_length: usize, + length: usize, + is_little_endian: bool, +) -> Result> { + let required_number_of_bytes = length.saturating_mul(std::mem::size_of::()); + if required_number_of_bytes > buffer_length { + return Err(Error::from(OutOfSpecKind::InvalidBuffer { + length, + type_name: std::any::type_name::(), + required_number_of_bytes, + buffer_length, + })); + // todo: move this to the error's Display + /* + return Err(Error::OutOfSpec( + format!("The slots of the array times the physical size must \ + be smaller or equal to the length of the IPC buffer. 
\ + However, this array reports {} slots, which, for physical type \"{}\", corresponds to {} bytes, \ + which is larger than the buffer length {}", + length, + std::any::type_name::(), + bytes, + buffer_length, + ), + )); + */ + } + + // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + + if is_native_little_endian() == is_little_endian { + // fast case where we can just copy the contents + let slice = bytemuck::cast_slice_mut(&mut buffer); + reader.read_exact(slice)?; + } else { + read_swapped(reader, length, &mut buffer, is_little_endian)?; + } + Ok(buffer) +} + +fn read_compressed_buffer( + reader: &mut R, + buffer_length: usize, + length: usize, + is_little_endian: bool, + compression: Compression, + scratch: &mut Vec, +) -> Result> { + if is_little_endian != is_native_little_endian() { + return Err(Error::NotYetImplemented( + "Reading compressed and big endian IPC".to_string(), + )); + } + + // it is undefined behavior to call read_exact on un-initialized, https://doc.rust-lang.org/std/io/trait.Read.html#tymethod.read + // see also https://github.com/MaikKlein/ash/issues/354#issue-781730580 + let mut buffer = vec![T::default(); length]; + + // decompress first + scratch.clear(); + scratch.try_reserve(buffer_length)?; + reader + .by_ref() + .take(buffer_length as u64) + .read_to_end(scratch)?; + + let out_slice = bytemuck::cast_slice_mut(&mut buffer); + + let compression = compression + .codec() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { + arrow_format::ipc::CompressionType::Lz4Frame => { + compression::decompress_lz4(&scratch[8..], out_slice)?; + }, + arrow_format::ipc::CompressionType::Zstd => { + compression::decompress_zstd(&scratch[8..], out_slice)?; + }, + } + Ok(buffer) +} + +pub fn read_buffer( + buf: &mut VecDeque, + length: usize, // in slots + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + scratch: &mut Vec, +) -> Result> { + let buf = buf + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let buffer_length: usize = buf + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + if let Some(compression) = compression { + Ok(read_compressed_buffer( + reader, + buffer_length, + length, + is_little_endian, + compression, + scratch, + )? 
+ .into()) + } else { + Ok(read_uncompressed_buffer(reader, buffer_length, length, is_little_endian)?.into()) + } +} + +fn read_uncompressed_bitmap( + length: usize, + bytes: usize, + reader: &mut R, +) -> Result> { + if length > bytes * 8 { + return Err(Error::from(OutOfSpecKind::InvalidBitmap { + length, + number_of_bits: bytes * 8, + })); + } + + let mut buffer = vec![]; + buffer.try_reserve(bytes)?; + reader + .by_ref() + .take(bytes as u64) + .read_to_end(&mut buffer)?; + + Ok(buffer) +} + +fn read_compressed_bitmap( + length: usize, + bytes: usize, + compression: Compression, + reader: &mut R, + scratch: &mut Vec, +) -> Result> { + let mut buffer = vec![0; (length + 7) / 8]; + + scratch.clear(); + scratch.try_reserve(bytes)?; + reader.by_ref().take(bytes as u64).read_to_end(scratch)?; + + let compression = compression + .codec() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferCompression(err)))?; + + match compression { + arrow_format::ipc::CompressionType::Lz4Frame => { + compression::decompress_lz4(&scratch[8..], &mut buffer)?; + }, + arrow_format::ipc::CompressionType::Zstd => { + compression::decompress_zstd(&scratch[8..], &mut buffer)?; + }, + } + Ok(buffer) +} + +pub fn read_bitmap( + buf: &mut VecDeque, + length: usize, + reader: &mut R, + block_offset: u64, + _: bool, + compression: Option, + scratch: &mut Vec, +) -> Result { + let buf = buf + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: u64 = buf + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let bytes: usize = buf + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + reader.seek(SeekFrom::Start(block_offset + offset))?; + + let buffer = if let Some(compression) = compression { + read_compressed_bitmap(length, bytes, compression, reader, scratch) + } else { + read_uncompressed_bitmap(length, bytes, reader) + }?; + + Bitmap::try_new(buffer, length) +} + +#[allow(clippy::too_many_arguments)] +pub fn read_validity( + buffers: &mut VecDeque, + field_node: Node, + reader: &mut R, + block_offset: u64, + is_little_endian: bool, + compression: Option, + limit: Option, + scratch: &mut Vec, +) -> Result> { + let length: usize = field_node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let length = limit.map(|limit| limit.min(length)).unwrap_or(length); + + Ok(if field_node.null_count() > 0 { + Some(read_bitmap( + buffers, + length, + reader, + block_offset, + is_little_endian, + compression, + scratch, + )?) + } else { + let _ = buffers + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + None + }) +} diff --git a/crates/nano-arrow/src/io/ipc/read/reader.rs b/crates/nano-arrow/src/io/ipc/read/reader.rs new file mode 100644 index 000000000000..80c900fd9a76 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/reader.rs @@ -0,0 +1,137 @@ +use std::io::{Read, Seek}; + +use ahash::AHashMap; + +use super::common::*; +use super::{read_batch, read_file_dictionaries, Dictionaries, FileMetadata}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::Result; + +/// An iterator of [`Chunk`]s from an Arrow IPC file. 
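+///
+/// Dictionary blocks are read lazily, on the first call to [`Iterator::next`];
+/// after that, each call reads only the record-batch block it yields.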
+pub struct FileReader { + reader: R, + metadata: FileMetadata, + // the dictionaries are going to be read + dictionaries: Option, + current_block: usize, + projection: Option<(Vec, AHashMap, Schema)>, + remaining: usize, + data_scratch: Vec, + message_scratch: Vec, +} + +impl FileReader { + /// Creates a new [`FileReader`]. Use `projection` to only take certain columns. + /// # Panic + /// Panics iff the projection is not in increasing order (e.g. `[1, 0]` nor `[0, 1, 1]` are valid) + pub fn new( + reader: R, + metadata: FileMetadata, + projection: Option>, + limit: Option, + ) -> Self { + let projection = projection.map(|projection| { + let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); + let schema = Schema { + fields, + metadata: metadata.schema.metadata.clone(), + }; + (p, h, schema) + }); + Self { + reader, + metadata, + dictionaries: Default::default(), + projection, + remaining: limit.unwrap_or(usize::MAX), + current_block: 0, + data_scratch: Default::default(), + message_scratch: Default::default(), + } + } + + /// Return the schema of the file + pub fn schema(&self) -> &Schema { + self.projection + .as_ref() + .map(|x| &x.2) + .unwrap_or(&self.metadata.schema) + } + + /// Returns the [`FileMetadata`] + pub fn metadata(&self) -> &FileMetadata { + &self.metadata + } + + /// Consumes this FileReader, returning the underlying reader + pub fn into_inner(self) -> R { + self.reader + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn get_scratches(&mut self) -> (Vec, Vec) { + ( + std::mem::take(&mut self.data_scratch), + std::mem::take(&mut self.message_scratch), + ) + } + + /// Set the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. 
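+    /// A typical pattern is `new_reader.set_scratches(old_reader.get_scratches())`
+    /// once the previous reader has been exhausted.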
+ pub fn set_scratches(&mut self, scratches: (Vec, Vec)) { + (self.data_scratch, self.message_scratch) = scratches; + } + + fn read_dictionaries(&mut self) -> Result<()> { + if self.dictionaries.is_none() { + self.dictionaries = Some(read_file_dictionaries( + &mut self.reader, + &self.metadata, + &mut self.data_scratch, + )?); + }; + Ok(()) + } +} + +impl Iterator for FileReader { + type Item = Result>>; + + fn next(&mut self) -> Option { + // get current block + if self.current_block == self.metadata.blocks.len() { + return None; + } + + match self.read_dictionaries() { + Ok(_) => {}, + Err(e) => return Some(Err(e)), + }; + + let block = self.current_block; + self.current_block += 1; + + let chunk = read_batch( + &mut self.reader, + self.dictionaries.as_ref().unwrap(), + &self.metadata, + self.projection.as_ref().map(|x| x.0.as_ref()), + Some(self.remaining), + block, + &mut self.message_scratch, + &mut self.data_scratch, + ); + self.remaining -= chunk.as_ref().map(|x| x.len()).unwrap_or_default(); + + let chunk = if let Some((_, map, _)) = &self.projection { + // re-order according to projection + chunk.map(|chunk| apply_projection(chunk, map)) + } else { + chunk + }; + Some(chunk) + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/schema.rs b/crates/nano-arrow/src/io/ipc/read/schema.rs new file mode 100644 index 000000000000..1b6687f30c95 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/schema.rs @@ -0,0 +1,429 @@ +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef}; + +use super::super::{IpcField, IpcSchema}; +use super::{OutOfSpecKind, StreamMetadata}; +use crate::datatypes::{ + get_extension, DataType, Extension, Field, IntegerType, IntervalUnit, Metadata, Schema, + TimeUnit, UnionMode, +}; +use crate::error::{Error, Result}; + +fn try_unzip_vec>>(iter: I) -> Result<(Vec, Vec)> { + let mut a = vec![]; + let mut b = vec![]; + for maybe_item in iter { + let (a_i, b_i) = maybe_item?; + a.push(a_i); + b.push(b_i); + } + + Ok((a, b)) +} + +fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> Result<(Field, IpcField)> { + let metadata = read_metadata(&ipc_field)?; + + let extension = get_extension(&metadata); + + let (data_type, ipc_field_) = get_data_type(ipc_field, extension, true)?; + + let field = Field { + name: ipc_field + .name()? + .ok_or_else(|| Error::oos("Every field in IPC must have a name"))? + .to_string(), + data_type, + is_nullable: ipc_field.nullable()?, + metadata, + }; + + Ok((field, ipc_field_)) +} + +fn read_metadata(field: &arrow_format::ipc::FieldRef) -> Result { + Ok(if let Some(list) = field.custom_metadata()? { + let mut metadata_map = Metadata::new(); + for kv in list { + let kv = kv?; + if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) { + metadata_map.insert(k.to_string(), v.to_string()); + } + } + metadata_map + } else { + Metadata::default() + }) +} + +fn deserialize_integer(int: arrow_format::ipc::IntRef) -> Result { + Ok(match (int.bit_width()?, int.is_signed()?) 
{ + (8, true) => IntegerType::Int8, + (8, false) => IntegerType::UInt8, + (16, true) => IntegerType::Int16, + (16, false) => IntegerType::UInt16, + (32, true) => IntegerType::Int32, + (32, false) => IntegerType::UInt32, + (64, true) => IntegerType::Int64, + (64, false) => IntegerType::UInt64, + _ => return Err(Error::oos("IPC: indexType can only be 8, 16, 32 or 64.")), + }) +} + +fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> Result { + use arrow_format::ipc::TimeUnit::*; + Ok(match time_unit { + Second => TimeUnit::Second, + Millisecond => TimeUnit::Millisecond, + Microsecond => TimeUnit::Microsecond, + Nanosecond => TimeUnit::Nanosecond, + }) +} + +fn deserialize_time(time: TimeRef) -> Result<(DataType, IpcField)> { + let unit = deserialize_timeunit(time.unit()?)?; + + let data_type = match (time.bit_width()?, unit) { + (32, TimeUnit::Second) => DataType::Time32(TimeUnit::Second), + (32, TimeUnit::Millisecond) => DataType::Time32(TimeUnit::Millisecond), + (64, TimeUnit::Microsecond) => DataType::Time64(TimeUnit::Microsecond), + (64, TimeUnit::Nanosecond) => DataType::Time64(TimeUnit::Nanosecond), + (bits, precision) => { + return Err(Error::nyi(format!( + "Time type with bit width of {bits} and unit of {precision:?}" + ))) + }, + }; + Ok((data_type, IpcField::default())) +} + +fn deserialize_timestamp(timestamp: TimestampRef) -> Result<(DataType, IpcField)> { + let timezone = timestamp.timezone()?.map(|tz| tz.to_string()); + let time_unit = deserialize_timeunit(timestamp.unit()?)?; + Ok(( + DataType::Timestamp(time_unit, timezone), + IpcField::default(), + )) +} + +fn deserialize_union(union_: UnionRef, field: FieldRef) -> Result<(DataType, IpcField)> { + let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse); + let ids = union_.type_ids()?.map(|x| x.iter().collect()); + + let fields = field + .children()? + .ok_or_else(|| Error::oos("IPC: Union must contain children"))?; + if fields.is_empty() { + return Err(Error::oos("IPC: Union must contain at least one child")); + } + + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + Ok((DataType::Union(fields, ids, mode), ipc_field)) +} + +fn deserialize_map(map: MapRef, field: FieldRef) -> Result<(DataType, IpcField)> { + let is_sorted = map.keys_sorted()?; + + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: Map must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: Map must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + let data_type = DataType::Map(Box::new(field), is_sorted); + Ok(( + data_type, + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_struct(field: FieldRef) -> Result<(DataType, IpcField)> { + let fields = field + .children()? 
+ .ok_or_else(|| Error::oos("IPC: Struct must contain children"))?; + if fields.is_empty() { + return Err(Error::oos("IPC: Struct must contain at least one child")); + } + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + let ipc_field = IpcField { + fields: ipc_fields, + dictionary_id: None, + }; + Ok((DataType::Struct(fields), ipc_field)) +} + +fn deserialize_list(field: FieldRef) -> Result<(DataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: List must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: List must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + Ok(( + DataType::List(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_large_list(field: FieldRef) -> Result<(DataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: List must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: List must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + Ok(( + DataType::LargeList(Box::new(field)), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +fn deserialize_fixed_size_list( + list: FixedSizeListRef, + field: FieldRef, +) -> Result<(DataType, IpcField)> { + let children = field + .children()? + .ok_or_else(|| Error::oos("IPC: FixedSizeList must contain children"))?; + let inner = children + .get(0) + .ok_or_else(|| Error::oos("IPC: FixedSizeList must contain one child"))??; + let (field, ipc_field) = deserialize_field(inner)?; + + let size = list + .list_size()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok(( + DataType::FixedSizeList(Box::new(field), size), + IpcField { + fields: vec![ipc_field], + dictionary_id: None, + }, + )) +} + +/// Get the Arrow data type from the flatbuffer Field table +fn get_data_type( + field: arrow_format::ipc::FieldRef, + extension: Extension, + may_be_dictionary: bool, +) -> Result<(DataType, IpcField)> { + if let Some(dictionary) = field.dictionary()? { + if may_be_dictionary { + let int = dictionary + .index_type()? + .ok_or_else(|| Error::oos("indexType is mandatory in Dictionary."))?; + let index_type = deserialize_integer(int)?; + let (inner, mut ipc_field) = get_data_type(field, extension, false)?; + ipc_field.dictionary_id = Some(dictionary.id()?); + return Ok(( + DataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?), + ipc_field, + )); + } + } + + if let Some(extension) = extension { + let (name, metadata) = extension; + let (data_type, fields) = get_data_type(field, None, false)?; + return Ok(( + DataType::Extension(name, Box::new(data_type), metadata), + fields, + )); + } + + let type_ = field + .type_()? 
+ .ok_or_else(|| Error::oos("IPC: field type is mandatory"))?; + + use arrow_format::ipc::TypeRef::*; + Ok(match type_ { + Null(_) => (DataType::Null, IpcField::default()), + Bool(_) => (DataType::Boolean, IpcField::default()), + Int(int) => { + let data_type = deserialize_integer(int)?.into(); + (data_type, IpcField::default()) + }, + Binary(_) => (DataType::Binary, IpcField::default()), + LargeBinary(_) => (DataType::LargeBinary, IpcField::default()), + Utf8(_) => (DataType::Utf8, IpcField::default()), + LargeUtf8(_) => (DataType::LargeUtf8, IpcField::default()), + FixedSizeBinary(fixed) => ( + DataType::FixedSizeBinary( + fixed + .byte_width()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?, + ), + IpcField::default(), + ), + FloatingPoint(float) => { + let data_type = match float.precision()? { + arrow_format::ipc::Precision::Half => DataType::Float16, + arrow_format::ipc::Precision::Single => DataType::Float32, + arrow_format::ipc::Precision::Double => DataType::Float64, + }; + (data_type, IpcField::default()) + }, + Date(date) => { + let data_type = match date.unit()? { + arrow_format::ipc::DateUnit::Day => DataType::Date32, + arrow_format::ipc::DateUnit::Millisecond => DataType::Date64, + }; + (data_type, IpcField::default()) + }, + Time(time) => deserialize_time(time)?, + Timestamp(timestamp) => deserialize_timestamp(timestamp)?, + Interval(interval) => { + let data_type = match interval.unit()? { + arrow_format::ipc::IntervalUnit::YearMonth => { + DataType::Interval(IntervalUnit::YearMonth) + }, + arrow_format::ipc::IntervalUnit::DayTime => { + DataType::Interval(IntervalUnit::DayTime) + }, + arrow_format::ipc::IntervalUnit::MonthDayNano => { + DataType::Interval(IntervalUnit::MonthDayNano) + }, + }; + (data_type, IpcField::default()) + }, + Duration(duration) => { + let time_unit = deserialize_timeunit(duration.unit()?)?; + (DataType::Duration(time_unit), IpcField::default()) + }, + Decimal(decimal) => { + let bit_width: usize = decimal + .bit_width()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let precision: usize = decimal + .precision()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + let scale: usize = decimal + .scale()? + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_type = match bit_width { + 128 => DataType::Decimal(precision, scale), + 256 => DataType::Decimal256(precision, scale), + _ => return Err(Error::from(OutOfSpecKind::NegativeFooterLength)), + }; + + (data_type, IpcField::default()) + }, + List(_) => deserialize_list(field)?, + LargeList(_) => deserialize_large_list(field)?, + FixedSizeList(list) => deserialize_fixed_size_list(list, field)?, + Struct(_) => deserialize_struct(field)?, + Union(union_) => deserialize_union(union_, field)?, + Map(map) => deserialize_map(map, field)?, + }) +} + +/// Deserialize an flatbuffers-encoded Schema message into [`Schema`] and [`IpcSchema`]. +pub fn deserialize_schema(message: &[u8]) -> Result<(Schema, IpcSchema)> { + let message = arrow_format::ipc::MessageRef::read_as_root(message) + .map_err(|err| Error::oos(format!("Unable deserialize message: {err:?}")))?; + + let schema = match message + .header()? + .ok_or_else(|| Error::oos("Unable to convert header to a schema".to_string()))? 
+ { + arrow_format::ipc::MessageHeaderRef::Schema(schema) => Ok(schema), + _ => Err(Error::nyi("The message is expected to be a Schema message")), + }?; + + fb_to_schema(schema) +} + +/// Deserialize the raw Schema table from IPC format to Schema data type +pub(super) fn fb_to_schema(schema: arrow_format::ipc::SchemaRef) -> Result<(Schema, IpcSchema)> { + let fields = schema + .fields()? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingFields))?; + let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| { + let (field, fields) = deserialize_field(field?)?; + Ok((field, fields)) + }))?; + + let is_little_endian = match schema.endianness()? { + arrow_format::ipc::Endianness::Little => true, + arrow_format::ipc::Endianness::Big => false, + }; + + let mut metadata = Metadata::default(); + if let Some(md_fields) = schema.custom_metadata()? { + for kv in md_fields { + let kv = kv?; + let k_str = kv.key()?; + let v_str = kv.value()?; + if let Some(k) = k_str { + if let Some(v) = v_str { + metadata.insert(k.to_string(), v.to_string()); + } + } + } + } + + Ok(( + Schema { fields, metadata }, + IpcSchema { + fields: ipc_fields, + is_little_endian, + }, + )) +} + +pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> Result { + let message = arrow_format::ipc::MessageRef::read_as_root(meta) + .map_err(|err| Error::OutOfSpec(format!("Unable to get root as message: {err:?}")))?; + let version = message.version()?; + // message header is a Schema, so read it + let header = message + .header()? + .ok_or_else(|| Error::oos("Unable to read the first IPC message"))?; + let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header { + schema + } else { + return Err(Error::oos( + "The first IPC message of the stream must be a schema", + )); + }; + let (schema, ipc_schema) = fb_to_schema(schema)?; + + Ok(StreamMetadata { + schema, + version, + ipc_schema, + }) +} diff --git a/crates/nano-arrow/src/io/ipc/read/stream.rs b/crates/nano-arrow/src/io/ipc/read/stream.rs new file mode 100644 index 000000000000..848bf5acb938 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/stream.rs @@ -0,0 +1,318 @@ +use std::io::Read; + +use ahash::AHashMap; +use arrow_format; +use arrow_format::ipc::planus::ReadAsRoot; + +use super::super::CONTINUATION_MARKER; +use super::common::*; +use super::schema::deserialize_stream_metadata; +use super::{Dictionaries, OutOfSpecKind}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; +use crate::io::ipc::IpcSchema; + +/// Metadata of an Arrow IPC stream, written at the start of the stream +#[derive(Debug, Clone)] +pub struct StreamMetadata { + /// The schema that is read from the stream's first message + pub schema: Schema, + + /// The IPC version of the stream + pub version: arrow_format::ipc::MetadataVersion, + + /// The IPC fields tracking dictionaries + pub ipc_schema: IpcSchema, +} + +/// Reads the metadata of the stream +pub fn read_stream_metadata(reader: &mut R) -> Result { + // determine metadata length + let mut meta_size: [u8; 4] = [0; 4]; + reader.read_exact(&mut meta_size)?; + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. 
+ if meta_size == CONTINUATION_MARKER { + reader.read_exact(&mut meta_size)?; + } + i32::from_le_bytes(meta_size) + }; + + let length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let mut buffer = vec![]; + buffer.try_reserve(length)?; + reader + .by_ref() + .take(length as u64) + .read_to_end(&mut buffer)?; + + deserialize_stream_metadata(&buffer) +} + +/// Encodes the stream's status after each read. +/// +/// A stream is an iterator, and an iterator returns `Option`. The `Item` +/// type in the [`StreamReader`] case is `StreamState`, which means that an Arrow +/// stream may yield one of three values: (1) `None`, which signals that the stream +/// is done; (2) [`StreamState::Some`], which signals that there was +/// data waiting in the stream and we read it; and finally (3) +/// [`Some(StreamState::Waiting)`], which means that the stream is still "live", it +/// just doesn't hold any data right now. +pub enum StreamState { + /// A live stream without data + Waiting, + /// Next item in the stream + Some(Chunk>), +} + +impl StreamState { + /// Return the data inside this wrapper. + /// + /// # Panics + /// + /// If the `StreamState` was `Waiting`. + pub fn unwrap(self) -> Chunk> { + if let StreamState::Some(batch) = self { + batch + } else { + panic!("The batch is not available") + } + } +} + +/// Reads the next item, yielding `None` if the stream is done, +/// and a [`StreamState`] otherwise. +fn read_next( + reader: &mut R, + metadata: &StreamMetadata, + dictionaries: &mut Dictionaries, + message_buffer: &mut Vec, + data_buffer: &mut Vec, + projection: &Option<(Vec, AHashMap, Schema)>, + scratch: &mut Vec, +) -> Result> { + // determine metadata length + let mut meta_length: [u8; 4] = [0; 4]; + + match reader.read_exact(&mut meta_length) { + Ok(()) => (), + Err(e) => { + return if e.kind() == std::io::ErrorKind::UnexpectedEof { + // Handle EOF without the "0xFFFFFFFF 0x00000000" + // valid according to: + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + Ok(Some(StreamState::Waiting)) + } else { + Err(Error::from(e)) + }; + }, + } + + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_length == CONTINUATION_MARKER { + reader.read_exact(&mut meta_length)?; + } + i32::from_le_bytes(meta_length) + }; + + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + if meta_length == 0 { + // the stream has ended, mark the reader as finished + return Ok(None); + } + + message_buffer.clear(); + message_buffer.try_reserve(meta_length)?; + reader + .by_ref() + .take(meta_length as u64) + .read_to_end(message_buffer)?; + + let message = arrow_format::ipc::MessageRef::read_as_root(message_buffer.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? 
+ .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { + data_buffer.clear(); + data_buffer.try_reserve(block_length)?; + reader + .by_ref() + .take(block_length as u64) + .read_to_end(data_buffer)?; + + let file_size = data_buffer.len() as u64; + + let mut reader = std::io::Cursor::new(data_buffer); + + let chunk = read_record_batch( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + projection.as_ref().map(|x| x.0.as_ref()), + None, + dictionaries, + metadata.version, + &mut reader, + 0, + file_size, + scratch, + ); + + if let Some((_, map, _)) = projection { + // re-order according to projection + chunk + .map(|chunk| apply_projection(chunk, map)) + .map(|x| Some(StreamState::Some(x))) + } else { + chunk.map(|x| Some(StreamState::Some(x))) + } + }, + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { + data_buffer.clear(); + data_buffer.try_reserve(block_length)?; + reader + .by_ref() + .take(block_length as u64) + .read_to_end(data_buffer)?; + + let file_size = data_buffer.len() as u64; + let mut dict_reader = std::io::Cursor::new(&data_buffer); + + read_dictionary( + batch, + &metadata.schema.fields, + &metadata.ipc_schema, + dictionaries, + &mut dict_reader, + 0, + file_size, + scratch, + )?; + + // read the next message until we encounter a RecordBatch message + read_next( + reader, + metadata, + dictionaries, + message_buffer, + data_buffer, + projection, + scratch, + ) + }, + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +/// Arrow Stream reader. +/// +/// An [`Iterator`] over an Arrow stream that yields a result of [`StreamState`]s. +/// This is the recommended way to read an arrow stream (by iterating over its data). +/// +/// For a more thorough walkthrough consult [this example](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow). +pub struct StreamReader { + reader: R, + metadata: StreamMetadata, + dictionaries: Dictionaries, + finished: bool, + data_buffer: Vec, + message_buffer: Vec, + projection: Option<(Vec, AHashMap, Schema)>, + scratch: Vec, +} + +impl StreamReader { + /// Try to create a new stream reader + /// + /// The first message in the stream is the schema, the reader will fail if it does not + /// encounter a schema. 
+ /// To check if the reader is done, use `is_finished(self)` + pub fn new(reader: R, metadata: StreamMetadata, projection: Option>) -> Self { + let projection = projection.map(|projection| { + let (p, h, fields) = prepare_projection(&metadata.schema.fields, projection); + let schema = Schema { + fields, + metadata: metadata.schema.metadata.clone(), + }; + (p, h, schema) + }); + + Self { + reader, + metadata, + dictionaries: Default::default(), + finished: false, + data_buffer: Default::default(), + message_buffer: Default::default(), + projection, + scratch: Default::default(), + } + } + + /// Return the schema of the stream + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata + } + + /// Return the schema of the file + pub fn schema(&self) -> &Schema { + self.projection + .as_ref() + .map(|x| &x.2) + .unwrap_or(&self.metadata.schema) + } + + /// Check if the stream is finished + pub fn is_finished(&self) -> bool { + self.finished + } + + fn maybe_next(&mut self) -> Result> { + if self.finished { + return Ok(None); + } + let batch = read_next( + &mut self.reader, + &self.metadata, + &mut self.dictionaries, + &mut self.message_buffer, + &mut self.data_buffer, + &self.projection, + &mut self.scratch, + )?; + if batch.is_none() { + self.finished = true; + } + Ok(batch) + } +} + +impl Iterator for StreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + self.maybe_next().transpose() + } +} diff --git a/crates/nano-arrow/src/io/ipc/read/stream_async.rs b/crates/nano-arrow/src/io/ipc/read/stream_async.rs new file mode 100644 index 000000000000..f87f84a8d317 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/read/stream_async.rs @@ -0,0 +1,237 @@ +//! APIs to read Arrow streams asynchronously + +use arrow_format::ipc::planus::ReadAsRoot; +use futures::future::BoxFuture; +use futures::{AsyncRead, AsyncReadExt, FutureExt, Stream}; + +use super::super::CONTINUATION_MARKER; +use super::common::{read_dictionary, read_record_batch}; +use super::schema::deserialize_stream_metadata; +use super::{Dictionaries, OutOfSpecKind, StreamMetadata}; +use crate::array::*; +use crate::chunk::Chunk; +use crate::error::{Error, Result}; + +/// A (private) state of stream messages +struct ReadState { + pub reader: R, + pub metadata: StreamMetadata, + pub dictionaries: Dictionaries, + /// The internal buffer to read data inside the messages (records and dictionaries) to + pub data_buffer: Vec, + /// The internal buffer to read messages to + pub message_buffer: Vec, +} + +/// The state of an Arrow stream +enum StreamState { + /// The stream does not contain new chunks (and it has not been closed) + Waiting(ReadState), + /// The stream contain a new chunk + Some((ReadState, Chunk>)), +} + +/// Reads the [`StreamMetadata`] of the Arrow stream asynchronously +pub async fn read_stream_metadata_async( + reader: &mut R, +) -> Result { + // determine metadata length + let mut meta_size: [u8; 4] = [0; 4]; + reader.read_exact(&mut meta_size).await?; + let meta_len = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. 
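// A small sketch of driving the synchronous StreamReader defined above, following
// the StreamState contract it documents: `Waiting` means the stream is still live
// but holds no data yet, so the caller may back off and poll again. The sleep
// interval and the printing are illustrative assumptions, not part of the API.
fn consume_stream<R: std::io::Read>(reader: StreamReader<R>) -> Result<()> {
    for state in reader {
        match state? {
            StreamState::Some(chunk) => println!("received a chunk of {} rows", chunk.len()),
            StreamState::Waiting => std::thread::sleep(std::time::Duration::from_millis(10)),
        }
    }
    Ok(())
}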
+ if meta_size == CONTINUATION_MARKER { + reader.read_exact(&mut meta_size).await?; + } + i32::from_le_bytes(meta_size) + }; + + let meta_len: usize = meta_len + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let mut meta_buffer = vec![]; + meta_buffer.try_reserve(meta_len)?; + reader + .take(meta_len as u64) + .read_to_end(&mut meta_buffer) + .await?; + + deserialize_stream_metadata(&meta_buffer) +} + +/// Reads the next item, yielding `None` if the stream has been closed, +/// or a [`StreamState`] otherwise. +async fn maybe_next( + mut state: ReadState, +) -> Result>> { + let mut scratch = Default::default(); + // determine metadata length + let mut meta_length: [u8; 4] = [0; 4]; + + match state.reader.read_exact(&mut meta_length).await { + Ok(()) => (), + Err(e) => { + return if e.kind() == std::io::ErrorKind::UnexpectedEof { + // Handle EOF without the "0xFFFFFFFF 0x00000000" + // valid according to: + // https://arrow.apache.org/docs/format/Columnar.html#ipc-streaming-format + Ok(Some(StreamState::Waiting(state))) + } else { + Err(Error::from(e)) + }; + }, + } + + let meta_length = { + // If a continuation marker is encountered, skip over it and read + // the size from the next four bytes. + if meta_length == CONTINUATION_MARKER { + state.reader.read_exact(&mut meta_length).await?; + } + i32::from_le_bytes(meta_length) + }; + + let meta_length: usize = meta_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + if meta_length == 0 { + // the stream has ended, mark the reader as finished + return Ok(None); + } + + state.message_buffer.clear(); + state.message_buffer.try_reserve(meta_length)?; + (&mut state.reader) + .take(meta_length as u64) + .read_to_end(&mut state.message_buffer) + .await?; + + let message = arrow_format::ipc::MessageRef::read_as_root(state.message_buffer.as_ref()) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + let header = message + .header() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferHeader(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageHeader))?; + + let block_length: usize = message + .body_length() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBodyLength(err)))? 
+ .try_into() + .map_err(|_| Error::from(OutOfSpecKind::UnexpectedNegativeInteger))?; + + match header { + arrow_format::ipc::MessageHeaderRef::RecordBatch(batch) => { + state.data_buffer.clear(); + state.data_buffer.try_reserve(block_length)?; + (&mut state.reader) + .take(block_length as u64) + .read_to_end(&mut state.data_buffer) + .await?; + + read_record_batch( + batch, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + None, + None, + &state.dictionaries, + state.metadata.version, + &mut std::io::Cursor::new(&state.data_buffer), + 0, + state.data_buffer.len() as u64, + &mut scratch, + ) + .map(|chunk| Some(StreamState::Some((state, chunk)))) + }, + arrow_format::ipc::MessageHeaderRef::DictionaryBatch(batch) => { + state.data_buffer.clear(); + state.data_buffer.try_reserve(block_length)?; + (&mut state.reader) + .take(block_length as u64) + .read_to_end(&mut state.data_buffer) + .await?; + + let file_size = state.data_buffer.len() as u64; + + let mut dict_reader = std::io::Cursor::new(&state.data_buffer); + + read_dictionary( + batch, + &state.metadata.schema.fields, + &state.metadata.ipc_schema, + &mut state.dictionaries, + &mut dict_reader, + 0, + file_size, + &mut scratch, + )?; + + // read the next message until we encounter a Chunk> message + Ok(Some(StreamState::Waiting(state))) + }, + _ => Err(Error::from(OutOfSpecKind::UnexpectedMessageType)), + } +} + +/// A [`Stream`] over an Arrow IPC stream that asynchronously yields [`Chunk`]s. +pub struct AsyncStreamReader<'a, R: AsyncRead + Unpin + Send + 'a> { + metadata: StreamMetadata, + future: Option>>>>, +} + +impl<'a, R: AsyncRead + Unpin + Send + 'a> AsyncStreamReader<'a, R> { + /// Creates a new [`AsyncStreamReader`] + pub fn new(reader: R, metadata: StreamMetadata) -> Self { + let state = ReadState { + reader, + metadata: metadata.clone(), + dictionaries: Default::default(), + data_buffer: Default::default(), + message_buffer: Default::default(), + }; + let future = Some(maybe_next(state).boxed()); + Self { metadata, future } + } + + /// Return the schema of the stream + pub fn metadata(&self) -> &StreamMetadata { + &self.metadata + } +} + +impl<'a, R: AsyncRead + Unpin + Send> Stream for AsyncStreamReader<'a, R> { + type Item = Result>>; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::pin::Pin; + use std::task::Poll; + let me = Pin::into_inner(self); + + match &mut me.future { + Some(fut) => match fut.as_mut().poll(cx) { + Poll::Ready(Ok(None)) => { + me.future = None; + Poll::Ready(None) + }, + Poll::Ready(Ok(Some(StreamState::Some((state, batch))))) => { + me.future = Some(Box::pin(maybe_next(state))); + Poll::Ready(Some(Ok(batch))) + }, + Poll::Ready(Ok(Some(StreamState::Waiting(_)))) => Poll::Pending, + Poll::Ready(Err(err)) => { + me.future = None; + Poll::Ready(Some(Err(err))) + }, + Poll::Pending => Poll::Pending, + }, + None => Poll::Ready(None), + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/common.rs b/crates/nano-arrow/src/io/ipc/write/common.rs new file mode 100644 index 000000000000..4684bd7f658d --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/common.rs @@ -0,0 +1,448 @@ +use std::borrow::{Borrow, Cow}; + +use arrow_format::ipc::planus::Builder; + +use super::super::IpcField; +use super::{write, write_dictionary}; +use crate::array::*; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::io::ipc::endianness::is_native_little_endian; +use 
crate::io::ipc::read::Dictionaries; + +/// Compression codec +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Compression { + /// LZ4 (framed) + LZ4, + /// ZSTD + ZSTD, +} + +/// Options declaring the behaviour of writing to IPC +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +pub struct WriteOptions { + /// Whether the buffers should be compressed and which codec to use. + /// Note: to use compression the crate must be compiled with feature `io_ipc_compression`. + pub compression: Option, +} + +fn encode_dictionary( + field: &IpcField, + array: &dyn Array, + options: &WriteOptions, + dictionary_tracker: &mut DictionaryTracker, + encoded_dictionaries: &mut Vec, +) -> Result<()> { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Utf8 | LargeUtf8 | Binary | LargeBinary | Primitive(_) | Boolean | Null + | FixedSizeBinary => Ok(()), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let dict_id = field.dictionary_id + .ok_or_else(|| Error::InvalidArgumentError("Dictionaries must have an associated id".to_string()))?; + + let emit = dictionary_tracker.insert(dict_id, array)?; + + let array = array.as_any().downcast_ref::>().unwrap(); + let values = array.values(); + encode_dictionary(field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries + )?; + + if emit { + encoded_dictionaries.push(dictionary_batch_to_bytes::<$T>( + dict_id, + array, + options, + is_native_little_endian(), + )); + }; + Ok(()) + }), + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + let fields = field.fields.as_slice(); + if array.fields().len() != fields.len() { + return Err(Error::InvalidArgumentError( + "The number of fields in a struct must equal the number of children in IpcField".to_string(), + )); + } + fields + .iter() + .zip(array.values().iter()) + .try_for_each(|(field, values)| { + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }) + }, + List => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }, + LargeList => { + let values = array + .as_any() + .downcast_ref::>() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }, + FixedSizeList => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .values(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }, + Union => { + let values = array + .as_any() + .downcast_ref::() + .unwrap() + .fields(); + let fields = &field.fields[..]; // todo: error instead + if values.len() != fields.len() { + return Err(Error::InvalidArgumentError( + "The number of fields in a union must equal the number of children in IpcField" + .to_string(), + )); + } + fields + .iter() + .zip(values.iter()) + .try_for_each(|(field, values)| { + encode_dictionary( + field, + values.as_ref(), + options, + dictionary_tracker, + encoded_dictionaries, + ) + }) + }, + Map => { + let values = array.as_any().downcast_ref::().unwrap().field(); + let field = &field.fields[0]; // todo: error instead + encode_dictionary( + field, + values.as_ref(), + options, + 
dictionary_tracker, + encoded_dictionaries, + ) + }, + } +} + +pub fn encode_chunk( + chunk: &Chunk>, + fields: &[IpcField], + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, +) -> Result<(Vec, EncodedData)> { + let mut encoded_message = EncodedData::default(); + let encoded_dictionaries = encode_chunk_amortized( + chunk, + fields, + dictionary_tracker, + options, + &mut encoded_message, + )?; + Ok((encoded_dictionaries, encoded_message)) +} + +// Amortizes `EncodedData` allocation. +pub fn encode_chunk_amortized( + chunk: &Chunk>, + fields: &[IpcField], + dictionary_tracker: &mut DictionaryTracker, + options: &WriteOptions, + encoded_message: &mut EncodedData, +) -> Result> { + let mut encoded_dictionaries = vec![]; + + for (field, array) in fields.iter().zip(chunk.as_ref()) { + encode_dictionary( + field, + array.as_ref(), + options, + dictionary_tracker, + &mut encoded_dictionaries, + )?; + } + + chunk_to_bytes_amortized(chunk, options, encoded_message); + + Ok(encoded_dictionaries) +} + +fn serialize_compression( + compression: Option, +) -> Option> { + if let Some(compression) = compression { + let codec = match compression { + Compression::LZ4 => arrow_format::ipc::CompressionType::Lz4Frame, + Compression::ZSTD => arrow_format::ipc::CompressionType::Zstd, + }; + Some(Box::new(arrow_format::ipc::BodyCompression { + codec, + method: arrow_format::ipc::BodyCompressionMethod::Buffer, + })) + } else { + None + } +} + +/// Write [`Chunk`] into two sets of bytes, one for the header (ipc::Schema::Message) and the +/// other for the batch's data +fn chunk_to_bytes_amortized( + chunk: &Chunk>, + options: &WriteOptions, + encoded_message: &mut EncodedData, +) { + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data = std::mem::take(&mut encoded_message.arrow_data); + arrow_data.clear(); + + let mut offset = 0; + for array in chunk.arrays() { + write( + array.as_ref(), + &mut buffers, + &mut arrow_data, + &mut nodes, + &mut offset, + is_native_little_endian(), + options.compression, + ) + } + + let compression = serialize_compression(options.compression); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::RecordBatch(Box::new( + arrow_format::ipc::RecordBatch { + length: chunk.len() as i64, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + }, + ))), + body_length: arrow_data.len() as i64, + custom_metadata: None, + }; + + let mut builder = Builder::new(); + let ipc_message = builder.finish(&message, None); + encoded_message.ipc_message = ipc_message.to_vec(); + encoded_message.arrow_data = arrow_data +} + +/// Write dictionary values into two sets of bytes, one for the header (ipc::Schema::Message) and the +/// other for the data +fn dictionary_batch_to_bytes( + dict_id: i64, + array: &DictionaryArray, + options: &WriteOptions, + is_little_endian: bool, +) -> EncodedData { + let mut nodes: Vec = vec![]; + let mut buffers: Vec = vec![]; + let mut arrow_data: Vec = vec![]; + + let length = write_dictionary( + array, + &mut buffers, + &mut arrow_data, + &mut nodes, + &mut 0, + is_little_endian, + options.compression, + false, + ); + + let compression = serialize_compression(options.compression); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::DictionaryBatch(Box::new( + arrow_format::ipc::DictionaryBatch { + id: dict_id, + data: 
Some(Box::new(arrow_format::ipc::RecordBatch { + length: length as i64, + nodes: Some(nodes), + buffers: Some(buffers), + compression, + })), + is_delta: false, + }, + ))), + body_length: arrow_data.len() as i64, + custom_metadata: None, + }; + + let mut builder = Builder::new(); + let ipc_message = builder.finish(&message, None); + + EncodedData { + ipc_message: ipc_message.to_vec(), + arrow_data, + } +} + +/// Keeps track of dictionaries that have been written, to avoid emitting the same dictionary +/// multiple times. Can optionally error if an update to an existing dictionary is attempted, which +/// isn't allowed in the `FileWriter`. +pub struct DictionaryTracker { + pub dictionaries: Dictionaries, + pub cannot_replace: bool, +} + +impl DictionaryTracker { + /// Keep track of the dictionary with the given ID and values. Behavior: + /// + /// * If this ID has been written already and has the same data, return `Ok(false)` to indicate + /// that the dictionary was not actually inserted (because it's already been seen). + /// * If this ID has been written already but with different data, and this tracker is + /// configured to return an error, return an error. + /// * If the tracker has not been configured to error on replacement or this dictionary + /// has never been seen before, return `Ok(true)` to indicate that the dictionary was just + /// inserted. + pub fn insert(&mut self, dict_id: i64, array: &dyn Array) -> Result { + let values = match array.data_type() { + DataType::Dictionary(key_type, _, _) => { + match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + array.values() + }) + }, + _ => unreachable!(), + }; + + // If a dictionary with this id was already emitted, check if it was the same. + if let Some(last) = self.dictionaries.get(&dict_id) { + if last.as_ref() == values.as_ref() { + // Same dictionary values => no need to emit it again + return Ok(false); + } else if self.cannot_replace { + return Err(Error::InvalidArgumentError( + "Dictionary replacement detected when writing IPC file format. \ + Arrow IPC files only support a single dictionary for a given field \ + across all batches." + .to_string(), + )); + } + }; + + self.dictionaries.insert(dict_id, values.clone()); + Ok(true) + } +} + +/// Stores the encoded data, which is an ipc::Schema::Message, and optional Arrow data +#[derive(Debug, Default)] +pub struct EncodedData { + /// An encoded ipc::Schema::Message + pub ipc_message: Vec, + /// Arrow buffers to be written, should be an empty vec for schema messages + pub arrow_data: Vec, +} + +/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes +#[inline] +pub(crate) fn pad_to_64(len: usize) -> usize { + ((len + 63) & !63) - len +} + +/// An array [`Chunk`] with optional accompanying IPC fields. +#[derive(Debug, Clone, PartialEq)] +pub struct Record<'a> { + columns: Cow<'a, Chunk>>, + fields: Option>, +} + +impl<'a> Record<'a> { + /// Get the IPC fields for this record. + pub fn fields(&self) -> Option<&[IpcField]> { + self.fields.as_deref() + } + + /// Get the Arrow columns in this record. 
+ pub fn columns(&self) -> &Chunk> { + self.columns.borrow() + } +} + +impl From>> for Record<'static> { + fn from(columns: Chunk>) -> Self { + Self { + columns: Cow::Owned(columns), + fields: None, + } + } +} + +impl<'a, F> From<(Chunk>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (Chunk>, Option)) -> Self { + Self { + columns: Cow::Owned(columns), + fields: fields.map(|f| f.into()), + } + } +} + +impl<'a, F> From<(&'a Chunk>, Option)> for Record<'a> +where + F: Into>, +{ + fn from((columns, fields): (&'a Chunk>, Option)) -> Self { + Self { + columns: Cow::Borrowed(columns), + fields: fields.map(|f| f.into()), + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/common_async.rs b/crates/nano-arrow/src/io/ipc/write/common_async.rs new file mode 100644 index 000000000000..397391cd24ee --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/common_async.rs @@ -0,0 +1,66 @@ +use futures::{AsyncWrite, AsyncWriteExt}; + +use super::super::CONTINUATION_MARKER; +use super::common::{pad_to_64, EncodedData}; +use crate::error::Result; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub async fn write_message( + mut writer: W, + encoded: EncodedData, +) -> Result<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + + let a = 64 - 1; + let buffer = encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; // the message length + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(&mut writer, (aligned_size - prefix_size) as i32).await?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(&buffer).await?; + } + // write padding + writer.write_all(&vec![0; padding_bytes]).await?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data).await? 
+ } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub async fn write_continuation( + mut writer: W, + total_len: i32, +) -> Result { + writer.write_all(&CONTINUATION_MARKER).await?; + writer.write_all(&total_len.to_le_bytes()[..]).await?; + Ok(8) +} + +async fn write_body_buffers( + mut writer: W, + data: &[u8], +) -> Result { + let len = data.len(); + let pad_len = pad_to_64(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data).await?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..]).await?; + } + + Ok(total_len) +} diff --git a/crates/nano-arrow/src/io/ipc/write/common_sync.rs b/crates/nano-arrow/src/io/ipc/write/common_sync.rs new file mode 100644 index 000000000000..b20196419b2c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/common_sync.rs @@ -0,0 +1,59 @@ +use std::io::Write; + +use super::super::CONTINUATION_MARKER; +use super::common::{pad_to_64, EncodedData}; +use crate::error::Result; + +/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written +pub fn write_message(writer: &mut W, encoded: &EncodedData) -> Result<(usize, usize)> { + let arrow_data_len = encoded.arrow_data.len(); + + let a = 8 - 1; + let buffer = &encoded.ipc_message; + let flatbuf_size = buffer.len(); + let prefix_size = 8; + let aligned_size = (flatbuf_size + prefix_size + a) & !a; + let padding_bytes = aligned_size - flatbuf_size - prefix_size; + + write_continuation(writer, (aligned_size - prefix_size) as i32)?; + + // write the flatbuf + if flatbuf_size > 0 { + writer.write_all(buffer)?; + } + // write padding + // aligned to a 8 byte boundary, so maximum is [u8;8] + const PADDING_MAX: [u8; 8] = [0u8; 8]; + writer.write_all(&PADDING_MAX[..padding_bytes])?; + + // write arrow data + let body_len = if arrow_data_len > 0 { + write_body_buffers(writer, &encoded.arrow_data)? + } else { + 0 + }; + + Ok((aligned_size, body_len)) +} + +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { + let len = data.len(); + let pad_len = pad_to_64(data.len()); + let total_len = len + pad_len; + + // write body buffer + writer.write_all(data)?; + if pad_len > 0 { + writer.write_all(&vec![0u8; pad_len][..])?; + } + + Ok(total_len) +} + +/// Write a record batch to the writer, writing the message size before the message +/// if the record batch is being written to a stream +pub fn write_continuation(writer: &mut W, total_len: i32) -> Result { + writer.write_all(&CONTINUATION_MARKER)?; + writer.write_all(&total_len.to_le_bytes()[..])?; + Ok(8) +} diff --git a/crates/nano-arrow/src/io/ipc/write/file_async.rs b/crates/nano-arrow/src/io/ipc/write/file_async.rs new file mode 100644 index 000000000000..93a1715282e2 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/file_async.rs @@ -0,0 +1,252 @@ +//! Async writer for IPC files. 
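// A worked sketch of the buffer padding used by the write_body_buffers helpers
// above: despite the "8-byte" wording in its doc comment, pad_to_64 rounds a
// length up to the next multiple of 64 and returns the number of padding bytes.
fn pad_to_64(len: usize) -> usize {
    ((len + 63) & !63) - len
}

fn padding_examples() {
    assert_eq!(pad_to_64(0), 0);    // already aligned
    assert_eq!(pad_to_64(1), 63);   // 1 + 63 = 64
    assert_eq!(pad_to_64(64), 0);   // already aligned
    assert_eq!(pad_to_64(100), 28); // 100 + 28 = 128
}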
+ +use std::task::Poll; + +use arrow_format::ipc::planus::Builder; +use arrow_format::ipc::{Block, Footer, MetadataVersion}; +use futures::future::BoxFuture; +use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink}; + +use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; +use super::common_async::{write_continuation, write_message}; +use super::schema::serialize_schema; +use super::{default_ipc_fields, schema_to_bytes, Record}; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::io::ipc::{IpcField, ARROW_MAGIC_V2}; + +type WriteOutput = (usize, Option, Vec, Option); + +/// Sink that writes array [`chunks`](crate::chunk::Chunk) as an IPC file. +/// +/// The file header is automatically written before writing the first chunk, and the file footer is +/// automatically written when the sink is closed. +/// +/// # Examples +/// +/// ``` +/// use futures::{SinkExt, TryStreamExt, io::Cursor}; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// use arrow2::io::ipc::write::file_async::FileSink; +/// use arrow2::io::ipc::read::file_async::{read_file_metadata_async, FileStream}; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = Cursor::new(vec![]); +/// let mut sink = FileSink::new( +/// &mut buffer, +/// schema, +/// None, +/// Default::default(), +/// ); +/// +/// // Write chunks to file +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk.into()).await?; +/// } +/// sink.close().await?; +/// drop(sink); +/// +/// // Read chunks from file +/// buffer.set_position(0); +/// let metadata = read_file_metadata_async(&mut buffer).await?; +/// let mut stream = FileStream::new(buffer, metadata, None, None); +/// let chunks = stream.try_collect::>().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct FileSink<'a, W: AsyncWrite + Unpin + Send + 'a> { + writer: Option, + task: Option>>>, + options: WriteOptions, + dictionary_tracker: DictionaryTracker, + offset: usize, + fields: Vec, + record_blocks: Vec, + dictionary_blocks: Vec, + schema: Schema, +} + +impl<'a, W> FileSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + /// Create a new file writer. 
+ pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let encoded = EncodedData { + ipc_message: schema_to_bytes(&schema, &fields), + arrow_data: vec![], + }; + let task = Some(Self::start(writer, encoded).boxed()); + Self { + writer: None, + task, + options, + fields, + offset: 0, + schema, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: true, + }, + record_blocks: vec![], + dictionary_blocks: vec![], + } + } + + async fn start(mut writer: W, encoded: EncodedData) -> Result> { + writer.write_all(&ARROW_MAGIC_V2[..]).await?; + writer.write_all(&[0, 0]).await?; + let (meta, data) = write_message(&mut writer, encoded).await?; + + Ok((meta + data + 8, None, vec![], Some(writer))) + } + + async fn write( + mut writer: W, + mut offset: usize, + record: EncodedData, + dictionaries: Vec, + ) -> Result> { + let mut dict_blocks = vec![]; + for dict in dictionaries { + let (meta, data) = write_message(&mut writer, dict).await?; + let block = Block { + offset: offset as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + dict_blocks.push(block); + offset += meta + data; + } + let (meta, data) = write_message(&mut writer, record).await?; + let block = Block { + offset: offset as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + offset += meta + data; + Ok((offset, Some(block), dict_blocks, Some(writer))) + } + + async fn finish(mut writer: W, footer: Footer) -> Result> { + write_continuation(&mut writer, 0).await?; + let footer = { + let mut builder = Builder::new(); + builder.finish(&footer, None).to_owned() + }; + writer.write_all(&footer[..]).await?; + writer + .write_all(&(footer.len() as i32).to_le_bytes()) + .await?; + writer.write_all(&ARROW_MAGIC_V2).await?; + writer.close().await?; + + Ok((0, None, vec![], None)) + } + + fn poll_write(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok((offset, record, mut dictionaries, writer)) => { + self.task = None; + self.writer = writer; + self.offset = offset; + if let Some(block) = record { + self.record_blocks.push(block); + } + self.dictionary_blocks.append(&mut dictionaries); + Poll::Ready(Ok(())) + }, + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + }, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink> for FileSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + type Error = Error; + + fn poll_ready( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_write(cx) + } + + fn start_send(self: std::pin::Pin<&mut Self>, item: Record<'_>) -> Result<()> { + let this = self.get_mut(); + + if let Some(writer) = this.writer.take() { + let fields = item.fields().unwrap_or_else(|| &this.fields[..]); + + let (dictionaries, record) = encode_chunk( + item.columns(), + fields, + &mut this.dictionary_tracker, + &this.options, + )?; + + this.task = Some(Self::write(writer, this.offset, record, dictionaries).boxed()); + Ok(()) + } else { + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer is closed", + ))) + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_write(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, 
+ cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match futures::ready!(this.poll_write(cx)) { + Ok(()) => { + if let Some(writer) = this.writer.take() { + let schema = serialize_schema(&this.schema, &this.fields); + let footer = Footer { + version: MetadataVersion::V5, + schema: Some(Box::new(schema)), + dictionaries: Some(std::mem::take(&mut this.dictionary_blocks)), + record_batches: Some(std::mem::take(&mut this.record_blocks)), + custom_metadata: None, + }; + this.task = Some(Self::finish(writer, footer).boxed()); + this.poll_write(cx) + } else { + Poll::Ready(Ok(())) + } + }, + Err(error) => Poll::Ready(Err(error)), + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/mod.rs b/crates/nano-arrow/src/io/ipc/write/mod.rs new file mode 100644 index 000000000000..55672a85da3c --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/mod.rs @@ -0,0 +1,70 @@ +//! APIs to write to Arrow's IPC format. +pub(crate) mod common; +mod schema; +mod serialize; +mod stream; +pub(crate) mod writer; + +pub use common::{Compression, Record, WriteOptions}; +pub use schema::schema_to_bytes; +pub use serialize::write; +use serialize::write_dictionary; +pub use stream::StreamWriter; +pub use writer::FileWriter; + +pub(crate) mod common_sync; + +#[cfg(feature = "io_ipc_write_async")] +mod common_async; +#[cfg(feature = "io_ipc_write_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] +pub mod stream_async; + +#[cfg(feature = "io_ipc_write_async")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc_write_async")))] +pub mod file_async; + +use super::IpcField; +use crate::datatypes::{DataType, Field}; + +fn default_ipc_field(data_type: &DataType, current_id: &mut i64) -> IpcField { + use crate::datatypes::DataType::*; + match data_type.to_logical_type() { + // single child => recurse + Map(inner, ..) | FixedSizeList(inner, _) | LargeList(inner) | List(inner) => IpcField { + fields: vec![default_ipc_field(inner.data_type(), current_id)], + dictionary_id: None, + }, + // multiple children => recurse + Union(fields, ..) | Struct(fields) => IpcField { + fields: fields + .iter() + .map(|f| default_ipc_field(f.data_type(), current_id)) + .collect(), + dictionary_id: None, + }, + // dictionary => current_id + Dictionary(_, data_type, _) => { + let dictionary_id = Some(*current_id); + *current_id += 1; + IpcField { + fields: vec![default_ipc_field(data_type, current_id)], + dictionary_id, + } + }, + // no children => do nothing + _ => IpcField { + fields: vec![], + dictionary_id: None, + }, + } +} + +/// Assigns every dictionary field a unique ID +pub fn default_ipc_fields(fields: &[Field]) -> Vec { + let mut dictionary_id = 0i64; + fields + .iter() + .map(|field| default_ipc_field(field.data_type().to_logical_type(), &mut dictionary_id)) + .collect() +} diff --git a/crates/nano-arrow/src/io/ipc/write/schema.rs b/crates/nano-arrow/src/io/ipc/write/schema.rs new file mode 100644 index 000000000000..dd6f44bbd33a --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/schema.rs @@ -0,0 +1,333 @@ +use arrow_format::ipc::planus::Builder; + +use super::super::IpcField; +use crate::datatypes::{ + DataType, Field, IntegerType, IntervalUnit, Metadata, Schema, TimeUnit, UnionMode, +}; +use crate::io::ipc::endianness::is_native_little_endian; + +/// Converts a [Schema] and [IpcField]s to a flatbuffers-encoded [arrow_format::ipc::Message]. 
+pub fn schema_to_bytes(schema: &Schema, ipc_fields: &[IpcField]) -> Vec { + let schema = serialize_schema(schema, ipc_fields); + + let message = arrow_format::ipc::Message { + version: arrow_format::ipc::MetadataVersion::V5, + header: Some(arrow_format::ipc::MessageHeader::Schema(Box::new(schema))), + body_length: 0, + custom_metadata: None, // todo: allow writing custom metadata + }; + let mut builder = Builder::new(); + let footer_data = builder.finish(&message, None); + footer_data.to_vec() +} + +pub fn serialize_schema(schema: &Schema, ipc_fields: &[IpcField]) -> arrow_format::ipc::Schema { + let endianness = if is_native_little_endian() { + arrow_format::ipc::Endianness::Little + } else { + arrow_format::ipc::Endianness::Big + }; + + let fields = schema + .fields + .iter() + .zip(ipc_fields.iter()) + .map(|(field, ipc_field)| serialize_field(field, ipc_field)) + .collect::>(); + + let mut custom_metadata = vec![]; + for (key, value) in &schema.metadata { + custom_metadata.push(arrow_format::ipc::KeyValue { + key: Some(key.clone()), + value: Some(value.clone()), + }); + } + let custom_metadata = if custom_metadata.is_empty() { + None + } else { + Some(custom_metadata) + }; + + arrow_format::ipc::Schema { + endianness, + fields: Some(fields), + custom_metadata, + features: None, // todo add this one + } +} + +fn write_metadata(metadata: &Metadata, kv_vec: &mut Vec) { + for (k, v) in metadata { + if k != "ARROW:extension:name" && k != "ARROW:extension:metadata" { + let entry = arrow_format::ipc::KeyValue { + key: Some(k.clone()), + value: Some(v.clone()), + }; + kv_vec.push(entry); + } + } +} + +fn write_extension( + name: &str, + metadata: &Option, + kv_vec: &mut Vec, +) { + // metadata + if let Some(metadata) = metadata { + let entry = arrow_format::ipc::KeyValue { + key: Some("ARROW:extension:metadata".to_string()), + value: Some(metadata.clone()), + }; + kv_vec.push(entry); + } + + // name + let entry = arrow_format::ipc::KeyValue { + key: Some("ARROW:extension:name".to_string()), + value: Some(name.to_string()), + }; + kv_vec.push(entry); +} + +/// Create an IPC Field from an Arrow Field +pub(crate) fn serialize_field(field: &Field, ipc_field: &IpcField) -> arrow_format::ipc::Field { + // custom metadata. 
+ let mut kv_vec = vec![]; + if let DataType::Extension(name, _, metadata) = field.data_type() { + write_extension(name, metadata, &mut kv_vec); + } + + let type_ = serialize_type(field.data_type()); + let children = serialize_children(field.data_type(), ipc_field); + + let dictionary = if let DataType::Dictionary(index_type, inner, is_ordered) = field.data_type() + { + if let DataType::Extension(name, _, metadata) = inner.as_ref() { + write_extension(name, metadata, &mut kv_vec); + } + Some(serialize_dictionary( + index_type, + ipc_field + .dictionary_id + .expect("All Dictionary types have `dict_id`"), + *is_ordered, + )) + } else { + None + }; + + write_metadata(&field.metadata, &mut kv_vec); + + let custom_metadata = if !kv_vec.is_empty() { + Some(kv_vec) + } else { + None + }; + + arrow_format::ipc::Field { + name: Some(field.name.clone()), + nullable: field.is_nullable, + type_: Some(type_), + dictionary: dictionary.map(Box::new), + children: Some(children), + custom_metadata, + } +} + +fn serialize_time_unit(unit: &TimeUnit) -> arrow_format::ipc::TimeUnit { + match unit { + TimeUnit::Second => arrow_format::ipc::TimeUnit::Second, + TimeUnit::Millisecond => arrow_format::ipc::TimeUnit::Millisecond, + TimeUnit::Microsecond => arrow_format::ipc::TimeUnit::Microsecond, + TimeUnit::Nanosecond => arrow_format::ipc::TimeUnit::Nanosecond, + } +} + +fn serialize_type(data_type: &DataType) -> arrow_format::ipc::Type { + use arrow_format::ipc; + use DataType::*; + match data_type { + Null => ipc::Type::Null(Box::new(ipc::Null {})), + Boolean => ipc::Type::Bool(Box::new(ipc::Bool {})), + UInt8 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 8, + is_signed: false, + })), + UInt16 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 16, + is_signed: false, + })), + UInt32 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 32, + is_signed: false, + })), + UInt64 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 64, + is_signed: false, + })), + Int8 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 8, + is_signed: true, + })), + Int16 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 16, + is_signed: true, + })), + Int32 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 32, + is_signed: true, + })), + Int64 => ipc::Type::Int(Box::new(ipc::Int { + bit_width: 64, + is_signed: true, + })), + Float16 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Half, + })), + Float32 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Single, + })), + Float64 => ipc::Type::FloatingPoint(Box::new(ipc::FloatingPoint { + precision: ipc::Precision::Double, + })), + Decimal(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width: 128, + })), + Decimal256(precision, scale) => ipc::Type::Decimal(Box::new(ipc::Decimal { + precision: *precision as i32, + scale: *scale as i32, + bit_width: 256, + })), + Binary => ipc::Type::Binary(Box::new(ipc::Binary {})), + LargeBinary => ipc::Type::LargeBinary(Box::new(ipc::LargeBinary {})), + Utf8 => ipc::Type::Utf8(Box::new(ipc::Utf8 {})), + LargeUtf8 => ipc::Type::LargeUtf8(Box::new(ipc::LargeUtf8 {})), + FixedSizeBinary(size) => ipc::Type::FixedSizeBinary(Box::new(ipc::FixedSizeBinary { + byte_width: *size as i32, + })), + Date32 => ipc::Type::Date(Box::new(ipc::Date { + unit: ipc::DateUnit::Day, + })), + Date64 => ipc::Type::Date(Box::new(ipc::Date { + unit: ipc::DateUnit::Millisecond, + })), + Duration(unit) => 
ipc::Type::Duration(Box::new(ipc::Duration { + unit: serialize_time_unit(unit), + })), + Time32(unit) => ipc::Type::Time(Box::new(ipc::Time { + unit: serialize_time_unit(unit), + bit_width: 32, + })), + Time64(unit) => ipc::Type::Time(Box::new(ipc::Time { + unit: serialize_time_unit(unit), + bit_width: 64, + })), + Timestamp(unit, tz) => ipc::Type::Timestamp(Box::new(ipc::Timestamp { + unit: serialize_time_unit(unit), + timezone: tz.as_ref().cloned(), + })), + Interval(unit) => ipc::Type::Interval(Box::new(ipc::Interval { + unit: match unit { + IntervalUnit::YearMonth => ipc::IntervalUnit::YearMonth, + IntervalUnit::DayTime => ipc::IntervalUnit::DayTime, + IntervalUnit::MonthDayNano => ipc::IntervalUnit::MonthDayNano, + }, + })), + List(_) => ipc::Type::List(Box::new(ipc::List {})), + LargeList(_) => ipc::Type::LargeList(Box::new(ipc::LargeList {})), + FixedSizeList(_, size) => ipc::Type::FixedSizeList(Box::new(ipc::FixedSizeList { + list_size: *size as i32, + })), + Union(_, type_ids, mode) => ipc::Type::Union(Box::new(ipc::Union { + mode: match mode { + UnionMode::Dense => ipc::UnionMode::Dense, + UnionMode::Sparse => ipc::UnionMode::Sparse, + }, + type_ids: type_ids.clone(), + })), + Map(_, keys_sorted) => ipc::Type::Map(Box::new(ipc::Map { + keys_sorted: *keys_sorted, + })), + Struct(_) => ipc::Type::Struct(Box::new(ipc::Struct {})), + Dictionary(_, v, _) => serialize_type(v), + Extension(_, v, _) => serialize_type(v), + } +} + +fn serialize_children(data_type: &DataType, ipc_field: &IpcField) -> Vec { + use DataType::*; + match data_type { + Null + | Boolean + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Timestamp(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Duration(_) + | Interval(_) + | Binary + | FixedSizeBinary(_) + | LargeBinary + | Utf8 + | LargeUtf8 + | Decimal(_, _) + | Decimal256(_, _) => vec![], + FixedSizeList(inner, _) | LargeList(inner) | List(inner) | Map(inner, _) => { + vec![serialize_field(inner, &ipc_field.fields[0])] + }, + Union(fields, _, _) | Struct(fields) => fields + .iter() + .zip(ipc_field.fields.iter()) + .map(|(field, ipc)| serialize_field(field, ipc)) + .collect(), + Dictionary(_, inner, _) => serialize_children(inner, ipc_field), + Extension(_, inner, _) => serialize_children(inner, ipc_field), + } +} + +/// Create an IPC dictionary encoding +pub(crate) fn serialize_dictionary( + index_type: &IntegerType, + dict_id: i64, + dict_is_ordered: bool, +) -> arrow_format::ipc::DictionaryEncoding { + use IntegerType::*; + let is_signed = match index_type { + Int8 | Int16 | Int32 | Int64 => true, + UInt8 | UInt16 | UInt32 | UInt64 => false, + }; + + let bit_width = match index_type { + Int8 | UInt8 => 8, + Int16 | UInt16 => 16, + Int32 | UInt32 => 32, + Int64 | UInt64 => 64, + }; + + let index_type = arrow_format::ipc::Int { + bit_width, + is_signed, + }; + + arrow_format::ipc::DictionaryEncoding { + id: dict_id, + index_type: Some(Box::new(index_type)), + is_ordered: dict_is_ordered, + dictionary_kind: arrow_format::ipc::DictionaryKind::DenseArray, + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/serialize.rs b/crates/nano-arrow/src/io/ipc/write/serialize.rs new file mode 100644 index 000000000000..f5bad22d6fe4 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/serialize.rs @@ -0,0 +1,763 @@ +#![allow(clippy::ptr_arg)] // false positive in clippy, see https://github.com/rust-lang/rust-clippy/issues/8463 +use arrow_format::ipc; + +use super::super::compression; 
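// A small sketch of the dictionary-id assignment performed by default_ipc_fields
// in write/mod.rs above: ids are handed out in visit order, one per
// dictionary-encoded field, and non-dictionary fields get none. Field and
// DataType are assumed to come from this crate's datatypes module.
fn dictionary_id_example() {
    use crate::datatypes::{DataType, Field, IntegerType};

    let fields = vec![
        Field::new("keys", DataType::Dictionary(IntegerType::Int32, Box::new(DataType::Utf8), false), true),
        Field::new("plain", DataType::Utf8, true),
        Field::new("more_keys", DataType::Dictionary(IntegerType::UInt8, Box::new(DataType::Int64), false), true),
    ];
    let ipc_fields = default_ipc_fields(&fields);
    assert_eq!(ipc_fields[0].dictionary_id, Some(0));
    assert_eq!(ipc_fields[1].dictionary_id, None);
    assert_eq!(ipc_fields[2].dictionary_id, Some(1));
}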
+use super::super::endianness::is_native_little_endian; +use super::common::{pad_to_64, Compression}; +use crate::array::*; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; +use crate::offset::{Offset, OffsetsBuffer}; +use crate::trusted_len::TrustedLen; +use crate::types::NativeType; + +fn write_primitive( + array: &PrimitiveArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + + write_buffer( + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ) +} + +fn write_boolean( + array: &BooleanArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + _: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write_bitmap( + Some(&array.values().clone()), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); +} + +#[allow(clippy::too_many_arguments)] +fn write_generic_binary( + validity: Option<&Bitmap>, + offsets: &OffsetsBuffer, + values: &[u8], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = offsets.buffer(); + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == O::default() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write_bytes( + &values[first.to_usize()..last.to_usize()], + buffers, + arrow_data, + offset, + compression, + ); +} + +fn write_binary( + array: &BinaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_generic_binary( + array.validity(), + array.offsets(), + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); +} + +fn write_utf8( + array: &Utf8Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_generic_binary( + array.validity(), + array.offsets(), + array.values(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); +} + +fn write_fixed_size_binary( + array: &FixedSizeBinaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + _is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write_bytes(array.values(), buffers, arrow_data, offset, compression); +} + +fn write_list( + array: &ListArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = array.offsets().buffer(); + let validity = array.validity(); + + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == O::zero() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + 
write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write( + array + .values() + .sliced(first.to_usize(), last.to_usize() - first.to_usize()) + .as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} + +pub fn write_struct( + array: &StructArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + array.values().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }); +} + +pub fn write_union( + array: &UnionArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_buffer( + array.types(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + + if let Some(offsets) = array.offsets() { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + array.fields().iter().for_each(|array| { + write( + array.as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ) + }); +} + +fn write_map( + array: &MapArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let offsets = array.offsets().buffer(); + let validity = array.validity(); + + write_bitmap( + validity, + offsets.len() - 1, + buffers, + arrow_data, + offset, + compression, + ); + + let first = *offsets.first().unwrap(); + let last = *offsets.last().unwrap(); + if first == 0 { + write_buffer( + offsets, + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } else { + write_buffer_from_iter( + offsets.iter().map(|x| *x - first), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + } + + write( + array + .field() + .sliced(first as usize, last as usize - first as usize) + .as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} + +fn write_fixed_size_list( + array: &FixedSizeListArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + write_bitmap( + array.validity(), + array.len(), + buffers, + arrow_data, + offset, + compression, + ); + write( + array.values().as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); +} + +// use `write_keys` to either write keys or values +#[allow(clippy::too_many_arguments)] +pub(super) fn write_dictionary( + array: &DictionaryArray, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, + write_keys: bool, +) -> usize { + if write_keys { + write_primitive( + array.keys(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ); + array.keys().len() + } else { + write( + array.values().as_ref(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + array.values().len() + } +} + +/// Writes an [`Array`] to `arrow_data` +pub fn write( + array: &dyn Array, + buffers: &mut Vec, + arrow_data: &mut Vec, + nodes: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: 
Option, +) { + nodes.push(ipc::FieldNode { + length: array.len() as i64, + null_count: array.null_count() as i64, + }); + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => (), + Boolean => write_boolean( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array.as_any().downcast_ref().unwrap(); + write_primitive::<$T>(array, buffers, arrow_data, offset, is_little_endian, compression) + }), + Binary => write_binary::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + LargeBinary => write_binary::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + FixedSizeBinary => write_fixed_size_binary( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + Utf8 => write_utf8::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + LargeUtf8 => write_utf8::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + offset, + is_little_endian, + compression, + ), + List => write_list::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + LargeList => write_list::( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + FixedSizeList => write_fixed_size_list( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + Struct => write_struct( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + write_dictionary::<$T>( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + true, + ); + }), + Union => { + write_union( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }, + Map => { + write_map( + array.as_any().downcast_ref().unwrap(), + buffers, + arrow_data, + nodes, + offset, + is_little_endian, + compression, + ); + }, + } +} + +#[inline] +fn pad_buffer_to_64(buffer: &mut Vec, length: usize) { + let pad_len = pad_to_64(length); + buffer.extend_from_slice(&vec![0u8; pad_len]); +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. 
+fn write_bytes( + bytes: &[u8], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + compression: Option, +) { + let start = arrow_data.len(); + if let Some(compression) = compression { + arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(bytes, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(bytes, arrow_data).unwrap(); + }, + } + } else { + arrow_data.extend_from_slice(bytes); + }; + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +fn write_bitmap( + bitmap: Option<&Bitmap>, + length: usize, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + compression: Option, +) { + match bitmap { + Some(bitmap) => { + assert_eq!(bitmap.len(), length); + let (slice, slice_offset, _) = bitmap.as_slice(); + if slice_offset != 0 { + // case where we can't slice the bitmap as the offsets are not multiple of 8 + let bytes = Bitmap::from_trusted_len_iter(bitmap.iter()); + let (slice, _, _) = bytes.as_slice(); + write_bytes(slice, buffers, arrow_data, offset, compression) + } else { + write_bytes(slice, buffers, arrow_data, offset, compression) + } + }, + None => { + buffers.push(ipc::Buffer { + offset: *offset, + length: 0, + }); + }, + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +fn write_buffer( + buffer: &[T], + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let start = arrow_data.len(); + if let Some(compression) = compression { + _write_compressed_buffer(buffer, arrow_data, is_little_endian, compression); + } else { + _write_buffer(buffer, arrow_data, is_little_endian); + }; + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +#[inline] +fn _write_buffer_from_iter>( + buffer: I, + arrow_data: &mut Vec, + is_little_endian: bool, +) { + let len = buffer.size_hint().0; + arrow_data.reserve(len * std::mem::size_of::()); + if is_little_endian { + buffer + .map(|x| T::to_le_bytes(&x)) + .for_each(|x| arrow_data.extend_from_slice(x.as_ref())) + } else { + buffer + .map(|x| T::to_be_bytes(&x)) + .for_each(|x| arrow_data.extend_from_slice(x.as_ref())) + } +} + +#[inline] +fn _write_compressed_buffer_from_iter>( + buffer: I, + arrow_data: &mut Vec, + is_little_endian: bool, + compression: Compression, +) { + let len = buffer.size_hint().0; + let mut swapped = Vec::with_capacity(len * std::mem::size_of::()); + if is_little_endian { + buffer + .map(|x| T::to_le_bytes(&x)) + .for_each(|x| swapped.extend_from_slice(x.as_ref())); + } else { + buffer + .map(|x| T::to_be_bytes(&x)) + .for_each(|x| swapped.extend_from_slice(x.as_ref())) + }; + arrow_data.extend_from_slice(&(swapped.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(&swapped, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(&swapped, arrow_data).unwrap(); + }, + } +} + +fn _write_buffer(buffer: &[T], arrow_data: &mut Vec, is_little_endian: bool) { + if is_little_endian == is_native_little_endian() { + // in native endianness we can use the bytes directly. 
+ let buffer = bytemuck::cast_slice(buffer); + arrow_data.extend_from_slice(buffer); + } else { + _write_buffer_from_iter(buffer.iter().copied(), arrow_data, is_little_endian) + } +} + +fn _write_compressed_buffer( + buffer: &[T], + arrow_data: &mut Vec, + is_little_endian: bool, + compression: Compression, +) { + if is_little_endian == is_native_little_endian() { + let bytes = bytemuck::cast_slice(buffer); + arrow_data.extend_from_slice(&(bytes.len() as i64).to_le_bytes()); + match compression { + Compression::LZ4 => { + compression::compress_lz4(bytes, arrow_data).unwrap(); + }, + Compression::ZSTD => { + compression::compress_zstd(bytes, arrow_data).unwrap(); + }, + } + } else { + todo!() + } +} + +/// writes `bytes` to `arrow_data` updating `buffers` and `offset` and guaranteeing a 8 byte boundary. +#[inline] +fn write_buffer_from_iter>( + buffer: I, + buffers: &mut Vec, + arrow_data: &mut Vec, + offset: &mut i64, + is_little_endian: bool, + compression: Option, +) { + let start = arrow_data.len(); + + if let Some(compression) = compression { + _write_compressed_buffer_from_iter(buffer, arrow_data, is_little_endian, compression); + } else { + _write_buffer_from_iter(buffer, arrow_data, is_little_endian); + } + + buffers.push(finish_buffer(arrow_data, start, offset)); +} + +fn finish_buffer(arrow_data: &mut Vec, start: usize, offset: &mut i64) -> ipc::Buffer { + let buffer_len = (arrow_data.len() - start) as i64; + + pad_buffer_to_64(arrow_data, arrow_data.len() - start); + let total_len = (arrow_data.len() - start) as i64; + + let buffer = ipc::Buffer { + offset: *offset, + length: buffer_len, + }; + *offset += total_len; + buffer +} diff --git a/crates/nano-arrow/src/io/ipc/write/stream.rs b/crates/nano-arrow/src/io/ipc/write/stream.rs new file mode 100644 index 000000000000..3fe7e143e02d --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/stream.rs @@ -0,0 +1,113 @@ +//! Arrow IPC File and Stream Writers +//! +//! The `FileWriter` and `StreamWriter` have similar interfaces, +//! however the `FileWriter` expects a reader that supports `Seek`ing + +use std::io::Write; + +use super::super::IpcField; +use super::common::{encode_chunk, DictionaryTracker, EncodedData, WriteOptions}; +use super::common_sync::{write_continuation, write_message}; +use super::{default_ipc_fields, schema_to_bytes}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; + +/// Arrow stream writer +/// +/// The data written by this writer must be read in order. To signal that no more +/// data is arriving through the stream call [`self.finish()`](StreamWriter::finish); +/// +/// For a usage walkthrough consult [this example](https://github.com/jorgecarleitao/arrow2/tree/main/examples/ipc_pyarrow). +pub struct StreamWriter { + /// The object to write to + writer: W, + /// IPC write options + write_options: WriteOptions, + /// Whether the stream has been finished + finished: bool, + /// Keeps track of dictionaries that have been written + dictionary_tracker: DictionaryTracker, + + ipc_fields: Option>, +} + +impl StreamWriter { + /// Creates a new [`StreamWriter`] + pub fn new(writer: W, write_options: WriteOptions) -> Self { + Self { + writer, + write_options, + finished: false, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }, + ipc_fields: None, + } + } + + /// Starts the stream by writing a Schema message to it. 
+ /// Use `ipc_fields` to declare dictionary ids in the schema, for dictionary-reuse + pub fn start(&mut self, schema: &Schema, ipc_fields: Option>) -> Result<()> { + self.ipc_fields = Some(if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(&schema.fields) + }); + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(schema, self.ipc_fields.as_ref().unwrap()), + arrow_data: vec![], + }; + write_message(&mut self.writer, &encoded_message)?; + Ok(()) + } + + /// Writes [`Chunk`] to the stream + pub fn write( + &mut self, + columns: &Chunk>, + ipc_fields: Option<&[IpcField]>, + ) -> Result<()> { + if self.finished { + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Cannot write to a finished stream".to_string(), + ))); + } + + // we can't make it a closure because it borrows (and it can't borrow mut and non-mut below) + #[allow(clippy::or_fun_call)] + let fields = ipc_fields.unwrap_or(self.ipc_fields.as_ref().unwrap()); + + let (encoded_dictionaries, encoded_message) = encode_chunk( + columns, + fields, + &mut self.dictionary_tracker, + &self.write_options, + )?; + + for encoded_dictionary in encoded_dictionaries { + write_message(&mut self.writer, &encoded_dictionary)?; + } + + write_message(&mut self.writer, &encoded_message)?; + Ok(()) + } + + /// Write continuation bytes, and mark the stream as done + pub fn finish(&mut self) -> Result<()> { + write_continuation(&mut self.writer, 0)?; + + self.finished = true; + + Ok(()) + } + + /// Consumes itself, returning the inner writer. + pub fn into_inner(self) -> W { + self.writer + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/stream_async.rs b/crates/nano-arrow/src/io/ipc/write/stream_async.rs new file mode 100644 index 000000000000..7af62682935a --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/stream_async.rs @@ -0,0 +1,188 @@ +//! `async` writing of arrow streams + +use std::pin::Pin; +use std::task::Poll; + +use futures::future::BoxFuture; +use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink}; + +use super::super::IpcField; +pub use super::common::WriteOptions; +use super::common::{encode_chunk, DictionaryTracker, EncodedData}; +use super::common_async::{write_continuation, write_message}; +use super::{default_ipc_fields, schema_to_bytes, Record}; +use crate::datatypes::*; +use crate::error::{Error, Result}; + +/// A sink that writes array [`chunks`](crate::chunk::Chunk) as an IPC stream. +/// +/// The stream header is automatically written before writing the first chunk. 
+/// +/// # Examples +/// +/// ``` +/// use futures::SinkExt; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// # use arrow2::io::ipc::write::stream_async::StreamSink; +/// # futures::executor::block_on(async move { +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// +/// let mut buffer = vec![]; +/// let mut sink = StreamSink::new( +/// &mut buffer, +/// &schema, +/// None, +/// Default::default(), +/// ); +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk.into()).await?; +/// } +/// sink.close().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct StreamSink<'a, W: AsyncWrite + Unpin + Send + 'a> { + writer: Option, + task: Option>>>, + options: WriteOptions, + dictionary_tracker: DictionaryTracker, + fields: Vec, +} + +impl<'a, W> StreamSink<'a, W> +where + W: AsyncWrite + Unpin + Send + 'a, +{ + /// Create a new [`StreamSink`]. + pub fn new( + writer: W, + schema: &Schema, + ipc_fields: Option>, + write_options: WriteOptions, + ) -> Self { + let fields = ipc_fields.unwrap_or_else(|| default_ipc_fields(&schema.fields)); + let task = Some(Self::start(writer, schema, &fields[..])); + Self { + writer: None, + task, + fields, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: false, + }, + options: write_options, + } + } + + fn start( + mut writer: W, + schema: &Schema, + ipc_fields: &[IpcField], + ) -> BoxFuture<'a, Result>> { + let message = EncodedData { + ipc_message: schema_to_bytes(schema, ipc_fields), + arrow_data: vec![], + }; + async move { + write_message(&mut writer, message).await?; + Ok(Some(writer)) + } + .boxed() + } + + fn write(&mut self, record: Record<'_>) -> Result<()> { + let fields = record.fields().unwrap_or(&self.fields[..]); + let (dictionaries, message) = encode_chunk( + record.columns(), + fields, + &mut self.dictionary_tracker, + &self.options, + )?; + + if let Some(mut writer) = self.writer.take() { + self.task = Some( + async move { + for d in dictionaries { + write_message(&mut writer, d).await?; + } + write_message(&mut writer, message).await?; + Ok(Some(writer)) + } + .boxed(), + ); + Ok(()) + } else { + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer closed".to_string(), + ))) + } + } + + fn poll_complete(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok(writer) => { + self.writer = writer; + self.task = None; + Poll::Ready(Ok(())) + }, + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + }, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink> for StreamSink<'a, W> +where + W: AsyncWrite + Unpin + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.get_mut().poll_complete(cx) + } + + fn start_send(self: Pin<&mut Self>, item: Record<'_>) -> Result<()> { + self.get_mut().write(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let this = self.get_mut(); + match this.poll_complete(cx) { + Poll::Ready(Ok(())) => { + if let Some(mut writer) = this.writer.take() { + 
this.task = Some( + async move { + write_continuation(&mut writer, 0).await?; + writer.flush().await?; + writer.close().await?; + Ok(None) + } + .boxed(), + ); + this.poll_complete(cx) + } else { + Poll::Ready(Ok(())) + } + }, + res => res, + } + } +} diff --git a/crates/nano-arrow/src/io/ipc/write/writer.rs b/crates/nano-arrow/src/io/ipc/write/writer.rs new file mode 100644 index 000000000000..8fcdd2a8bd66 --- /dev/null +++ b/crates/nano-arrow/src/io/ipc/write/writer.rs @@ -0,0 +1,210 @@ +use std::io::Write; + +use arrow_format::ipc::planus::Builder; + +use super::super::{IpcField, ARROW_MAGIC_V2}; +use super::common::{DictionaryTracker, EncodedData, WriteOptions}; +use super::common_sync::{write_continuation, write_message}; +use super::{default_ipc_fields, schema, schema_to_bytes}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::io::ipc::write::common::encode_chunk_amortized; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) enum State { + None, + Started, + Finished, +} + +/// Arrow file writer +pub struct FileWriter { + /// The object to write to + pub(crate) writer: W, + /// IPC write options + pub(crate) options: WriteOptions, + /// A reference to the schema, used in validating record batches + pub(crate) schema: Schema, + pub(crate) ipc_fields: Vec, + /// The number of bytes between each block of bytes, as an offset for random access + pub(crate) block_offsets: usize, + /// Dictionary blocks that will be written as part of the IPC footer + pub(crate) dictionary_blocks: Vec, + /// Record blocks that will be written as part of the IPC footer + pub(crate) record_blocks: Vec, + /// Whether the writer footer has been written, and the writer is finished + pub(crate) state: State, + /// Keeps track of dictionaries that have been written + pub(crate) dictionary_tracker: DictionaryTracker, + /// Buffer/scratch that is reused between writes + pub(crate) encoded_message: EncodedData, +} + +impl FileWriter { + /// Creates a new [`FileWriter`] and writes the header to `writer` + pub fn try_new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Result { + let mut slf = Self::new(writer, schema, ipc_fields, options); + slf.start()?; + + Ok(slf) + } + + /// Creates a new [`FileWriter`]. + pub fn new( + writer: W, + schema: Schema, + ipc_fields: Option>, + options: WriteOptions, + ) -> Self { + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + default_ipc_fields(&schema.fields) + }; + + Self { + writer, + options, + schema, + ipc_fields, + block_offsets: 0, + dictionary_blocks: vec![], + record_blocks: vec![], + state: State::None, + dictionary_tracker: DictionaryTracker { + dictionaries: Default::default(), + cannot_replace: true, + }, + encoded_message: Default::default(), + } + } + + /// Consumes itself into the inner writer + pub fn into_inner(self) -> W { + self.writer + } + + /// Get the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn get_scratches(&mut self) -> EncodedData { + std::mem::take(&mut self.encoded_message) + } + /// Set the inner memory scratches so they can be reused in a new writer. + /// This can be utilized to save memory allocations for performance reasons. + pub fn set_scratches(&mut self, scratches: EncodedData) { + self.encoded_message = scratches; + } + + /// Writes the header and first (schema) message to the file. 
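+ /// The header consists of the Arrow magic bytes plus two padding bytes, so that the schema
+ /// message (and every later message) starts at an 8-byte boundary.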
+ /// # Errors + /// Errors if the file has been started or has finished. + pub fn start(&mut self) -> Result<()> { + if self.state != State::None { + return Err(Error::oos("The IPC file can only be started once")); + } + // write magic to header + self.writer.write_all(&ARROW_MAGIC_V2[..])?; + // create an 8-byte boundary after the header + self.writer.write_all(&[0, 0])?; + // write the schema, set the written bytes to the schema + + let encoded_message = EncodedData { + ipc_message: schema_to_bytes(&self.schema, &self.ipc_fields), + arrow_data: vec![], + }; + + let (meta, data) = write_message(&mut self.writer, &encoded_message)?; + self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment + self.state = State::Started; + Ok(()) + } + + /// Writes [`Chunk`] to the file + pub fn write( + &mut self, + chunk: &Chunk>, + ipc_fields: Option<&[IpcField]>, + ) -> Result<()> { + if self.state != State::Started { + return Err(Error::oos( + "The IPC file must be started before it can be written to. Call `start` before `write`", + )); + } + + let ipc_fields = if let Some(ipc_fields) = ipc_fields { + ipc_fields + } else { + self.ipc_fields.as_ref() + }; + let encoded_dictionaries = encode_chunk_amortized( + chunk, + ipc_fields, + &mut self.dictionary_tracker, + &self.options, + &mut self.encoded_message, + )?; + + // add all dictionaries + for encoded_dictionary in encoded_dictionaries { + let (meta, data) = write_message(&mut self.writer, &encoded_dictionary)?; + + let block = arrow_format::ipc::Block { + offset: self.block_offsets as i64, + meta_data_length: meta as i32, + body_length: data as i64, + }; + self.dictionary_blocks.push(block); + self.block_offsets += meta + data; + } + + let (meta, data) = write_message(&mut self.writer, &self.encoded_message)?; + // add a record block for the footer + let block = arrow_format::ipc::Block { + offset: self.block_offsets as i64, + meta_data_length: meta as i32, // TODO: is this still applicable? + body_length: data as i64, + }; + self.record_blocks.push(block); + self.block_offsets += meta + data; + Ok(()) + } + + /// Write footer and closing tag, then mark the writer as done + pub fn finish(&mut self) -> Result<()> { + if self.state != State::Started { + return Err(Error::oos( + "The IPC file must be started before it can be finished. Call `start` before `finish`", + )); + } + + // write EOS + write_continuation(&mut self.writer, 0)?; + + let schema = schema::serialize_schema(&self.schema, &self.ipc_fields); + + let root = arrow_format::ipc::Footer { + version: arrow_format::ipc::MetadataVersion::V5, + schema: Some(Box::new(schema)), + dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)), + record_batches: Some(std::mem::take(&mut self.record_blocks)), + custom_metadata: None, + }; + let mut builder = Builder::new(); + let footer_data = builder.finish(&root, None); + self.writer.write_all(footer_data)?; + self.writer + .write_all(&(footer_data.len() as i32).to_le_bytes())?; + self.writer.write_all(&ARROW_MAGIC_V2)?; + self.writer.flush()?; + self.state = State::Finished; + + Ok(()) + } +} diff --git a/crates/nano-arrow/src/io/iterator.rs b/crates/nano-arrow/src/io/iterator.rs new file mode 100644 index 000000000000..91ec86fc2e04 --- /dev/null +++ b/crates/nano-arrow/src/io/iterator.rs @@ -0,0 +1,65 @@ +pub use streaming_iterator::StreamingIterator; + +/// A [`StreamingIterator`] with an internal buffer of [`Vec`] used to efficiently +/// present items of type `T` as `&[u8]`. 
+/// It is generic over the type `T` and the transformation `F: T -> &[u8]`. +pub struct BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + iterator: I, + f: F, + buffer: Vec, + is_valid: bool, +} + +impl BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + #[inline] + pub fn new(iterator: I, f: F, buffer: Vec) -> Self { + Self { + iterator, + f, + buffer, + is_valid: false, + } + } +} + +impl StreamingIterator for BufStreamingIterator +where + I: Iterator, + F: FnMut(T, &mut Vec), +{ + type Item = [u8]; + + #[inline] + fn advance(&mut self) { + let a = self.iterator.next(); + if let Some(a) = a { + self.is_valid = true; + self.buffer.clear(); + (self.f)(a, &mut self.buffer); + } else { + self.is_valid = false; + } + } + + #[inline] + fn get(&self) -> Option<&Self::Item> { + if self.is_valid { + Some(&self.buffer) + } else { + None + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iterator.size_hint() + } +} diff --git a/crates/nano-arrow/src/io/mod.rs b/crates/nano-arrow/src/io/mod.rs new file mode 100644 index 000000000000..72bf37ba9ea5 --- /dev/null +++ b/crates/nano-arrow/src/io/mod.rs @@ -0,0 +1,21 @@ +#![forbid(unsafe_code)] +//! Contains modules to interface with other formats such as [`csv`], +//! [`parquet`], [`json`], [`ipc`], [`mod@print`] and [`avro`]. + +#[cfg(feature = "io_ipc")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] +pub mod ipc; + +#[cfg(feature = "io_flight")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_flight")))] +pub mod flight; + +#[cfg(feature = "io_parquet")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_parquet")))] +pub mod parquet; + +#[cfg(feature = "io_avro")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_avro")))] +pub mod avro; + +pub mod iterator; diff --git a/crates/nano-arrow/src/io/parquet/mod.rs b/crates/nano-arrow/src/io/parquet/mod.rs new file mode 100644 index 000000000000..04e5693fcfe6 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/mod.rs @@ -0,0 +1,31 @@ +//! APIs to read from and write to Parquet format. +use crate::error::Error; + +pub mod read; +pub mod write; + +#[cfg(feature = "io_parquet_bloom_filter")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_parquet_bloom_filter")))] +pub use parquet2::bloom_filter; + +const ARROW_SCHEMA_META_KEY: &str = "ARROW:schema"; + +impl From for Error { + fn from(error: parquet2::error::Error) -> Self { + match error { + parquet2::error::Error::FeatureNotActive(_, _) => { + let message = "Failed to read a compressed parquet file. \ + Use the cargo feature \"io_parquet_compression\" to read compressed parquet files." + .to_string(); + Error::ExternalFormat(message) + }, + _ => Error::ExternalFormat(error.to_string()), + } + } +} + +impl From for parquet2::error::Error { + fn from(error: Error) -> Self { + parquet2::error::Error::OutOfSpec(error.to_string()) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/README.md b/crates/nano-arrow/src/io/parquet/read/README.md new file mode 100644 index 000000000000..c36aaafaf79a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/README.md @@ -0,0 +1,36 @@ +## Observations + +### LSB equivalence between definition levels and bitmaps + +When the maximum repetition level is 0 and the maximum definition level is 1, +the RLE-encoded definition levels correspond exactly to Arrow's bitmap and can be +memcopied without further transformations. 
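As a rough illustration of this observation, the sketch below (a hypothetical helper, not part of this crate) treats the bit-packed definition levels of a flat optional column directly as an Arrow-style validity bitmap. It assumes the RLE/bit-packed header has already been stripped and that the levels are stored one bit per value, LSB first.

```rust
/// Minimal sketch: with max rep level == 0 and max def level == 1, every
/// definition level is a single bit (0 = null, 1 = valid), so the bit-packed
/// levels already have the layout of an Arrow validity bitmap.
fn def_levels_as_validity(def_levels: &[u8], num_values: usize) -> (Vec<u8>, usize) {
    // count the zero bits (nulls) among the first `num_values` values
    let null_count = (0..num_values)
        .filter(|&i| (def_levels[i / 8] & (1u8 << (i % 8))) == 0)
        .count();
    // the bytes themselves can be reused as the validity bitmap; no per-value work is needed
    (def_levels.to_vec(), null_count)
}
```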
+ +## Nested parquet groups are deserialized recursively + +Reading a parquet nested field is done by reading each primitive +column sequentially, and building the nested struct recursively. + +Rows of nested parquet groups are encoded in the repetition and definition levels. +In Arrow, they correspond to: + +- list's offsets and validity +- struct's validity + +The implementation in this module leverages this observation: + +Nested parquet fields are initially recursed over to gather +whether the type is a Struct or List, and whether it is required or optional, which we store +in `nested_info: Vec<Box<dyn Nested>>`. `Nested` is a trait object that receives definition +and repetition levels depending on the type and nullability of the nested item. +We process the definition and repetition levels into `nested_info`. + +When we finish a field, we recursively pop from `nested_info` as we build +the `StructArray` or `ListArray`. + +With this approach, the only differences vs the flat case are: + +1. we do not leverage the bitmap optimization, and instead need to deserialize the repetition + and definition levels to `i32`. +2. we deserialize definition levels twice, once to extend the values/nullability and + once to extend `nested_info`. diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/README.md b/crates/nano-arrow/src/io/parquet/read/deserialize/README.md new file mode 100644 index 000000000000..5b985bac8e9b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/README.md @@ -0,0 +1,71 @@ +# Design + +## Non-nested types + +Let's start with the design used for non-nested arrays. The (private) entry point of this +module for non-nested arrays is `simple::page_iter_to_arrays`. + +This function expects + +- a (fallible) streaming iterator of decompressed and encoded pages, `Pages` +- the source (parquet) column type, including its logical information +- the target (arrow) `DataType` +- the chunk size + +and returns an iterator of `Array`, `ArrayIter`. + +This design is shared among _all_ `(parquet, arrow)` implemented tuples. Their main +difference is how they are deserialized, which depends on the source and target types. + +When the array iterator is pulled the first time, the following happens: + +- a page from `Pages` is pulled +- a `PageState<'a>` is built from the page +- the `PageState` is consumed into a mutable array: + - if `chunk_size` is larger than the number of rows in the page, the mutable array state is preserved, a new page is pulled, and the process is repeated until we fill a chunk. + - if `chunk_size` is smaller than the number of rows in the page, the mutable array state + is returned and the remainder of the page is consumed into multiple mutable arrays of length `chunk_size` that are pushed into a FIFO queue. + +Subsequent pulls of arrays will first try to pull from the FIFO queue. Once the queue is empty, +a new page is pulled. + +### `PageState` + +As mentioned above, the iterator leverages the idea that we attach a state to a page. Recall +that a page is essentially `[header][data]`. The `data` part contains encoded +`[rep levels][def levels][non-null values]`. Some pages have an associated dictionary page, +in which case the `non-null values` represent the indices. + +Irrespective of the physical type, the main idea is to split the page into two iterators: + +- An iterator over `def levels` +- An iterator over `non-null values` + +and progress the iterators as needed.
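The sketch below illustrates this split with hypothetical signatures (not the module's actual `Decoder`/`PageState` traits), assuming the `def levels` have already been decoded to booleans: the validity is driven by the `def levels` iterator, and only valid slots consume an item from the `non-null values` iterator.

```rust
// Rough sketch of the two-iterator split described above (hypothetical types,
// not the crate's API). `additional` plays the role of `chunk_size`.
fn extend_from_page<'a>(
    def_levels: impl Iterator<Item = bool>,              // decoded `def levels`
    mut non_null_values: impl Iterator<Item = &'a [u8]>, // `non-null values`
    additional: usize,
    validity: &mut Vec<bool>,
    values: &mut Vec<Vec<u8>>,
) {
    for is_valid in def_levels.take(additional) {
        validity.push(is_valid);
        if is_valid {
            // only valid slots consume an item from the values iterator
            values.push(non_null_values.next().unwrap_or_default().to_vec());
        } else {
            // nulls are "expanded" with a default value
            values.push(Vec::new());
        }
    }
}
```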
In particular, for non-nested types, `def levels` is +a bitmap with the same representation as Arrow, in which case the validity is extended directly. + +The `non-null values` are "expanded" by filling null values with the default value of each physical +type. + +## Nested types + +For a nested type with N+1 levels (1 being the primitive), we need to build the nesting information of each +of the N levels plus the non-nested Arrow array. + +This is done by first traversing the parquet types and using them to initialize, per chunk, the N levels. + +The per-chunk execution is then similar but `chunk_size` only drives the number of retrieved +rows from the outermost parquet group (the field). Each of these pulls knows how many items need +to be pulled from the inner groups, all the way to the primitive type. This works because +in parquet a row cannot be split between two pages and thus each page is guaranteed +to contain a full row. + +The `PageState` of nested types is composed of the following iterators: + +- A (zipped) iterator over `rep levels` and `def levels` +- An iterator over `def levels` +- An iterator over `non-null values` + +The idea is that an iterator of `rep, def` contains all the information needed to decode the +nesting structure of an Arrow array. The other two iterators are equivalent to the non-nested +types with the exception that `def levels` are not equivalent to Arrow bitmaps. diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/basic.rs new file mode 100644 index 000000000000..6008dd9de005 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/basic.rs @@ -0,0 +1,516 @@ +use std::collections::VecDeque; +use std::default::Default; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::{delta_length_byte_array, hybrid_rle, Encoding}; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{ + extend_from_decoder, get_selected_rows, next, DecodedState, FilteredOptionalPageValidity, + MaybeNext, OptionalPageValidity, +}; +use super::super::{utils, Pages}; +use super::utils::*; +use crate::array::{Array, BinaryArray, Utf8Array}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::{Error, Result}; +use crate::offset::Offset; + +#[derive(Debug)] +pub(super) struct Required<'a> { + pub values: SizedBinaryIter<'a>, +} + +impl<'a> Required<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + let values = SizedBinaryIter::new(values, page.num_values()); + + Ok(Self { values }) + } + + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct Delta<'a> { + pub lengths: std::vec::IntoIter, + pub values: &'a [u8], +} + +impl<'a> Delta<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + + let mut lengths_iter = delta_length_byte_array::Decoder::try_new(values)?; + + #[allow(clippy::needless_collect)] // we need to consume it to get the values + let lengths = lengths_iter + .by_ref() + .map(|x| x.map(|x| x as usize).map_err(Error::from)) + .collect::>>()?; + + let values = lengths_iter.into_values(); + Ok(Self { + lengths: lengths.into_iter(), + values, + }) + } + + pub fn len(&self) -> usize { + self.lengths.size_hint().0 + } +} + +impl<'a> Iterator for Delta<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { let
length = self.lengths.next()?; + let (item, remaining) = self.values.split_at(length); + self.values = remaining; + Some(item) + } + + fn size_hint(&self) -> (usize, Option) { + self.lengths.size_hint() + } +} + +#[derive(Debug)] +pub(super) struct FilteredRequired<'a> { + pub values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + pub fn new(page: &'a DataPage) -> Self { + let values = SizedBinaryIter::new(page.buffer(), page.num_values()); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Self { values } + } + + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct FilteredDelta<'a> { + pub values: SliceFilteredIter>, +} + +impl<'a> FilteredDelta<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let values = Delta::try_new(page)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } + + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +pub(super) type Dict = Vec>; + +#[derive(Debug)] +pub(super) struct RequiredDictionary<'a> { + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Dict, +} + +impl<'a> RequiredDictionary<'a> { + pub fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = utils::dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct FilteredRequiredDictionary<'a> { + pub values: SliceFilteredIter>, + pub dict: &'a Dict, +} + +impl<'a> FilteredRequiredDictionary<'a> { + pub fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = utils::dict_indices_decoder(page)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values, dict }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct ValuesDictionary<'a> { + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Dict, +} + +impl<'a> ValuesDictionary<'a> { + pub fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = utils::dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +enum State<'a> { + Optional(OptionalPageValidity<'a>, BinaryIter<'a>), + Required(Required<'a>), + RequiredDictionary(RequiredDictionary<'a>), + OptionalDictionary(OptionalPageValidity<'a>, ValuesDictionary<'a>), + Delta(Delta<'a>), + OptionalDelta(OptionalPageValidity<'a>, Delta<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredDelta(FilteredDelta<'a>), + FilteredOptionalDelta(FilteredOptionalPageValidity<'a>, Delta<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, BinaryIter<'a>), + FilteredRequiredDictionary(FilteredRequiredDictionary<'a>), + FilteredOptionalDictionary(FilteredOptionalPageValidity<'a>, ValuesDictionary<'a>), +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(validity, _) => validity.len(), + State::Required(state) => state.len(), + State::Delta(state) => state.len(), + State::OptionalDelta(state, _) => state.len(), + State::RequiredDictionary(values) => values.len(), + State::OptionalDictionary(optional, _) => optional.len(), + State::FilteredRequired(state) => state.len(), + State::FilteredOptional(validity, _) => validity.len(), + 
State::FilteredDelta(state) => state.len(), + State::FilteredOptionalDelta(state, _) => state.len(), + State::FilteredRequiredDictionary(values) => values.len(), + State::FilteredOptionalDictionary(optional, _) => optional.len(), + } + } +} + +impl DecodedState for (Binary, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +#[derive(Debug, Default)] +struct BinaryDecoder { + phantom_o: std::marker::PhantomData, +} + +impl<'a, O: Offset> utils::Decoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dict = Dict; + type DecodedState = (Binary, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => Ok( + State::RequiredDictionary(RequiredDictionary::try_new(page, dict)?), + ), + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + Ok(State::OptionalDictionary( + OptionalPageValidity::try_new(page)?, + ValuesDictionary::try_new(page, dict)?, + )) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, true) => { + FilteredRequiredDictionary::try_new(page, dict) + .map(State::FilteredRequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, true) => { + Ok(State::FilteredOptionalDictionary( + FilteredOptionalPageValidity::try_new(page)?, + ValuesDictionary::try_new(page, dict)?, + )) + }, + (Encoding::Plain, _, true, false) => { + let (_, _, values) = split_buffer(page)?; + + let values = BinaryIter::new(values); + + Ok(State::Optional( + OptionalPageValidity::try_new(page)?, + values, + )) + }, + (Encoding::Plain, _, false, false) => Ok(State::Required(Required::try_new(page)?)), + (Encoding::Plain, _, false, true) => { + Ok(State::FilteredRequired(FilteredRequired::new(page))) + }, + (Encoding::Plain, _, true, true) => { + let (_, _, values) = split_buffer(page)?; + + Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + BinaryIter::new(values), + )) + }, + (Encoding::DeltaLengthByteArray, _, false, false) => { + Delta::try_new(page).map(State::Delta) + }, + (Encoding::DeltaLengthByteArray, _, true, false) => Ok(State::OptionalDelta( + OptionalPageValidity::try_new(page)?, + Delta::try_new(page)?, + )), + (Encoding::DeltaLengthByteArray, _, false, true) => { + FilteredDelta::try_new(page).map(State::FilteredDelta) + }, + (Encoding::DeltaLengthByteArray, _, true, true) => Ok(State::FilteredOptionalDelta( + FilteredOptionalPageValidity::try_new(page)?, + Delta::try_new(page)?, + )), + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Binary::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + additional: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page_validity, page_values) => extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + page_values, + ), + State::Required(page) => { + for x in page.values.by_ref().take(additional) { + values.push(x) + } + }, + State::Delta(page) => { + values.extend_lengths(page.lengths.by_ref().take(additional), &mut 
page.values); + }, + State::OptionalDelta(page_validity, page_values) => { + let Binary { + offsets, + values: values_, + } = values; + + let last_offset = *offsets.last(); + extend_from_decoder( + validity, + page_validity, + Some(additional), + offsets, + page_values.lengths.by_ref(), + ); + + let length = *offsets.last() - last_offset; + + let (consumed, remaining) = page_values.values.split_at(length.to_usize()); + page_values.values = remaining; + values_.extend_from_slice(consumed); + }, + State::FilteredRequired(page) => { + for x in page.values.by_ref().take(additional) { + values.push(x) + } + }, + State::FilteredDelta(page) => { + for x in page.values.by_ref().take(additional) { + values.push(x) + } + }, + State::OptionalDictionary(page_validity, page_values) => { + let page_dict = &page_values.dict; + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + &mut page_values + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()), + ) + }, + State::RequiredDictionary(page) => { + let page_dict = &page.dict; + + for x in page + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()) + .take(additional) + { + values.push(x) + } + }, + State::FilteredOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + page_values.by_ref(), + ); + }, + State::FilteredOptionalDelta(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + page_values.by_ref(), + ); + }, + State::FilteredRequiredDictionary(page) => { + let page_dict = &page.dict; + for x in page + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()) + .take(additional) + { + values.push(x) + } + }, + State::FilteredOptionalDictionary(page_validity, page_values) => { + let page_dict = &page_values.dict; + utils::extend_from_decoder( + validity, + page_validity, + Some(additional), + values, + &mut page_values + .values + .by_ref() + .map(|index| page_dict[index.unwrap() as usize].as_ref()), + ) + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + deserialize_plain(&page.buffer, page.num_values) + } +} + +pub(super) fn finish( + data_type: &DataType, + mut values: Binary, + mut validity: MutableBitmap, +) -> Result> { + values.offsets.shrink_to_fit(); + values.values.shrink_to_fit(); + validity.shrink_to_fit(); + + match data_type.to_physical_type() { + PhysicalType::Binary | PhysicalType::LargeBinary => BinaryArray::::try_new( + data_type.clone(), + values.offsets.into(), + values.values.into(), + validity.into(), + ) + .map(|x| x.boxed()), + PhysicalType::Utf8 | PhysicalType::LargeUtf8 => Utf8Array::::try_new( + data_type.clone(), + values.offsets.into(), + values.values.into(), + validity.into(), + ) + .map(|x| x.boxed()), + _ => unreachable!(), + } +} + +pub struct Iter { + iter: I, + data_type: DataType, + items: VecDeque<(Binary, MutableBitmap)>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl Iter { + pub fn new(iter: I, data_type: DataType, chunk_size: Option, num_rows: usize) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for Iter { + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + 
&BinaryDecoder::::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((values, validity))) => { + Some(finish(&self.data_type, values, validity)) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +pub(super) fn deserialize_plain(values: &[u8], num_values: usize) -> Dict { + SizedBinaryIter::new(values, num_values) + .map(|x| x.to_vec()) + .collect() +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/dictionary.rs new file mode 100644 index 000000000000..0fb3615de050 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/dictionary.rs @@ -0,0 +1,174 @@ +use std::collections::VecDeque; + +use parquet2::page::DictPage; + +use super::super::dictionary::*; +use super::super::utils::MaybeNext; +use super::super::Pages; +use super::utils::{Binary, SizedBinaryIter}; +use crate::array::{Array, BinaryArray, DictionaryArray, DictionaryKey, Utf8Array}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Result; +use crate::io::parquet::read::deserialize::nested_utils::{InitNested, NestedState}; +use crate::offset::Offset; + +/// An iterator adapter over [`Pages`] assumed to be encoded as parquet's dictionary-encoded binary representation +#[derive(Debug)] +pub struct DictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + iter: I, + data_type: DataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + phantom: std::marker::PhantomData, +} + +impl DictIter +where + K: DictionaryKey, + O: Offset, + I: Pages, +{ + pub fn new(iter: I, data_type: DataType, num_rows: usize, chunk_size: Option) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + phantom: std::marker::PhantomData, + } + } +} + +fn read_dict(data_type: DataType, dict: &DictPage) -> Box { + let data_type = match data_type { + DataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + + let values = SizedBinaryIter::new(&dict.buffer, dict.num_values); + + let mut data = Binary::::with_capacity(dict.num_values); + data.values = Vec::with_capacity(dict.buffer.len() - 4 * dict.num_values); + for item in values { + data.push(item) + } + + match data_type.to_physical_type() { + PhysicalType::Utf8 | PhysicalType::LargeUtf8 => { + Utf8Array::::new(data_type, data.offsets.into(), data.values.into(), None).boxed() + }, + PhysicalType::Binary | PhysicalType::LargeBinary => { + BinaryArray::::new(data_type, data.offsets.into(), data.values.into(), None).boxed() + }, + _ => unreachable!(), + } +} + +impl Iterator for DictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`] +#[derive(Debug)] +pub struct NestedDictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + iter: I, + init: Vec, + data_type: DataType, + 
values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, + phantom: std::marker::PhantomData, +} + +impl NestedDictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + phantom: Default::default(), + } + } +} + +impl Iterator for NestedDictIter +where + I: Pages, + O: Offset, + K: DictionaryKey, +{ + type Item = Result<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/mod.rs new file mode 100644 index 000000000000..c48bfe276bcc --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/mod.rs @@ -0,0 +1,8 @@ +mod basic; +mod dictionary; +mod nested; +mod utils; + +pub use basic::Iter; +pub use dictionary::{DictIter, NestedDictIter}; +pub use nested::NestedIter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/nested.rs new file mode 100644 index 000000000000..64f076932e49 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/nested.rs @@ -0,0 +1,187 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::nested_utils::*; +use super::super::utils; +use super::super::utils::MaybeNext; +use super::basic::{deserialize_plain, finish, Dict, ValuesDictionary}; +use super::utils::*; +use crate::array::Array; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::Pages; +use crate::offset::Offset; + +#[derive(Debug)] +enum State<'a> { + Optional(BinaryIter<'a>), + Required(BinaryIter<'a>), + RequiredDictionary(ValuesDictionary<'a>), + OptionalDictionary(ValuesDictionary<'a>), +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(validity) => validity.size_hint().0, + State::Required(state) => state.size_hint().0, + State::RequiredDictionary(required) => required.len(), + State::OptionalDictionary(optional) => optional.len(), + } + } +} + +#[derive(Debug, Default)] +struct BinaryDecoder { + phantom_o: std::marker::PhantomData, +} + +impl<'a, O: Offset> NestedDecoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dictionary = Dict; + type DecodedState = (Binary, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, 
Some(dict), false, false) => { + ValuesDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + ValuesDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + (Encoding::Plain, _, true, false) => { + let (_, _, values) = split_buffer(page)?; + + let values = BinaryIter::new(values); + + Ok(State::Optional(values)) + }, + (Encoding::Plain, _, false, false) => { + let (_, _, values) = split_buffer(page)?; + + let values = BinaryIter::new(values); + + Ok(State::Required(values)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Binary::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page) => { + let value = page.next().unwrap_or_default(); + values.push(value); + validity.push(true); + }, + State::Required(page) => { + let value = page.next().unwrap_or_default(); + values.push(value); + }, + State::RequiredDictionary(page) => { + let dict_values = &page.dict; + let item = page + .values + .next() + .map(|index| dict_values[index.unwrap() as usize].as_ref()) + .unwrap_or_default(); + values.push(item); + }, + State::OptionalDictionary(page) => { + let dict_values = &page.dict; + let item = page + .values + .next() + .map(|index| dict_values[index.unwrap() as usize].as_ref()) + .unwrap_or_default(); + values.push(item); + validity.push(true); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(&[]); + validity.push(false); + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + deserialize_plain(&page.buffer, page.num_values) + } +} + +pub struct NestedIter { + iter: I, + data_type: DataType, + init: Vec, + items: VecDeque<(NestedState, (Binary, MutableBitmap))>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl NestedIter { + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + data_type, + init, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for NestedIter { + type Item = Result<(NestedState, Box)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + &self.init, + self.chunk_size, + &BinaryDecoder::::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((nested, decoded))) => { + Some(finish(&self.data_type, decoded.0, decoded.1).map(|array| (nested, array))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/binary/utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/utils.rs new file mode 100644 index 000000000000..0a2a0f3466f8 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/binary/utils.rs @@ -0,0 +1,169 @@ +use super::super::utils::Pushable; +use crate::offset::{Offset, Offsets}; + +/// [`Pushable`] for variable length binary data. 
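+/// It accumulates the offsets and a single contiguous `values` buffer, i.e. the layout of a
+/// `BinaryArray` under construction.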
+#[derive(Debug)] +pub struct Binary { + pub offsets: Offsets, + pub values: Vec, +} + +impl Pushable for Offsets { + fn reserve(&mut self, additional: usize) { + self.reserve(additional) + } + #[inline] + fn len(&self) -> usize { + self.len_proxy() + } + + #[inline] + fn push(&mut self, value: usize) { + self.try_push_usize(value).unwrap() + } + + #[inline] + fn push_null(&mut self) { + self.extend_constant(1); + } + + #[inline] + fn extend_constant(&mut self, additional: usize, _: usize) { + self.extend_constant(additional) + } +} + +impl Binary { + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + offsets: Offsets::with_capacity(capacity), + values: Vec::with_capacity(capacity.min(100) * 24), + } + } + + #[inline] + pub fn push(&mut self, v: &[u8]) { + if self.offsets.len_proxy() == 100 && self.offsets.capacity() > 100 { + let bytes_per_row = self.values.len() / 100 + 1; + let bytes_estimate = bytes_per_row * self.offsets.capacity(); + if bytes_estimate > self.values.capacity() { + self.values.reserve(bytes_estimate - self.values.capacity()); + } + } + + self.values.extend(v); + self.offsets.try_push_usize(v.len()).unwrap() + } + + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + self.offsets.extend_constant(additional); + } + + #[inline] + pub fn len(&self) -> usize { + self.offsets.len_proxy() + } + + #[inline] + pub fn extend_lengths>(&mut self, lengths: I, values: &mut &[u8]) { + let current_offset = *self.offsets.last(); + self.offsets.try_extend_from_lengths(lengths).unwrap(); + let new_offset = *self.offsets.last(); + let length = new_offset.to_usize() - current_offset.to_usize(); + let (consumed, remaining) = values.split_at(length); + *values = remaining; + self.values.extend_from_slice(consumed); + } +} + +impl<'a, O: Offset> Pushable<&'a [u8]> for Binary { + #[inline] + fn reserve(&mut self, additional: usize) { + let avg_len = self.values.len() / std::cmp::max(self.offsets.last().to_usize(), 1); + self.values.reserve(additional * avg_len); + self.offsets.reserve(additional); + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push_null(&mut self) { + self.push(&[]) + } + + #[inline] + fn push(&mut self, value: &[u8]) { + self.push(value) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: &[u8]) { + assert_eq!(value.len(), 0); + self.extend_constant(additional) + } +} + +#[derive(Debug)] +pub struct BinaryIter<'a> { + values: &'a [u8], +} + +impl<'a> BinaryIter<'a> { + pub fn new(values: &'a [u8]) -> Self { + Self { values } + } +} + +impl<'a> Iterator for BinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + if self.values.is_empty() { + return None; + } + let (length, remaining) = self.values.split_at(4); + let length = u32::from_le_bytes(length.try_into().unwrap()) as usize; + let (result, remaining) = remaining.split_at(length); + self.values = remaining; + Some(result) + } +} + +#[derive(Debug)] +pub struct SizedBinaryIter<'a> { + iter: BinaryIter<'a>, + remaining: usize, +} + +impl<'a> SizedBinaryIter<'a> { + pub fn new(values: &'a [u8], size: usize) -> Self { + let iter = BinaryIter::new(values); + Self { + iter, + remaining: size, + } + } +} + +impl<'a> Iterator for SizedBinaryIter<'a> { + type Item = &'a [u8]; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } else { + self.remaining -= 1 + }; + self.iter.next() + } + + fn size_hint(&self) -> (usize, Option) { + (self.remaining, 
Some(self.remaining)) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/basic.rs new file mode 100644 index 000000000000..dd3ac9eb52c5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/basic.rs @@ -0,0 +1,229 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{ + extend_from_decoder, get_selected_rows, next, DecodedState, Decoder, + FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, +}; +use super::super::{utils, Pages}; +use crate::array::BooleanArray; +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +#[derive(Debug)] +struct Values<'a>(BitmapIter<'a>); + +impl<'a> Values<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + + Ok(Self(BitmapIter::new(values, 0, values.len() * 8))) + } +} + +// The state of a required DataPage with a boolean physical type +#[derive(Debug)] +struct Required<'a> { + values: &'a [u8], + // invariant: offset <= length; + offset: usize, + length: usize, +} + +impl<'a> Required<'a> { + pub fn new(page: &'a DataPage) -> Self { + Self { + values: page.buffer(), + offset: 0, + length: page.num_values(), + } + } +} + +#[derive(Debug)] +struct FilteredRequired<'a> { + values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + // todo: replace this by an iterator over slices, for faster deserialization + let values = BitmapIter::new(values, 0, page.num_values()); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +// The state of a `DataPage` of `Boolean` parquet boolean type +#[derive(Debug)] +enum State<'a> { + Optional(OptionalPageValidity<'a>, Values<'a>), + Required(Required<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), +} + +impl<'a> State<'a> { + pub fn len(&self) -> usize { + match self { + State::Optional(validity, _) => validity.len(), + State::Required(page) => page.length - page.offset, + State::FilteredRequired(page) => page.len(), + State::FilteredOptional(optional, _) => optional.len(), + } + } +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + self.len() + } +} + +impl DecodedState for (MutableBitmap, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +#[derive(Default)] +struct BooleanDecoder {} + +impl<'a> Decoder<'a> for BooleanDecoder { + type State = State<'a>; + type Dict = (); + type DecodedState = (MutableBitmap, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, _: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::Plain, true, false) => Ok(State::Optional( + OptionalPageValidity::try_new(page)?, + Values::try_new(page)?, + )), + (Encoding::Plain, false, false) => Ok(State::Required(Required::new(page))), + 
(Encoding::Plain, true, true) => Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + Values::try_new(page)?, + )), + (Encoding::Plain, false, true) => { + Ok(State::FilteredRequired(FilteredRequired::try_new(page)?)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + MutableBitmap::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page_validity, page_values) => extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.0, + ), + State::Required(page) => { + let remaining = remaining.min(page.length - page.offset); + values.extend_from_slice(page.values, page.offset, remaining); + page.offset += remaining; + }, + State::FilteredRequired(page) => { + values.reserve(remaining); + for item in page.values.by_ref().take(remaining) { + values.push(item) + } + }, + State::FilteredOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.0.by_ref(), + ); + }, + } + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dict {} +} + +fn finish(data_type: &DataType, values: MutableBitmap, validity: MutableBitmap) -> BooleanArray { + BooleanArray::new(data_type.clone(), values.into(), validity.into()) +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct Iter { + iter: I, + data_type: DataType, + items: VecDeque<(MutableBitmap, MutableBitmap)>, + chunk_size: Option, + remaining: usize, +} + +impl Iter { + pub fn new(iter: I, data_type: DataType, chunk_size: Option, num_rows: usize) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for Iter { + type Item = Result; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut None, + &mut self.remaining, + self.chunk_size, + &BooleanDecoder::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/mod.rs new file mode 100644 index 000000000000..dc00cc2a4249 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use nested::NestedIter; + +pub use self::basic::Iter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/nested.rs new file mode 100644 index 000000000000..f3e684ab9fe3 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/boolean/nested.rs @@ -0,0 +1,153 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::nested_utils::*; +use super::super::utils::MaybeNext; +use super::super::{utils, Pages}; +use crate::array::BooleanArray; +use crate::bitmap::utils::BitmapIter; +use 
crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +// The state of a `DataPage` of `Boolean` parquet boolean type +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +enum State<'a> { + Optional(BitmapIter<'a>), + Required(BitmapIter<'a>), +} + +impl<'a> State<'a> { + pub fn len(&self) -> usize { + match self { + State::Optional(iter) => iter.size_hint().0, + State::Required(iter) => iter.size_hint().0, + } + } +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + self.len() + } +} + +#[derive(Default)] +struct BooleanDecoder {} + +impl<'a> NestedDecoder<'a> for BooleanDecoder { + type State = State<'a>; + type Dictionary = (); + type DecodedState = (MutableBitmap, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + _: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::Plain, true, false) => { + let (_, _, values) = split_buffer(page)?; + let values = BitmapIter::new(values, 0, values.len() * 8); + + Ok(State::Optional(values)) + }, + (Encoding::Plain, false, false) => { + let (_, _, values) = split_buffer(page)?; + let values = BitmapIter::new(values, 0, values.len() * 8); + + Ok(State::Required(values)) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + MutableBitmap::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page_values) => { + let value = page_values.next().unwrap_or_default(); + values.push(value); + validity.push(true); + }, + State::Required(page_values) => { + let value = page_values.next().unwrap_or_default(); + values.push(value); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(false); + validity.push(false); + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dictionary {} +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct NestedIter { + iter: I, + init: Vec, + items: VecDeque<(NestedState, (MutableBitmap, MutableBitmap))>, + remaining: usize, + chunk_size: Option, +} + +impl NestedIter { + pub fn new(iter: I, init: Vec, num_rows: usize, chunk_size: Option) -> Self { + Self { + iter, + init, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + } + } +} + +fn finish(data_type: &DataType, values: MutableBitmap, validity: MutableBitmap) -> BooleanArray { + BooleanArray::new(data_type.clone(), values.into(), validity.into()) +} + +impl Iterator for NestedIter { + type Item = Result<(NestedState, BooleanArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut None, + &mut self.remaining, + &self.init, + self.chunk_size, + &BooleanDecoder::default(), + ); + match maybe_state { + MaybeNext::Some(Ok((nested, (values, validity)))) => { + Some(Ok((nested, finish(&DataType::Boolean, values, validity)))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/mod.rs 
b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/mod.rs new file mode 100644 index 000000000000..7826f5856c0e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/mod.rs @@ -0,0 +1,314 @@ +mod nested; + +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::hybrid_rle::HybridRleDecoder; +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage, Page}; +use parquet2::schema::Repetition; + +use super::utils::{ + self, dict_indices_decoder, extend_from_decoder, get_selected_rows, DecodedState, Decoder, + FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, +}; +use super::Pages; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +// The state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +pub enum State<'a> { + Optional(Optional<'a>), + Required(Required<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, HybridRleDecoder<'a>), +} + +#[derive(Debug)] +pub struct Required<'a> { + values: HybridRleDecoder<'a>, +} + +impl<'a> Required<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = dict_indices_decoder(page)?; + Ok(Self { values }) + } +} + +#[derive(Debug)] +pub struct FilteredRequired<'a> { + values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = dict_indices_decoder(page)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } +} + +#[derive(Debug)] +pub struct Optional<'a> { + values: HybridRleDecoder<'a>, + validity: OptionalPageValidity<'a>, +} + +impl<'a> Optional<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = dict_indices_decoder(page)?; + + Ok(Self { + values, + validity: OptionalPageValidity::try_new(page)?, + }) + } +} + +impl<'a> utils::PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(optional) => optional.validity.len(), + State::Required(required) => required.values.size_hint().0, + State::FilteredRequired(required) => required.values.size_hint().0, + State::FilteredOptional(validity, _) => validity.len(), + } + } +} + +#[derive(Debug)] +pub struct PrimitiveDecoder +where + K: DictionaryKey, +{ + phantom_k: std::marker::PhantomData, +} + +impl Default for PrimitiveDecoder +where + K: DictionaryKey, +{ + #[inline] + fn default() -> Self { + Self { + phantom_k: std::marker::PhantomData, + } + } +} + +impl<'a, K> utils::Decoder<'a> for PrimitiveDecoder +where + K: DictionaryKey, +{ + type State = State<'a>; + type Dict = (); + type DecodedState = (Vec, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, _: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, false, false) => { + Required::try_new(page).map(State::Required) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, true, false) => { + Optional::try_new(page).map(State::Optional) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, false, true) => { + FilteredRequired::try_new(page).map(State::FilteredRequired) + }, + 
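// The arms above only choose how the RLE/bit-packed indices are iterated; the
// mapping from each dictionary index to an arrow key `K` (i8, i16, i32, ...)
// happens in `extend_from_state` below, failing when the key type cannot
// represent the index. A std-only sketch of that mapping, with illustrative
// names only:
//
//     fn indices_to_keys<K: TryFrom<usize>>(indices: &[u32]) -> Result<Vec<K>, String> {
//         indices
//             .iter()
//             .map(|&i| {
//                 K::try_from(i as usize)
//                     .map_err(|_| format!("dictionary index {i} does not fit the key type"))
//             })
//             .collect()
//     }
//
// For example, `indices_to_keys::<i8>(&[128])` fails, since i8 keys can only
// address 128 dictionary entries.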
(Encoding::PlainDictionary | Encoding::RleDictionary, true, true) => { + Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + dict_indices_decoder(page)?, + )) + }, + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page) => extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + &mut page.values.by_ref().map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => panic!("The maximum key is too small"), + } + }), + ), + State::Required(page) => { + values.extend( + page.values + .by_ref() + .map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }) + .take(remaining), + ); + }, + State::FilteredOptional(page_validity, page_values) => extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.by_ref().map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }), + ), + State::FilteredRequired(page) => { + values.extend( + page.values + .by_ref() + .map(|x| { + // todo: rm unwrap + let x: usize = x.unwrap().try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. 
+ Err(_) => { + panic!("The maximum key is too small") + }, + }; + x + }) + .take(remaining), + ); + }, + } + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dict {} +} + +fn finish_key(values: Vec, validity: MutableBitmap) -> PrimitiveArray { + PrimitiveArray::new(K::PRIMITIVE.into(), values.into(), validity.into()) +} + +#[inline] +pub(super) fn next_dict Box>( + iter: &mut I, + items: &mut VecDeque<(Vec, MutableBitmap)>, + dict: &mut Option>, + data_type: DataType, + remaining: &mut usize, + chunk_size: Option, + read_dict: F, +) -> MaybeNext>> { + if items.len() > 1 { + let (values, validity) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + return MaybeNext::Some(DictionaryArray::try_new( + data_type, + keys, + dict.clone().unwrap(), + )); + } + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(Some(page)) => { + let (page, dict) = match (&dict, page) { + (None, Page::Data(_)) => { + return MaybeNext::Some(Err(Error::nyi( + "dictionary arrays from non-dict-encoded pages", + ))); + }, + (_, Page::Dict(dict_page)) => { + *dict = Some(read_dict(dict_page)); + return next_dict( + iter, items, dict, data_type, remaining, chunk_size, read_dict, + ); + }, + (Some(dict), Page::Data(page)) => (page, dict), + }; + + // there is a new page => consume the page from the start + let maybe_page = PrimitiveDecoder::::default().build_state(page, None); + let page = match maybe_page { + Ok(page) => page, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + utils::extend_from_new_page( + page, + chunk_size, + items, + remaining, + &PrimitiveDecoder::::default(), + ); + + if items.front().unwrap().len() < chunk_size.unwrap_or(usize::MAX) { + MaybeNext::More + } else { + let (values, validity) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + MaybeNext::Some(DictionaryArray::try_new(data_type, keys, dict.clone())) + } + }, + Ok(None) => { + if let Some((values, validity)) = items.pop_front() { + // we have a populated item and no more pages + // the only case where an item's length may be smaller than chunk_size + debug_assert!(values.len() <= chunk_size.unwrap_or(usize::MAX)); + + let keys = finish_key(values, validity); + MaybeNext::Some(DictionaryArray::try_new( + data_type, + keys, + dict.clone().unwrap(), + )) + } else { + MaybeNext::None + } + }, + } +} + +pub use nested::next_dict as nested_next_dict; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/nested.rs new file mode 100644 index 000000000000..1fb1919d1504 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/dictionary/nested.rs @@ -0,0 +1,213 @@ +use std::collections::VecDeque; + +use parquet2::encoding::hybrid_rle::HybridRleDecoder; +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage, Page}; +use parquet2::schema::Repetition; + +use super::super::super::Pages; +use super::super::nested_utils::*; +use super::super::utils::{dict_indices_decoder, not_implemented, MaybeNext, PageState}; +use super::finish_key; +use crate::array::{Array, DictionaryArray, DictionaryKey}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; + +// The state of a required DataPage with a boolean physical type +#[derive(Debug)] +pub struct Required<'a> { + values: HybridRleDecoder<'a>, + length: usize, +} + +impl<'a> Required<'a> { + fn try_new(page: &'a DataPage) -> Result { + let values = 
dict_indices_decoder(page)?; + let length = page.num_values(); + Ok(Self { values, length }) + } +} + +// The state of a `DataPage` of a `Dictionary` type +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub enum State<'a> { + Optional(HybridRleDecoder<'a>), + Required(Required<'a>), +} + +impl<'a> State<'a> { + pub fn len(&self) -> usize { + match self { + State::Optional(page) => page.len(), + State::Required(page) => page.length, + } + } +} + +impl<'a> PageState<'a> for State<'a> { + fn len(&self) -> usize { + self.len() + } +} + +#[derive(Debug)] +pub struct DictionaryDecoder +where + K: DictionaryKey, +{ + phantom_k: std::marker::PhantomData, +} + +impl Default for DictionaryDecoder +where + K: DictionaryKey, +{ + #[inline] + fn default() -> Self { + Self { + phantom_k: std::marker::PhantomData, + } + } +} + +impl<'a, K: DictionaryKey> NestedDecoder<'a> for DictionaryDecoder { + type State = State<'a>; + type Dictionary = (); + type DecodedState = (Vec, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + _: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), is_optional, is_filtered) { + (Encoding::RleDictionary | Encoding::PlainDictionary, true, false) => { + dict_indices_decoder(page).map(State::Optional) + }, + (Encoding::RleDictionary | Encoding::PlainDictionary, false, false) => { + Required::try_new(page).map(State::Required) + }, + _ => Err(not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page_values) => { + let key = page_values.next().transpose()?; + // todo: convert unwrap to error + let key = match K::try_from(key.unwrap_or_default() as usize) { + Ok(key) => key, + Err(_) => todo!(), + }; + values.push(key); + validity.push(true); + }, + State::Required(page_values) => { + let key = page_values.values.next().transpose()?; + let key = match K::try_from(key.unwrap_or_default() as usize) { + Ok(key) => key, + Err(_) => todo!(), + }; + values.push(key); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(K::default()); + validity.push(false) + } + + fn deserialize_dict(&self, _: &DictPage) -> Self::Dictionary {} +} + +#[allow(clippy::too_many_arguments)] +pub fn next_dict Box>( + iter: &mut I, + items: &mut VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: &mut usize, + init: &[InitNested], + dict: &mut Option>, + data_type: DataType, + chunk_size: Option, + read_dict: F, +) -> MaybeNext)>> { + if items.len() > 1 { + let (nested, (values, validity)) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + let dict = DictionaryArray::try_new(data_type, keys, dict.clone().unwrap()); + return MaybeNext::Some(dict.map(|dict| (nested, dict))); + } + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(Some(page)) => { + let (page, dict) = match (&dict, page) { + (None, Page::Data(_)) => { + return MaybeNext::Some(Err(Error::nyi( + "dictionary arrays from non-dict-encoded pages", + ))); + }, + (_, Page::Dict(dict_page)) => { + *dict = Some(read_dict(dict_page)); + 
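// A dictionary page, when present, is the first page of the column chunk, so
// it is deserialized exactly once via `read_dict`, cached in `dict`, and the
// call recurses to pull the next page; the data pages that follow only carry
// indices into this cached dictionary.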
return next_dict( + iter, items, remaining, init, dict, data_type, chunk_size, read_dict, + ); + }, + (Some(dict), Page::Data(page)) => (page, dict), + }; + + let error = extend( + page, + init, + items, + None, + remaining, + &DictionaryDecoder::::default(), + chunk_size, + ); + match error { + Ok(_) => {}, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + if items.front().unwrap().0.len() < chunk_size.unwrap_or(usize::MAX) { + MaybeNext::More + } else { + let (nested, (values, validity)) = items.pop_front().unwrap(); + let keys = finish_key(values, validity); + let dict = DictionaryArray::try_new(data_type, keys, dict.clone()); + MaybeNext::Some(dict.map(|dict| (nested, dict))) + } + }, + Ok(None) => { + if let Some((nested, (values, validity))) = items.pop_front() { + // we have a populated item and no more pages + // the only case where an item's length may be smaller than chunk_size + debug_assert!(values.len() <= chunk_size.unwrap_or(usize::MAX)); + + let keys = finish_key(values, validity); + let dict = DictionaryArray::try_new(data_type, keys, dict.clone().unwrap()); + MaybeNext::Some(dict.map(|dict| (nested, dict))) + } else { + MaybeNext::None + } + }, + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/basic.rs new file mode 100644 index 000000000000..aee3116ed64e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/basic.rs @@ -0,0 +1,322 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::{hybrid_rle, Encoding}; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{ + dict_indices_decoder, extend_from_decoder, get_selected_rows, next, not_implemented, + DecodedState, Decoder, FilteredOptionalPageValidity, MaybeNext, OptionalPageValidity, + PageState, Pushable, +}; +use super::super::Pages; +use super::utils::FixedSizeBinary; +use crate::array::FixedSizeBinaryArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; + +pub(super) type Dict = Vec; + +#[derive(Debug)] +pub(super) struct Optional<'a> { + pub(super) values: std::slice::ChunksExact<'a, u8>, + pub(super) validity: OptionalPageValidity<'a>, +} + +impl<'a> Optional<'a> { + pub(super) fn try_new(page: &'a DataPage, size: usize) -> Result { + let (_, _, values) = split_buffer(page)?; + + let values = values.chunks_exact(size); + + Ok(Self { + values, + validity: OptionalPageValidity::try_new(page)?, + }) + } +} + +#[derive(Debug)] +pub(super) struct Required<'a> { + pub values: std::slice::ChunksExact<'a, u8>, +} + +impl<'a> Required<'a> { + pub(super) fn new(page: &'a DataPage, size: usize) -> Self { + let values = page.buffer(); + assert_eq!(values.len() % size, 0); + let values = values.chunks_exact(size); + Self { values } + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct FilteredRequired<'a> { + pub values: SliceFilteredIter>, +} + +impl<'a> FilteredRequired<'a> { + fn new(page: &'a DataPage, size: usize) -> Self { + let values = page.buffer(); + assert_eq!(values.len() % size, 0); + let values = values.chunks_exact(size); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Self { values } + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + 
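// A minimal, std-only sketch of the two access patterns this decoder relies on
// (the names below are illustrative and not part of this crate): PLAIN pages
// lay the fixed-size values out back to back, so `chunks_exact(size)` yields
// exactly one value per chunk, while dictionary pages store indices and entry
// `i` of the dictionary buffer is the slice `i * size..(i + 1) * size`.
fn plain_values(buffer: &[u8], size: usize) -> std::slice::ChunksExact<'_, u8> {
    assert_eq!(buffer.len() % size, 0);
    buffer.chunks_exact(size)
}

fn dict_value(dict: &[u8], index: usize, size: usize) -> &[u8] {
    &dict[index * size..(index + 1) * size]
}

#[test]
fn fixed_size_access() {
    let buffer = [1u8, 2, 3, 4, 5, 6];
    assert_eq!(plain_values(&buffer, 2).count(), 3);
    assert_eq!(dict_value(&buffer, 1, 2), [3u8, 4].as_slice());
}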
+#[derive(Debug)] +pub(super) struct RequiredDictionary<'a> { + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Dict, +} + +impl<'a> RequiredDictionary<'a> { + pub(super) fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct OptionalDictionary<'a> { + pub(super) values: hybrid_rle::HybridRleDecoder<'a>, + pub(super) validity: OptionalPageValidity<'a>, + pub(super) dict: &'a Dict, +} + +impl<'a> OptionalDictionary<'a> { + pub(super) fn try_new(page: &'a DataPage, dict: &'a Dict) -> Result { + let values = dict_indices_decoder(page)?; + + Ok(Self { + values, + validity: OptionalPageValidity::try_new(page)?, + dict, + }) + } +} + +#[derive(Debug)] +enum State<'a> { + Optional(Optional<'a>), + Required(Required<'a>), + RequiredDictionary(RequiredDictionary<'a>), + OptionalDictionary(OptionalDictionary<'a>), + FilteredRequired(FilteredRequired<'a>), + FilteredOptional( + FilteredOptionalPageValidity<'a>, + std::slice::ChunksExact<'a, u8>, + ), +} + +impl<'a> PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(state) => state.validity.len(), + State::Required(state) => state.len(), + State::RequiredDictionary(state) => state.len(), + State::OptionalDictionary(state) => state.validity.len(), + State::FilteredRequired(state) => state.len(), + State::FilteredOptional(state, _) => state.len(), + } + } +} + +struct BinaryDecoder { + size: usize, +} + +impl DecodedState for (FixedSizeBinary, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +impl<'a> Decoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dict = Dict; + type DecodedState = (FixedSizeBinary, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::Plain, _, true, false) => { + Ok(State::Optional(Optional::try_new(page, self.size)?)) + }, + (Encoding::Plain, _, false, false) => { + Ok(State::Required(Required::new(page, self.size))) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + RequiredDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + OptionalDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + (Encoding::Plain, None, false, true) => Ok(State::FilteredRequired( + FilteredRequired::new(page, self.size), + )), + (Encoding::Plain, _, true, true) => { + let (_, _, values) = split_buffer(page)?; + + Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + values.chunks_exact(self.size), + )) + }, + _ => Err(not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + FixedSizeBinary::with_capacity(capacity, self.size), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page) => extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + &mut 
page.values, + ), + State::Required(page) => { + for x in page.values.by_ref().take(remaining) { + values.push(x) + } + }, + State::FilteredRequired(page) => { + for x in page.values.by_ref().take(remaining) { + values.push(x) + } + }, + State::OptionalDictionary(page) => extend_from_decoder( + validity, + &mut page.validity, + Some(remaining), + values, + page.values.by_ref().map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }), + ), + State::RequiredDictionary(page) => { + for x in page + .values + .by_ref() + .map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }) + .take(remaining) + { + values.push(x) + } + }, + State::FilteredOptional(page_validity, page_values) => { + extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.by_ref(), + ); + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + page.buffer.clone() + } +} + +pub fn finish( + data_type: &DataType, + values: FixedSizeBinary, + validity: MutableBitmap, +) -> FixedSizeBinaryArray { + FixedSizeBinaryArray::new(data_type.clone(), values.values.into(), validity.into()) +} + +pub struct Iter { + iter: I, + data_type: DataType, + size: usize, + items: VecDeque<(FixedSizeBinary, MutableBitmap)>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl Iter { + pub fn new(iter: I, data_type: DataType, num_rows: usize, chunk_size: Option) -> Self { + let size = FixedSizeBinaryArray::get_size(&data_type); + Self { + iter, + data_type, + size, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for Iter { + type Item = Result; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + &BinaryDecoder { size: self.size }, + ); + match maybe_state { + MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs new file mode 100644 index 000000000000..3f5455b0bdb8 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs @@ -0,0 +1,150 @@ +use std::collections::VecDeque; + +use parquet2::page::DictPage; + +use super::super::dictionary::*; +use super::super::utils::MaybeNext; +use super::super::Pages; +use crate::array::{Array, DictionaryArray, DictionaryKey, FixedSizeBinaryArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::deserialize::nested_utils::{InitNested, NestedState}; + +/// An iterator adapter over [`Pages`] assumed to be encoded as parquet's dictionary-encoded binary representation +#[derive(Debug)] +pub struct DictIter +where + I: Pages, + K: DictionaryKey, +{ + iter: I, + data_type: DataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, +} + +impl DictIter +where + K: DictionaryKey, + I: Pages, +{ + pub fn new(iter: I, data_type: DataType, num_rows: usize, chunk_size: Option) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + remaining: 
num_rows, + chunk_size, + } + } +} + +fn read_dict(data_type: DataType, dict: &DictPage) -> Box { + let data_type = match data_type { + DataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + + let values = dict.buffer.clone(); + + FixedSizeBinaryArray::try_new(data_type, values.into(), None) + .unwrap() + .boxed() +} + +impl Iterator for DictIter +where + I: Pages, + K: DictionaryKey, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`]. +#[derive(Debug)] +pub struct NestedDictIter +where + I: Pages, + K: DictionaryKey, +{ + iter: I, + init: Vec, + data_type: DataType, + values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, +} + +impl NestedDictIter +where + I: Pages, + K: DictionaryKey, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + remaining: num_rows, + items: VecDeque::new(), + chunk_size, + } + } +} + +impl Iterator for NestedDictIter +where + I: Pages, + K: DictionaryKey, +{ + type Item = Result<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict(self.data_type.clone(), dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/mod.rs new file mode 100644 index 000000000000..c48bfe276bcc --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/mod.rs @@ -0,0 +1,8 @@ +mod basic; +mod dictionary; +mod nested; +mod utils; + +pub use basic::Iter; +pub use dictionary::{DictIter, NestedDictIter}; +pub use nested::NestedIter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs new file mode 100644 index 000000000000..f2b65380baad --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs @@ -0,0 +1,189 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage}; +use parquet2::schema::Repetition; + +use super::super::utils::{not_implemented, MaybeNext, PageState}; +use super::utils::FixedSizeBinary; +use crate::array::FixedSizeBinaryArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::deserialize::fixed_size_binary::basic::{ + finish, Dict, Optional, OptionalDictionary, Required, RequiredDictionary, +}; +use crate::io::parquet::read::deserialize::nested_utils::{next, 
NestedDecoder}; +use crate::io::parquet::read::deserialize::utils::Pushable; +use crate::io::parquet::read::{InitNested, NestedState, Pages}; + +#[derive(Debug)] +enum State<'a> { + Optional(Optional<'a>), + Required(Required<'a>), + RequiredDictionary(RequiredDictionary<'a>), + OptionalDictionary(OptionalDictionary<'a>), +} + +impl<'a> PageState<'a> for State<'a> { + fn len(&self) -> usize { + match self { + State::Optional(state) => state.validity.len(), + State::Required(state) => state.len(), + State::RequiredDictionary(state) => state.len(), + State::OptionalDictionary(state) => state.validity.len(), + } + } +} + +#[derive(Debug, Default)] +struct BinaryDecoder { + size: usize, +} + +impl<'a> NestedDecoder<'a> for BinaryDecoder { + type State = State<'a>; + type Dictionary = Dict; + type DecodedState = (FixedSizeBinary, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::Plain, _, true, false) => { + Ok(State::Optional(Optional::try_new(page, self.size)?)) + }, + (Encoding::Plain, _, false, false) => { + Ok(State::Required(Required::new(page, self.size))) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + RequiredDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + OptionalDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + _ => Err(not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + FixedSizeBinary::with_capacity(capacity, self.size), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page) => { + let value = page.values.by_ref().next().unwrap_or_default(); + values.push(value); + validity.push(true); + }, + State::Required(page) => { + let value = page.values.by_ref().next().unwrap_or_default(); + values.push(value); + }, + State::RequiredDictionary(page) => { + let item = page + .values + .by_ref() + .next() + .map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }) + .unwrap_or_default(); + values.push(item); + }, + State::OptionalDictionary(page) => { + let item = page + .values + .by_ref() + .next() + .map(|index| { + let index = index.unwrap() as usize; + &page.dict[index * self.size..(index + 1) * self.size] + }) + .unwrap_or_default(); + values.push(item); + validity.push(true); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push_null(); + validity.push(false); + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + page.buffer.clone() + } +} + +pub struct NestedIter { + iter: I, + data_type: DataType, + size: usize, + init: Vec, + items: VecDeque<(NestedState, (FixedSizeBinary, MutableBitmap))>, + dict: Option, + chunk_size: Option, + remaining: usize, +} + +impl NestedIter { + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + let size = FixedSizeBinaryArray::get_size(&data_type); + 
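// `size` is the fixed width of every value. In the nested decoder above, a
// null slot still occupies `size` zero bytes in the values buffer (via
// `push_null`), while the validity bitmap records that the slot is null, so
// value `i` can always be sliced as `i * size..(i + 1) * size` regardless of
// nulls.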
Self { + iter, + data_type, + size, + init, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + } + } +} + +impl Iterator for NestedIter { + type Item = Result<(NestedState, FixedSizeBinaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + &self.init, + self.chunk_size, + &BinaryDecoder { size: self.size }, + ); + match maybe_state { + MaybeNext::Some(Ok((nested, decoded))) => { + Some(Ok((nested, finish(&self.data_type, decoded.0, decoded.1)))) + }, + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/utils.rs new file mode 100644 index 000000000000..f718ce1bdc2b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/fixed_size_binary/utils.rs @@ -0,0 +1,58 @@ +use super::super::utils::Pushable; + +/// A [`Pushable`] for fixed sized binary data +#[derive(Debug)] +pub struct FixedSizeBinary { + pub values: Vec, + pub size: usize, +} + +impl FixedSizeBinary { + #[inline] + pub fn with_capacity(capacity: usize, size: usize) -> Self { + Self { + values: Vec::with_capacity(capacity * size), + size, + } + } + + #[inline] + pub fn push(&mut self, value: &[u8]) { + debug_assert_eq!(value.len(), self.size); + self.values.extend(value); + } + + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + self.values + .resize(self.values.len() + additional * self.size, 0); + } +} + +impl<'a> Pushable<&'a [u8]> for FixedSizeBinary { + #[inline] + fn reserve(&mut self, additional: usize) { + self.values.reserve(additional * self.size); + } + #[inline] + fn push(&mut self, value: &[u8]) { + debug_assert_eq!(value.len(), self.size); + self.push(value); + } + + #[inline] + fn push_null(&mut self) { + self.values.extend(std::iter::repeat(0).take(self.size)) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: &[u8]) { + assert_eq!(value.len(), 0); + self.extend_constant(additional) + } + + #[inline] + fn len(&self) -> usize { + self.values.len() / self.size + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/mod.rs new file mode 100644 index 000000000000..098430b3d154 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/mod.rs @@ -0,0 +1,212 @@ +//! APIs to read from Parquet format. +mod binary; +mod boolean; +mod dictionary; +mod fixed_size_binary; +mod nested; +mod nested_utils; +mod null; +mod primitive; +mod simple; +mod struct_; +mod utils; + +use parquet2::read::get_page_iterator as _get_page_iterator; +use parquet2::schema::types::PrimitiveType; +use simple::page_iter_to_arrays; + +pub use self::nested_utils::{init_nested, InitNested, NestedArrayIter, NestedState}; +pub use self::struct_::StructIterator; +use super::*; +use crate::array::{Array, DictionaryKey, FixedSizeListArray, ListArray, MapArray}; +use crate::datatypes::{DataType, Field, IntervalUnit}; +use crate::error::Result; +use crate::offset::Offsets; + +/// Creates a new iterator of compressed pages. 
+pub fn get_page_iterator( + column_metadata: &ColumnChunkMetaData, + reader: R, + pages_filter: Option, + buffer: Vec, + max_header_size: usize, +) -> Result> { + Ok(_get_page_iterator( + column_metadata, + reader, + pages_filter, + buffer, + max_header_size, + )?) +} + +/// Creates a new [`ListArray`] or [`FixedSizeListArray`]. +pub fn create_list( + data_type: DataType, + nested: &mut NestedState, + values: Box, +) -> Box { + let (mut offsets, validity) = nested.nested.pop().unwrap().inner(); + match data_type.to_logical_type() { + DataType::List(_) => { + offsets.push(values.len() as i64); + + let offsets = offsets.iter().map(|x| *x as i32).collect::>(); + + let offsets: Offsets = offsets + .try_into() + .expect("i64 offsets do not fit in i32 offsets"); + + Box::new(ListArray::::new( + data_type, + offsets.into(), + values, + validity.and_then(|x| x.into()), + )) + }, + DataType::LargeList(_) => { + offsets.push(values.len() as i64); + + Box::new(ListArray::::new( + data_type, + offsets.try_into().expect("List too large"), + values, + validity.and_then(|x| x.into()), + )) + }, + DataType::FixedSizeList(_, _) => Box::new(FixedSizeListArray::new( + data_type, + values, + validity.and_then(|x| x.into()), + )), + _ => unreachable!(), + } +} + +/// Creates a new [`MapArray`]. +pub fn create_map( + data_type: DataType, + nested: &mut NestedState, + values: Box, +) -> Box { + let (mut offsets, validity) = nested.nested.pop().unwrap().inner(); + match data_type.to_logical_type() { + DataType::Map(_, _) => { + offsets.push(values.len() as i64); + let offsets = offsets.iter().map(|x| *x as i32).collect::>(); + + let offsets: Offsets = offsets + .try_into() + .expect("i64 offsets do not fit in i32 offsets"); + + Box::new(MapArray::new( + data_type, + offsets.into(), + values, + validity.and_then(|x| x.into()), + )) + }, + _ => unreachable!(), + } +} + +fn is_primitive(data_type: &DataType) -> bool { + matches!( + data_type.to_physical_type(), + crate::datatypes::PhysicalType::Primitive(_) + | crate::datatypes::PhysicalType::Null + | crate::datatypes::PhysicalType::Boolean + | crate::datatypes::PhysicalType::Utf8 + | crate::datatypes::PhysicalType::LargeUtf8 + | crate::datatypes::PhysicalType::Binary + | crate::datatypes::PhysicalType::LargeBinary + | crate::datatypes::PhysicalType::FixedSizeBinary + | crate::datatypes::PhysicalType::Dictionary(_) + ) +} + +fn columns_to_iter_recursive<'a, I: 'a>( + mut columns: Vec, + mut types: Vec<&PrimitiveType>, + field: Field, + init: Vec, + num_rows: usize, + chunk_size: Option, +) -> Result> +where + I: Pages, +{ + if init.is_empty() && is_primitive(&field.data_type) { + return Ok(Box::new( + page_iter_to_arrays( + columns.pop().unwrap(), + types.pop().unwrap(), + field.data_type, + chunk_size, + num_rows, + )? + .map(|x| Ok((NestedState::new(vec![]), x?))), + )); + } + + nested::columns_to_iter_recursive(columns, types, field, init, num_rows, chunk_size) +} + +/// Returns the number of (parquet) columns that a [`DataType`] contains. 
+pub fn n_columns(data_type: &DataType) -> usize { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | Dictionary(_) | LargeUtf8 => 1, + List | FixedSizeList | LargeList => { + let a = data_type.to_logical_type(); + if let DataType::List(inner) = a { + n_columns(&inner.data_type) + } else if let DataType::LargeList(inner) = a { + n_columns(&inner.data_type) + } else if let DataType::FixedSizeList(inner, _) = a { + n_columns(&inner.data_type) + } else { + unreachable!() + } + }, + Map => { + let a = data_type.to_logical_type(); + if let DataType::Map(inner, _) = a { + n_columns(&inner.data_type) + } else { + unreachable!() + } + }, + Struct => { + if let DataType::Struct(fields) = data_type.to_logical_type() { + fields.iter().map(|inner| n_columns(&inner.data_type)).sum() + } else { + unreachable!() + } + }, + _ => todo!(), + } +} + +/// An iterator adapter that maps multiple iterators of [`Pages`] into an iterator of [`Array`]s. +/// +/// For a non-nested datatypes such as [`DataType::Int32`], this function requires a single element in `columns` and `types`. +/// For nested types, `columns` must be composed by all parquet columns with associated types `types`. +/// +/// The arrays are guaranteed to be at most of size `chunk_size` and data type `field.data_type`. +pub fn column_iter_to_arrays<'a, I: 'a>( + columns: Vec, + types: Vec<&PrimitiveType>, + field: Field, + chunk_size: Option, + num_rows: usize, +) -> Result> +where + I: Pages, +{ + Ok(Box::new( + columns_to_iter_recursive(columns, types, field, vec![], num_rows, chunk_size)? + .map(|x| x.map(|x| x.1)), + )) +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/nested.rs new file mode 100644 index 000000000000..14f75fa8d672 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/nested.rs @@ -0,0 +1,590 @@ +use ethnum::I256; +use parquet2::schema::types::PrimitiveType; + +use super::nested_utils::{InitNested, NestedArrayIter}; +use super::*; +use crate::array::PrimitiveArray; +use crate::datatypes::{DataType, Field}; +use crate::error::{Error, Result}; + +/// Converts an iterator of arrays to a trait object returning trait objects +#[inline] +fn remove_nested<'a, I>(iter: I) -> NestedArrayIter<'a> +where + I: Iterator)>> + Send + Sync + 'a, +{ + Box::new(iter.map(|x| { + x.map(|(mut nested, array)| { + let _ = nested.nested.pop().unwrap(); // the primitive + (nested, array) + }) + })) +} + +/// Converts an iterator of arrays to a trait object returning trait objects +#[inline] +fn primitive<'a, A, I>(iter: I) -> NestedArrayIter<'a> +where + A: Array, + I: Iterator> + Send + Sync + 'a, +{ + Box::new(iter.map(|x| { + x.map(|(mut nested, array)| { + let _ = nested.nested.pop().unwrap(); // the primitive + (nested, Box::new(array) as _) + }) + })) +} + +pub fn columns_to_iter_recursive<'a, I: 'a>( + mut columns: Vec, + mut types: Vec<&PrimitiveType>, + field: Field, + mut init: Vec, + num_rows: usize, + chunk_size: Option, +) -> Result> +where + I: Pages, +{ + use crate::datatypes::PhysicalType::*; + use crate::datatypes::PrimitiveType::*; + + Ok(match field.data_type().to_physical_type() { + Null => { + // physical type is i32 + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(null::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + )) + }, + Boolean 
=> { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(boolean::NestedIter::new( + columns.pop().unwrap(), + init, + num_rows, + chunk_size, + )) + }, + Primitive(Int8) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as i8, + )) + }, + Primitive(Int16) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as i16, + )) + }, + Primitive(Int32) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x, + )) + }, + Primitive(Int64) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i64| x, + )) + }, + Primitive(UInt8) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as u8, + )) + }, + Primitive(UInt16) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as u16, + )) + }, + Primitive(UInt32) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + match type_.physical_type { + PhysicalType::Int32 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i32| x as u32, + )), + // some implementations of parquet write arrow's u32 into i64. 
+ PhysicalType::Int64 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i64| x as u32, + )), + other => { + return Err(Error::nyi(format!( + "Deserializing UInt32 from {other:?}'s parquet" + ))) + }, + } + }, + Primitive(UInt64) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: i64| x as u64, + )) + }, + Primitive(Float32) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: f32| x, + )) + }, + Primitive(Float64) => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + |x: f64| x, + )) + }, + Binary | Utf8 => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + remove_nested(binary::NestedIter::::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + )) + }, + LargeBinary | LargeUtf8 => { + init.push(InitNested::Primitive(field.is_nullable)); + types.pop(); + remove_nested(binary::NestedIter::::new( + columns.pop().unwrap(), + init, + field.data_type().clone(), + num_rows, + chunk_size, + )) + }, + _ => match field.data_type().to_logical_type() { + DataType::Dictionary(key_type, _, _) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + let iter = columns.pop().unwrap(); + let data_type = field.data_type().clone(); + match_integer_type!(key_type, |$K| { + dict_read::<$K, _>(iter, init, type_, data_type, num_rows, chunk_size) + })? + }, + DataType::List(inner) + | DataType::LargeList(inner) + | DataType::FixedSizeList(inner, _) => { + init.push(InitNested::List(field.is_nullable)); + let iter = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + num_rows, + chunk_size, + )?; + let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let array = create_list(field.data_type().clone(), &mut nested, array); + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + DataType::Decimal(_, _) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + match type_.physical_type { + PhysicalType::Int32 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i32| x as i128, + )), + PhysicalType::Int64 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i64| x as i128, + )), + PhysicalType::FixedLenByteArray(n) if n > 16 => { + return Err(Error::InvalidArgumentError(format!( + "Can't decode Decimal128 type from `FixedLenByteArray` of len {n}" + ))) + }, + PhysicalType::FixedLenByteArray(n) => { + let iter = fixed_size_binary::NestedIter::new( + columns.pop().unwrap(), + init, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + // Convert the fixed length byte array to Decimal. 
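// (The FIXED_LEN_BYTE_ARRAY holds the unscaled decimal as an `n`-byte
// big-endian two's-complement integer; `convert_i128` sign-extends those bytes
// into an i128, and the results are rebuilt as a `PrimitiveArray<i128>` with
// the original validity.)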
+ let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| super::super::convert_i128(value, n)) + .collect::>(); + let validity = array.validity().cloned(); + + let array: Box = Box::new(PrimitiveArray::::try_new( + field.data_type.clone(), + values.into(), + validity, + )?); + + let _ = nested.nested.pop().unwrap(); // the primitive + + Ok((nested, array)) + }); + Box::new(iter) + }, + _ => { + return Err(Error::nyi(format!( + "Deserializing type for Decimal {:?} from parquet", + type_.physical_type + ))) + }, + } + }, + DataType::Decimal256(_, _) => { + init.push(InitNested::Primitive(field.is_nullable)); + let type_ = types.pop().unwrap(); + match type_.physical_type { + PhysicalType::Int32 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i32| i256(I256::new(x as i128)), + )), + PhysicalType::Int64 => primitive(primitive::NestedIter::new( + columns.pop().unwrap(), + init, + field.data_type.clone(), + num_rows, + chunk_size, + |x: i64| i256(I256::new(x as i128)), + )), + PhysicalType::FixedLenByteArray(n) if n <= 16 => { + let iter = fixed_size_binary::NestedIter::new( + columns.pop().unwrap(), + init, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + // Convert the fixed length byte array to Decimal. + let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let values = array + .values() + .chunks_exact(n) + .map(|value| i256(I256::new(super::super::convert_i128(value, n)))) + .collect::>(); + let validity = array.validity().cloned(); + + let array: Box = Box::new(PrimitiveArray::::try_new( + field.data_type.clone(), + values.into(), + validity, + )?); + + let _ = nested.nested.pop().unwrap(); // the primitive + + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + + PhysicalType::FixedLenByteArray(n) if n <= 32 => { + let iter = fixed_size_binary::NestedIter::new( + columns.pop().unwrap(), + init, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + // Convert the fixed length byte array to Decimal. 
+ let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_i256) + .collect::>(); + let validity = array.validity().cloned(); + + let array: Box = Box::new(PrimitiveArray::::try_new( + field.data_type.clone(), + values.into(), + validity, + )?); + + let _ = nested.nested.pop().unwrap(); // the primitive + + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + PhysicalType::FixedLenByteArray(n) => { + return Err(Error::InvalidArgumentError(format!( + "Can't decode Decimal256 type from from `FixedLenByteArray` of len {n}" + ))) + }, + _ => { + return Err(Error::nyi(format!( + "Deserializing type for Decimal {:?} from parquet", + type_.physical_type + ))) + }, + } + }, + DataType::Struct(fields) => { + let columns = fields + .iter() + .rev() + .map(|f| { + let mut init = init.clone(); + init.push(InitNested::Struct(field.is_nullable)); + let n = n_columns(&f.data_type); + let columns = columns.drain(columns.len() - n..).collect(); + let types = types.drain(types.len() - n..).collect(); + columns_to_iter_recursive( + columns, + types, + f.clone(), + init, + num_rows, + chunk_size, + ) + }) + .collect::>>()?; + let columns = columns.into_iter().rev().collect(); + Box::new(struct_::StructIterator::new(columns, fields.clone())) + }, + DataType::Map(inner, _) => { + init.push(InitNested::List(field.is_nullable)); + let iter = columns_to_iter_recursive( + columns, + types, + inner.as_ref().clone(), + init, + num_rows, + chunk_size, + )?; + let iter = iter.map(move |x| { + let (mut nested, array) = x?; + let array = create_map(field.data_type().clone(), &mut nested, array); + Ok((nested, array)) + }); + Box::new(iter) as _ + }, + other => { + return Err(Error::nyi(format!( + "Deserializing type {other:?} from parquet" + ))) + }, + }, + }) +} + +fn dict_read<'a, K: DictionaryKey, I: 'a + Pages>( + iter: I, + init: Vec, + _type_: &PrimitiveType, + data_type: DataType, + num_rows: usize, + chunk_size: Option, +) -> Result> { + use DataType::*; + let values_data_type = if let Dictionary(_, v, _) = &data_type { + v.as_ref() + } else { + panic!() + }; + + Ok(match values_data_type.to_logical_type() { + UInt8 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as u8, + )), + UInt16 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as u16, + )), + UInt32 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as u32, + )), + Int8 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as i8, + )), + Int16 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x as i16, + )), + Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth) => { + primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i32| x, + )) + }, + Int64 | Date64 | Time64(_) | Duration(_) => { + primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: i64| x as i32, + )) + }, + Float32 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: f32| x, + )), + Float64 => primitive(primitive::NestedDictIter::::new( + iter, + init, + data_type, + num_rows, + chunk_size, + |x: f64| x, + )), + Utf8 | Binary => 
primitive(binary::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), + LargeUtf8 | LargeBinary => primitive(binary::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), + FixedSizeBinary(_) => primitive(fixed_size_binary::NestedDictIter::::new( + iter, init, data_type, num_rows, chunk_size, + )), + /* + + Timestamp(time_unit, _) => { + let time_unit = *time_unit; + return timestamp_dict::( + iter, + physical_type, + logical_type, + data_type, + chunk_size, + time_unit, + ); + } + */ + other => { + return Err(Error::nyi(format!( + "Reading nested dictionaries of type {other:?}" + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/nested_utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/nested_utils.rs new file mode 100644 index 000000000000..fc68080e0799 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/nested_utils.rs @@ -0,0 +1,556 @@ +use std::collections::VecDeque; + +use parquet2::encoding::hybrid_rle::HybridRleDecoder; +use parquet2::page::{split_buffer, DataPage, DictPage, Page}; +use parquet2::read::levels::get_bit_width; + +use super::super::Pages; +pub use super::utils::Zip; +use super::utils::{DecodedState, MaybeNext, PageState}; +use crate::array::Array; +use crate::bitmap::MutableBitmap; +use crate::error::Result; + +/// trait describing deserialized repetition and definition levels +pub trait Nested: std::fmt::Debug + Send + Sync { + fn inner(&mut self) -> (Vec, Option); + + fn push(&mut self, length: i64, is_valid: bool); + + fn is_nullable(&self) -> bool; + + fn is_repeated(&self) -> bool { + false + } + + // Whether the Arrow container requires all items to be filled. + fn is_required(&self) -> bool; + + /// number of rows + fn len(&self) -> usize; + + /// number of values associated to the primitive type this nested tracks + fn num_values(&self) -> usize; +} + +#[derive(Debug, Default)] +pub struct NestedPrimitive { + is_nullable: bool, + length: usize, +} + +impl NestedPrimitive { + pub fn new(is_nullable: bool) -> Self { + Self { + is_nullable, + length: 0, + } + } +} + +impl Nested for NestedPrimitive { + fn inner(&mut self) -> (Vec, Option) { + (Default::default(), Default::default()) + } + + fn is_nullable(&self) -> bool { + self.is_nullable + } + + fn is_required(&self) -> bool { + false + } + + fn push(&mut self, _value: i64, _is_valid: bool) { + self.length += 1 + } + + fn len(&self) -> usize { + self.length + } + + fn num_values(&self) -> usize { + self.length + } +} + +#[derive(Debug, Default)] +pub struct NestedOptional { + pub validity: MutableBitmap, + pub offsets: Vec, +} + +impl Nested for NestedOptional { + fn inner(&mut self) -> (Vec, Option) { + let offsets = std::mem::take(&mut self.offsets); + let validity = std::mem::take(&mut self.validity); + (offsets, Some(validity)) + } + + fn is_nullable(&self) -> bool { + true + } + + fn is_repeated(&self) -> bool { + true + } + + fn is_required(&self) -> bool { + // it may be for FixedSizeList + false + } + + fn push(&mut self, value: i64, is_valid: bool) { + self.offsets.push(value); + self.validity.push(is_valid); + } + + fn len(&self) -> usize { + self.offsets.len() + } + + fn num_values(&self) -> usize { + self.offsets.last().copied().unwrap_or(0) as usize + } +} + +impl NestedOptional { + pub fn with_capacity(capacity: usize) -> Self { + let offsets = Vec::::with_capacity(capacity + 1); + let validity = MutableBitmap::with_capacity(capacity); + Self { validity, offsets } + } +} + +#[derive(Debug, 
Default)] +pub struct NestedValid { + pub offsets: Vec, +} + +impl Nested for NestedValid { + fn inner(&mut self) -> (Vec, Option) { + let offsets = std::mem::take(&mut self.offsets); + (offsets, None) + } + + fn is_nullable(&self) -> bool { + false + } + + fn is_repeated(&self) -> bool { + true + } + + fn is_required(&self) -> bool { + // it may be for FixedSizeList + false + } + + fn push(&mut self, value: i64, _is_valid: bool) { + self.offsets.push(value); + } + + fn len(&self) -> usize { + self.offsets.len() + } + + fn num_values(&self) -> usize { + self.offsets.last().copied().unwrap_or(0) as usize + } +} + +impl NestedValid { + pub fn with_capacity(capacity: usize) -> Self { + let offsets = Vec::::with_capacity(capacity + 1); + Self { offsets } + } +} + +#[derive(Debug, Default)] +pub struct NestedStructValid { + length: usize, +} + +impl NestedStructValid { + pub fn new() -> Self { + Self { length: 0 } + } +} + +impl Nested for NestedStructValid { + fn inner(&mut self) -> (Vec, Option) { + (Default::default(), None) + } + + fn is_nullable(&self) -> bool { + false + } + + fn is_required(&self) -> bool { + true + } + + fn push(&mut self, _value: i64, _is_valid: bool) { + self.length += 1; + } + + fn len(&self) -> usize { + self.length + } + + fn num_values(&self) -> usize { + self.length + } +} + +#[derive(Debug, Default)] +pub struct NestedStruct { + validity: MutableBitmap, +} + +impl NestedStruct { + pub fn with_capacity(capacity: usize) -> Self { + Self { + validity: MutableBitmap::with_capacity(capacity), + } + } +} + +impl Nested for NestedStruct { + fn inner(&mut self) -> (Vec, Option) { + (Default::default(), Some(std::mem::take(&mut self.validity))) + } + + fn is_nullable(&self) -> bool { + true + } + + fn is_required(&self) -> bool { + true + } + + fn push(&mut self, _value: i64, is_valid: bool) { + self.validity.push(is_valid) + } + + fn len(&self) -> usize { + self.validity.len() + } + + fn num_values(&self) -> usize { + self.validity.len() + } +} + +/// A decoder that knows how to map `State` -> Array +pub(super) trait NestedDecoder<'a> { + type State: PageState<'a>; + type Dictionary; + type DecodedState: DecodedState; + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result; + + /// Initializes a new state + fn with_capacity(&self, capacity: usize) -> Self::DecodedState; + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()>; + fn push_null(&self, decoded: &mut Self::DecodedState); + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary; +} + +/// The initial info of nested data types. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InitNested { + /// Primitive data types + Primitive(bool), + /// List data types + List(bool), + /// Struct data types + Struct(bool), +} + +/// Initialize [`NestedState`] from `&[InitNested]`. 
+pub fn init_nested(init: &[InitNested], capacity: usize) -> NestedState { + let container = init + .iter() + .map(|init| match init { + InitNested::Primitive(is_nullable) => { + Box::new(NestedPrimitive::new(*is_nullable)) as Box + }, + InitNested::List(is_nullable) => { + if *is_nullable { + Box::new(NestedOptional::with_capacity(capacity)) as Box + } else { + Box::new(NestedValid::with_capacity(capacity)) as Box + } + }, + InitNested::Struct(is_nullable) => { + if *is_nullable { + Box::new(NestedStruct::with_capacity(capacity)) as Box + } else { + Box::new(NestedStructValid::new()) as Box + } + }, + }) + .collect(); + NestedState::new(container) +} + +pub struct NestedPage<'a> { + iter: std::iter::Peekable, HybridRleDecoder<'a>>>, +} + +impl<'a> NestedPage<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (rep_levels, def_levels, _) = split_buffer(page)?; + + let max_rep_level = page.descriptor.max_rep_level; + let max_def_level = page.descriptor.max_def_level; + + let reps = + HybridRleDecoder::try_new(rep_levels, get_bit_width(max_rep_level), page.num_values())?; + let defs = + HybridRleDecoder::try_new(def_levels, get_bit_width(max_def_level), page.num_values())?; + + let iter = reps.zip(defs).peekable(); + + Ok(Self { iter }) + } + + // number of values (!= number of rows) + pub fn len(&self) -> usize { + self.iter.size_hint().0 + } +} + +/// The state of nested data types. +#[derive(Debug)] +pub struct NestedState { + /// The nesteds composing `NestedState`. + pub nested: Vec>, +} + +impl NestedState { + /// Creates a new [`NestedState`]. + pub fn new(nested: Vec>) -> Self { + Self { nested } + } + + /// The number of rows in this state + pub fn len(&self) -> usize { + // outermost is the number of rows + self.nested[0].len() + } +} + +/// Extends `items` by consuming `page`, first trying to complete the last `item` +/// and extending it if more are needed +pub(super) fn extend<'a, D: NestedDecoder<'a>>( + page: &'a DataPage, + init: &[InitNested], + items: &mut VecDeque<(NestedState, D::DecodedState)>, + dict: Option<&'a D::Dictionary>, + remaining: &mut usize, + decoder: &D, + chunk_size: Option, +) -> Result<()> { + let mut values_page = decoder.build_state(page, dict)?; + let mut page = NestedPage::try_new(page)?; + + let capacity = chunk_size.unwrap_or(0); + // chunk_size = None, remaining = 44 => chunk_size = 44 + let chunk_size = chunk_size.unwrap_or(usize::MAX); + + let (mut nested, mut decoded) = if let Some((nested, decoded)) = items.pop_back() { + (nested, decoded) + } else { + // there is no state => initialize it + (init_nested(init, capacity), decoder.with_capacity(0)) + }; + let existing = nested.len(); + + let additional = (chunk_size - existing).min(*remaining); + + // extend the current state + extend_offsets2( + &mut page, + &mut values_page, + &mut nested.nested, + &mut decoded, + decoder, + additional, + )?; + *remaining -= nested.len() - existing; + items.push_back((nested, decoded)); + + while page.len() > 0 && *remaining > 0 { + let additional = chunk_size.min(*remaining); + + let mut nested = init_nested(init, additional); + let mut decoded = decoder.with_capacity(0); + extend_offsets2( + &mut page, + &mut values_page, + &mut nested.nested, + &mut decoded, + decoder, + additional, + )?; + *remaining -= nested.len(); + items.push_back((nested, decoded)); + } + Ok(()) +} + +fn extend_offsets2<'a, D: NestedDecoder<'a>>( + page: &mut NestedPage<'a>, + values_state: &mut D::State, + nested: &mut [Box], + decoded: &mut D::DecodedState, + decoder: &D, + 
additional: usize, +) -> Result<()> { + let max_depth = nested.len(); + + let mut cum_sum = vec![0u32; max_depth + 1]; + for (i, nest) in nested.iter().enumerate() { + let delta = nest.is_nullable() as u32 + nest.is_repeated() as u32; + cum_sum[i + 1] = cum_sum[i] + delta; + } + + let mut cum_rep = vec![0u32; max_depth + 1]; + for (i, nest) in nested.iter().enumerate() { + let delta = nest.is_repeated() as u32; + cum_rep[i + 1] = cum_rep[i] + delta; + } + + let mut rows = 0; + while let Some((rep, def)) = page.iter.next() { + let rep = rep?; + let def = def?; + if rep == 0 { + rows += 1; + } + + let mut is_required = false; + for depth in 0..max_depth { + let right_level = rep <= cum_rep[depth] && def >= cum_sum[depth]; + if is_required || right_level { + let length = nested + .get(depth + 1) + .map(|x| x.len() as i64) + // the last depth is the leaf, which is always increased by 1 + .unwrap_or(1); + + let nest = &mut nested[depth]; + + let is_valid = nest.is_nullable() && def > cum_sum[depth]; + nest.push(length, is_valid); + is_required = nest.is_required() && !is_valid; + + if depth == max_depth - 1 { + // the leaf / primitive + let is_valid = (def != cum_sum[depth]) || !nest.is_nullable(); + if right_level && is_valid { + decoder.push_valid(values_state, decoded)?; + } else { + decoder.push_null(decoded); + } + } + } + } + + let next_rep = *page + .iter + .peek() + .map(|x| x.0.as_ref()) + .transpose() + .unwrap() // todo: fix this + .unwrap_or(&0); + + if next_rep == 0 && rows == additional { + break; + } + } + Ok(()) +} + +#[inline] +pub(super) fn next<'a, I, D>( + iter: &'a mut I, + items: &mut VecDeque<(NestedState, D::DecodedState)>, + dict: &'a mut Option, + remaining: &mut usize, + init: &[InitNested], + chunk_size: Option, + decoder: &D, +) -> MaybeNext> +where + I: Pages, + D: NestedDecoder<'a>, +{ + // front[a1, a2, a3, ...]back + if items.len() > 1 { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if (items.len() == 1) && items.front().unwrap().0.len() == chunk_size.unwrap_or(usize::MAX) { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if *remaining == 0 { + return match items.pop_front() { + Some(decoded) => MaybeNext::Some(Ok(decoded)), + None => MaybeNext::None, + }; + } + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(None) => { + if let Some(decoded) = items.pop_front() { + MaybeNext::Some(Ok(decoded)) + } else { + MaybeNext::None + } + }, + Ok(Some(page)) => { + let page = match page { + Page::Data(page) => page, + Page::Dict(dict_page) => { + *dict = Some(decoder.deserialize_dict(dict_page)); + return MaybeNext::More; + }, + }; + + // there is a new page => consume the page from the start + let error = extend( + page, + init, + items, + dict.as_ref(), + remaining, + decoder, + chunk_size, + ); + match error { + Ok(_) => {}, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + if (items.len() == 1) + && items.front().unwrap().0.len() < chunk_size.unwrap_or(usize::MAX) + { + MaybeNext::More + } else { + MaybeNext::Some(Ok(items.pop_front().unwrap())) + } + }, + } +} + +/// Type def for a sharable, boxed dyn [`Iterator`] of NestedStates and arrays +pub type NestedArrayIter<'a> = + Box)>> + Send + Sync + 'a>; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/null/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/null/mod.rs new file mode 100644 index 000000000000..576db09d364b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/null/mod.rs @@ -0,0 +1,104 @@ +mod nested; + 
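Reviewer note, not part of the diff above: `extend_offsets2` interprets each (repetition level, definition level) pair against cumulative per-level counts built from the `Nested` stack. The sketch below uses illustrative names only (not the crate's API) and reproduces that cumulative computation for a nullable Arrow list of nullable Int64 values, assuming the same (is_nullable, is_repeated) flags the `Nested` implementations report.

fn cumulative_levels(levels: &[(bool, bool)]) -> (Vec<u32>, Vec<u32>) {
    // `levels` holds (is_nullable, is_repeated) per nesting level, outermost first,
    // mirroring the cum_sum / cum_rep loops in `extend_offsets2`.
    let mut cum_sum = vec![0u32; levels.len() + 1];
    let mut cum_rep = vec![0u32; levels.len() + 1];
    for (i, (is_nullable, is_repeated)) in levels.iter().enumerate() {
        cum_sum[i + 1] = cum_sum[i] + *is_nullable as u32 + *is_repeated as u32;
        cum_rep[i + 1] = cum_rep[i] + *is_repeated as u32;
    }
    (cum_sum, cum_rep)
}

#[test]
fn nullable_list_of_nullable_int64() {
    // Outer level: nullable, repeated list; leaf: nullable primitive.
    let (cum_sum, cum_rep) = cumulative_levels(&[(true, true), (true, false)]);
    assert_eq!(cum_sum, vec![0, 2, 3]);
    assert_eq!(cum_rep, vec![0, 1, 1]);
    // With these sums: def == 3 is a valid leaf value, def == 2 is a null leaf
    // value, and def < 2 is a null or empty list.
}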
+pub(super) use nested::NestedIter; +use parquet2::page::Page; + +use super::super::{ArrayIter, Pages}; +use crate::array::NullArray; +use crate::datatypes::DataType; + +/// Converts [`Pages`] to an [`ArrayIter`] +pub fn iter_to_arrays<'a, I>( + mut iter: I, + data_type: DataType, + chunk_size: Option, + num_rows: usize, +) -> ArrayIter<'a> +where + I: 'a + Pages, +{ + let mut len = 0usize; + + while let Ok(Some(page)) = iter.next() { + match page { + Page::Dict(_) => continue, + Page::Data(page) => { + let rows = page.num_values(); + len = (len + rows).min(num_rows); + if len == num_rows { + break; + } + }, + } + } + + if len == 0 { + return Box::new(std::iter::empty()); + } + + let chunk_size = chunk_size.unwrap_or(len); + + let complete_chunks = len / chunk_size; + + let remainder = len - (complete_chunks * chunk_size); + let i_data_type = data_type.clone(); + let complete = (0..complete_chunks) + .map(move |_| Ok(NullArray::new(i_data_type.clone(), chunk_size).boxed())); + if len % chunk_size == 0 { + Box::new(complete) + } else { + let array = NullArray::new(data_type, remainder); + Box::new(complete.chain(std::iter::once(Ok(array.boxed())))) + } +} + +#[cfg(test)] +mod tests { + use parquet2::encoding::Encoding; + use parquet2::error::Error as ParquetError; + use parquet2::metadata::Descriptor; + use parquet2::page::{DataPage, DataPageHeader, DataPageHeaderV1, Page}; + use parquet2::schema::types::{PhysicalType, PrimitiveType}; + + use super::iter_to_arrays; + use crate::array::NullArray; + use crate::datatypes::DataType; + use crate::error::Error; + + #[test] + fn limit() { + let new_page = |values: i32| { + Page::Data(DataPage::new( + DataPageHeader::V1(DataPageHeaderV1 { + num_values: values, + encoding: Encoding::Plain.into(), + definition_level_encoding: Encoding::Plain.into(), + repetition_level_encoding: Encoding::Plain.into(), + statistics: None, + }), + vec![], + Descriptor { + primitive_type: PrimitiveType::from_physical( + "a".to_string(), + PhysicalType::Int32, + ), + max_def_level: 0, + max_rep_level: 0, + }, + None, + )) + }; + + let p1 = new_page(100); + let p2 = new_page(100); + let pages = vec![Result::<_, ParquetError>::Ok(&p1), Ok(&p2)]; + let pages = fallible_streaming_iterator::convert(pages.into_iter()); + let arrays = iter_to_arrays(pages, DataType::Null, Some(10), 101); + + let arrays = arrays.collect::, Error>>().unwrap(); + let expected = std::iter::repeat(NullArray::new(DataType::Null, 10).boxed()) + .take(10) + .chain(std::iter::once(NullArray::new(DataType::Null, 1).boxed())); + assert_eq!(arrays, expected.collect::>()) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/null/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/null/nested.rs new file mode 100644 index 000000000000..9528720e73be --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/null/nested.rs @@ -0,0 +1,126 @@ +use std::collections::VecDeque; + +use parquet2::page::{DataPage, DictPage}; + +use super::super::nested_utils::*; +use super::super::{utils, Pages}; +use crate::array::NullArray; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::io::parquet::read::deserialize::utils::DecodedState; + +impl<'a> utils::PageState<'a> for usize { + fn len(&self) -> usize { + *self + } +} + +#[derive(Debug)] +struct NullDecoder {} + +impl DecodedState for usize { + fn len(&self) -> usize { + *self + } +} + +impl<'a> NestedDecoder<'a> for NullDecoder { + type State = usize; + type Dictionary = usize; + type DecodedState = usize; + + fn 
build_state( + &self, + _page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + if let Some(n) = dict { + return Ok(*n); + } + Ok(1) + } + + /// Initializes a new state + fn with_capacity(&self, _capacity: usize) -> Self::DecodedState { + 0 + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + *decoded += *state; + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let length = decoded; + *length += 1; + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + page.num_values + } +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as null arrays +#[derive(Debug)] +pub struct NestedIter +where + I: Pages, +{ + iter: I, + init: Vec, + data_type: DataType, + items: VecDeque<(NestedState, usize)>, + remaining: usize, + chunk_size: Option, + decoder: NullDecoder, +} + +impl NestedIter +where + I: Pages, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + ) -> Self { + Self { + iter, + init, + data_type, + items: VecDeque::new(), + chunk_size, + remaining: num_rows, + decoder: NullDecoder {}, + } + } +} + +impl Iterator for NestedIter +where + I: Pages, +{ + type Item = Result<(NestedState, NullArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut None, + &mut self.remaining, + &self.init, + self.chunk_size, + &self.decoder, + ); + match maybe_state { + utils::MaybeNext::Some(Ok((nested, state))) => { + Some(Ok((nested, NullArray::new(self.data_type.clone(), state)))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/basic.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/basic.rs new file mode 100644 index 000000000000..200c9a517dd0 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/basic.rs @@ -0,0 +1,370 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::{hybrid_rle, Encoding}; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; +use parquet2::types::{decode, NativeType as ParquetNativeType}; + +use super::super::utils::{get_selected_rows, FilteredOptionalPageValidity, OptionalPageValidity}; +use super::super::{utils, Pages}; +use crate::array::MutablePrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +#[derive(Debug)] +pub(super) struct FilteredRequiredValues<'a> { + values: SliceFilteredIter>, +} + +impl<'a> FilteredRequiredValues<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + assert_eq!(values.len() % std::mem::size_of::
<P>
(), 0); + + let values = values.chunks_exact(std::mem::size_of::
<P>
()); + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(Self { values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct Values<'a> { + pub values: std::slice::ChunksExact<'a, u8>, +} + +impl<'a> Values<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, values) = split_buffer(page)?; + assert_eq!(values.len() % std::mem::size_of::
<P>
(), 0); + Ok(Self { + values: values.chunks_exact(std::mem::size_of::
<P>
()), + }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +#[derive(Debug)] +pub(super) struct ValuesDictionary<'a, T> +where + T: NativeType, +{ + pub values: hybrid_rle::HybridRleDecoder<'a>, + pub dict: &'a Vec, +} + +impl<'a, T> ValuesDictionary<'a, T> +where + T: NativeType, +{ + pub fn try_new(page: &'a DataPage, dict: &'a Vec) -> Result { + let values = utils::dict_indices_decoder(page)?; + + Ok(Self { dict, values }) + } + + #[inline] + pub fn len(&self) -> usize { + self.values.size_hint().0 + } +} + +// The state of a `DataPage` of `Primitive` parquet primitive type +#[derive(Debug)] +pub(super) enum State<'a, T> +where + T: NativeType, +{ + Optional(OptionalPageValidity<'a>, Values<'a>), + Required(Values<'a>), + RequiredDictionary(ValuesDictionary<'a, T>), + OptionalDictionary(OptionalPageValidity<'a>, ValuesDictionary<'a, T>), + FilteredRequired(FilteredRequiredValues<'a>), + FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), +} + +impl<'a, T> utils::PageState<'a> for State<'a, T> +where + T: NativeType, +{ + fn len(&self) -> usize { + match self { + State::Optional(optional, _) => optional.len(), + State::Required(values) => values.len(), + State::RequiredDictionary(values) => values.len(), + State::OptionalDictionary(optional, _) => optional.len(), + State::FilteredRequired(values) => values.len(), + State::FilteredOptional(optional, _) => optional.len(), + } + } +} + +#[derive(Debug)] +pub(super) struct PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData
<P>
, + pub op: F, +} + +impl PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + #[inline] + pub(super) fn new(op: F) -> Self { + Self { + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData, + op, + } + } +} + +impl utils::DecodedState for (Vec, MutableBitmap) { + fn len(&self) -> usize { + self.0.len() + } +} + +impl<'a, T, P, F> utils::Decoder<'a> for PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type State = State<'a, T>; + type Dict = Vec; + type DecodedState = (Vec, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + ValuesDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + Ok(State::OptionalDictionary( + OptionalPageValidity::try_new(page)?, + ValuesDictionary::try_new(page, dict)?, + )) + }, + (Encoding::Plain, _, true, false) => { + let validity = OptionalPageValidity::try_new(page)?; + let values = Values::try_new::
<P>
(page)?; + + Ok(State::Optional(validity, values)) + }, + (Encoding::Plain, _, false, false) => Ok(State::Required(Values::try_new::
<P>
(page)?)), + (Encoding::Plain, _, false, true) => { + FilteredRequiredValues::try_new::
<P>
(page).map(State::FilteredRequired) + }, + (Encoding::Plain, _, true, true) => Ok(State::FilteredOptional( + FilteredOptionalPageValidity::try_new(page)?, + Values::try_new::
<P>
(page)?, + )), + _ => Err(utils::not_implemented(page)), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Optional(page_validity, page_values) => utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.values.by_ref().map(decode).map(self.op), + ), + State::Required(page) => { + values.extend( + page.values + .by_ref() + .map(decode) + .map(self.op) + .take(remaining), + ); + }, + State::OptionalDictionary(page_validity, page_values) => { + let op1 = |index: u32| page_values.dict[index as usize]; + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.values.by_ref().map(|x| x.unwrap()).map(op1), + ) + }, + State::RequiredDictionary(page) => { + let op1 = |index: u32| page.dict[index as usize]; + values.extend( + page.values + .by_ref() + .map(|x| x.unwrap()) + .map(op1) + .take(remaining), + ); + }, + State::FilteredRequired(page) => { + values.extend( + page.values + .by_ref() + .map(decode) + .map(self.op) + .take(remaining), + ); + }, + State::FilteredOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values.values.by_ref().map(decode).map(self.op), + ); + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + deserialize_plain(&page.buffer, self.op) + } +} + +pub(super) fn finish( + data_type: &DataType, + values: Vec, + validity: MutableBitmap, +) -> MutablePrimitiveArray { + let validity = if validity.is_empty() { + None + } else { + Some(validity) + }; + MutablePrimitiveArray::try_new(data_type.clone(), values, validity).unwrap() +} + +/// An [`Iterator`] adapter over [`Pages`] assumed to be encoded as primitive arrays +#[derive(Debug)] +pub struct Iter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + data_type: DataType, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + dict: Option>, + op: F, + phantom: std::marker::PhantomData
<P>
, +} + +impl Iter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + dict: None, + remaining: num_rows, + chunk_size, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for Iter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = utils::next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + &PrimitiveDecoder::new(self.op), + ); + match maybe_state { + utils::MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} + +pub(super) fn deserialize_plain(values: &[u8], op: F) -> Vec +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + values + .chunks_exact(std::mem::size_of::
<P>
()) + .map(decode) + .map(op) + .collect::>() +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/dictionary.rs new file mode 100644 index 000000000000..35293d582d10 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/dictionary.rs @@ -0,0 +1,190 @@ +use std::collections::VecDeque; + +use parquet2::page::DictPage; +use parquet2::types::NativeType as ParquetNativeType; + +use super::super::dictionary::{nested_next_dict, *}; +use super::super::nested_utils::{InitNested, NestedState}; +use super::super::utils::MaybeNext; +use super::super::Pages; +use super::basic::deserialize_plain; +use crate::array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +fn read_dict(data_type: DataType, op: F, dict: &DictPage) -> Box +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + let data_type = match data_type { + DataType::Dictionary(_, values, _) => *values, + _ => data_type, + }; + let values = deserialize_plain(&dict.buffer, op); + + Box::new(PrimitiveArray::new(data_type, values.into(), None)) +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct DictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + data_type: DataType, + values: Option>, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + op: F, + phantom: std::marker::PhantomData

, +} + +impl DictIter +where + K: DictionaryKey, + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + data_type, + values: None, + items: VecDeque::new(), + chunk_size, + remaining: num_rows, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for DictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = next_dict( + &mut self.iter, + &mut self.items, + &mut self.values, + self.data_type.clone(), + &mut self.remaining, + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), self.op, dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} + +/// An iterator adapter that converts [`DataPages`] into an [`Iterator`] of [`DictionaryArray`] +#[derive(Debug)] +pub struct NestedDictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + init: Vec, + data_type: DataType, + values: Option>, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + remaining: usize, + chunk_size: Option, + op: F, + phantom: std::marker::PhantomData

, +} + +impl NestedDictIter +where + K: DictionaryKey, + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + init, + data_type, + values: None, + items: VecDeque::new(), + remaining: num_rows, + chunk_size, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for NestedDictIter +where + I: Pages, + T: NativeType, + K: DictionaryKey, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result<(NestedState, DictionaryArray)>; + + fn next(&mut self) -> Option { + let maybe_state = nested_next_dict( + &mut self.iter, + &mut self.items, + &mut self.remaining, + &self.init, + &mut self.values, + self.data_type.clone(), + self.chunk_size, + |dict| read_dict::(self.data_type.clone(), self.op, dict), + ); + match maybe_state { + MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), + MaybeNext::Some(Err(e)) => Some(Err(e)), + MaybeNext::None => None, + MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/integer.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/integer.rs new file mode 100644 index 000000000000..ac6c0bac0c1f --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/integer.rs @@ -0,0 +1,262 @@ +use std::collections::VecDeque; + +use num_traits::AsPrimitive; +use parquet2::deserialize::SliceFilteredIter; +use parquet2::encoding::delta_bitpacked::Decoder; +use parquet2::encoding::Encoding; +use parquet2::page::{split_buffer, DataPage, DictPage}; +use parquet2::schema::Repetition; +use parquet2::types::NativeType as ParquetNativeType; + +use super::super::{utils, Pages}; +use super::basic::{finish, PrimitiveDecoder, State as PrimitiveState}; +use crate::array::MutablePrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::io::parquet::read::deserialize::utils::{ + get_selected_rows, FilteredOptionalPageValidity, OptionalPageValidity, +}; +use crate::types::NativeType; + +/// The state of a [`DataPage`] of an integer parquet type (i32 or i64) +#[derive(Debug)] +enum State<'a, T> +where + T: NativeType, +{ + Common(PrimitiveState<'a, T>), + DeltaBinaryPackedRequired(Decoder<'a>), + DeltaBinaryPackedOptional(OptionalPageValidity<'a>, Decoder<'a>), + FilteredDeltaBinaryPackedRequired(SliceFilteredIter>), + FilteredDeltaBinaryPackedOptional(FilteredOptionalPageValidity<'a>, Decoder<'a>), +} + +impl<'a, T> utils::PageState<'a> for State<'a, T> +where + T: NativeType, +{ + fn len(&self) -> usize { + match self { + State::Common(state) => state.len(), + State::DeltaBinaryPackedRequired(state) => state.size_hint().0, + State::DeltaBinaryPackedOptional(state, _) => state.len(), + State::FilteredDeltaBinaryPackedRequired(state) => state.size_hint().0, + State::FilteredDeltaBinaryPackedOptional(state, _) => state.len(), + } + } +} + +/// Decoder of integer parquet type +#[derive(Debug)] +struct IntDecoder(PrimitiveDecoder) +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Fn(P) -> T; + +impl IntDecoder +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Fn(P) -> T, +{ + #[inline] + fn new(op: F) -> Self { + Self(PrimitiveDecoder::new(op)) + } +} + +impl<'a, T, P, F> utils::Decoder<'a> for IntDecoder +where + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Copy + Fn(P) -> T, +{ + type State = State<'a, T>; + type Dict = Vec; + type DecodedState = (Vec, MutableBitmap); + + fn build_state(&self, page: &'a DataPage, dict: Option<&'a Self::Dict>) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::DeltaBinaryPacked, _, false, false) => { + let (_, _, values) = split_buffer(page)?; + Decoder::try_new(values) + .map(State::DeltaBinaryPackedRequired) + .map_err(Error::from) + }, + (Encoding::DeltaBinaryPacked, _, true, false) => { + let (_, _, values) = split_buffer(page)?; + Ok(State::DeltaBinaryPackedOptional( + OptionalPageValidity::try_new(page)?, + Decoder::try_new(values)?, + )) + }, + (Encoding::DeltaBinaryPacked, _, false, true) => { + let (_, _, values) = split_buffer(page)?; + let values = Decoder::try_new(values)?; + + let rows = get_selected_rows(page); + let values = SliceFilteredIter::new(values, rows); + + Ok(State::FilteredDeltaBinaryPackedRequired(values)) + }, + (Encoding::DeltaBinaryPacked, _, true, true) => { + let (_, _, values) = split_buffer(page)?; + let values = Decoder::try_new(values)?; + + Ok(State::FilteredDeltaBinaryPackedOptional( + FilteredOptionalPageValidity::try_new(page)?, + values, + )) + }, + _ => self.0.build_state(page, dict).map(State::Common), + } + } + + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + self.0.with_capacity(capacity) + } + + fn extend_from_state( + &self, + state: &mut Self::State, + decoded: &mut Self::DecodedState, + remaining: usize, + ) { + let (values, validity) = decoded; + match state { + State::Common(state) => self.0.extend_from_state(state, decoded, remaining), + State::DeltaBinaryPackedRequired(state) => { + values.extend( + state + .by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op) + .take(remaining), + ); + }, + State::DeltaBinaryPackedOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values + .by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op), + ) + }, + State::FilteredDeltaBinaryPackedRequired(page) => { + values.extend( + page.by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op) + .take(remaining), + ); + }, + State::FilteredDeltaBinaryPackedOptional(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + page_values + .by_ref() + .map(|x| x.unwrap().as_()) + .map(self.0.op), + ); + }, + } + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict { + self.0.deserialize_dict(page) + } +} + +/// An [`Iterator`] adapter over [`Pages`] assumed to be encoded as primitive arrays +/// encoded as parquet integer types +#[derive(Debug)] +pub struct IntegerIter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + iter: I, + data_type: DataType, + items: VecDeque<(Vec, MutableBitmap)>, + remaining: usize, + chunk_size: Option, + dict: Option>, + op: F, + phantom: std::marker::PhantomData
<P>
, +} + +impl IntegerIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + data_type, + items: VecDeque::new(), + dict: None, + remaining: num_rows, + chunk_size, + op, + phantom: Default::default(), + } + } +} + +impl Iterator for IntegerIter +where + I: Pages, + T: NativeType, + P: ParquetNativeType, + i64: num_traits::AsPrimitive
<P>
, + F: Copy + Fn(P) -> T, +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let maybe_state = utils::next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + self.chunk_size, + &IntDecoder::new(self.op), + ); + match maybe_state { + utils::MaybeNext::Some(Ok((values, validity))) => { + Some(Ok(finish(&self.data_type, values, validity))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/mod.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/mod.rs new file mode 100644 index 000000000000..27d9c27c3186 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/mod.rs @@ -0,0 +1,9 @@ +mod basic; +mod dictionary; +mod integer; +mod nested; + +pub use basic::Iter; +pub use dictionary::{DictIter, NestedDictIter}; +pub use integer::IntegerIter; +pub use nested::NestedIter; diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/nested.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/nested.rs new file mode 100644 index 000000000000..405e2d9a7c09 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/primitive/nested.rs @@ -0,0 +1,244 @@ +use std::collections::VecDeque; + +use parquet2::encoding::Encoding; +use parquet2::page::{DataPage, DictPage}; +use parquet2::schema::Repetition; +use parquet2::types::{decode, NativeType as ParquetNativeType}; + +use super::super::nested_utils::*; +use super::super::{utils, Pages}; +use super::basic::{deserialize_plain, Values, ValuesDictionary}; +use crate::array::PrimitiveArray; +use crate::bitmap::MutableBitmap; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::types::NativeType; + +// The state of a `DataPage` of `Primitive` parquet primitive type +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +enum State<'a, T> +where + T: NativeType, +{ + Optional(Values<'a>), + Required(Values<'a>), + RequiredDictionary(ValuesDictionary<'a, T>), + OptionalDictionary(ValuesDictionary<'a, T>), +} + +impl<'a, T> utils::PageState<'a> for State<'a, T> +where + T: NativeType, +{ + fn len(&self) -> usize { + match self { + State::Optional(values) => values.len(), + State::Required(values) => values.len(), + State::RequiredDictionary(values) => values.len(), + State::OptionalDictionary(values) => values.len(), + } + } +} + +#[derive(Debug)] +struct PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData
<P>
, + op: F, +} + +impl PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Fn(P) -> T, +{ + #[inline] + fn new(op: F) -> Self { + Self { + phantom: std::marker::PhantomData, + phantom_p: std::marker::PhantomData, + op, + } + } +} + +impl<'a, T, P, F> NestedDecoder<'a> for PrimitiveDecoder +where + T: NativeType, + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type State = State<'a, T>; + type Dictionary = Vec; + type DecodedState = (Vec, MutableBitmap); + + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dictionary>, + ) -> Result { + let is_optional = + page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + + match (page.encoding(), dict, is_optional, is_filtered) { + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), false, false) => { + ValuesDictionary::try_new(page, dict).map(State::RequiredDictionary) + }, + (Encoding::PlainDictionary | Encoding::RleDictionary, Some(dict), true, false) => { + ValuesDictionary::try_new(page, dict).map(State::OptionalDictionary) + }, + (Encoding::Plain, _, true, false) => Values::try_new::
<P>
(page).map(State::Optional), + (Encoding::Plain, _, false, false) => Values::try_new::
<P>
(page).map(State::Required), + _ => Err(utils::not_implemented(page)), + } + } + + /// Initializes a new state + fn with_capacity(&self, capacity: usize) -> Self::DecodedState { + ( + Vec::::with_capacity(capacity), + MutableBitmap::with_capacity(capacity), + ) + } + + fn push_valid(&self, state: &mut Self::State, decoded: &mut Self::DecodedState) -> Result<()> { + let (values, validity) = decoded; + match state { + State::Optional(page_values) => { + let value = page_values.values.by_ref().next().map(decode).map(self.op); + // convert unwrap to error + values.push(value.unwrap_or_default()); + validity.push(true); + }, + State::Required(page_values) => { + let value = page_values.values.by_ref().next().map(decode).map(self.op); + // convert unwrap to error + values.push(value.unwrap_or_default()); + }, + State::RequiredDictionary(page) => { + let value = page + .values + .next() + .map(|index| page.dict[index.unwrap() as usize]); + + values.push(value.unwrap_or_default()); + }, + State::OptionalDictionary(page) => { + let value = page + .values + .next() + .map(|index| page.dict[index.unwrap() as usize]); + + values.push(value.unwrap_or_default()); + validity.push(true); + }, + } + Ok(()) + } + + fn push_null(&self, decoded: &mut Self::DecodedState) { + let (values, validity) = decoded; + values.push(T::default()); + validity.push(false) + } + + fn deserialize_dict(&self, page: &DictPage) -> Self::Dictionary { + deserialize_plain(&page.buffer, self.op) + } +} + +fn finish( + data_type: &DataType, + values: Vec, + validity: MutableBitmap, +) -> PrimitiveArray { + PrimitiveArray::new(data_type.clone(), values.into(), validity.into()) +} + +/// An iterator adapter over [`Pages`] assumed to be encoded as boolean arrays +#[derive(Debug)] +pub struct NestedIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + iter: I, + init: Vec, + data_type: DataType, + items: VecDeque<(NestedState, (Vec, MutableBitmap))>, + dict: Option>, + remaining: usize, + chunk_size: Option, + decoder: PrimitiveDecoder, +} + +impl NestedIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + pub fn new( + iter: I, + init: Vec, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + op: F, + ) -> Self { + Self { + iter, + init, + data_type, + items: VecDeque::new(), + dict: None, + chunk_size, + remaining: num_rows, + decoder: PrimitiveDecoder::new(op), + } + } +} + +impl Iterator for NestedIter +where + I: Pages, + T: NativeType, + + P: ParquetNativeType, + F: Copy + Fn(P) -> T, +{ + type Item = Result<(NestedState, PrimitiveArray)>; + + fn next(&mut self) -> Option { + let maybe_state = next( + &mut self.iter, + &mut self.items, + &mut self.dict, + &mut self.remaining, + &self.init, + self.chunk_size, + &self.decoder, + ); + match maybe_state { + utils::MaybeNext::Some(Ok((nested, state))) => { + Some(Ok((nested, finish(&self.data_type, state.0, state.1)))) + }, + utils::MaybeNext::Some(Err(e)) => Some(Err(e)), + utils::MaybeNext::None => None, + utils::MaybeNext::More => self.next(), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/simple.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/simple.rs new file mode 100644 index 000000000000..83d9d8fbae8a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/simple.rs @@ -0,0 +1,651 @@ +use ethnum::I256; +use parquet2::schema::types::{ + PhysicalType, PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, +}; +use 
parquet2::types::int96_to_i64_ns; + +use super::super::{ArrayIter, Pages}; +use super::{binary, boolean, fixed_size_binary, null, primitive}; +use crate::array::{Array, DictionaryKey, MutablePrimitiveArray, PrimitiveArray}; +use crate::datatypes::{DataType, IntervalUnit, TimeUnit}; +use crate::error::{Error, Result}; +use crate::types::{days_ms, i256, NativeType}; + +/// Converts an iterator of arrays to a trait object returning trait objects +#[inline] +fn dyn_iter<'a, A, I>(iter: I) -> ArrayIter<'a> +where + A: Array, + I: Iterator> + Send + Sync + 'a, +{ + Box::new(iter.map(|x| x.map(|x| Box::new(x) as Box))) +} + +/// Converts an iterator of [MutablePrimitiveArray] into an iterator of [PrimitiveArray] +#[inline] +fn iden(iter: I) -> impl Iterator>> +where + T: NativeType, + I: Iterator>>, +{ + iter.map(|x| x.map(|x| x.into())) +} + +#[inline] +fn op(iter: I, op: F) -> impl Iterator>> +where + T: NativeType, + I: Iterator>>, + F: Fn(T) -> T + Copy, +{ + iter.map(move |x| { + x.map(move |mut x| { + x.values_mut_slice().iter_mut().for_each(|x| *x = op(*x)); + x.into() + }) + }) +} + +/// An iterator adapter that maps an iterator of Pages into an iterator of Arrays +/// of [`DataType`] `data_type` and length `chunk_size`. +pub fn page_iter_to_arrays<'a, I: Pages + 'a>( + pages: I, + type_: &PrimitiveType, + data_type: DataType, + chunk_size: Option, + num_rows: usize, +) -> Result> { + use DataType::*; + + let physical_type = &type_.physical_type; + let logical_type = &type_.logical_type; + + Ok(match (physical_type, data_type.to_logical_type()) { + (_, Null) => null::iter_to_arrays(pages, data_type, chunk_size, num_rows), + (PhysicalType::Boolean, Boolean) => { + dyn_iter(boolean::Iter::new(pages, data_type, chunk_size, num_rows)) + }, + (PhysicalType::Int32, UInt8) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as u8, + ))), + (PhysicalType::Int32, UInt16) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as u16, + ))), + (PhysicalType::Int32, UInt32) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as u32, + ))), + (PhysicalType::Int64, UInt32) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x as u32, + ))), + (PhysicalType::Int32, Int8) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i8, + ))), + (PhysicalType::Int32, Int16) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i16, + ))), + (PhysicalType::Int32, Int32 | Date32 | Time32(_)) => dyn_iter(iden( + primitive::IntegerIter::new(pages, data_type, num_rows, chunk_size, |x: i32| x), + )), + (PhysicalType::Int64 | PhysicalType::Int96, Timestamp(time_unit, _)) => { + let time_unit = *time_unit; + return timestamp( + pages, + physical_type, + logical_type, + data_type, + num_rows, + chunk_size, + time_unit, + ); + }, + (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => dyn_iter( + fixed_size_binary::Iter::new(pages, data_type, num_rows, chunk_size), + ), + (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::YearMonth)) => { + let n = 12; + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + 
.chunks_exact(n) + .map(|value: &[u8]| i32::from_le_bytes(value[..4].try_into().unwrap())) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::FixedLenByteArray(12), Interval(IntervalUnit::DayTime)) => { + let n = 12; + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_days_ms) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::Int32, Decimal(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i128, + ))), + (PhysicalType::Int64, Decimal(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x as i128, + ))), + (PhysicalType::FixedLenByteArray(n), Decimal(_, _)) if *n > 16 => { + return Err(Error::NotYetImplemented(format!( + "Can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}" + ))) + }, + (PhysicalType::FixedLenByteArray(n), Decimal(_, _)) => { + let n = *n; + + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| super::super::convert_i128(value, n)) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::Int32, Decimal256(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| i256(I256::new(x as i128)), + ))), + (PhysicalType::Int64, Decimal256(_, _)) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| i256(I256::new(x as i128)), + ))), + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n <= 16 => { + let n = *n; + + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(|value: &[u8]| i256(I256::new(super::super::convert_i128(value, n)))) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); + + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n <= 32 => { + let n = *n; + + let pages = fixed_size_binary::Iter::new( + pages, + DataType::FixedSizeBinary(n), + num_rows, + chunk_size, + ); + + let pages = pages.map(move |maybe_array| { + let array = maybe_array?; + let values = array + .values() + .chunks_exact(n) + .map(super::super::convert_i256) + .collect::>(); + let validity = array.validity().cloned(); + + PrimitiveArray::::try_new(data_type.clone(), values.into(), validity) + }); 
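Reviewer note, not part of the diff above: the Decimal and Decimal256 branches rely on `convert_i128`/`convert_i256` to turn each fixed-size, big-endian two's-complement value into an integer. A minimal sketch of that conversion for widths up to 16 bytes, using a hypothetical helper name rather than the crate's API:

fn be_bytes_to_i128(bytes: &[u8]) -> i128 {
    // Sign-extend: pad with 0xFF when the most significant bit is set, else with zeros.
    assert!(!bytes.is_empty() && bytes.len() <= 16);
    let mut buf = if bytes[0] & 0x80 != 0 { [0xFFu8; 16] } else { [0u8; 16] };
    buf[16 - bytes.len()..].copy_from_slice(bytes);
    i128::from_be_bytes(buf)
}

#[test]
fn sign_extension() {
    assert_eq!(be_bytes_to_i128(&[0x00, 0x01]), 1);
    assert_eq!(be_bytes_to_i128(&[0xFF, 0xFF]), -1);
    assert_eq!(be_bytes_to_i128(&[0xFE]), -2);
}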
+ + let arrays = pages.map(|x| x.map(|x| x.boxed())); + + Box::new(arrays) as _ + }, + (PhysicalType::FixedLenByteArray(n), Decimal256(_, _)) if *n > 32 => { + return Err(Error::NotYetImplemented(format!( + "Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}" + ))) + }, + (PhysicalType::Int32, Date64) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i32| x as i64 * 86400000, + ))), + (PhysicalType::Int64, Date64) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x, + ))), + (PhysicalType::Int64, Int64 | Time64(_) | Duration(_)) => dyn_iter(iden( + primitive::IntegerIter::new(pages, data_type, num_rows, chunk_size, |x: i64| x), + )), + (PhysicalType::Int64, UInt64) => dyn_iter(iden(primitive::IntegerIter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: i64| x as u64, + ))), + (PhysicalType::Float, Float32) => dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: f32| x, + ))), + (PhysicalType::Double, Float64) => dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + |x: f64| x, + ))), + + (PhysicalType::ByteArray, Utf8 | Binary) => Box::new(binary::Iter::::new( + pages, data_type, chunk_size, num_rows, + )), + (PhysicalType::ByteArray, LargeBinary | LargeUtf8) => Box::new( + binary::Iter::::new(pages, data_type, chunk_size, num_rows), + ), + + (_, Dictionary(key_type, _, _)) => { + return match_integer_type!(key_type, |$K| { + dict_read::<$K, _>(pages, physical_type, logical_type, data_type, num_rows, chunk_size) + }) + }, + (from, to) => { + return Err(Error::NotYetImplemented(format!( + "Reading parquet type {from:?} to {to:?} still not implemented" + ))) + }, + }) +} + +/// Unify the timestamp unit from parquet TimeUnit into arrow's TimeUnit +/// Returns (a int64 factor, is_multiplier) +fn unify_timestamp_unit( + logical_type: &Option, + time_unit: TimeUnit, +) -> (i64, bool) { + if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
}) = logical_type { + match (*unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) + | (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => (1, true), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) + | (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => (1000, false), + + (ParquetTimeUnit::Microseconds, TimeUnit::Second) + | (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => (1_000_000, false), + + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => (1_000_000_000, false), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) + | (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => (1_000, true), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => (1_000_000, true), + } + } else { + (1, true) + } +} + +#[inline] +pub fn int96_to_i64_us(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const MICROS_PER_SECOND: i64 = 1_000_000; + + let day = value[2] as i64; + let microseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * MICROS_PER_SECOND + microseconds +} + +#[inline] +pub fn int96_to_i64_ms(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + const MILLIS_PER_SECOND: i64 = 1_000; + + let day = value[2] as i64; + let milliseconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000; + let seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + seconds * MILLIS_PER_SECOND + milliseconds +} + +#[inline] +pub fn int96_to_i64_s(value: [u32; 3]) -> i64 { + const JULIAN_DAY_OF_EPOCH: i64 = 2_440_588; + const SECONDS_PER_DAY: i64 = 86_400; + + let day = value[2] as i64; + let seconds = (((value[1] as i64) << 32) + value[0] as i64) / 1_000_000_000; + let day_seconds = (day - JULIAN_DAY_OF_EPOCH) * SECONDS_PER_DAY; + + day_seconds + seconds +} + +fn timestamp<'a, I: Pages + 'a>( + pages: I, + physical_type: &PhysicalType, + logical_type: &Option, + data_type: DataType, + num_rows: usize, + chunk_size: Option, + time_unit: TimeUnit, +) -> Result> { + if physical_type == &PhysicalType::Int96 { + return match time_unit { + TimeUnit::Nanosecond => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_ns, + )))), + TimeUnit::Microsecond => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_us, + )))), + TimeUnit::Millisecond => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_ms, + )))), + TimeUnit::Second => Ok(dyn_iter(iden(primitive::Iter::new( + pages, + data_type, + num_rows, + chunk_size, + int96_to_i64_s, + )))), + }; + }; + + if physical_type != &PhysicalType::Int64 { + return Err(Error::nyi( + "Can't decode a timestamp from a non-int64 parquet type", + )); + } + + let iter = primitive::IntegerIter::new(pages, data_type, num_rows, chunk_size, |x: i64| x); + let (factor, is_multiplier) = unify_timestamp_unit(logical_type, time_unit); + match (factor, is_multiplier) { + (1, _) => Ok(dyn_iter(iden(iter))), + (a, true) => Ok(dyn_iter(op(iter, move |x| x * a))), + (a, false) => Ok(dyn_iter(op(iter, move |x| x / a))), + } +} + +fn timestamp_dict<'a, K: DictionaryKey, I: Pages + 'a>( + pages: I, + physical_type: &PhysicalType, + logical_type: &Option, + 
data_type: DataType, + num_rows: usize, + chunk_size: Option, + time_unit: TimeUnit, +) -> Result> { + if physical_type == &PhysicalType::Int96 { + let logical_type = PrimitiveLogicalType::Timestamp { + unit: ParquetTimeUnit::Nanoseconds, + is_adjusted_to_utc: false, + }; + let (factor, is_multiplier) = unify_timestamp_unit(&Some(logical_type), time_unit); + return match (factor, is_multiplier) { + (a, true) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + DataType::Timestamp(TimeUnit::Nanosecond, None), + num_rows, + chunk_size, + move |x| int96_to_i64_ns(x) * a, + ))), + (a, false) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + DataType::Timestamp(TimeUnit::Nanosecond, None), + num_rows, + chunk_size, + move |x| int96_to_i64_ns(x) / a, + ))), + }; + }; + + let (factor, is_multiplier) = unify_timestamp_unit(logical_type, time_unit); + match (factor, is_multiplier) { + (a, true) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + data_type, + num_rows, + chunk_size, + move |x: i64| x * a, + ))), + (a, false) => Ok(dyn_iter(primitive::DictIter::::new( + pages, + data_type, + num_rows, + chunk_size, + move |x: i64| x / a, + ))), + } +} + +fn dict_read<'a, K: DictionaryKey, I: Pages + 'a>( + iter: I, + physical_type: &PhysicalType, + logical_type: &Option, + data_type: DataType, + num_rows: usize, + chunk_size: Option, +) -> Result> { + use DataType::*; + let values_data_type = if let Dictionary(_, v, _) = &data_type { + v.as_ref() + } else { + panic!() + }; + + Ok(match (physical_type, values_data_type.to_logical_type()) { + (PhysicalType::Int32, UInt8) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as u8, + )), + (PhysicalType::Int32, UInt16) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as u16, + )), + (PhysicalType::Int32, UInt32) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as u32, + )), + (PhysicalType::Int64, UInt64) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i64| x as u64, + )), + (PhysicalType::Int32, Int8) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as i8, + )), + (PhysicalType::Int32, Int16) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x as i16, + )), + (PhysicalType::Int32, Int32 | Date32 | Time32(_) | Interval(IntervalUnit::YearMonth)) => { + dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i32| x, + )) + }, + + (PhysicalType::Int64, Timestamp(time_unit, _)) => { + let time_unit = *time_unit; + return timestamp_dict::( + iter, + physical_type, + logical_type, + data_type, + num_rows, + chunk_size, + time_unit, + ); + }, + + (PhysicalType::Int64, Int64 | Date64 | Time64(_) | Duration(_)) => { + dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: i64| x, + )) + }, + (PhysicalType::Float, Float32) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: f32| x, + )), + (PhysicalType::Double, Float64) => dyn_iter(primitive::DictIter::::new( + iter, + data_type, + num_rows, + chunk_size, + |x: f64| x, + )), + + (PhysicalType::ByteArray, Utf8 | Binary) => dyn_iter(binary::DictIter::::new( + iter, data_type, num_rows, chunk_size, + )), + (PhysicalType::ByteArray, LargeUtf8 | LargeBinary) => dyn_iter( + binary::DictIter::::new(iter, data_type, 
num_rows, chunk_size), + ), + (PhysicalType::FixedLenByteArray(_), FixedSizeBinary(_)) => dyn_iter( + fixed_size_binary::DictIter::::new(iter, data_type, num_rows, chunk_size), + ), + other => { + return Err(Error::nyi(format!( + "Reading dictionaries of type {other:?}" + ))) + }, + }) +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/struct_.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/struct_.rs new file mode 100644 index 000000000000..947e7f1141e5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/struct_.rs @@ -0,0 +1,58 @@ +use super::nested_utils::{NestedArrayIter, NestedState}; +use crate::array::{Array, StructArray}; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; + +/// An iterator adapter over [`NestedArrayIter`] assumed to be encoded as Struct arrays +pub struct StructIterator<'a> { + iters: Vec>, + fields: Vec, +} + +impl<'a> StructIterator<'a> { + /// Creates a new [`StructIterator`] with `iters` and `fields`. + pub fn new(iters: Vec>, fields: Vec) -> Self { + assert_eq!(iters.len(), fields.len()); + Self { iters, fields } + } +} + +impl<'a> Iterator for StructIterator<'a> { + type Item = Result<(NestedState, Box), Error>; + + fn next(&mut self) -> Option { + let values = self + .iters + .iter_mut() + .map(|iter| iter.next()) + .collect::>(); + + if values.iter().any(|x| x.is_none()) { + return None; + } + + // todo: unzip of Result not yet supported in stable Rust + let mut nested = vec![]; + let mut new_values = vec![]; + for x in values { + match x.unwrap() { + Ok((nest, values)) => { + new_values.push(values); + nested.push(nest); + }, + Err(e) => return Some(Err(e)), + } + } + let mut nested = nested.pop().unwrap(); + let (_, validity) = nested.nested.pop().unwrap().inner(); + + Some(Ok(( + nested, + Box::new(StructArray::new( + DataType::Struct(self.fields.clone()), + new_values, + validity.and_then(|x| x.into()), + )), + ))) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/deserialize/utils.rs b/crates/nano-arrow/src/io/parquet/read/deserialize/utils.rs new file mode 100644 index 000000000000..a39a7506d8e1 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/deserialize/utils.rs @@ -0,0 +1,524 @@ +use std::collections::VecDeque; + +use parquet2::deserialize::{ + FilteredHybridEncoded, FilteredHybridRleDecoderIter, HybridDecoderBitmapIter, HybridEncoded, +}; +use parquet2::encoding::hybrid_rle; +use parquet2::indexes::Interval; +use parquet2::page::{split_buffer, DataPage, DictPage, Page}; +use parquet2::schema::Repetition; + +use super::super::Pages; +use crate::bitmap::utils::BitmapIter; +use crate::bitmap::MutableBitmap; +use crate::error::Error; + +pub fn not_implemented(page: &DataPage) -> Error { + let is_optional = page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; + let is_filtered = page.selected_rows().is_some(); + let required = if is_optional { "optional" } else { "required" }; + let is_filtered = if is_filtered { ", index-filtered" } else { "" }; + Error::NotYetImplemented(format!( + "Decoding {:?} \"{:?}\"-encoded {} {} parquet pages", + page.descriptor.primitive_type.physical_type, + page.encoding(), + required, + is_filtered, + )) +} + +/// A private trait representing structs that can receive elements. 
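// Illustrative sketch (not part of this patch): the two implementors below are plain
// vectors (for values) and `MutableBitmap` (for validity), so a decoder can grow both
// through the same calls. Hypothetical usage, assuming the trait is in scope:
//
//     let mut values: Vec<i64> = Vec::new();
//     values.push(42);      // a decoded value
//     values.push_null();   // appends `i64::default()` as a placeholder for a null slot
//
//     let mut validity = MutableBitmap::new();
//     validity.push(true);
//     validity.push_null(); // appends `false`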
+pub(super) trait Pushable: Sized { + fn reserve(&mut self, additional: usize); + fn push(&mut self, value: T); + fn len(&self) -> usize; + fn push_null(&mut self); + fn extend_constant(&mut self, additional: usize, value: T); +} + +impl Pushable for MutableBitmap { + #[inline] + fn reserve(&mut self, additional: usize) { + MutableBitmap::reserve(self, additional) + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push(&mut self, value: bool) { + self.push(value) + } + + #[inline] + fn push_null(&mut self) { + self.push(false) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: bool) { + self.extend_constant(additional, value) + } +} + +impl Pushable for Vec { + #[inline] + fn reserve(&mut self, additional: usize) { + Vec::reserve(self, additional) + } + #[inline] + fn len(&self) -> usize { + self.len() + } + + #[inline] + fn push_null(&mut self) { + self.push(A::default()) + } + + #[inline] + fn push(&mut self, value: A) { + self.push(value) + } + + #[inline] + fn extend_constant(&mut self, additional: usize, value: A) { + self.resize(self.len() + additional, value); + } +} + +/// The state of a partially deserialized page +pub(super) trait PageValidity<'a> { + fn next_limited(&mut self, limit: usize) -> Option>; +} + +#[derive(Debug, Clone)] +pub struct FilteredOptionalPageValidity<'a> { + iter: FilteredHybridRleDecoderIter<'a>, + current: Option<(FilteredHybridEncoded<'a>, usize)>, +} + +impl<'a> FilteredOptionalPageValidity<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, validity, _) = split_buffer(page)?; + + let iter = hybrid_rle::Decoder::new(validity, 1); + let iter = HybridDecoderBitmapIter::new(iter, page.num_values()); + let selected_rows = get_selected_rows(page); + let iter = FilteredHybridRleDecoderIter::new(iter, selected_rows); + + Ok(Self { + iter, + current: None, + }) + } + + pub fn len(&self) -> usize { + self.iter.len() + } +} + +pub fn get_selected_rows(page: &DataPage) -> VecDeque { + page.selected_rows() + .unwrap_or(&[Interval::new(0, page.num_values())]) + .iter() + .copied() + .collect() +} + +impl<'a> PageValidity<'a> for FilteredOptionalPageValidity<'a> { + fn next_limited(&mut self, limit: usize) -> Option> { + let (run, own_offset) = if let Some((run, offset)) = self.current { + (run, offset) + } else { + // a new run + let run = self.iter.next()?.unwrap(); // no run -> None + self.current = Some((run, 0)); + return self.next_limited(limit); + }; + + match run { + FilteredHybridEncoded::Bitmap { + values, + offset, + length, + } => { + let run_length = length - own_offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, own_offset + length)); + } + + Some(FilteredHybridEncoded::Bitmap { + values, + offset, + length, + }) + }, + FilteredHybridEncoded::Repeated { is_set, length } => { + let run_length = length - own_offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, own_offset + length)); + } + + Some(FilteredHybridEncoded::Repeated { is_set, length }) + }, + FilteredHybridEncoded::Skipped(set) => { + self.current = None; + Some(FilteredHybridEncoded::Skipped(set)) + }, + } + } +} + +pub struct Zip { + validity: V, + values: I, +} + +impl Zip { + pub fn new(validity: V, values: I) -> Self { + Self { validity, values } + } +} + +impl, I: Iterator> Iterator for Zip { + type Item = Option; + + #[inline] + fn next(&mut self) -> 
Option { + self.validity + .next() + .map(|x| if x { self.values.next() } else { None }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.validity.size_hint() + } +} + +#[derive(Debug, Clone)] +pub struct OptionalPageValidity<'a> { + iter: HybridDecoderBitmapIter<'a>, + current: Option<(HybridEncoded<'a>, usize)>, +} + +impl<'a> OptionalPageValidity<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, validity, _) = split_buffer(page)?; + + let iter = hybrid_rle::Decoder::new(validity, 1); + let iter = HybridDecoderBitmapIter::new(iter, page.num_values()); + Ok(Self { + iter, + current: None, + }) + } + + /// Number of items remaining + pub fn len(&self) -> usize { + self.iter.len() + + self + .current + .as_ref() + .map(|(run, offset)| run.len() - offset) + .unwrap_or_default() + } + + fn next_limited(&mut self, limit: usize) -> Option> { + let (run, offset) = if let Some((run, offset)) = self.current { + (run, offset) + } else { + // a new run + let run = self.iter.next()?.unwrap(); // no run -> None + self.current = Some((run, 0)); + return self.next_limited(limit); + }; + + match run { + HybridEncoded::Bitmap(values, length) => { + let run_length = length - offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, offset + length)); + } + + Some(FilteredHybridEncoded::Bitmap { + values, + offset, + length, + }) + }, + HybridEncoded::Repeated(is_set, run_length) => { + let run_length = run_length - offset; + + let length = limit.min(run_length); + + if length == run_length { + self.current = None; + } else { + self.current = Some((run, offset + length)); + } + + Some(FilteredHybridEncoded::Repeated { is_set, length }) + }, + } + } +} + +impl<'a> PageValidity<'a> for OptionalPageValidity<'a> { + fn next_limited(&mut self, limit: usize) -> Option> { + self.next_limited(limit) + } +} + +/// Extends a [`Pushable`] from an iterator of non-null values and an hybrid-rle decoder +pub(super) fn extend_from_decoder, I: Iterator>( + validity: &mut MutableBitmap, + page_validity: &mut dyn PageValidity, + limit: Option, + pushable: &mut P, + mut values_iter: I, +) { + let limit = limit.unwrap_or(usize::MAX); + + let mut runs = vec![]; + let mut remaining = limit; + let mut reserve_pushable = 0; + + // first do a scan so that we know how much to reserve up front + while remaining > 0 { + let run = page_validity.next_limited(remaining); + let run = if let Some(run) = run { run } else { break }; + + match run { + FilteredHybridEncoded::Bitmap { length, .. } => { + reserve_pushable += length; + remaining -= length; + }, + FilteredHybridEncoded::Repeated { length, .. 
} => { + reserve_pushable += length; + remaining -= length; + }, + _ => {}, + }; + runs.push(run) + } + pushable.reserve(reserve_pushable); + validity.reserve(reserve_pushable); + + // then a second loop to really fill the buffers + for run in runs { + match run { + FilteredHybridEncoded::Bitmap { + values, + offset, + length, + } => { + // consume `length` items + let iter = BitmapIter::new(values, offset, length); + let iter = Zip::new(iter, &mut values_iter); + + for item in iter { + if let Some(item) = item { + pushable.push(item) + } else { + pushable.push_null() + } + } + validity.extend_from_slice(values, offset, length); + }, + FilteredHybridEncoded::Repeated { is_set, length } => { + validity.extend_constant(length, is_set); + if is_set { + for v in (&mut values_iter).take(length) { + pushable.push(v) + } + } else { + pushable.extend_constant(length, T::default()); + } + }, + FilteredHybridEncoded::Skipped(valids) => for _ in values_iter.by_ref().take(valids) {}, + }; + } +} + +/// The state of a partially deserialized page +pub(super) trait PageState<'a>: std::fmt::Debug { + fn len(&self) -> usize; +} + +/// The state of a partially deserialized page +pub(super) trait DecodedState: std::fmt::Debug { + // the number of values that the state already has + fn len(&self) -> usize; +} + +/// A decoder that knows how to map `State` -> Array +pub(super) trait Decoder<'a> { + /// The state that this decoder derives from a [`DataPage`]. This is bound to the page. + type State: PageState<'a>; + /// The dictionary representation that the decoder uses + type Dict; + /// The target state that this Decoder decodes into. + type DecodedState: DecodedState; + + /// Creates a new `Self::State` + fn build_state( + &self, + page: &'a DataPage, + dict: Option<&'a Self::Dict>, + ) -> Result; + + /// Initializes a new [`Self::DecodedState`]. + fn with_capacity(&self, capacity: usize) -> Self::DecodedState; + + /// extends [`Self::DecodedState`] by deserializing items in [`Self::State`]. + /// It guarantees that the length of `decoded` is at most `decoded.len() + remaining`. + fn extend_from_state( + &self, + page: &mut Self::State, + decoded: &mut Self::DecodedState, + additional: usize, + ); + + /// Deserializes a [`DictPage`] into [`Self::Dict`]. + fn deserialize_dict(&self, page: &DictPage) -> Self::Dict; +} + +pub(super) fn extend_from_new_page<'a, T: Decoder<'a>>( + mut page: T::State, + chunk_size: Option, + items: &mut VecDeque, + remaining: &mut usize, + decoder: &T, +) { + let capacity = chunk_size.unwrap_or(0); + let chunk_size = chunk_size.unwrap_or(usize::MAX); + + let mut decoded = if let Some(decoded) = items.pop_back() { + decoded + } else { + // there is no state => initialize it + decoder.with_capacity(capacity) + }; + let existing = decoded.len(); + + let additional = (chunk_size - existing).min(*remaining); + + decoder.extend_from_state(&mut page, &mut decoded, additional); + *remaining -= decoded.len() - existing; + items.push_back(decoded); + + while page.len() > 0 && *remaining > 0 { + let additional = chunk_size.min(*remaining); + + let mut decoded = decoder.with_capacity(additional); + decoder.extend_from_state(&mut page, &mut decoded, additional); + *remaining -= decoded.len(); + items.push_back(decoded) + } +} + +/// Represents what happened when a new page was consumed +#[derive(Debug)] +pub enum MaybeNext
<P>
{ + /// Whether the page was sufficient to fill `chunk_size` + Some(P), + /// whether there are no more pages or intermediary decoded states + None, + /// Whether the page was insufficient to fill `chunk_size` and a new page is required + More, +} + +#[inline] +pub(super) fn next<'a, I: Pages, D: Decoder<'a>>( + iter: &'a mut I, + items: &'a mut VecDeque, + dict: &'a mut Option, + remaining: &'a mut usize, + chunk_size: Option, + decoder: &'a D, +) -> MaybeNext> { + // front[a1, a2, a3, ...]back + if items.len() > 1 { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if (items.len() == 1) && items.front().unwrap().len() == chunk_size.unwrap_or(usize::MAX) { + return MaybeNext::Some(Ok(items.pop_front().unwrap())); + } + if *remaining == 0 { + return match items.pop_front() { + Some(decoded) => MaybeNext::Some(Ok(decoded)), + None => MaybeNext::None, + }; + } + + match iter.next() { + Err(e) => MaybeNext::Some(Err(e.into())), + Ok(Some(page)) => { + let page = match page { + Page::Data(page) => page, + Page::Dict(dict_page) => { + *dict = Some(decoder.deserialize_dict(dict_page)); + return MaybeNext::More; + }, + }; + + // there is a new page => consume the page from the start + let maybe_page = decoder.build_state(page, dict.as_ref()); + let page = match maybe_page { + Ok(page) => page, + Err(e) => return MaybeNext::Some(Err(e)), + }; + + extend_from_new_page(page, chunk_size, items, remaining, decoder); + + if (items.len() == 1) && items.front().unwrap().len() < chunk_size.unwrap_or(usize::MAX) + { + MaybeNext::More + } else { + let decoded = items.pop_front().unwrap(); + MaybeNext::Some(Ok(decoded)) + } + }, + Ok(None) => { + if let Some(decoded) = items.pop_front() { + // we have a populated item and no more pages + // the only case where an item's length may be smaller than chunk_size + debug_assert!(decoded.len() <= chunk_size.unwrap_or(usize::MAX)); + MaybeNext::Some(Ok(decoded)) + } else { + MaybeNext::None + } + }, + } +} + +#[inline] +pub(super) fn dict_indices_decoder(page: &DataPage) -> Result { + let (_, _, indices_buffer) = split_buffer(page)?; + + // SPEC: Data page format: the bit width used to encode the entry ids stored as 1 byte (max bit width = 32), + // SPEC: followed by the values encoded using RLE/Bit packed described above (with the given bit width). + let bit_width = indices_buffer[0]; + let indices_buffer = &indices_buffer[1..]; + + hybrid_rle::HybridRleDecoder::try_new(indices_buffer, bit_width as u32, page.num_values()) + .map_err(Error::from) +} diff --git a/crates/nano-arrow/src/io/parquet/read/file.rs b/crates/nano-arrow/src/io/parquet/read/file.rs new file mode 100644 index 000000000000..750340c60ef7 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/file.rs @@ -0,0 +1,205 @@ +use std::io::{Read, Seek}; + +use parquet2::indexes::FilteredPage; + +use super::{RowGroupDeserializer, RowGroupMetaData}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::Result; +use crate::io::parquet::read::read_columns_many; + +/// An iterator of [`Chunk`]s coming from row groups of a parquet file. +/// +/// This can be thought of a flatten chain of [`Iterator`] - each row group is sequentially +/// mapped to an [`Iterator`] and each iterator is iterated upon until either the limit +/// or the last iterator ends. +/// # Implementation +/// This iterator is single threaded on both IO-bounded and CPU-bounded tasks, and mixes them. 
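// Illustrative usage sketch (not part of this patch). It assumes the `read_metadata`
// and `infer_schema` helpers re-exported from this module, a local file
// "example.parquet", and elides error handling to `?`:
//
//     let mut file = std::fs::File::open("example.parquet")?;
//     let metadata = read_metadata(&mut file)?;
//     let schema = infer_schema(&metadata)?;
//     // no chunk-size cap, no row limit, no page indexes
//     let reader = FileReader::new(file, metadata.row_groups, schema, None, None, None);
//     for maybe_chunk in reader {
//         let chunk = maybe_chunk?;
//         println!("read a chunk with {} rows", chunk.len());
//     }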
+pub struct FileReader { + row_groups: RowGroupReader, + remaining_rows: usize, + current_row_group: Option, +} + +impl FileReader { + /// Returns a new [`FileReader`]. + pub fn new( + reader: R, + row_groups: Vec, + schema: Schema, + chunk_size: Option, + limit: Option, + page_indexes: Option>>>>, + ) -> Self { + let row_groups = + RowGroupReader::new(reader, schema, row_groups, chunk_size, limit, page_indexes); + + Self { + row_groups, + remaining_rows: limit.unwrap_or(usize::MAX), + current_row_group: None, + } + } + + fn next_row_group(&mut self) -> Result> { + let result = self.row_groups.next().transpose()?; + + // If current_row_group is None, then there will be no elements to remove. + if self.current_row_group.is_some() { + self.remaining_rows = self.remaining_rows.saturating_sub( + result + .as_ref() + .map(|x| x.num_rows()) + .unwrap_or(self.remaining_rows), + ); + } + Ok(result) + } + + /// Returns the [`Schema`] associated to this file. + pub fn schema(&self) -> &Schema { + &self.row_groups.schema + } +} + +impl Iterator for FileReader { + type Item = Result>>; + + fn next(&mut self) -> Option { + if self.remaining_rows == 0 { + // reached the limit + return None; + } + + if let Some(row_group) = &mut self.current_row_group { + match row_group.next() { + // no more chunks in the current row group => try a new one + None => match self.next_row_group() { + Ok(Some(row_group)) => { + self.current_row_group = Some(row_group); + // new found => pull again + self.next() + }, + Ok(None) => { + self.current_row_group = None; + None + }, + Err(e) => Some(Err(e)), + }, + other => other, + } + } else { + match self.next_row_group() { + Ok(Some(row_group)) => { + self.current_row_group = Some(row_group); + self.next() + }, + Ok(None) => { + self.current_row_group = None; + None + }, + Err(e) => Some(Err(e)), + } + } + } +} + +/// An [`Iterator`] from row groups of a parquet file. +/// +/// # Implementation +/// Advancing this iterator is IO-bounded - each iteration reads all the column chunks from the file +/// to memory and attaches [`RowGroupDeserializer`] to them so that they can be iterated in chunks. +pub struct RowGroupReader { + reader: R, + schema: Schema, + row_groups: std::vec::IntoIter, + chunk_size: Option, + remaining_rows: usize, + page_indexes: Option>>>>, +} + +impl RowGroupReader { + /// Returns a new [`RowGroupReader`] + pub fn new( + reader: R, + schema: Schema, + row_groups: Vec, + chunk_size: Option, + limit: Option, + page_indexes: Option>>>>, + ) -> Self { + if let Some(pages) = &page_indexes { + assert_eq!(pages.len(), row_groups.len()) + } + Self { + reader, + schema, + row_groups: row_groups.into_iter(), + chunk_size, + remaining_rows: limit.unwrap_or(usize::MAX), + page_indexes: page_indexes.map(|pages| pages.into_iter()), + } + } + + #[inline] + fn _next(&mut self) -> Result> { + if self.schema.fields.is_empty() { + return Ok(None); + } + if self.remaining_rows == 0 { + // reached the limit + return Ok(None); + } + + let row_group = if let Some(row_group) = self.row_groups.next() { + row_group + } else { + return Ok(None); + }; + + let pages = self.page_indexes.as_mut().and_then(|iter| iter.next()); + + // the number of rows depends on whether indexes are selected or not. 
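// Illustrative sketch (not part of this patch): if the page index selected two pages of
// the first column covering rows [0, 100) and [250, 300), the selected intervals' lengths
// sum to num_rows = 100 + 50 = 150; with no page index, the row group's full row count
// is used instead.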
+ let num_rows = pages + .as_ref() + .map(|x| { + // first field, first column within that field + x[0][0] + .iter() + .map(|page| { + page.selected_rows + .iter() + .map(|interval| interval.length) + .sum::() + }) + .sum() + }) + .unwrap_or_else(|| row_group.num_rows()); + + let column_chunks = read_columns_many( + &mut self.reader, + &row_group, + self.schema.fields.clone(), + self.chunk_size, + Some(self.remaining_rows), + pages, + )?; + + let result = RowGroupDeserializer::new(column_chunks, num_rows, Some(self.remaining_rows)); + self.remaining_rows = self.remaining_rows.saturating_sub(num_rows); + Ok(Some(result)) + } +} + +impl Iterator for RowGroupReader { + type Item = Result; + + fn next(&mut self) -> Option { + self._next().transpose() + } + + fn size_hint(&self) -> (usize, Option) { + self.row_groups.size_hint() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/binary.rs b/crates/nano-arrow/src/io/parquet/read/indexes/binary.rs new file mode 100644 index 000000000000..9a7c7c4ca90b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/binary.rs @@ -0,0 +1,40 @@ +use parquet2::indexes::PageIndex; + +use super::ColumnPageStatistics; +use crate::array::{Array, BinaryArray, PrimitiveArray, Utf8Array}; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Error; +use crate::trusted_len::TrustedLen; + +pub fn deserialize( + indexes: &[PageIndex>], + data_type: &DataType, +) -> Result { + Ok(ColumnPageStatistics { + min: deserialize_binary_iter(indexes.iter().map(|index| index.min.as_ref()), data_type)?, + max: deserialize_binary_iter(indexes.iter().map(|index| index.max.as_ref()), data_type)?, + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + }) +} + +fn deserialize_binary_iter<'a, I: TrustedLen>>>( + iter: I, + data_type: &DataType, +) -> Result, Error> { + match data_type.to_physical_type() { + PhysicalType::LargeBinary => Ok(Box::new(BinaryArray::::from_iter(iter))), + PhysicalType::Utf8 => { + let iter = iter.map(|x| x.map(|x| std::str::from_utf8(x)).transpose()); + Ok(Box::new(Utf8Array::::try_from_trusted_len_iter(iter)?)) + }, + PhysicalType::LargeUtf8 => { + let iter = iter.map(|x| x.map(|x| std::str::from_utf8(x)).transpose()); + Ok(Box::new(Utf8Array::::try_from_trusted_len_iter(iter)?)) + }, + _ => Ok(Box::new(BinaryArray::::from_iter(iter))), + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/boolean.rs b/crates/nano-arrow/src/io/parquet/read/indexes/boolean.rs new file mode 100644 index 000000000000..70977197d103 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/boolean.rs @@ -0,0 +1,20 @@ +use parquet2::indexes::PageIndex; + +use super::ColumnPageStatistics; +use crate::array::{BooleanArray, PrimitiveArray}; + +pub fn deserialize(indexes: &[PageIndex]) -> ColumnPageStatistics { + ColumnPageStatistics { + min: Box::new(BooleanArray::from_trusted_len_iter( + indexes.iter().map(|index| index.min), + )), + max: Box::new(BooleanArray::from_trusted_len_iter( + indexes.iter().map(|index| index.max), + )), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/fixed_len_binary.rs b/crates/nano-arrow/src/io/parquet/read/indexes/fixed_len_binary.rs new file mode 100644 index 000000000000..26002e5857d5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/fixed_len_binary.rs @@ 
-0,0 +1,67 @@ +use parquet2::indexes::PageIndex; + +use super::ColumnPageStatistics; +use crate::array::{Array, FixedSizeBinaryArray, MutableFixedSizeBinaryArray, PrimitiveArray}; +use crate::datatypes::{DataType, PhysicalType, PrimitiveType}; +use crate::trusted_len::TrustedLen; +use crate::types::{i256, NativeType}; + +pub fn deserialize(indexes: &[PageIndex>], data_type: DataType) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_binary_iter( + indexes.iter().map(|index| index.min.as_ref()), + data_type.clone(), + ), + max: deserialize_binary_iter(indexes.iter().map(|index| index.max.as_ref()), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +fn deserialize_binary_iter<'a, I: TrustedLen>>>( + iter: I, + data_type: DataType, +) -> Box { + match data_type.to_physical_type() { + PhysicalType::Primitive(PrimitiveType::Int128) => { + Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| { + v.map(|x| { + // Copy the fixed-size byte value to the start of a 16 byte stack + // allocated buffer, then use an arithmetic right shift to fill in + // MSBs, which accounts for leading 1's in negative (two's complement) + // values. + let n = x.len(); + let mut bytes = [0u8; 16]; + bytes[..n].copy_from_slice(x); + i128::from_be_bytes(bytes) >> (8 * (16 - n)) + }) + }))) + }, + PhysicalType::Primitive(PrimitiveType::Int256) => { + Box::new(PrimitiveArray::from_trusted_len_iter(iter.map(|v| { + v.map(|x| { + let n = x.len(); + let mut bytes = [0u8; 32]; + bytes[..n].copy_from_slice(x); + i256::from_be_bytes(bytes) + }) + }))) + }, + _ => { + let mut a = MutableFixedSizeBinaryArray::try_new( + data_type, + Vec::with_capacity(iter.size_hint().0), + None, + ) + .unwrap(); + for item in iter { + a.push(item); + } + let a: FixedSizeBinaryArray = a.into(); + Box::new(a) + }, + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/mod.rs b/crates/nano-arrow/src/io/parquet/read/indexes/mod.rs new file mode 100644 index 000000000000..b60b717ebfd5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/mod.rs @@ -0,0 +1,381 @@ +//! API to perform page-level filtering (also known as indexes) +use parquet2::error::Error as ParquetError; +use parquet2::indexes::{ + select_pages, BooleanIndex, ByteIndex, FixedLenByteIndex, Index as ParquetIndex, NativeIndex, + PageLocation, +}; +use parquet2::metadata::{ColumnChunkMetaData, RowGroupMetaData}; +use parquet2::read::{read_columns_indexes as _read_columns_indexes, read_pages_locations}; +use parquet2::schema::types::PhysicalType as ParquetPhysicalType; + +mod binary; +mod boolean; +mod fixed_len_binary; +mod primitive; + +use std::collections::VecDeque; +use std::io::{Read, Seek}; + +pub use parquet2::indexes::{FilteredPage, Interval}; + +use super::get_field_pages; +use crate::array::{Array, UInt64Array}; +use crate::datatypes::{DataType, Field, PhysicalType, PrimitiveType}; +use crate::error::Error; + +/// Page statistics of an Arrow field. +#[derive(Debug, PartialEq)] +pub enum FieldPageStatistics { + /// Variant used for fields with a single parquet column (e.g. primitives, dictionaries, list) + Single(ColumnPageStatistics), + /// Variant used for fields with multiple parquet columns (e.g. 
Struct, Map) + Multiple(Vec), +} + +impl From for FieldPageStatistics { + fn from(column: ColumnPageStatistics) -> Self { + Self::Single(column) + } +} + +/// [`ColumnPageStatistics`] contains the minimum, maximum, and null_count +/// of each page of a parquet column, as an [`Array`]. +/// This struct has the following invariants: +/// * `min`, `max` and `null_count` have the same length (equal to the number of pages in the column) +/// * `min`, `max` and `null_count` are guaranteed to be non-null +/// * `min` and `max` have the same logical type +#[derive(Debug, PartialEq)] +pub struct ColumnPageStatistics { + /// The minimum values in the pages + pub min: Box, + /// The maximum values in the pages + pub max: Box, + /// The number of null values in the pages. + pub null_count: UInt64Array, +} + +/// Given a sequence of [`ParquetIndex`] representing the page indexes of each column in the +/// parquet file, returns the page-level statistics as a [`FieldPageStatistics`]. +/// +/// This function maps timestamps, decimal types, etc. accordingly. +/// # Implementation +/// This function is CPU-bounded `O(P)` where `P` is the total number of pages on all columns. +/// # Error +/// This function errors iff the value is not deserializable to arrow (e.g. invalid utf-8) +fn deserialize( + indexes: &mut VecDeque<&Box>, + data_type: DataType, +) -> Result { + match data_type.to_physical_type() { + PhysicalType::Boolean => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + Ok(boolean::deserialize(&index.indexes).into()) + }, + PhysicalType::Primitive(PrimitiveType::Int128) => { + let index = indexes.pop_front().unwrap(); + match index.physical_type() { + ParquetPhysicalType::Int32 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) + }, + parquet2::schema::types::PhysicalType::Int64 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok( + primitive::deserialize_i64( + &index.indexes, + &index.primitive_type, + data_type, + ) + .into(), + ) + }, + parquet2::schema::types::PhysicalType::FixedLenByteArray(_) => { + let index = index.as_any().downcast_ref::().unwrap(); + Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) + }, + other => Err(Error::nyi(format!( + "Deserialize {other:?} to arrow's int64" + ))), + } + }, + PhysicalType::Primitive(PrimitiveType::Int256) => { + let index = indexes.pop_front().unwrap(); + match index.physical_type() { + ParquetPhysicalType::Int32 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok(primitive::deserialize_i32(&index.indexes, data_type).into()) + }, + parquet2::schema::types::PhysicalType::Int64 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok( + primitive::deserialize_i64( + &index.indexes, + &index.primitive_type, + data_type, + ) + .into(), + ) + }, + parquet2::schema::types::PhysicalType::FixedLenByteArray(_) => { + let index = index.as_any().downcast_ref::().unwrap(); + Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) + }, + other => Err(Error::nyi(format!( + "Deserialize {other:?} to arrow's int64" + ))), + } + }, + PhysicalType::Primitive(PrimitiveType::UInt8) + | PhysicalType::Primitive(PrimitiveType::UInt16) + | PhysicalType::Primitive(PrimitiveType::UInt32) + | PhysicalType::Primitive(PrimitiveType::Int32) => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_i32(&index.indexes, 
data_type).into()) + }, + PhysicalType::Primitive(PrimitiveType::UInt64) + | PhysicalType::Primitive(PrimitiveType::Int64) => { + let index = indexes.pop_front().unwrap(); + match index.physical_type() { + ParquetPhysicalType::Int64 => { + let index = index.as_any().downcast_ref::>().unwrap(); + Ok( + primitive::deserialize_i64( + &index.indexes, + &index.primitive_type, + data_type, + ) + .into(), + ) + }, + parquet2::schema::types::PhysicalType::Int96 => { + let index = index + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_i96(&index.indexes, data_type).into()) + }, + other => Err(Error::nyi(format!( + "Deserialize {other:?} to arrow's int64" + ))), + } + }, + PhysicalType::Primitive(PrimitiveType::Float32) => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_id(&index.indexes, data_type).into()) + }, + PhysicalType::Primitive(PrimitiveType::Float64) => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::>() + .unwrap(); + Ok(primitive::deserialize_id(&index.indexes, data_type).into()) + }, + PhysicalType::Binary + | PhysicalType::LargeBinary + | PhysicalType::Utf8 + | PhysicalType::LargeUtf8 => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + binary::deserialize(&index.indexes, &data_type).map(|x| x.into()) + }, + PhysicalType::FixedSizeBinary => { + let index = indexes + .pop_front() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + Ok(fixed_len_binary::deserialize(&index.indexes, data_type).into()) + }, + PhysicalType::Dictionary(_) => { + if let DataType::Dictionary(_, inner, _) = data_type.to_logical_type() { + deserialize(indexes, (**inner).clone()) + } else { + unreachable!() + } + }, + PhysicalType::List => { + if let DataType::List(inner) = data_type.to_logical_type() { + deserialize(indexes, inner.data_type.clone()) + } else { + unreachable!() + } + }, + PhysicalType::LargeList => { + if let DataType::LargeList(inner) = data_type.to_logical_type() { + deserialize(indexes, inner.data_type.clone()) + } else { + unreachable!() + } + }, + PhysicalType::Map => { + if let DataType::Map(inner, _) = data_type.to_logical_type() { + deserialize(indexes, inner.data_type.clone()) + } else { + unreachable!() + } + }, + PhysicalType::Struct => { + let children_fields = if let DataType::Struct(children) = data_type.to_logical_type() { + children + } else { + unreachable!() + }; + let children = children_fields + .iter() + .map(|child| deserialize(indexes, child.data_type.clone())) + .collect::, Error>>()?; + + Ok(FieldPageStatistics::Multiple(children)) + }, + + other => Err(Error::nyi(format!( + "Deserialize into arrow's {other:?} page index" + ))), + } +} + +/// Checks whether the row group have page index information (page statistics) +pub fn has_indexes(row_group: &RowGroupMetaData) -> bool { + row_group + .columns() + .iter() + .all(|chunk| chunk.column_chunk().column_index_offset.is_some()) +} + +/// Reads the column indexes from the reader assuming a valid set of derived Arrow fields +/// for all parquet the columns in the file. +/// +/// It returns one [`FieldPageStatistics`] per field in `fields` +/// +/// This function is expected to be used to filter out parquet pages. +/// +/// # Implementation +/// This function is IO-bounded and calls `reader.read_exact` exactly once. +/// # Error +/// Errors iff the indexes can't be read or their deserialization to arrow is incorrect (e.g. 
invalid utf-8) +pub fn read_columns_indexes( + reader: &mut R, + chunks: &[ColumnChunkMetaData], + fields: &[Field], +) -> Result, Error> { + let indexes = _read_columns_indexes(reader, chunks)?; + + fields + .iter() + .map(|field| { + let indexes = get_field_pages(chunks, &indexes, &field.name); + let mut indexes = indexes.into_iter().collect(); + + deserialize(&mut indexes, field.data_type.clone()) + }) + .collect() +} + +/// Returns the set of (row) intervals of the pages. +pub fn compute_page_row_intervals( + locations: &[PageLocation], + num_rows: usize, +) -> Result, ParquetError> { + if locations.is_empty() { + return Ok(vec![]); + }; + + let last = (|| { + let start: usize = locations.last().unwrap().first_row_index.try_into()?; + let length = num_rows - start; + Result::<_, ParquetError>::Ok(Interval::new(start, length)) + })(); + + let pages_lengths = locations + .windows(2) + .map(|x| { + let start = usize::try_from(x[0].first_row_index)?; + let length = usize::try_from(x[1].first_row_index - x[0].first_row_index)?; + Ok(Interval::new(start, length)) + }) + .chain(std::iter::once(last)); + pages_lengths.collect() +} + +/// Reads all page locations and index locations (IO-bounded) and uses `predicate` to compute +/// the set of [`FilteredPage`] that fulfill the predicate. +/// +/// The non-trivial argument of this function is `predicate`, that controls which pages are selected. +/// Its signature contains 2 arguments: +/// * 0th argument (indexes): contains one [`ColumnPageStatistics`] (page statistics) per field. +/// Use it to evaluate the predicate against +/// * 1th argument (intervals): contains one [`Vec>`] (row positions) per field. +/// For each field, the outermost vector corresponds to each parquet column: +/// a primitive field contains 1 column, a struct field with 2 primitive fields contain 2 columns. +/// The inner `Vec` contains one [`Interval`] per page: its length equals the length of [`ColumnPageStatistics`]. +/// It returns a single [`Vec`] denoting the set of intervals that the predicate selects (over all columns). +/// +/// This returns one item per `field`. For each field, there is one item per column (for non-nested types it returns one column) +/// and finally [`Vec`], that corresponds to the set of selected pages. 
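// Illustrative sketch (not part of this patch): a trivial predicate that keeps every page
// of the first column of the first field, assuming `row_group` and `fields` were obtained
// from the file's metadata:
//
//     let selected = read_filtered_pages(&mut reader, &row_group, &fields, |_stats, intervals| {
//         intervals[0][0].clone()
//     })?;
//
// Real predicates typically inspect the per-page statistics in `FieldPageStatistics`
// (e.g. page minima/maxima) and return only the intervals whose pages may satisfy a filter.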
+pub fn read_filtered_pages< + R: Read + Seek, + F: Fn(&[FieldPageStatistics], &[Vec>]) -> Vec, +>( + reader: &mut R, + row_group: &RowGroupMetaData, + fields: &[Field], + predicate: F, + //is_intersection: bool, +) -> Result>>, Error> { + let num_rows = row_group.num_rows(); + + // one vec per column + let locations = read_pages_locations(reader, row_group.columns())?; + // one Vec> per field (non-nested contain a single entry on the first column) + let locations = fields + .iter() + .map(|field| get_field_pages(row_group.columns(), &locations, &field.name)) + .collect::>(); + + // one ColumnPageStatistics per field + let indexes = read_columns_indexes(reader, row_group.columns(), fields)?; + + let intervals = locations + .iter() + .map(|locations| { + locations + .iter() + .map(|locations| Ok(compute_page_row_intervals(locations, num_rows)?)) + .collect::, Error>>() + }) + .collect::, Error>>()?; + + let intervals = predicate(&indexes, &intervals); + + locations + .into_iter() + .map(|locations| { + locations + .into_iter() + .map(|locations| Ok(select_pages(&intervals, locations, num_rows)?)) + .collect::, Error>>() + }) + .collect() +} diff --git a/crates/nano-arrow/src/io/parquet/read/indexes/primitive.rs b/crates/nano-arrow/src/io/parquet/read/indexes/primitive.rs new file mode 100644 index 000000000000..90e52e4a4aaf --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/indexes/primitive.rs @@ -0,0 +1,222 @@ +use ethnum::I256; +use parquet2::indexes::PageIndex; +use parquet2::schema::types::{PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit}; +use parquet2::types::int96_to_i64_ns; + +use super::ColumnPageStatistics; +use crate::array::{Array, MutablePrimitiveArray, PrimitiveArray}; +use crate::datatypes::{DataType, TimeUnit}; +use crate::trusted_len::TrustedLen; +use crate::types::{i256, NativeType}; + +#[inline] +fn deserialize_int32>>( + iter: I, + data_type: DataType, +) -> Box { + use DataType::*; + match data_type.to_logical_type() { + UInt8 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u8))) + .to(data_type), + ) as _, + UInt16 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u16))) + .to(data_type), + ), + UInt32 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u32))) + .to(data_type), + ), + Decimal(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as i128))) + .to(data_type), + ), + Decimal256(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter( + iter.map(|x| x.map(|x| i256(I256::new(x.into())))), + ) + .to(data_type), + ) as _, + _ => Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)), + } +} + +#[inline] +fn timestamp( + array: &mut MutablePrimitiveArray, + time_unit: TimeUnit, + logical_type: Option, +) { + let unit = if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
}) = logical_type { + unit + } else { + return; + }; + + match (unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + (ParquetTimeUnit::Microseconds, TimeUnit::Second) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000_000), + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000_000_000), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) => {}, + (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000_000), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x *= 1_000), + (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) => {}, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x *= 1_000_000), + (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => array + .values_mut_slice() + .iter_mut() + .for_each(|x| *x /= 1_000), + (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => {}, + } +} + +#[inline] +fn deserialize_int64>>( + iter: I, + primitive_type: &PrimitiveType, + data_type: DataType, +) -> Box { + use DataType::*; + match data_type.to_logical_type() { + UInt64 => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as u64))) + .to(data_type), + ) as _, + Decimal(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(|x| x as i128))) + .to(data_type), + ) as _, + Decimal256(_, _) => Box::new( + PrimitiveArray::::from_trusted_len_iter( + iter.map(|x| x.map(|x| i256(I256::new(x.into())))), + ) + .to(data_type), + ) as _, + Timestamp(time_unit, _) => { + let mut array = + MutablePrimitiveArray::::from_trusted_len_iter(iter).to(data_type.clone()); + + timestamp(&mut array, *time_unit, primitive_type.logical_type); + + let array: PrimitiveArray = array.into(); + + Box::new(array) + }, + _ => Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)), + } +} + +#[inline] +fn deserialize_int96>>( + iter: I, + data_type: DataType, +) -> Box { + Box::new( + PrimitiveArray::::from_trusted_len_iter(iter.map(|x| x.map(int96_to_i64_ns))) + .to(data_type), + ) +} + +#[inline] +fn deserialize_id_s>>( + iter: I, + data_type: DataType, +) -> Box { + Box::new(PrimitiveArray::::from_trusted_len_iter(iter).to(data_type)) +} + +pub fn deserialize_i32(indexes: &[PageIndex], data_type: DataType) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_int32(indexes.iter().map(|index| index.min), data_type.clone()), + max: deserialize_int32(indexes.iter().map(|index| index.max), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +pub fn deserialize_i64( + indexes: &[PageIndex], + primitive_type: &PrimitiveType, + data_type: DataType, +) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_int64( + indexes.iter().map(|index| index.min), + primitive_type, + data_type.clone(), + ), + max: deserialize_int64( + indexes.iter().map(|index| index.max), + 
primitive_type, + data_type, + ), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +pub fn deserialize_i96( + indexes: &[PageIndex<[u32; 3]>], + data_type: DataType, +) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_int96(indexes.iter().map(|index| index.min), data_type.clone()), + max: deserialize_int96(indexes.iter().map(|index| index.max), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} + +pub fn deserialize_id( + indexes: &[PageIndex], + data_type: DataType, +) -> ColumnPageStatistics { + ColumnPageStatistics { + min: deserialize_id_s(indexes.iter().map(|index| index.min), data_type.clone()), + max: deserialize_id_s(indexes.iter().map(|index| index.max), data_type), + null_count: PrimitiveArray::from_trusted_len_iter( + indexes + .iter() + .map(|index| index.null_count.map(|x| x as u64)), + ), + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/mod.rs b/crates/nano-arrow/src/io/parquet/read/mod.rs new file mode 100644 index 000000000000..52a4d07d922e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/mod.rs @@ -0,0 +1,95 @@ +//! APIs to read from Parquet format. +#![allow(clippy::type_complexity)] + +mod deserialize; +mod file; +pub mod indexes; +mod row_group; +pub mod schema; +pub mod statistics; + +use std::io::{Read, Seek}; + +pub use deserialize::{ + column_iter_to_arrays, create_list, create_map, get_page_iterator, init_nested, n_columns, + InitNested, NestedArrayIter, NestedState, StructIterator, +}; +pub use file::{FileReader, RowGroupReader}; +use futures::{AsyncRead, AsyncSeek}; +// re-exports of parquet2's relevant APIs +pub use parquet2::{ + error::Error as ParquetError, + fallible_streaming_iterator, + metadata::{ColumnChunkMetaData, ColumnDescriptor, RowGroupMetaData}, + page::{CompressedDataPage, DataPageHeader, Page}, + read::{ + decompress, get_column_iterator, get_page_stream, + read_columns_indexes as _read_columns_indexes, read_metadata as _read_metadata, + read_metadata_async as _read_metadata_async, read_pages_locations, BasicDecompressor, + Decompressor, MutStreamingIterator, PageFilter, PageReader, ReadColumnIterator, State, + }, + schema::types::{ + GroupLogicalType, ParquetType, PhysicalType, PrimitiveConvertedType, PrimitiveLogicalType, + TimeUnit as ParquetTimeUnit, + }, + types::int96_to_i64_ns, + FallibleStreamingIterator, +}; +pub use row_group::*; +pub use schema::{infer_schema, FileMetaData}; + +use crate::array::Array; +use crate::error::Result; +use crate::types::{i256, NativeType}; + +/// Trait describing a [`FallibleStreamingIterator`] of [`Page`] +pub trait Pages: + FallibleStreamingIterator + Send + Sync +{ +} + +impl + Send + Sync> Pages for I {} + +/// Type def for a sharable, boxed dyn [`Iterator`] of arrays +pub type ArrayIter<'a> = Box>> + Send + Sync + 'a>; + +/// Reads parquets' metadata synchronously. +pub fn read_metadata(reader: &mut R) -> Result { + Ok(_read_metadata(reader)?) +} + +/// Reads parquets' metadata asynchronously. +pub async fn read_metadata_async( + reader: &mut R, +) -> Result { + Ok(_read_metadata_async(reader).await?) 
+} + +fn convert_days_ms(value: &[u8]) -> crate::types::days_ms { + crate::types::days_ms( + i32::from_le_bytes(value[4..8].try_into().unwrap()), + i32::from_le_bytes(value[8..12].try_into().unwrap()), + ) +} + +fn convert_i128(value: &[u8], n: usize) -> i128 { + // Copy the fixed-size byte value to the start of a 16 byte stack + // allocated buffer, then use an arithmetic right shift to fill in + // MSBs, which accounts for leading 1's in negative (two's complement) + // values. + let mut bytes = [0u8; 16]; + bytes[..n].copy_from_slice(value); + i128::from_be_bytes(bytes) >> (8 * (16 - n)) +} + +fn convert_i256(value: &[u8]) -> i256 { + if value[0] >= 128 { + let mut neg_bytes = [255u8; 32]; + neg_bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(neg_bytes) + } else { + let mut bytes = [0u8; 32]; + bytes[32 - value.len()..].copy_from_slice(value); + i256::from_be_bytes(bytes) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/row_group.rs b/crates/nano-arrow/src/io/parquet/read/row_group.rs new file mode 100644 index 000000000000..0b72897c5ac6 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/row_group.rs @@ -0,0 +1,339 @@ +use std::io::{Read, Seek}; + +use futures::future::{try_join_all, BoxFuture}; +use futures::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; +use parquet2::indexes::FilteredPage; +use parquet2::metadata::ColumnChunkMetaData; +use parquet2::read::{BasicDecompressor, IndexedPageReader, PageMetaData, PageReader}; + +use super::{ArrayIter, RowGroupMetaData}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Field; +use crate::error::Result; +use crate::io::parquet::read::column_iter_to_arrays; + +/// An [`Iterator`] of [`Chunk`] that (dynamically) adapts a vector of iterators of [`Array`] into +/// an iterator of [`Chunk`]. +/// +/// This struct tracks advances each of the iterators individually and combines the +/// result in a single [`Chunk`]. +/// +/// # Implementation +/// This iterator is single-threaded and advancing it is CPU-bounded. +pub struct RowGroupDeserializer { + num_rows: usize, + remaining_rows: usize, + column_chunks: Vec>, +} + +impl RowGroupDeserializer { + /// Creates a new [`RowGroupDeserializer`]. + /// + /// # Panic + /// This function panics iff any of the `column_chunks` + /// do not return an array with an equal length. + pub fn new( + column_chunks: Vec>, + num_rows: usize, + limit: Option, + ) -> Self { + Self { + num_rows, + remaining_rows: limit.unwrap_or(usize::MAX).min(num_rows), + column_chunks, + } + } + + /// Returns the number of rows on this row group + pub fn num_rows(&self) -> usize { + self.num_rows + } +} + +impl Iterator for RowGroupDeserializer { + type Item = Result>>; + + fn next(&mut self) -> Option { + if self.remaining_rows == 0 { + return None; + } + let chunk = self + .column_chunks + .iter_mut() + .map(|iter| iter.next().unwrap()) + .collect::>>() + .and_then(Chunk::try_new); + self.remaining_rows = self.remaining_rows.saturating_sub( + chunk + .as_ref() + .map(|x| x.len()) + .unwrap_or(self.remaining_rows), + ); + + Some(chunk) + } +} + +/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. +/// For non-nested parquet types, this returns a single column +pub fn get_field_columns<'a>( + columns: &'a [ColumnChunkMetaData], + field_name: &str, +) -> Vec<&'a ColumnChunkMetaData> { + columns + .iter() + .filter(|x| x.descriptor().path_in_schema[0] == field_name) + .collect() +} + +/// Returns all [`ColumnChunkMetaData`] associated to `field_name`. 
+/// For non-nested parquet types, this returns a single column +pub fn get_field_pages<'a, T>( + columns: &'a [ColumnChunkMetaData], + items: &'a [T], + field_name: &str, +) -> Vec<&'a T> { + columns + .iter() + .zip(items) + .filter(|(metadata, _)| metadata.descriptor().path_in_schema[0] == field_name) + .map(|(_, item)| item) + .collect() +} + +/// Reads all columns that are part of the parquet field `field_name` +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns associated to +/// the field (one for non-nested types) +pub fn read_columns<'a, R: Read + Seek>( + reader: &mut R, + columns: &'a [ColumnChunkMetaData], + field_name: &str, +) -> Result)>> { + get_field_columns(columns, field_name) + .into_iter() + .map(|meta| _read_single_column(reader, meta)) + .collect() +} + +fn _read_single_column<'a, R>( + reader: &mut R, + meta: &'a ColumnChunkMetaData, +) -> Result<(&'a ColumnChunkMetaData, Vec)> +where + R: Read + Seek, +{ + let (start, length) = meta.byte_range(); + reader.seek(std::io::SeekFrom::Start(start))?; + + let mut chunk = vec![]; + chunk.try_reserve(length as usize)?; + reader.by_ref().take(length).read_to_end(&mut chunk)?; + Ok((meta, chunk)) +} + +async fn _read_single_column_async<'b, R, F>( + reader_factory: F, + meta: &ColumnChunkMetaData, +) -> Result<(&ColumnChunkMetaData, Vec)> +where + R: AsyncRead + AsyncSeek + Send + Unpin, + F: Fn() -> BoxFuture<'b, std::io::Result>, +{ + let mut reader = reader_factory().await?; + let (start, length) = meta.byte_range(); + reader.seek(std::io::SeekFrom::Start(start)).await?; + + let mut chunk = vec![]; + chunk.try_reserve(length as usize)?; + reader.take(length).read_to_end(&mut chunk).await?; + Result::Ok((meta, chunk)) +} + +/// Reads all columns that are part of the parquet field `field_name` +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns associated to +/// the field (one for non-nested types) +/// +/// It does so asynchronously via a single `join_all` over all the necessary columns for +/// `field_name`. +pub async fn read_columns_async< + 'a, + 'b, + R: AsyncRead + AsyncSeek + Send + Unpin, + F: Fn() -> BoxFuture<'b, std::io::Result> + Clone, +>( + reader_factory: F, + columns: &'a [ColumnChunkMetaData], + field_name: &str, +) -> Result)>> { + let futures = get_field_columns(columns, field_name) + .into_iter() + .map(|meta| async { _read_single_column_async(reader_factory.clone(), meta).await }); + + try_join_all(futures).await +} + +type Pages = Box< + dyn Iterator> + + Sync + + Send, +>; + +/// Converts a vector of columns associated with the parquet field whose name is [`Field`] +/// to an iterator of [`Array`], [`ArrayIter`] of chunk size `chunk_size`. 
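// Illustrative sketch (not part of this patch): decoding a single field from one row group
// by combining `read_columns` with `to_deserializer`. `field` is assumed to be one of the
// arrow fields inferred from the file's schema:
//
//     let columns = read_columns(&mut reader, row_group.columns(), &field.name)?;
//     let num_rows = row_group.num_rows();
//     let mut arrays = to_deserializer(columns, field.clone(), num_rows, Some(64 * 1024), None)?;
//     while let Some(array) = arrays.next().transpose()? {
//         println!("decoded {} rows", array.len());
//     }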
+pub fn to_deserializer<'a>( + columns: Vec<(&ColumnChunkMetaData, Vec)>, + field: Field, + num_rows: usize, + chunk_size: Option, + pages: Option>>, +) -> Result> { + let chunk_size = chunk_size.map(|c| c.min(num_rows)); + + let (columns, types) = if let Some(pages) = pages { + let (columns, types): (Vec<_>, Vec<_>) = columns + .into_iter() + .zip(pages) + .map(|((column_meta, chunk), mut pages)| { + // de-offset the start, since we read in chunks (and offset is from start of file) + let mut meta: PageMetaData = column_meta.into(); + pages + .iter_mut() + .for_each(|page| page.start -= meta.column_start); + meta.column_start = 0; + let pages = IndexedPageReader::new_with_page_meta( + std::io::Cursor::new(chunk), + meta, + pages, + vec![], + vec![], + ); + let pages = Box::new(pages) as Pages; + ( + BasicDecompressor::new(pages, vec![]), + &column_meta.descriptor().descriptor.primitive_type, + ) + }) + .unzip(); + + (columns, types) + } else { + let (columns, types): (Vec<_>, Vec<_>) = columns + .into_iter() + .map(|(column_meta, chunk)| { + let len = chunk.len(); + let pages = PageReader::new( + std::io::Cursor::new(chunk), + column_meta, + std::sync::Arc::new(|_, _| true), + vec![], + len * 2 + 1024, + ); + let pages = Box::new(pages) as Pages; + ( + BasicDecompressor::new(pages, vec![]), + &column_meta.descriptor().descriptor.primitive_type, + ) + }) + .unzip(); + + (columns, types) + }; + + column_iter_to_arrays(columns, types, field, chunk_size, num_rows) +} + +/// Returns a vector of iterators of [`Array`] ([`ArrayIter`]) corresponding to the top +/// level parquet fields whose name matches `fields`'s names. +/// +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns in the row group - +/// it reads all the columns to memory from the row group associated to the requested fields. +/// +/// This operation is single-threaded. For readers with stronger invariants +/// (e.g. implement [`Clone`]) you can use [`read_columns`] to read multiple columns at once +/// and convert them to [`ArrayIter`] via [`to_deserializer`]. +pub fn read_columns_many<'a, R: Read + Seek>( + reader: &mut R, + row_group: &RowGroupMetaData, + fields: Vec, + chunk_size: Option, + limit: Option, + pages: Option>>>, +) -> Result>> { + let num_rows = row_group.num_rows(); + let num_rows = limit.map(|limit| limit.min(num_rows)).unwrap_or(num_rows); + + // reads all the necessary columns for all fields from the row group + // This operation is IO-bounded `O(C)` where C is the number of columns in the row group + let field_columns = fields + .iter() + .map(|field| read_columns(reader, row_group.columns(), &field.name)) + .collect::>>()?; + + if let Some(pages) = pages { + field_columns + .into_iter() + .zip(fields) + .zip(pages) + .map(|((columns, field), pages)| { + to_deserializer(columns, field, num_rows, chunk_size, Some(pages)) + }) + .collect() + } else { + field_columns + .into_iter() + .zip(fields) + .map(|(columns, field)| to_deserializer(columns, field, num_rows, chunk_size, None)) + .collect() + } +} + +/// Returns a vector of iterators of [`Array`] corresponding to the top level parquet fields whose +/// name matches `fields`'s names. +/// +/// # Implementation +/// This operation is IO-bounded `O(C)` where C is the number of columns in the row group - +/// it reads all the columns to memory from the row group associated to the requested fields. 
+/// It does so asynchronously via `join_all` +pub async fn read_columns_many_async< + 'a, + 'b, + R: AsyncRead + AsyncSeek + Send + Unpin, + F: Fn() -> BoxFuture<'b, std::io::Result> + Clone, +>( + reader_factory: F, + row_group: &RowGroupMetaData, + fields: Vec, + chunk_size: Option, + limit: Option, + pages: Option>>>, +) -> Result>> { + let num_rows = row_group.num_rows(); + let num_rows = limit.map(|limit| limit.min(num_rows)).unwrap_or(num_rows); + + let futures = fields + .iter() + .map(|field| read_columns_async(reader_factory.clone(), row_group.columns(), &field.name)); + + let field_columns = try_join_all(futures).await?; + + if let Some(pages) = pages { + field_columns + .into_iter() + .zip(fields) + .zip(pages) + .map(|((columns, field), pages)| { + to_deserializer(columns, field, num_rows, chunk_size, Some(pages)) + }) + .collect() + } else { + field_columns + .into_iter() + .zip(fields.into_iter()) + .map(|(columns, field)| to_deserializer(columns, field, num_rows, chunk_size, None)) + .collect() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/schema/convert.rs b/crates/nano-arrow/src/io/parquet/read/schema/convert.rs new file mode 100644 index 000000000000..4ae50e05e8e0 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/schema/convert.rs @@ -0,0 +1,1049 @@ +//! This module has entry points, [`parquet_to_arrow_schema`] and the more configurable [`parquet_to_arrow_schema_with_options`]. +use parquet2::schema::types::{ + FieldInfo, GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType, + PrimitiveConvertedType, PrimitiveLogicalType, PrimitiveType, TimeUnit as ParquetTimeUnit, +}; +use parquet2::schema::Repetition; + +use crate::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use crate::io::parquet::read::schema::SchemaInferenceOptions; + +/// Converts [`ParquetType`]s to a [`Field`], ignoring parquet fields that do not contain +/// any physical column. +pub fn parquet_to_arrow_schema(fields: &[ParquetType]) -> Vec { + parquet_to_arrow_schema_with_options(fields, &None) +} + +/// Like [`parquet_to_arrow_schema`] but with configurable options which affect the behavior of schema inference +pub fn parquet_to_arrow_schema_with_options( + fields: &[ParquetType], + options: &Option, +) -> Vec { + fields + .iter() + .filter_map(|f| to_field(f, options.as_ref().unwrap_or(&Default::default()))) + .collect::>() +} + +fn from_int32( + logical_type: Option, + converted_type: Option, +) -> DataType { + use PrimitiveLogicalType::*; + match (logical_type, converted_type) { + // handle logical types first + (Some(Integer(t)), _) => match t { + IntegerType::Int8 => DataType::Int8, + IntegerType::Int16 => DataType::Int16, + IntegerType::Int32 => DataType::Int32, + IntegerType::UInt8 => DataType::UInt8, + IntegerType::UInt16 => DataType::UInt16, + IntegerType::UInt32 => DataType::UInt32, + // The above are the only possible annotations for parquet's int32. Anything else + // is a deviation to the parquet specification and we ignore + _ => DataType::Int32, + }, + (Some(Decimal(precision, scale)), _) => DataType::Decimal(precision, scale), + (Some(Date), _) => DataType::Date32, + (Some(Time { unit, .. }), _) => match unit { + ParquetTimeUnit::Milliseconds => DataType::Time32(TimeUnit::Millisecond), + // MILLIS is the only possible annotation for parquet's int32. 
Anything else + // is a deviation to the parquet specification and we ignore + _ => DataType::Int32, + }, + // handle converted types: + (_, Some(PrimitiveConvertedType::Uint8)) => DataType::UInt8, + (_, Some(PrimitiveConvertedType::Uint16)) => DataType::UInt16, + (_, Some(PrimitiveConvertedType::Uint32)) => DataType::UInt32, + (_, Some(PrimitiveConvertedType::Int8)) => DataType::Int8, + (_, Some(PrimitiveConvertedType::Int16)) => DataType::Int16, + (_, Some(PrimitiveConvertedType::Int32)) => DataType::Int32, + (_, Some(PrimitiveConvertedType::Date)) => DataType::Date32, + (_, Some(PrimitiveConvertedType::TimeMillis)) => DataType::Time32(TimeUnit::Millisecond), + (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + DataType::Decimal(precision, scale) + }, + (_, _) => DataType::Int32, + } +} + +fn from_int64( + logical_type: Option, + converted_type: Option, +) -> DataType { + use PrimitiveLogicalType::*; + match (logical_type, converted_type) { + // handle logical types first + (Some(Integer(integer)), _) => match integer { + IntegerType::UInt64 => DataType::UInt64, + IntegerType::Int64 => DataType::Int64, + _ => DataType::Int64, + }, + ( + Some(Timestamp { + is_adjusted_to_utc, + unit, + }), + _, + ) => { + let timezone = if is_adjusted_to_utc { + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // A TIMESTAMP with isAdjustedToUTC=true is defined as [...] elapsed since the Unix epoch + Some("+00:00".to_string()) + } else { + // PARQUET: + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // A TIMESTAMP with isAdjustedToUTC=false represents [...] such + // timestamps should always be displayed the same way, regardless of the local time zone in effect + // ARROW: + // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + // If the time zone is null or equal to an empty string, the data is "time + // zone naive" and shall be displayed *as is* to the user, not localized + // to the locale of the user. + None + }; + + match unit { + ParquetTimeUnit::Milliseconds => { + DataType::Timestamp(TimeUnit::Millisecond, timezone) + }, + ParquetTimeUnit::Microseconds => { + DataType::Timestamp(TimeUnit::Microsecond, timezone) + }, + ParquetTimeUnit::Nanoseconds => DataType::Timestamp(TimeUnit::Nanosecond, timezone), + } + }, + (Some(Time { unit, .. }), _) => match unit { + ParquetTimeUnit::Microseconds => DataType::Time64(TimeUnit::Microsecond), + ParquetTimeUnit::Nanoseconds => DataType::Time64(TimeUnit::Nanosecond), + // MILLIS is only possible for int32. 
Appearing in int64 is a deviation + // to parquet's spec, which we ignore + _ => DataType::Int64, + }, + (Some(Decimal(precision, scale)), _) => DataType::Decimal(precision, scale), + // handle converted types: + (_, Some(PrimitiveConvertedType::TimeMicros)) => DataType::Time64(TimeUnit::Microsecond), + (_, Some(PrimitiveConvertedType::TimestampMillis)) => { + DataType::Timestamp(TimeUnit::Millisecond, None) + }, + (_, Some(PrimitiveConvertedType::TimestampMicros)) => { + DataType::Timestamp(TimeUnit::Microsecond, None) + }, + (_, Some(PrimitiveConvertedType::Int64)) => DataType::Int64, + (_, Some(PrimitiveConvertedType::Uint64)) => DataType::UInt64, + (_, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + DataType::Decimal(precision, scale) + }, + + (_, _) => DataType::Int64, + } +} + +fn from_byte_array( + logical_type: &Option, + converted_type: &Option, +) -> DataType { + match (logical_type, converted_type) { + (Some(PrimitiveLogicalType::String), _) => DataType::Utf8, + (Some(PrimitiveLogicalType::Json), _) => DataType::Binary, + (Some(PrimitiveLogicalType::Bson), _) => DataType::Binary, + (Some(PrimitiveLogicalType::Enum), _) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Json)) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Bson)) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Enum)) => DataType::Binary, + (_, Some(PrimitiveConvertedType::Utf8)) => DataType::Utf8, + (_, _) => DataType::Binary, + } +} + +fn from_fixed_len_byte_array( + length: usize, + logical_type: Option, + converted_type: Option, +) -> DataType { + match (logical_type, converted_type) { + (Some(PrimitiveLogicalType::Decimal(precision, scale)), _) => { + DataType::Decimal(precision, scale) + }, + (None, Some(PrimitiveConvertedType::Decimal(precision, scale))) => { + DataType::Decimal(precision, scale) + }, + (None, Some(PrimitiveConvertedType::Interval)) => { + // There is currently no reliable way of determining which IntervalUnit + // to return. Thus without the original Arrow schema, the results + // would be incorrect if all 12 bytes of the interval are populated + DataType::Interval(IntervalUnit::DayTime) + }, + _ => DataType::FixedSizeBinary(length), + } +} + +/// Maps a [`PhysicalType`] with optional metadata to a [`DataType`] +fn to_primitive_type_inner( + primitive_type: &PrimitiveType, + options: &SchemaInferenceOptions, +) -> DataType { + match primitive_type.physical_type { + PhysicalType::Boolean => DataType::Boolean, + PhysicalType::Int32 => { + from_int32(primitive_type.logical_type, primitive_type.converted_type) + }, + PhysicalType::Int64 => { + from_int64(primitive_type.logical_type, primitive_type.converted_type) + }, + PhysicalType::Int96 => DataType::Timestamp(options.int96_coerce_to_timeunit, None), + PhysicalType::Float => DataType::Float32, + PhysicalType::Double => DataType::Float64, + PhysicalType::ByteArray => { + from_byte_array(&primitive_type.logical_type, &primitive_type.converted_type) + }, + PhysicalType::FixedLenByteArray(length) => from_fixed_len_byte_array( + length, + primitive_type.logical_type, + primitive_type.converted_type, + ), + } +} + +/// Entry point for converting parquet primitive type to arrow type. +/// +/// This function takes care of repetition. 
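+///
+/// A small illustration (a sketch added for clarity, not part of the original docs): a
+/// `REPEATED` primitive is wrapped in a `DataType::List` whose inner field keeps the
+/// primitive's name and nullability, while non-repeated primitives map directly.
+///
+/// ```ignore
+/// // hypothetical: `repeated_int32` is a PrimitiveType with Repetition::Repeated
+/// let dt = to_primitive_type(&repeated_int32, &SchemaInferenceOptions::default());
+/// assert!(matches!(dt, DataType::List(_)));
+/// ```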
+fn to_primitive_type(primitive_type: &PrimitiveType, options: &SchemaInferenceOptions) -> DataType { + let base_type = to_primitive_type_inner(primitive_type, options); + + if primitive_type.field_info.repetition == Repetition::Repeated { + DataType::List(Box::new(Field::new( + &primitive_type.field_info.name, + base_type, + is_nullable(&primitive_type.field_info), + ))) + } else { + base_type + } +} + +fn non_repeated_group( + logical_type: &Option, + converted_type: &Option, + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + debug_assert!(!fields.is_empty()); + match (logical_type, converted_type) { + (Some(GroupLogicalType::List), _) => to_list(fields, parent_name, options), + (None, Some(GroupConvertedType::List)) => to_list(fields, parent_name, options), + (Some(GroupLogicalType::Map), _) => to_list(fields, parent_name, options), + (None, Some(GroupConvertedType::Map) | Some(GroupConvertedType::MapKeyValue)) => { + to_map(fields, options) + }, + _ => to_struct(fields, options), + } +} + +/// Converts a parquet group type to an arrow [`DataType::Struct`]. +/// Returns [`None`] if all its fields are empty +fn to_struct(fields: &[ParquetType], options: &SchemaInferenceOptions) -> Option { + let fields = fields + .iter() + .filter_map(|f| to_field(f, options)) + .collect::>(); + if fields.is_empty() { + None + } else { + Some(DataType::Struct(fields)) + } +} + +/// Converts a parquet group type to an arrow [`DataType::Struct`]. +/// Returns [`None`] if all its fields are empty +fn to_map(fields: &[ParquetType], options: &SchemaInferenceOptions) -> Option { + let inner = to_field(&fields[0], options)?; + Some(DataType::Map(Box::new(inner), false)) +} + +/// Entry point for converting parquet group type. +/// +/// This function takes care of logical type and repetition. +fn to_group_type( + field_info: &FieldInfo, + logical_type: &Option, + converted_type: &Option, + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + debug_assert!(!fields.is_empty()); + if field_info.repetition == Repetition::Repeated { + Some(DataType::List(Box::new(Field::new( + &field_info.name, + to_struct(fields, options)?, + is_nullable(field_info), + )))) + } else { + non_repeated_group(logical_type, converted_type, fields, parent_name, options) + } +} + +/// Checks whether this schema is nullable. +pub(crate) fn is_nullable(field_info: &FieldInfo) -> bool { + match field_info.repetition { + Repetition::Optional => true, + Repetition::Repeated => true, + Repetition::Required => false, + } +} + +/// Converts parquet schema to arrow field. +/// Returns `None` iff the parquet type has no associated primitive types, +/// i.e. if it is a column-less group type. +fn to_field(type_: &ParquetType, options: &SchemaInferenceOptions) -> Option { + Some(Field::new( + &type_.get_field_info().name, + to_data_type(type_, options)?, + is_nullable(type_.get_field_info()), + )) +} + +/// Converts a parquet list to arrow list. +/// +/// To fully understand this algorithm, please refer to +/// [parquet doc](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md). +fn to_list( + fields: &[ParquetType], + parent_name: &str, + options: &SchemaInferenceOptions, +) -> Option { + let item = fields.first().unwrap(); + + let item_type = match item { + ParquetType::PrimitiveType(primitive) => Some(to_primitive_type_inner(primitive, options)), + ParquetType::GroupType { fields, .. 
} => { + if fields.len() == 1 + && item.name() != "array" + && item.name() != format!("{parent_name}_tuple") + { + // extract the repetition field + let nested_item = fields.first().unwrap(); + to_data_type(nested_item, options) + } else { + to_struct(fields, options) + } + }, + }?; + + // Check that the name of the list child is "list", in which case we + // get the child nullability and name (normally "element") from the nested + // group type. + // Without this step, the child incorrectly inherits the parent's optionality + let (list_item_name, item_is_optional) = match item { + ParquetType::GroupType { + field_info, fields, .. + } if field_info.name == "list" && fields.len() == 1 => { + let field = fields.first().unwrap(); + ( + &field.get_field_info().name, + field.get_field_info().repetition != Repetition::Required, + ) + }, + _ => ( + &item.get_field_info().name, + item.get_field_info().repetition != Repetition::Required, + ), + }; + + Some(DataType::List(Box::new(Field::new( + list_item_name, + item_type, + item_is_optional, + )))) +} + +/// Converts parquet schema to arrow data type. +/// +/// This function discards schema name. +/// +/// If this schema is a primitive type and not included in the leaves, the result is +/// Ok(None). +/// +/// If this schema is a group type and none of its children is reserved in the +/// conversion, the result is Ok(None). +pub(crate) fn to_data_type( + type_: &ParquetType, + options: &SchemaInferenceOptions, +) -> Option { + match type_ { + ParquetType::PrimitiveType(primitive) => Some(to_primitive_type(primitive, options)), + ParquetType::GroupType { + field_info, + logical_type, + converted_type, + fields, + } => { + if fields.is_empty() { + None + } else { + to_group_type( + field_info, + logical_type, + converted_type, + fields, + &field_info.name, + options, + ) + } + }, + } +} + +#[cfg(test)] +mod tests { + use parquet2::metadata::SchemaDescriptor; + + use super::*; + use crate::datatypes::{DataType, Field, TimeUnit}; + use crate::error::Result; + + #[test] + fn test_flat_primitives() -> Result<()> { + let message = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 uint8 (INTEGER(8,false)); + REQUIRED INT32 uint16 (INTEGER(16,false)); + REQUIRED INT32 int32; + REQUIRED INT64 int64 ; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + OPTIONAL BINARY string_2 (STRING); + } + "; + let expected = &[ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("int16", DataType::Int16, false), + Field::new("uint8", DataType::UInt8, false), + Field::new("uint16", DataType::UInt16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, true), + Field::new("string", DataType::Utf8, true), + Field::new("string_2", DataType::Utf8, true), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_byte_array_fields() -> Result<()> { + let message = " + message test_schema { + REQUIRED BYTE_ARRAY binary; + REQUIRED FIXED_LEN_BYTE_ARRAY (20) fixed_binary; + } + "; + let expected = vec![ + Field::new("binary", DataType::Binary, false), + Field::new("fixed_binary", DataType::FixedSizeBinary(20), false), + ]; + 
+ let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_duplicate_fields() -> Result<()> { + let message = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + } + "; + let expected = &[ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(fields, expected); + Ok(()) + } + + #[test] + fn test_parquet_lists() -> Result<()> { + let mut arrow_fields = Vec::new(); + + // LIST encoding example taken from parquet-format/LogicalTypes.md + let message_type = " + message test_schema { + REQUIRED GROUP my_list (LIST) { + REPEATED GROUP list { + OPTIONAL BINARY element (UTF8); + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + OPTIONAL GROUP array_of_arrays (LIST) { + REPEATED GROUP list { + REQUIRED GROUP element (LIST) { + REPEATED GROUP list { + REQUIRED INT32 element; + } + } + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP element { + REQUIRED BINARY str (UTF8); + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED INT32 element; + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP element { + REQUIRED BINARY str (UTF8); + REQUIRED INT32 num; + } + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP array { + REQUIRED BINARY str (UTF8); + } + + } + OPTIONAL GROUP my_list (LIST) { + REPEATED GROUP my_list_tuple { + REQUIRED BINARY str (UTF8); + } + } + REPEATED INT32 name; + } + "; + + // // List (list non-null, elements nullable) + // required group my_list (LIST) { + // repeated group list { + // optional binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Utf8, true))), + false, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + true, + )); + } + + // Element types can be nested structures. 
For example, a list of lists: + // + // // List> + // optional group array_of_arrays (LIST) { + // repeated group list { + // required group element (LIST) { + // repeated group list { + // required int32 element; + // } + // } + // } + // } + { + let arrow_inner_list = + DataType::List(Box::new(Field::new("element", DataType::Int32, false))); + arrow_fields.push(Field::new( + "array_of_arrays", + DataType::List(Box::new(Field::new("element", arrow_inner_list, false))), + true, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // }; + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Utf8, true))), + true, + )); + } + + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + { + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", DataType::Int32, true))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + { + let arrow_struct = DataType::Struct(vec![ + Field::new("str", DataType::Utf8, false), + Field::new("num", DataType::Int32, false), + ]); + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("element", arrow_struct, true))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // Special case: group is named array + { + let arrow_struct = DataType::Struct(vec![Field::new("str", DataType::Utf8, false)]); + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("array", arrow_struct, true))), + true, + )); + } + + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // Special case: group named ends in _tuple + { + let arrow_struct = DataType::Struct(vec![Field::new("str", DataType::Utf8, false)]); + arrow_fields.push(Field::new( + "my_list", + DataType::List(Box::new(Field::new("my_list_tuple", arrow_struct, true))), + true, + )); + } + + // One-level encoding: Only allows required lists with required cells + // repeated value_type name + { + arrow_fields.push(Field::new( + "name", + DataType::List(Box::new(Field::new("name", DataType::Int32, true))), + true, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_parquet_list_nullable() -> Result<()> { + let mut arrow_fields = Vec::new(); + + let message_type = " + message test_schema { + REQUIRED GROUP my_list1 (LIST) { + REPEATED GROUP list { + OPTIONAL BINARY element (UTF8); + } + } + OPTIONAL GROUP my_list2 (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + REQUIRED GROUP my_list3 (LIST) { + REPEATED GROUP list { + REQUIRED BINARY element (UTF8); + } + } + } + "; + + // // List (list non-null, elements nullable) + // required group my_list1 (LIST) { + // repeated group list { + // optional binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list1", + 
DataType::List(Box::new(Field::new("element", DataType::Utf8, true))), + false, + )); + } + + // // List (list nullable, elements non-null) + // optional group my_list2 (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list2", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + true, + )); + } + + // // List (list non-null, elements non-null) + // repeated group my_list3 (LIST) { + // repeated group list { + // required binary element (UTF8); + // } + // } + { + arrow_fields.push(Field::new( + "my_list3", + DataType::List(Box::new(Field::new("element", DataType::Utf8, false))), + false, + )); + } + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_nested_schema() -> Result<()> { + let mut arrow_fields = Vec::new(); + { + let group1_fields = vec![ + Field::new("leaf1", DataType::Boolean, false), + Field::new("leaf2", DataType::Int32, false), + ]; + let group1_struct = Field::new("group1", DataType::Struct(group1_fields), false); + arrow_fields.push(group1_struct); + + let leaf3_field = Field::new("leaf3", DataType::Int64, false); + arrow_fields.push(leaf3_field); + } + + let message_type = " + message test_schema { + REQUIRED GROUP group1 { + REQUIRED BOOLEAN leaf1; + REQUIRED INT32 leaf2; + } + REQUIRED INT64 leaf3; + } + "; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_repeated_nested_schema() -> Result<()> { + let mut arrow_fields = Vec::new(); + { + arrow_fields.push(Field::new("leaf1", DataType::Int32, true)); + + let inner_group_list = Field::new( + "innerGroup", + DataType::List(Box::new(Field::new( + "innerGroup", + DataType::Struct(vec![Field::new("leaf3", DataType::Int32, true)]), + true, + ))), + true, + ); + + let outer_group_list = Field::new( + "outerGroup", + DataType::List(Box::new(Field::new( + "outerGroup", + DataType::Struct(vec![ + Field::new("leaf2", DataType::Int32, true), + inner_group_list, + ]), + true, + ))), + true, + ); + arrow_fields.push(outer_group_list); + } + + let message_type = " + message test_schema { + OPTIONAL INT32 leaf1; + REPEATED GROUP outerGroup { + OPTIONAL INT32 leaf2; + REPEATED GROUP innerGroup { + OPTIONAL INT32 leaf3; + } + } + } + "; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_column_desc_to_field() -> Result<()> { + let message_type = " + message test_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 uint8 (INTEGER(8,false)); + REQUIRED INT32 int16 (INT_16); + REQUIRED INT32 uint16 (INTEGER(16,false)); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (UTF8); + REPEATED BOOLEAN bools; + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME_MILLIS); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 time_nano (TIME(NANOS,false)); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP_MICROS); + REQUIRED INT64 ts_nano (TIMESTAMP(NANOS,true)); + } + "; + let arrow_fields = vec![ + 
Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("uint8", DataType::UInt8, false), + Field::new("int16", DataType::Int16, false), + Field::new("uint16", DataType::UInt16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, true), + Field::new("string", DataType::Utf8, true), + Field::new( + "bools", + DataType::List(Box::new(Field::new("bools", DataType::Boolean, true))), + true, + ), + Field::new("date", DataType::Date32, true), + Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micro", DataType::Time64(TimeUnit::Microsecond), true), + Field::new("time_nano", DataType::Time64(TimeUnit::Nanosecond), true), + Field::new( + "ts_milli", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "ts_nano", + DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".to_string())), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_field_to_column_desc() -> Result<()> { + let message_type = " + message arrow_schema { + REQUIRED BOOLEAN boolean; + REQUIRED INT32 int8 (INT_8); + REQUIRED INT32 int16 (INTEGER(16,true)); + REQUIRED INT32 int32; + REQUIRED INT64 int64; + OPTIONAL DOUBLE double; + OPTIONAL FLOAT float; + OPTIONAL BINARY string (STRING); + OPTIONAL GROUP bools (LIST) { + REPEATED GROUP list { + OPTIONAL BOOLEAN element; + } + } + REQUIRED GROUP bools_non_null (LIST) { + REPEATED GROUP list { + REQUIRED BOOLEAN element; + } + } + OPTIONAL INT32 date (DATE); + OPTIONAL INT32 time_milli (TIME(MILLIS,false)); + OPTIONAL INT64 time_micro (TIME_MICROS); + OPTIONAL INT64 ts_milli (TIMESTAMP_MILLIS); + REQUIRED INT64 ts_micro (TIMESTAMP(MICROS,false)); + REQUIRED GROUP struct { + REQUIRED BOOLEAN bools; + REQUIRED INT32 uint32 (INTEGER(32,false)); + REQUIRED GROUP int32 (LIST) { + REPEATED GROUP list { + OPTIONAL INT32 element; + } + } + } + REQUIRED BINARY dictionary_strings (STRING); + } + "; + + let arrow_fields = vec![ + Field::new("boolean", DataType::Boolean, false), + Field::new("int8", DataType::Int8, false), + Field::new("int16", DataType::Int16, false), + Field::new("int32", DataType::Int32, false), + Field::new("int64", DataType::Int64, false), + Field::new("double", DataType::Float64, true), + Field::new("float", DataType::Float32, true), + Field::new("string", DataType::Utf8, true), + Field::new( + "bools", + DataType::List(Box::new(Field::new("element", DataType::Boolean, true))), + true, + ), + Field::new( + "bools_non_null", + DataType::List(Box::new(Field::new("element", DataType::Boolean, false))), + false, + ), + Field::new("date", DataType::Date32, true), + Field::new("time_milli", DataType::Time32(TimeUnit::Millisecond), true), + Field::new("time_micro", DataType::Time64(TimeUnit::Microsecond), true), + Field::new( + "ts_milli", + DataType::Timestamp(TimeUnit::Millisecond, None), + true, + ), + Field::new( + "ts_micro", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "struct", + DataType::Struct(vec![ + Field::new("bools", DataType::Boolean, false), + Field::new("uint32", DataType::UInt32, false), + Field::new( + 
"int32", + DataType::List(Box::new(Field::new("element", DataType::Int32, true))), + false, + ), + ]), + false, + ), + Field::new("dictionary_strings", DataType::Utf8, false), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema(parquet_schema.fields()); + + assert_eq!(arrow_fields, fields); + Ok(()) + } + + #[test] + fn test_int96_options() -> Result<()> { + for tu in [ + TimeUnit::Second, + TimeUnit::Microsecond, + TimeUnit::Millisecond, + TimeUnit::Nanosecond, + ] { + let message_type = " + message arrow_schema { + REQUIRED INT96 int96_field; + OPTIONAL GROUP int96_list (LIST) { + REPEATED GROUP list { + OPTIONAL INT96 element; + } + } + REQUIRED GROUP int96_struct { + REQUIRED INT96 int96_field; + } + } + "; + let coerced_to = DataType::Timestamp(tu, None); + let arrow_fields = vec![ + Field::new("int96_field", coerced_to.clone(), false), + Field::new( + "int96_list", + DataType::List(Box::new(Field::new("element", coerced_to.clone(), true))), + true, + ), + Field::new( + "int96_struct", + DataType::Struct(vec![Field::new("int96_field", coerced_to.clone(), false)]), + false, + ), + ]; + + let parquet_schema = SchemaDescriptor::try_from_message(message_type)?; + let fields = parquet_to_arrow_schema_with_options( + parquet_schema.fields(), + &Some(SchemaInferenceOptions { + int96_coerce_to_timeunit: tu, + }), + ); + assert_eq!(arrow_fields, fields); + } + Ok(()) + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/schema/metadata.rs b/crates/nano-arrow/src/io/parquet/read/schema/metadata.rs new file mode 100644 index 000000000000..574ff08d1fd5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/schema/metadata.rs @@ -0,0 +1,55 @@ +use base64::engine::general_purpose; +use base64::Engine as _; +pub use parquet2::metadata::KeyValue; + +use super::super::super::ARROW_SCHEMA_META_KEY; +use crate::datatypes::{Metadata, Schema}; +use crate::error::{Error, Result}; +use crate::io::ipc::read::deserialize_schema; + +/// Reads an arrow schema from Parquet's file metadata. Returns `None` if no schema was found. +/// # Errors +/// Errors iff the schema cannot be correctly parsed. +pub fn read_schema_from_metadata(metadata: &mut Metadata) -> Result> { + metadata + .remove(ARROW_SCHEMA_META_KEY) + .map(|encoded| get_arrow_schema_from_metadata(&encoded)) + .transpose() +} + +/// Try to convert Arrow schema metadata into a schema +fn get_arrow_schema_from_metadata(encoded_meta: &str) -> Result { + let decoded = general_purpose::STANDARD.decode(encoded_meta); + match decoded { + Ok(bytes) => { + let slice = if bytes[0..4] == [255u8; 4] { + &bytes[8..] + } else { + bytes.as_slice() + }; + deserialize_schema(slice).map(|x| x.0) + }, + Err(err) => { + // The C++ implementation returns an error if the schema can't be parsed. + Err(Error::InvalidArgumentError(format!( + "Unable to decode the encoded schema stored in {ARROW_SCHEMA_META_KEY}, {err:?}" + ))) + }, + } +} + +pub(super) fn parse_key_value_metadata(key_value_metadata: &Option>) -> Metadata { + key_value_metadata + .as_ref() + .map(|key_values| { + key_values + .iter() + .filter_map(|kv| { + kv.value + .as_ref() + .map(|value| (kv.key.clone(), value.clone())) + }) + .collect() + }) + .unwrap_or_default() +} diff --git a/crates/nano-arrow/src/io/parquet/read/schema/mod.rs b/crates/nano-arrow/src/io/parquet/read/schema/mod.rs new file mode 100644 index 000000000000..8b2394684440 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/schema/mod.rs @@ -0,0 +1,58 @@ +//! 
APIs to handle Parquet <-> Arrow schemas. +use crate::datatypes::{Schema, TimeUnit}; +use crate::error::Result; + +mod convert; +mod metadata; + +pub(crate) use convert::*; +pub use convert::{parquet_to_arrow_schema, parquet_to_arrow_schema_with_options}; +pub use metadata::read_schema_from_metadata; +pub use parquet2::metadata::{FileMetaData, KeyValue, SchemaDescriptor}; +pub use parquet2::schema::types::ParquetType; + +use self::metadata::parse_key_value_metadata; + +/// Options when inferring schemas from Parquet +pub struct SchemaInferenceOptions { + /// When inferring schemas from the Parquet INT96 timestamp type, this is the corresponding TimeUnit + /// in the inferred Arrow Timestamp type. + /// + /// This defaults to `TimeUnit::Nanosecond`, but INT96 timestamps outside of the range of years 1678-2262, + /// will overflow when parsed as `Timestamp(TimeUnit::Nanosecond)`. Setting this to a lower resolution + /// (e.g. TimeUnit::Milliseconds) will result in loss of precision, but support a larger range of dates + /// without overflowing when parsing the data. + pub int96_coerce_to_timeunit: TimeUnit, +} + +impl Default for SchemaInferenceOptions { + fn default() -> Self { + SchemaInferenceOptions { + int96_coerce_to_timeunit: TimeUnit::Nanosecond, + } + } +} + +/// Infers a [`Schema`] from parquet's [`FileMetaData`]. This first looks for the metadata key +/// `"ARROW:schema"`; if it does not exist, it converts the parquet types declared in the +/// file's parquet schema to Arrow's equivalent. +/// # Error +/// This function errors iff the key `"ARROW:schema"` exists but is not correctly encoded, +/// indicating that that the file's arrow metadata was incorrectly written. +pub fn infer_schema(file_metadata: &FileMetaData) -> Result { + infer_schema_with_options(file_metadata, &None) +} + +/// Like [`infer_schema`] but with configurable options which affects the behavior of inference +pub fn infer_schema_with_options( + file_metadata: &FileMetaData, + options: &Option, +) -> Result { + let mut metadata = parse_key_value_metadata(file_metadata.key_value_metadata()); + + let schema = read_schema_from_metadata(&mut metadata)?; + Ok(schema.unwrap_or_else(|| { + let fields = parquet_to_arrow_schema_with_options(file_metadata.schema().fields(), options); + Schema { fields, metadata } + })) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/binary.rs b/crates/nano-arrow/src/io/parquet/read/statistics/binary.rs new file mode 100644 index 000000000000..aeb43a6b3e0b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/binary.rs @@ -0,0 +1,24 @@ +use parquet2::statistics::{BinaryStatistics, Statistics as ParquetStatistics}; + +use crate::array::{MutableArray, MutableBinaryArray}; +use crate::error::Result; +use crate::offset::Offset; + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + min.push(from.and_then(|s| s.min_value.as_ref())); + max.push(from.and_then(|s| s.max_value.as_ref())); + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/boolean.rs b/crates/nano-arrow/src/io/parquet/read/statistics/boolean.rs new file mode 100644 index 000000000000..ebb0ce3dade2 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/boolean.rs @@ -0,0 +1,23 @@ +use 
parquet2::statistics::{BooleanStatistics, Statistics as ParquetStatistics}; + +use crate::array::{MutableArray, MutableBooleanArray}; +use crate::error::Result; + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + min.push(from.and_then(|s| s.min_value)); + max.push(from.and_then(|s| s.max_value)); + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/dictionary.rs b/crates/nano-arrow/src/io/parquet/read/statistics/dictionary.rs new file mode 100644 index 000000000000..f6e2fdddcce9 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/dictionary.rs @@ -0,0 +1,69 @@ +use super::make_mutable; +use crate::array::*; +use crate::datatypes::{DataType, PhysicalType}; +use crate::error::Result; + +#[derive(Debug)] +pub struct DynMutableDictionary { + data_type: DataType, + pub inner: Box, +} + +impl DynMutableDictionary { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inner = if let DataType::Dictionary(_, inner, _) = &data_type { + inner.as_ref() + } else { + unreachable!() + }; + let inner = make_mutable(inner, capacity)?; + + Ok(Self { data_type, inner }) + } +} + +impl MutableArray for DynMutableDictionary { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + self.inner.validity() + } + + fn as_box(&mut self) -> Box { + let inner = self.inner.as_box(); + match self.data_type.to_physical_type() { + PhysicalType::Dictionary(key) => match_integer_type!(key, |$T| { + let keys: Vec<$T> = (0..inner.len() as $T).collect(); + let keys = PrimitiveArray::<$T>::from_vec(keys); + Box::new(DictionaryArray::<$T>::try_new(self.data_type.clone(), keys, inner).unwrap()) + }), + _ => todo!(), + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/fixlen.rs b/crates/nano-arrow/src/io/parquet/read/statistics/fixlen.rs new file mode 100644 index 000000000000..1f9db20d9c9a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/fixlen.rs @@ -0,0 +1,146 @@ +use ethnum::I256; +use parquet2::statistics::{FixedLenStatistics, Statistics as ParquetStatistics}; + +use super::super::{convert_days_ms, convert_i128}; +use crate::array::*; +use crate::error::Result; +use crate::io::parquet::read::convert_i256; +use crate::types::{days_ms, i256}; + +pub(super) fn push_i128( + from: Option<&dyn ParquetStatistics>, + n: usize, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(|x| convert_i128(x, n)))); + max.push(from.and_then(|s| s.max_value.as_deref().map(|x| convert_i128(x, n)))); + + Ok(()) +} + +pub(super) fn push_i256_with_i128( + from: Option<&dyn ParquetStatistics>, + n: 
usize, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| { + s.min_value + .as_deref() + .map(|x| i256(I256::new(convert_i128(x, n)))) + })); + max.push(from.and_then(|s| { + s.max_value + .as_deref() + .map(|x| i256(I256::new(convert_i128(x, n)))) + })); + + Ok(()) +} + +pub(super) fn push_i256( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_i256))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_i256))); + + Ok(()) +} + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + min.push(from.and_then(|s| s.min_value.as_ref())); + max.push(from.and_then(|s| s.max_value.as_ref())); + Ok(()) +} + +fn convert_year_month(value: &[u8]) -> i32 { + i32::from_le_bytes(value[..4].try_into().unwrap()) +} + +pub(super) fn push_year_month( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_year_month))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_year_month))); + + Ok(()) +} + +pub(super) fn push_days_ms( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push(from.and_then(|s| s.min_value.as_deref().map(convert_days_ms))); + max.push(from.and_then(|s| s.max_value.as_deref().map(convert_days_ms))); + + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/list.rs b/crates/nano-arrow/src/io/parquet/read/statistics/list.rs new file mode 100644 index 000000000000..cb22cbf7063a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/list.rs @@ -0,0 +1,85 @@ +use super::make_mutable; +use crate::array::*; +use crate::datatypes::DataType; +use crate::error::Result; +use crate::offset::Offsets; + +#[derive(Debug)] +pub struct DynMutableListArray { + data_type: DataType, + pub inner: Box, +} + +impl DynMutableListArray { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inner = match data_type.to_logical_type() { + DataType::List(inner) | DataType::LargeList(inner) => inner.data_type(), + _ => unreachable!(), + }; + let inner = make_mutable(inner, capacity)?; + + Ok(Self { data_type, inner }) + } +} + +impl MutableArray for DynMutableListArray { + fn 
data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + self.inner.validity() + } + + fn as_box(&mut self) -> Box { + let inner = self.inner.as_box(); + + match self.data_type.to_logical_type() { + DataType::List(_) => { + let offsets = + Offsets::try_from_lengths(std::iter::repeat(1).take(inner.len())).unwrap(); + Box::new(ListArray::::new( + self.data_type.clone(), + offsets.into(), + inner, + None, + )) + }, + DataType::LargeList(_) => { + let offsets = + Offsets::try_from_lengths(std::iter::repeat(1).take(inner.len())).unwrap(); + Box::new(ListArray::::new( + self.data_type.clone(), + offsets.into(), + inner, + None, + )) + }, + _ => unreachable!(), + } + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/map.rs b/crates/nano-arrow/src/io/parquet/read/statistics/map.rs new file mode 100644 index 000000000000..d6b2a73388f5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/map.rs @@ -0,0 +1,65 @@ +use super::make_mutable; +use crate::array::{Array, MapArray, MutableArray}; +use crate::datatypes::DataType; +use crate::error::Error; + +#[derive(Debug)] +pub struct DynMutableMapArray { + data_type: DataType, + pub inner: Box, +} + +impl DynMutableMapArray { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inner = match data_type.to_logical_type() { + DataType::Map(inner, _) => inner, + _ => unreachable!(), + }; + let inner = make_mutable(inner.data_type(), capacity)?; + + Ok(Self { data_type, inner }) + } +} + +impl MutableArray for DynMutableMapArray { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner.len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + Box::new(MapArray::new( + self.data_type.clone(), + vec![0, self.inner.len() as i32].try_into().unwrap(), + self.inner.as_box(), + None, + )) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/mod.rs b/crates/nano-arrow/src/io/parquet/read/statistics/mod.rs new file mode 100644 index 000000000000..3048952530a6 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/mod.rs @@ -0,0 +1,577 @@ +//! APIs exposing `parquet2`'s statistics as arrow's statistics. 
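+//!
+//! A minimal usage sketch (illustrative only; obtaining `file_metadata` and picking a
+//! field are assumptions made for the example, not part of this module):
+//!
+//! ```ignore
+//! // infer the Arrow schema from the parquet metadata and pick the field of interest
+//! let schema = infer_schema(&file_metadata)?;
+//! let field = &schema.fields[0];
+//! // aggregate that field's min/max/null_count statistics over all row groups
+//! let stats = deserialize(field, &file_metadata.row_groups)?;
+//! println!("min: {:?}, max: {:?}", stats.min_value, stats.max_value);
+//! ```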
+use std::collections::VecDeque; +use std::sync::Arc; + +use ethnum::I256; +use parquet2::metadata::RowGroupMetaData; +use parquet2::schema::types::{ + PhysicalType as ParquetPhysicalType, PrimitiveType as ParquetPrimitiveType, +}; +use parquet2::statistics::{ + BinaryStatistics, BooleanStatistics, FixedLenStatistics, PrimitiveStatistics, + Statistics as ParquetStatistics, +}; +use parquet2::types::int96_to_i64_ns; + +use crate::array::*; +use crate::datatypes::{DataType, Field, IntervalUnit, PhysicalType}; +use crate::error::{Error, Result}; +use crate::types::i256; + +mod binary; +mod boolean; +mod dictionary; +mod fixlen; +mod list; +mod map; +mod null; +mod primitive; +mod struct_; +mod utf8; + +use self::list::DynMutableListArray; +use super::get_field_columns; + +/// Arrow-deserialized parquet Statistics of a file +#[derive(Debug, PartialEq)] +pub struct Statistics { + /// number of nulls. This is a [`UInt64Array`] for non-nested types + pub null_count: Box, + /// number of dictinct values. This is a [`UInt64Array`] for non-nested types + pub distinct_count: Box, + /// Minimum + pub min_value: Box, + /// Maximum + pub max_value: Box, +} + +/// Arrow-deserialized parquet Statistics of a file +#[derive(Debug)] +struct MutableStatistics { + /// number of nulls + pub null_count: Box, + /// number of dictinct values + pub distinct_count: Box, + /// Minimum + pub min_value: Box, + /// Maximum + pub max_value: Box, +} + +impl From for Statistics { + fn from(mut s: MutableStatistics) -> Self { + let null_count = if let PhysicalType::Struct = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::Map = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::List = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::LargeList = s.null_count.data_type().to_physical_type() { + s.null_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else { + s.null_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + }; + let distinct_count = if let PhysicalType::Struct = + s.distinct_count.data_type().to_physical_type() + { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::Map = s.distinct_count.data_type().to_physical_type() { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::List = s.distinct_count.data_type().to_physical_type() { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else if let PhysicalType::LargeList = s.distinct_count.data_type().to_physical_type() { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::>() + .unwrap() + .clone() + .boxed() + } else { + s.distinct_count + .as_box() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .boxed() + }; + Self { + null_count, + distinct_count, + min_value: s.min_value.as_box(), + max_value: s.max_value.as_box(), + } + } +} + +fn make_mutable(data_type: &DataType, capacity: usize) -> Result> { + Ok(match data_type.to_physical_type() { + PhysicalType::Boolean => { + Box::new(MutableBooleanArray::with_capacity(capacity)) as Box + 
}, + PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + Box::new(MutablePrimitiveArray::<$T>::with_capacity(capacity).to(data_type.clone())) + as Box + }), + PhysicalType::Binary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::LargeBinary => { + Box::new(MutableBinaryArray::::with_capacity(capacity)) as Box + }, + PhysicalType::Utf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::LargeUtf8 => { + Box::new(MutableUtf8Array::::with_capacity(capacity)) as Box + }, + PhysicalType::FixedSizeBinary => { + Box::new(MutableFixedSizeBinaryArray::try_new(data_type.clone(), vec![], None).unwrap()) + as _ + }, + PhysicalType::LargeList | PhysicalType::List => Box::new( + DynMutableListArray::try_with_capacity(data_type.clone(), capacity)?, + ) as Box, + PhysicalType::Dictionary(_) => Box::new( + dictionary::DynMutableDictionary::try_with_capacity(data_type.clone(), capacity)?, + ), + PhysicalType::Struct => Box::new(struct_::DynMutableStructArray::try_with_capacity( + data_type.clone(), + capacity, + )?), + PhysicalType::Map => Box::new(map::DynMutableMapArray::try_with_capacity( + data_type.clone(), + capacity, + )?), + PhysicalType::Null => { + Box::new(MutableNullArray::new(DataType::Null, 0)) as Box + }, + other => { + return Err(Error::NotYetImplemented(format!( + "Deserializing parquet stats from {other:?} is still not implemented" + ))) + }, + }) +} + +fn create_dt(data_type: &DataType) -> DataType { + if let DataType::Struct(fields) = data_type.to_logical_type() { + DataType::Struct( + fields + .iter() + .map(|f| Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)) + .collect(), + ) + } else if let DataType::Map(f, ordered) = data_type.to_logical_type() { + DataType::Map( + Box::new(Field::new(&f.name, create_dt(&f.data_type), f.is_nullable)), + *ordered, + ) + } else if let DataType::List(f) = data_type.to_logical_type() { + DataType::List(Box::new(Field::new( + &f.name, + create_dt(&f.data_type), + f.is_nullable, + ))) + } else if let DataType::LargeList(f) = data_type.to_logical_type() { + DataType::LargeList(Box::new(Field::new( + &f.name, + create_dt(&f.data_type), + f.is_nullable, + ))) + } else { + DataType::UInt64 + } +} + +impl MutableStatistics { + fn try_new(field: &Field) -> Result { + let min_value = make_mutable(&field.data_type, 0)?; + let max_value = make_mutable(&field.data_type, 0)?; + + let dt = create_dt(&field.data_type); + Ok(Self { + null_count: make_mutable(&dt, 0)?, + distinct_count: make_mutable(&dt, 0)?, + min_value, + max_value, + }) + } +} + +fn push_others( + from: Option<&dyn ParquetStatistics>, + distinct_count: &mut UInt64Vec, + null_count: &mut UInt64Vec, +) { + let from = if let Some(from) = from { + from + } else { + distinct_count.push(None); + null_count.push(None); + return; + }; + let (distinct, null_count1) = match from.physical_type() { + ParquetPhysicalType::Boolean => { + let from = from.as_any().downcast_ref::().unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Int32 => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Int64 => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Int96 => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Float => { + 
let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::Double => { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::ByteArray => { + let from = from.as_any().downcast_ref::().unwrap(); + (from.distinct_count, from.null_count) + }, + ParquetPhysicalType::FixedLenByteArray(_) => { + let from = from.as_any().downcast_ref::().unwrap(); + (from.distinct_count, from.null_count) + }, + }; + + distinct_count.push(distinct.map(|x| x as u64)); + null_count.push(null_count1.map(|x| x as u64)); +} + +fn push( + stats: &mut VecDeque<(Option>, ParquetPrimitiveType)>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, + distinct_count: &mut dyn MutableArray, + null_count: &mut dyn MutableArray, +) -> Result<()> { + match min.data_type().to_logical_type() { + List(_) | LargeList(_) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + return push( + stats, + min.inner.as_mut(), + max.inner.as_mut(), + distinct_count.inner.as_mut(), + null_count.inner.as_mut(), + ); + }, + Dictionary(_, _, _) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + return push( + stats, + min.inner.as_mut(), + max.inner.as_mut(), + distinct_count, + null_count, + ); + }, + Struct(_) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + return min + .inner + .iter_mut() + .zip(max.inner.iter_mut()) + .zip(distinct_count.inner.iter_mut()) + .zip(null_count.inner.iter_mut()) + .try_for_each(|(((min, max), distinct_count), null_count)| { + push( + stats, + min.as_mut(), + max.as_mut(), + distinct_count.as_mut(), + null_count.as_mut(), + ) + }); + }, + Map(_, _) => { + let min = min + .as_mut_any() + .downcast_mut::() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::() + .unwrap(); + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + return push( + stats, + min.inner.as_mut(), + max.inner.as_mut(), + distinct_count.inner.as_mut(), + null_count.inner.as_mut(), + ); + }, + _ => {}, + } + + let (from, type_) = stats.pop_front().unwrap(); + let from = from.as_deref(); + + let distinct_count = distinct_count + .as_mut_any() + .downcast_mut::() + .unwrap(); + let null_count = null_count.as_mut_any().downcast_mut::().unwrap(); + + push_others(from, distinct_count, null_count); + + let physical_type = &type_.physical_type; + + use DataType::*; + match min.data_type().to_logical_type() { + Boolean => boolean::push(from, min, max), + Int8 => primitive::push(from, min, max, |x: i32| Ok(x as i8)), + Int16 => primitive::push(from, min, max, |x: i32| Ok(x as i16)), + Date32 | Time32(_) => primitive::push::(from, min, max, Ok), + Interval(IntervalUnit::YearMonth) => fixlen::push_year_month(from, min, max), + Interval(IntervalUnit::DayTime) => fixlen::push_days_ms(from, min, max), + 
UInt8 => primitive::push(from, min, max, |x: i32| Ok(x as u8)), + UInt16 => primitive::push(from, min, max, |x: i32| Ok(x as u16)), + UInt32 => match physical_type { + // some implementations of parquet write arrow's u32 into i64. + ParquetPhysicalType::Int64 => primitive::push(from, min, max, |x: i64| Ok(x as u32)), + ParquetPhysicalType::Int32 => primitive::push(from, min, max, |x: i32| Ok(x as u32)), + other => Err(Error::NotYetImplemented(format!( + "Can't decode UInt32 type from parquet type {other:?}" + ))), + }, + Int32 => primitive::push::(from, min, max, Ok), + Date64 => match physical_type { + ParquetPhysicalType::Int64 => primitive::push::(from, min, max, Ok), + // some implementations of parquet write arrow's date64 into i32. + ParquetPhysicalType::Int32 => { + primitive::push(from, min, max, |x: i32| Ok(x as i64 * 86400000)) + }, + other => Err(Error::NotYetImplemented(format!( + "Can't decode Date64 type from parquet type {other:?}" + ))), + }, + Int64 | Time64(_) | Duration(_) => primitive::push::(from, min, max, Ok), + UInt64 => primitive::push(from, min, max, |x: i64| Ok(x as u64)), + Timestamp(time_unit, _) => { + let time_unit = *time_unit; + if physical_type == &ParquetPhysicalType::Int96 { + let from = from.map(|from| { + let from = from + .as_any() + .downcast_ref::>() + .unwrap(); + PrimitiveStatistics:: { + primitive_type: from.primitive_type.clone(), + null_count: from.null_count, + distinct_count: from.distinct_count, + min_value: from.min_value.map(int96_to_i64_ns), + max_value: from.max_value.map(int96_to_i64_ns), + } + }); + primitive::push( + from.as_ref().map(|x| x as &dyn ParquetStatistics), + min, + max, + |x: i64| { + Ok(primitive::timestamp( + type_.logical_type.as_ref(), + time_unit, + x, + )) + }, + ) + } else { + primitive::push(from, min, max, |x: i64| { + Ok(primitive::timestamp( + type_.logical_type.as_ref(), + time_unit, + x, + )) + }) + } + }, + Float32 => primitive::push::(from, min, max, Ok), + Float64 => primitive::push::(from, min, max, Ok), + Decimal(_, _) => match physical_type { + ParquetPhysicalType::Int32 => primitive::push(from, min, max, |x: i32| Ok(x as i128)), + ParquetPhysicalType::Int64 => primitive::push(from, min, max, |x: i64| Ok(x as i128)), + ParquetPhysicalType::FixedLenByteArray(n) if *n > 16 => Err(Error::NotYetImplemented( + format!("Can't decode Decimal128 type from Fixed Size Byte Array of len {n:?}"), + )), + ParquetPhysicalType::FixedLenByteArray(n) => fixlen::push_i128(from, *n, min, max), + _ => unreachable!(), + }, + Decimal256(_, _) => match physical_type { + ParquetPhysicalType::Int32 => { + primitive::push(from, min, max, |x: i32| Ok(i256(I256::new(x.into())))) + }, + ParquetPhysicalType::Int64 => { + primitive::push(from, min, max, |x: i64| Ok(i256(I256::new(x.into())))) + }, + ParquetPhysicalType::FixedLenByteArray(n) if *n <= 16 => { + fixlen::push_i256_with_i128(from, *n, min, max) + }, + ParquetPhysicalType::FixedLenByteArray(n) if *n > 32 => Err(Error::NotYetImplemented( + format!("Can't decode Decimal256 type from Fixed Size Byte Array of len {n:?}"), + )), + ParquetPhysicalType::FixedLenByteArray(_) => fixlen::push_i256(from, min, max), + _ => unreachable!(), + }, + Binary => binary::push::(from, min, max), + LargeBinary => binary::push::(from, min, max), + Utf8 => utf8::push::(from, min, max), + LargeUtf8 => utf8::push::(from, min, max), + FixedSizeBinary(_) => fixlen::push(from, min, max), + Null => null::push(min, max), + other => todo!("{:?}", other), + } +} + +/// Deserializes the statistics in the 
column chunks from all `row_groups` +/// into [`Statistics`] associated from `field`'s name. +/// +/// # Errors +/// This function errors if the deserialization of the statistics fails (e.g. invalid utf8) +pub fn deserialize(field: &Field, row_groups: &[RowGroupMetaData]) -> Result { + let mut statistics = MutableStatistics::try_new(field)?; + + // transpose + row_groups.iter().try_for_each(|group| { + let columns = get_field_columns(group.columns(), field.name.as_ref()); + let mut stats = columns + .into_iter() + .map(|column| { + Ok(( + column.statistics().transpose()?, + column.descriptor().descriptor.primitive_type.clone(), + )) + }) + .collect::, ParquetPrimitiveType)>>>()?; + push( + &mut stats, + statistics.min_value.as_mut(), + statistics.max_value.as_mut(), + statistics.distinct_count.as_mut(), + statistics.null_count.as_mut(), + ) + })?; + + Ok(statistics.into()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/null.rs b/crates/nano-arrow/src/io/parquet/read/statistics/null.rs new file mode 100644 index 000000000000..9102720ebc5c --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/null.rs @@ -0,0 +1,11 @@ +use crate::array::*; +use crate::error::Result; + +pub(super) fn push(min: &mut dyn MutableArray, max: &mut dyn MutableArray) -> Result<()> { + let min = min.as_mut_any().downcast_mut::().unwrap(); + let max = max.as_mut_any().downcast_mut::().unwrap(); + min.push_null(); + max.push_null(); + + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/primitive.rs b/crates/nano-arrow/src/io/parquet/read/statistics/primitive.rs new file mode 100644 index 000000000000..849363028ad1 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/primitive.rs @@ -0,0 +1,55 @@ +use parquet2::schema::types::{PrimitiveLogicalType, TimeUnit as ParquetTimeUnit}; +use parquet2::statistics::{PrimitiveStatistics, Statistics as ParquetStatistics}; +use parquet2::types::NativeType as ParquetNativeType; + +use crate::array::*; +use crate::datatypes::TimeUnit; +use crate::error::Result; +use crate::types::NativeType; + +pub fn timestamp(logical_type: Option<&PrimitiveLogicalType>, time_unit: TimeUnit, x: i64) -> i64 { + let unit = if let Some(PrimitiveLogicalType::Timestamp { unit, .. 
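A minimal sketch of driving the `deserialize` entry point above, assuming the upstream arrow2-style paths (`io::parquet::read::{read_metadata, infer_schema}` and `read::statistics::deserialize`) and a hypothetical `example.parquet`; the crate name used by this vendored fork may differ:

```rust
use arrow2::error::Result;
use arrow2::io::parquet::read::{infer_schema, read_metadata, statistics::deserialize};

fn main() -> Result<()> {
    // read only the footer metadata of the file
    let mut reader = std::fs::File::open("example.parquet")?;
    let metadata = read_metadata(&mut reader)?;

    // derive the arrow schema from the parquet schema
    let schema = infer_schema(&metadata)?;

    // collect the per-row-group statistics of the first field into arrow arrays
    let statistics = deserialize(&schema.fields[0], &metadata.row_groups)?;
    println!("{statistics:?}");
    Ok(())
}
```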
}) = logical_type { + unit + } else { + return x; + }; + + match (unit, time_unit) { + (ParquetTimeUnit::Milliseconds, TimeUnit::Second) => x / 1_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Second) => x / 1_000_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Second) => x * 1_000_000_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Millisecond) => x, + (ParquetTimeUnit::Microseconds, TimeUnit::Millisecond) => x / 1_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Millisecond) => x / 1_000_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Microsecond) => x * 1_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Microsecond) => x, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Microsecond) => x / 1_000, + + (ParquetTimeUnit::Milliseconds, TimeUnit::Nanosecond) => x * 1_000_000, + (ParquetTimeUnit::Microseconds, TimeUnit::Nanosecond) => x * 1_000, + (ParquetTimeUnit::Nanoseconds, TimeUnit::Nanosecond) => x, + } +} + +pub(super) fn push Result + Copy>( + from: Option<&dyn ParquetStatistics>, + min: &mut dyn MutableArray, + max: &mut dyn MutableArray, + map: F, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::>().unwrap()); + min.push(from.and_then(|s| s.min_value.map(map)).transpose()?); + max.push(from.and_then(|s| s.max_value.map(map)).transpose()?); + + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/struct_.rs b/crates/nano-arrow/src/io/parquet/read/statistics/struct_.rs new file mode 100644 index 000000000000..6aca0352701e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/struct_.rs @@ -0,0 +1,64 @@ +use super::make_mutable; +use crate::array::{Array, MutableArray, StructArray}; +use crate::datatypes::DataType; +use crate::error::Result; + +#[derive(Debug)] +pub struct DynMutableStructArray { + data_type: DataType, + pub inner: Vec>, +} + +impl DynMutableStructArray { + pub fn try_with_capacity(data_type: DataType, capacity: usize) -> Result { + let inners = match data_type.to_logical_type() { + DataType::Struct(inner) => inner, + _ => unreachable!(), + }; + let inner = inners + .iter() + .map(|f| make_mutable(f.data_type(), capacity)) + .collect::>>()?; + + Ok(Self { data_type, inner }) + } +} +impl MutableArray for DynMutableStructArray { + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn len(&self) -> usize { + self.inner[0].len() + } + + fn validity(&self) -> Option<&crate::bitmap::MutableBitmap> { + None + } + + fn as_box(&mut self) -> Box { + let inner = self.inner.iter_mut().map(|x| x.as_box()).collect(); + + Box::new(StructArray::new(self.data_type.clone(), inner, None)) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_mut_any(&mut self) -> &mut dyn std::any::Any { + self + } + + fn push_null(&mut self) { + todo!() + } + + fn reserve(&mut self, _: usize) { + todo!(); + } + + fn shrink_to_fit(&mut self) { + todo!() + } +} diff --git a/crates/nano-arrow/src/io/parquet/read/statistics/utf8.rs b/crates/nano-arrow/src/io/parquet/read/statistics/utf8.rs new file mode 100644 index 000000000000..da9fcb6e1119 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/read/statistics/utf8.rs @@ -0,0 +1,31 @@ +use parquet2::statistics::{BinaryStatistics, Statistics as ParquetStatistics}; + +use crate::array::{MutableArray, MutableUtf8Array}; +use crate::error::Result; +use crate::offset::Offset; + +pub(super) fn push( + from: Option<&dyn ParquetStatistics>, + min: &mut 
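The conversion table in `timestamp` above reduces to multiplying or dividing by powers of one thousand; for instance, a nanosecond statistic read from parquet and destined for a millisecond arrow column is simply divided by 1 000 000 (plain arithmetic, shown here without any API):

```rust
fn main() {
    let parquet_ns: i64 = 1_700_000_000_123_456_789;
    let arrow_ms = parquet_ns / 1_000_000;
    assert_eq!(arrow_ms, 1_700_000_000_123);
}
```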
dyn MutableArray, + max: &mut dyn MutableArray, +) -> Result<()> { + let min = min + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let max = max + .as_mut_any() + .downcast_mut::>() + .unwrap(); + let from = from.map(|s| s.as_any().downcast_ref::().unwrap()); + + min.push( + from.and_then(|s| s.min_value.as_deref().map(simdutf8::basic::from_utf8)) + .transpose()?, + ); + max.push( + from.and_then(|s| s.max_value.as_deref().map(simdutf8::basic::from_utf8)) + .transpose()?, + ); + Ok(()) +} diff --git a/crates/nano-arrow/src/io/parquet/write/binary/basic.rs b/crates/nano-arrow/src/io/parquet/write/binary/basic.rs new file mode 100644 index 000000000000..de840e45fa5a --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/binary/basic.rs @@ -0,0 +1,168 @@ +use parquet2::encoding::{delta_bitpacked, Encoding}; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, BinaryStatistics, ParquetStatistics, Statistics}; + +use super::super::{utils, WriteOptions}; +use crate::array::{Array, BinaryArray}; +use crate::bitmap::Bitmap; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::offset::Offset; + +pub(crate) fn encode_plain( + array: &BinaryArray, + is_optional: bool, + buffer: &mut Vec, +) { + // append the non-null values + if is_optional { + array.iter().for_each(|x| { + if let Some(x) = x { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x); + } + }) + } else { + array.values_iter().for_each(|x| { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x); + }) + } +} + +pub fn array_to_page( + array: &BinaryArray, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> Result { + let validity = array.validity(); + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, is_optional, &mut buffer), + Encoding::DeltaLengthByteArray => encode_delta( + array.values(), + array.offsets().buffer(), + array.validity(), + is_optional, + &mut buffer, + ), + _ => { + return Err(Error::InvalidArgumentError(format!( + "Datatype {:?} cannot be encoded by {:?} encoding", + array.data_type(), + encoding + ))) + }, + } + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +} + +pub(crate) fn build_statistics( + array: &BinaryArray, + primitive_type: PrimitiveType, +) -> ParquetStatistics { + let statistics = &BinaryStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + } as &dyn Statistics; + serialize_statistics(statistics) +} + +pub(crate) fn encode_delta( + values: &[u8], + offsets: &[O], + validity: 
Option<&Bitmap>, + is_optional: bool, + buffer: &mut Vec, +) { + if is_optional { + if let Some(validity) = validity { + let lengths = offsets + .windows(2) + .map(|w| (w[1] - w[0]).to_usize() as i64) + .zip(validity.iter()) + .flat_map(|(x, is_valid)| if is_valid { Some(x) } else { None }); + let length = offsets.len() - 1 - validity.unset_bits(); + let lengths = utils::ExactSizedIter::new(lengths, length); + + delta_bitpacked::encode(lengths, buffer); + } else { + let lengths = offsets.windows(2).map(|w| (w[1] - w[0]).to_usize() as i64); + delta_bitpacked::encode(lengths, buffer); + } + } else { + let lengths = offsets.windows(2).map(|w| (w[1] - w[0]).to_usize() as i64); + delta_bitpacked::encode(lengths, buffer); + } + + buffer.extend_from_slice( + &values[offsets.first().unwrap().to_usize()..offsets.last().unwrap().to_usize()], + ) +} + +/// Returns the ordering of two binary values. This corresponds to pyarrows' ordering +/// of statistics. +pub(crate) fn ord_binary<'a>(a: &'a [u8], b: &'a [u8]) -> std::cmp::Ordering { + use std::cmp::Ordering::*; + match (a.is_empty(), b.is_empty()) { + (true, true) => return Equal, + (true, false) => return Less, + (false, true) => return Greater, + (false, false) => {}, + } + + for (v1, v2) in a.iter().zip(b.iter()) { + match v1.cmp(v2) { + Equal => continue, + other => return other, + } + } + Equal +} diff --git a/crates/nano-arrow/src/io/parquet/write/binary/mod.rs b/crates/nano-arrow/src/io/parquet/write/binary/mod.rs new file mode 100644 index 000000000000..e942b4b69103 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/binary/mod.rs @@ -0,0 +1,7 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub(crate) use basic::{build_statistics, encode_plain}; +pub(super) use basic::{encode_delta, ord_binary}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/nano-arrow/src/io/parquet/write/binary/nested.rs b/crates/nano-arrow/src/io/parquet/write/binary/nested.rs new file mode 100644 index 000000000000..11de9d9676a7 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/binary/nested.rs @@ -0,0 +1,48 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, BinaryArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; +use crate::offset::Offset; + +pub fn array_to_page( + array: &BinaryArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result +where + O: Offset, +{ + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, is_optional, &mut buffer); + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/boolean/basic.rs b/crates/nano-arrow/src/io/parquet/write/boolean/basic.rs new file mode 100644 index 000000000000..833bfab09e5a --- /dev/null +++ 
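Two details of the binary writer above are worth spelling out. PLAIN-encoded BYTE_ARRAY values are length-prefixed with a little-endian `u32`, and `ord_binary` compares byte-wise to mirror pyarrow's statistics ordering; note that, as written, a value that is a strict prefix of another compares as `Equal` rather than `Less`. A standalone illustration of the PLAIN layout, using only the standard library:

```rust
fn main() {
    // PLAIN BYTE_ARRAY: a 4-byte little-endian length, then the raw bytes
    let mut buffer: Vec<u8> = vec![];
    for value in [b"hi".as_slice(), b"polars".as_slice()] {
        buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
        buffer.extend_from_slice(value);
    }
    assert_eq!(&buffer[..6], &[2, 0, 0, 0, b'h', b'i']);
}
```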
b/crates/nano-arrow/src/io/parquet/write/boolean/basic.rs @@ -0,0 +1,92 @@ +use parquet2::encoding::hybrid_rle::bitpacked_encode; +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{ + serialize_statistics, BooleanStatistics, ParquetStatistics, Statistics, +}; + +use super::super::{utils, WriteOptions}; +use crate::array::*; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; + +fn encode(iterator: impl Iterator, buffer: &mut Vec) -> Result<()> { + // encode values using bitpacking + let len = buffer.len(); + let mut buffer = std::io::Cursor::new(buffer); + buffer.set_position(len as u64); + Ok(bitpacked_encode(&mut buffer, iterator)?) +} + +pub(super) fn encode_plain( + array: &BooleanArray, + is_optional: bool, + buffer: &mut Vec, +) -> Result<()> { + if is_optional { + let iter = array.iter().flatten().take( + array + .validity() + .as_ref() + .map(|x| x.len() - x.unset_bits()) + .unwrap_or_else(|| array.len()), + ); + encode(iter, buffer) + } else { + let iter = array.values().iter(); + encode(iter, buffer) + } +} + +pub fn array_to_page( + array: &BooleanArray, + options: WriteOptions, + type_: PrimitiveType, +) -> Result { + let is_optional = is_nullable(&type_.field_info); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, is_optional, &mut buffer)?; + + let statistics = if options.write_statistics { + Some(build_statistics(array)) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} + +pub(super) fn build_statistics(array: &BooleanArray) -> ParquetStatistics { + let statistics = &BooleanStatistics { + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array.iter().flatten().max(), + min_value: array.iter().flatten().min(), + } as &dyn Statistics; + serialize_statistics(statistics) +} diff --git a/crates/nano-arrow/src/io/parquet/write/boolean/mod.rs b/crates/nano-arrow/src/io/parquet/write/boolean/mod.rs new file mode 100644 index 000000000000..280e2ff9efb5 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/boolean/mod.rs @@ -0,0 +1,5 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/nano-arrow/src/io/parquet/write/boolean/nested.rs b/crates/nano-arrow/src/io/parquet/write/boolean/nested.rs new file mode 100644 index 000000000000..656019100825 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/boolean/nested.rs @@ -0,0 +1,44 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, BooleanArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; + +pub fn array_to_page( + array: &BooleanArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result { + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + 
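For reference, `bitpacked_encode` above produces parquet's PLAIN representation for booleans: values are packed LSB-first, eight per byte. A standalone illustration of that bit layout (not using the parquet2 helper itself):

```rust
fn main() {
    // [true, false, true, true] packed LSB-first into one byte: 0b0000_1101
    let values = [true, false, true, true];
    let mut byte = 0u8;
    for (i, v) in values.iter().enumerate() {
        byte |= (*v as u8) << i;
    }
    assert_eq!(byte, 0b0000_1101);
}
```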
nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, is_optional, &mut buffer)?; + + let statistics = if options.write_statistics { + Some(build_statistics(array)) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/dictionary.rs b/crates/nano-arrow/src/io/parquet/write/dictionary.rs new file mode 100644 index 000000000000..4ee0a5c37eac --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/dictionary.rs @@ -0,0 +1,281 @@ +use parquet2::encoding::hybrid_rle::encode_u32; +use parquet2::encoding::Encoding; +use parquet2::page::{DictPage, Page}; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, ParquetStatistics}; +use parquet2::write::DynIter; + +use super::binary::{ + build_statistics as binary_build_statistics, encode_plain as binary_encode_plain, +}; +use super::fixed_len_bytes::{ + build_statistics as fixed_binary_build_statistics, encode_plain as fixed_binary_encode_plain, +}; +use super::primitive::{ + build_statistics as primitive_build_statistics, encode_plain as primitive_encode_plain, +}; +use super::utf8::{build_statistics as utf8_build_statistics, encode_plain as utf8_encode_plain}; +use super::{nested, Nested, WriteOptions}; +use crate::array::{Array, DictionaryArray, DictionaryKey}; +use crate::bitmap::{Bitmap, MutableBitmap}; +use crate::datatypes::DataType; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::{slice_nested_leaf, utils}; + +fn serialize_def_levels_simple( + validity: Option<&Bitmap>, + length: usize, + is_optional: bool, + options: WriteOptions, + buffer: &mut Vec, +) -> Result<()> { + utils::write_def_levels(buffer, is_optional, validity, length, options.version) +} + +fn serialize_keys_values( + array: &DictionaryArray, + validity: Option<&Bitmap>, + buffer: &mut Vec, +) -> Result<()> { + let keys = array.keys_values_iter().map(|x| x as u32); + if let Some(validity) = validity { + // discard indices whose values are null. + let keys = keys + .zip(validity.iter()) + .filter(|&(_key, is_valid)| is_valid) + .map(|(key, _is_valid)| key); + let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); + + let keys = utils::ExactSizedIter::new(keys, array.len() - validity.unset_bits()); + + // num_bits as a single byte + buffer.push(num_bits as u8); + + // followed by the encoded indices. + Ok(encode_u32(buffer, keys, num_bits)?) + } else { + let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64); + + // num_bits as a single byte + buffer.push(num_bits as u8); + + // followed by the encoded indices. + Ok(encode_u32(buffer, keys, num_bits)?) 
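The single byte pushed before the encoded indices above is the bit width of the largest key. A sketch of what `utils::get_bit_width` is assumed to compute here (the helper's name is taken from the surrounding code; its exact definition is not shown in this diff):

```rust
/// Number of bits needed to represent the largest dictionary key.
fn bit_width(max_key: u64) -> u8 {
    (64 - max_key.leading_zeros()) as u8
}

fn main() {
    assert_eq!(bit_width(0), 0); // all keys are 0: no bits needed
    assert_eq!(bit_width(5), 3); // keys 0..=5 fit in 3 bits
    assert_eq!(bit_width(255), 8);
}
```

The indices themselves then follow in parquet's RLE/bit-packed hybrid encoding, with keys pointing at null values filtered out beforehand so that only valid positions are stored.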
+ } +} + +fn serialize_levels( + validity: Option<&Bitmap>, + length: usize, + type_: &PrimitiveType, + nested: &[Nested], + options: WriteOptions, + buffer: &mut Vec, +) -> Result<(usize, usize)> { + if nested.len() == 1 { + let is_optional = is_nullable(&type_.field_info); + serialize_def_levels_simple(validity, length, is_optional, options, buffer)?; + let definition_levels_byte_length = buffer.len(); + Ok((0, definition_levels_byte_length)) + } else { + nested::write_rep_and_def(options.version, nested, buffer) + } +} + +fn normalized_validity(array: &DictionaryArray) -> Option { + match (array.keys().validity(), array.values().validity()) { + (None, None) => None, + (None, rhs) => rhs.cloned(), + (lhs, None) => lhs.cloned(), + (Some(_), Some(rhs)) => { + let projected_validity = array + .keys_iter() + .map(|x| x.map(|x| rhs.get_bit(x)).unwrap_or(false)); + MutableBitmap::from_trusted_len_iter(projected_validity).into() + }, + } +} + +fn serialize_keys( + array: &DictionaryArray, + type_: PrimitiveType, + nested: &[Nested], + statistics: Option, + options: WriteOptions, +) -> Result { + let mut buffer = vec![]; + + // parquet only accepts a single validity - we "&" the validities into a single one + // and ignore keys whole _value_ is null. + let validity = normalized_validity(array); + let (start, len) = slice_nested_leaf(nested); + + let mut nested = nested.to_vec(); + let array = array.clone().sliced(start, len); + if let Some(Nested::Primitive(_, _, c)) = nested.last_mut() { + *c = len; + } else { + unreachable!("") + } + + let (repetition_levels_byte_length, definition_levels_byte_length) = serialize_levels( + validity.as_ref(), + array.len(), + &type_, + &nested, + options, + &mut buffer, + )?; + + serialize_keys_values(&array, validity.as_ref(), &mut buffer)?; + + let (num_values, num_rows) = if nested.len() == 1 { + (array.len(), array.len()) + } else { + (nested::num_values(&nested), nested[0].len()) + }; + + utils::build_plain_page( + buffer, + num_values, + num_rows, + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::RleDictionary, + ) + .map(Page::Data) +} + +macro_rules! 
dyn_prim { + ($from:ty, $to:ty, $array:expr, $options:expr, $type_:expr) => {{ + let values = $array.values().as_any().downcast_ref().unwrap(); + + let buffer = primitive_encode_plain::<$from, $to>(values, false, vec![]); + + let stats: Option = if $options.write_statistics { + let mut stats = primitive_build_statistics::<$from, $to>(values, $type_.clone()); + stats.null_count = Some($array.null_count() as i64); + let stats = serialize_statistics(&stats); + Some(stats) + } else { + None + }; + (DictPage::new(buffer, values.len(), false), stats) + }}; +} + +pub fn array_to_pages( + array: &DictionaryArray, + type_: PrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> Result>> { + match encoding { + Encoding::PlainDictionary | Encoding::RleDictionary => { + // write DictPage + let (dict_page, statistics): (_, Option) = + match array.values().data_type().to_logical_type() { + DataType::Int8 => dyn_prim!(i8, i32, array, options, type_), + DataType::Int16 => dyn_prim!(i16, i32, array, options, type_), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + dyn_prim!(i32, i32, array, options, type_) + }, + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => dyn_prim!(i64, i64, array, options, type_), + DataType::UInt8 => dyn_prim!(u8, i32, array, options, type_), + DataType::UInt16 => dyn_prim!(u16, i32, array, options, type_), + DataType::UInt32 => dyn_prim!(u32, i32, array, options, type_), + DataType::UInt64 => dyn_prim!(u64, i64, array, options, type_), + DataType::Float32 => dyn_prim!(f32, f32, array, options, type_), + DataType::Float64 => dyn_prim!(f64, f64, array, options, type_), + DataType::Utf8 => { + let array = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + utf8_encode_plain::(array, false, &mut buffer); + let stats = if options.write_statistics { + Some(utf8_build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + DataType::LargeUtf8 => { + let array = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + utf8_encode_plain::(array, false, &mut buffer); + let stats = if options.write_statistics { + Some(utf8_build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + DataType::Binary => { + let array = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + binary_encode_plain::(array, false, &mut buffer); + let stats = if options.write_statistics { + Some(binary_build_statistics(array, type_.clone())) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + DataType::LargeBinary => { + let values = array.values().as_any().downcast_ref().unwrap(); + + let mut buffer = vec![]; + binary_encode_plain::(values, false, &mut buffer); + let stats = if options.write_statistics { + let mut stats = binary_build_statistics(values, type_.clone()); + stats.null_count = Some(array.null_count() as i64); + Some(stats) + } else { + None + }; + (DictPage::new(buffer, values.len(), false), stats) + }, + DataType::FixedSizeBinary(_) => { + let mut buffer = vec![]; + let array = array.values().as_any().downcast_ref().unwrap(); + fixed_binary_encode_plain(array, false, &mut buffer); + let stats = if options.write_statistics { + let mut stats = fixed_binary_build_statistics(array, type_.clone()); + stats.null_count = Some(array.null_count() as i64); + 
Some(serialize_statistics(&stats)) + } else { + None + }; + (DictPage::new(buffer, array.len(), false), stats) + }, + other => { + return Err(Error::NotYetImplemented(format!( + "Writing dictionary arrays to parquet only support data type {other:?}" + ))) + }, + }; + let dict_page = Page::Dict(dict_page); + + // write DataPage pointing to DictPage + let data_page = serialize_keys(array, type_, nested, statistics, options)?; + + let iter = std::iter::once(Ok(dict_page)).chain(std::iter::once(Ok(data_page))); + Ok(DynIter::new(Box::new(iter))) + }, + _ => Err(Error::NotYetImplemented( + "Dictionary arrays only support dictionary encoding".to_string(), + )), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/file.rs b/crates/nano-arrow/src/io/parquet/write/file.rs new file mode 100644 index 000000000000..4ec37b941ad9 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/file.rs @@ -0,0 +1,95 @@ +use std::io::Write; + +use parquet2::metadata::{KeyValue, SchemaDescriptor}; +use parquet2::write::{RowGroupIter, WriteOptions as FileWriteOptions}; + +use super::schema::schema_to_metadata_key; +use super::{to_parquet_schema, ThriftFileMetaData, WriteOptions}; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; + +/// Attaches [`Schema`] to `key_value_metadata` +pub fn add_arrow_schema( + schema: &Schema, + key_value_metadata: Option>, +) -> Option> { + key_value_metadata + .map(|mut x| { + x.push(schema_to_metadata_key(schema)); + x + }) + .or_else(|| Some(vec![schema_to_metadata_key(schema)])) +} + +/// An interface to write a parquet to a [`Write`] +pub struct FileWriter { + writer: parquet2::write::FileWriter, + schema: Schema, + options: WriteOptions, +} + +// Accessors +impl FileWriter { + /// The options assigned to the file + pub fn options(&self) -> WriteOptions { + self.options + } + + /// The [`SchemaDescriptor`] assigned to this file + pub fn parquet_schema(&self) -> &SchemaDescriptor { + self.writer.schema() + } + + /// The [`Schema`] assigned to this file + pub fn schema(&self) -> &Schema { + &self.schema + } +} + +impl FileWriter { + /// Returns a new [`FileWriter`]. + /// # Error + /// If it is unable to derive a parquet schema from [`Schema`]. + pub fn try_new(writer: W, schema: Schema, options: WriteOptions) -> Result { + let parquet_schema = to_parquet_schema(&schema)?; + + let created_by = Some("Arrow2 - Native Rust implementation of Arrow".to_string()); + + Ok(Self { + writer: parquet2::write::FileWriter::new( + writer, + parquet_schema, + FileWriteOptions { + version: options.version, + write_statistics: options.write_statistics, + }, + created_by, + ), + schema, + options, + }) + } + + /// Writes a row group to the file. + pub fn write(&mut self, row_group: RowGroupIter<'_, Error>) -> Result<()> { + Ok(self.writer.write(row_group)?) + } + + /// Writes the footer of the parquet file. Returns the total size of the file. + pub fn end(&mut self, key_value_metadata: Option>) -> Result { + let key_value_metadata = add_arrow_schema(&self.schema, key_value_metadata); + Ok(self.writer.end(key_value_metadata)?) 
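A minimal end-to-end sketch of driving this writer, assuming the upstream arrow2-style crate name and the `RowGroupIterator`, `transverse`, and `WriteOptions` items that appear later in this diff; the output path is hypothetical:

```rust
use arrow2::array::Int32Array;
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Field, Schema};
use arrow2::error::Result;
use arrow2::io::parquet::write::{
    transverse, CompressionOptions, Encoding, FileWriter, RowGroupIterator, Version, WriteOptions,
};

fn main() -> Result<()> {
    let schema = Schema::from(vec![Field::new("c1", DataType::Int32, true)]);
    let chunk = Chunk::new(vec![Int32Array::from(&[Some(1), None, Some(3)]).boxed()]);

    let options = WriteOptions {
        write_statistics: true,
        compression: CompressionOptions::Uncompressed,
        version: Version::V2,
        data_pagesize_limit: None,
    };

    // one encoding per parquet leaf column of every field
    let encodings = schema
        .fields
        .iter()
        .map(|f| transverse(&f.data_type, |_| Encoding::Plain))
        .collect();

    let row_groups =
        RowGroupIterator::try_new(vec![Ok(chunk)].into_iter(), &schema, options, encodings)?;

    let file = std::fs::File::create("example.parquet")?;
    let mut writer = FileWriter::try_new(file, schema, options)?;
    for group in row_groups {
        writer.write(group?)?;
    }
    writer.end(None)?;
    Ok(())
}
```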
+ } + + /// Consumes this writer and returns the inner writer + pub fn into_inner(self) -> W { + self.writer.into_inner() + } + + /// Returns the underlying writer and [`ThriftFileMetaData`] + /// # Panics + /// This function panics if [`Self::end`] has not yet been called + pub fn into_inner_and_metadata(self) -> (W, ThriftFileMetaData) { + self.writer.into_inner_and_metadata() + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/fixed_len_bytes.rs b/crates/nano-arrow/src/io/parquet/write/fixed_len_bytes.rs new file mode 100644 index 000000000000..86080ef7728f --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/fixed_len_bytes.rs @@ -0,0 +1,147 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, FixedLenStatistics}; + +use super::binary::ord_binary; +use super::{utils, WriteOptions}; +use crate::array::{Array, FixedSizeBinaryArray, PrimitiveArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::types::i256; + +pub(crate) fn encode_plain(array: &FixedSizeBinaryArray, is_optional: bool, buffer: &mut Vec) { + // append the non-null values + if is_optional { + array.iter().for_each(|x| { + if let Some(x) = x { + buffer.extend_from_slice(x); + } + }) + } else { + buffer.extend_from_slice(array.values()); + } +} + +pub fn array_to_page( + array: &FixedSizeBinaryArray, + options: WriteOptions, + type_: PrimitiveType, + statistics: Option, +) -> Result { + let is_optional = is_nullable(&type_.field_info); + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + encode_plain(array, is_optional, &mut buffer); + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics.map(|x| serialize_statistics(&x)), + type_, + options, + Encoding::Plain, + ) +} + +pub(super) fn build_statistics( + array: &FixedSizeBinaryArray, + primitive_type: PrimitiveType, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + } +} + +pub(super) fn build_statistics_decimal( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max() + .map(|x| x.to_be_bytes()[16 - size..].to_vec()), + min_value: array + .iter() + .flatten() + .min() + .map(|x| x.to_be_bytes()[16 - size..].to_vec()), + } +} + +pub(super) fn build_statistics_decimal256_with_i128( + array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max() + .map(|x| x.0.low().to_be_bytes()[16 - size..].to_vec()), + min_value: array + .iter() + .flatten() + .min() + .map(|x| x.0.low().to_be_bytes()[16 - size..].to_vec()), + } +} + +pub(super) fn build_statistics_decimal256( + 
array: &PrimitiveArray, + primitive_type: PrimitiveType, + size: usize, +) -> FixedLenStatistics { + FixedLenStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .max() + .map(|x| x.0.to_be_bytes()[32 - size..].to_vec()), + min_value: array + .iter() + .flatten() + .min() + .map(|x| x.0.to_be_bytes()[32 - size..].to_vec()), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/mod.rs b/crates/nano-arrow/src/io/parquet/write/mod.rs new file mode 100644 index 000000000000..b74daea04d7e --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/mod.rs @@ -0,0 +1,876 @@ +//! APIs to write to Parquet format. +//! +//! # Arrow/Parquet Interoperability +//! As of [parquet-format v2.9](https://github.com/apache/parquet-format/blob/master/LogicalTypes.md) +//! there are Arrow [DataTypes](crate::datatypes::DataType) which do not have a parquet +//! representation. These include but are not limited to: +//! * `DataType::Timestamp(TimeUnit::Second, _)` +//! * `DataType::Int64` +//! * `DataType::Duration` +//! * `DataType::Date64` +//! * `DataType::Time32(TimeUnit::Second)` +//! +//! The use of these arrow types will result in no logical type being stored within a parquet file. + +mod binary; +mod boolean; +mod dictionary; +mod file; +mod fixed_len_bytes; +mod nested; +mod pages; +mod primitive; +mod row_group; +mod schema; +mod sink; +mod utf8; +mod utils; + +pub use nested::{num_values, write_rep_and_def}; +pub use pages::{to_leaves, to_nested, to_parquet_leaves}; +pub use parquet2::compression::{BrotliLevel, CompressionOptions, GzipLevel, ZstdLevel}; +pub use parquet2::encoding::Encoding; +pub use parquet2::metadata::{ + Descriptor, FileMetaData, KeyValue, SchemaDescriptor, ThriftFileMetaData, +}; +pub use parquet2::page::{CompressedDataPage, CompressedPage, Page}; +use parquet2::schema::types::PrimitiveType as ParquetPrimitiveType; +pub use parquet2::schema::types::{FieldInfo, ParquetType, PhysicalType as ParquetPhysicalType}; +pub use parquet2::write::{ + compress, write_metadata_sidecar, Compressor, DynIter, DynStreamingIterator, RowGroupIter, + Version, +}; +pub use parquet2::{fallible_streaming_iterator, FallibleStreamingIterator}; +pub use utils::write_def_levels; + +use crate::array::*; +use crate::datatypes::*; +use crate::error::{Error, Result}; +use crate::types::{days_ms, i256, NativeType}; + +/// Currently supported options to write to parquet +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct WriteOptions { + /// Whether to write statistics + pub write_statistics: bool, + /// The page and file version to use + pub version: Version, + /// The compression to apply to every page + pub compression: CompressionOptions, + /// The size to flush a page, defaults to 1024 * 1024 if None + pub data_pagesize_limit: Option, +} + +pub use file::FileWriter; +pub use pages::{array_to_columns, Nested}; +pub use row_group::{row_group_iter, RowGroupIterator}; +pub use schema::to_parquet_type; +pub use sink::FileSink; + +use crate::compute::aggregate::estimated_bytes_size; + +/// returns offset and length to slice the leaf values +pub fn slice_nested_leaf(nested: &[Nested]) -> (usize, usize) { + // find the deepest recursive dremel structure as that one determines how many values we must + // take + let mut out = (0, 0); + for nested in nested.iter().rev() { + match nested { + Nested::LargeList(l_nested) => { + let start = *l_nested.offsets.first(); + let end = *l_nested.offsets.last(); + return (start 
as usize, (end - start) as usize); + }, + Nested::List(l_nested) => { + let start = *l_nested.offsets.first(); + let end = *l_nested.offsets.last(); + return (start as usize, (end - start) as usize); + }, + Nested::Primitive(_, _, len) => out = (0, *len), + _ => {}, + } + } + out +} + +fn decimal_length_from_precision(precision: usize) -> usize { + // digits = floor(log_10(2^(8*n - 1) - 1)) + // ceil(digits) = log10(2^(8*n - 1) - 1) + // 10^ceil(digits) = 2^(8*n - 1) - 1 + // 10^ceil(digits) + 1 = 2^(8*n - 1) + // log2(10^ceil(digits) + 1) = (8*n - 1) + // log2(10^ceil(digits) + 1) + 1 = 8*n + // (log2(10^ceil(a) + 1) + 1) / 8 = n + (((10.0_f64.powi(precision as i32) + 1.0).log2() + 1.0) / 8.0).ceil() as usize +} + +/// Creates a parquet [`SchemaDescriptor`] from a [`Schema`]. +pub fn to_parquet_schema(schema: &Schema) -> Result { + let parquet_types = schema + .fields + .iter() + .map(to_parquet_type) + .collect::>>()?; + Ok(SchemaDescriptor::new("root".to_string(), parquet_types)) +} + +/// Checks whether the `data_type` can be encoded as `encoding`. +/// Note that this is whether this implementation supports it, which is a subset of +/// what the parquet spec allows. +pub fn can_encode(data_type: &DataType, encoding: Encoding) -> bool { + if let (Encoding::DeltaBinaryPacked, DataType::Decimal(p, _)) = + (encoding, data_type.to_logical_type()) + { + return *p <= 18; + }; + + matches!( + (encoding, data_type.to_logical_type()), + (Encoding::Plain, _) + | ( + Encoding::DeltaLengthByteArray, + DataType::Binary | DataType::LargeBinary | DataType::Utf8 | DataType::LargeUtf8, + ) + | (Encoding::RleDictionary, DataType::Dictionary(_, _, _)) + | (Encoding::PlainDictionary, DataType::Dictionary(_, _, _)) + | ( + Encoding::DeltaBinaryPacked, + DataType::Null + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Date32 + | DataType::Time32(_) + | DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + ) + ) +} + +/// Slices the [`Array`] to `Box` and `Vec`. +pub fn slice_parquet_array( + primitive_array: &mut dyn Array, + nested: &mut [Nested], + mut current_offset: usize, + mut current_length: usize, +) { + for nested in nested.iter_mut() { + match nested { + Nested::LargeList(l_nested) => { + l_nested.offsets.slice(current_offset, current_length + 1); + if let Some(validity) = l_nested.validity.as_mut() { + validity.slice(current_offset, current_length) + }; + + current_length = l_nested.offsets.range() as usize; + current_offset = *l_nested.offsets.first() as usize; + }, + Nested::List(l_nested) => { + l_nested.offsets.slice(current_offset, current_length + 1); + if let Some(validity) = l_nested.validity.as_mut() { + validity.slice(current_offset, current_length) + }; + + current_length = l_nested.offsets.range() as usize; + current_offset = *l_nested.offsets.first() as usize; + }, + Nested::Struct(validity, _, length) => { + *length = current_length; + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + }, + Nested::Primitive(validity, _, length) => { + *length = current_length; + if let Some(validity) = validity.as_mut() { + validity.slice(current_offset, current_length) + }; + primitive_array.slice(current_offset, current_length); + }, + } + } +} + +/// Get the length of [`Array`] that should be sliced. 
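To make the precision-to-width mapping concrete: `decimal_length_from_precision` above picks the smallest byte width whose signed range covers every value of the given decimal precision, and the decimal pages and statistics later in this diff truncate the backing `i128` to that many low-order big-endian bytes. A self-contained check of the formula (copied verbatim from above) and of the truncation pattern:

```rust
fn decimal_length_from_precision(precision: usize) -> usize {
    (((10.0_f64.powi(precision as i32) + 1.0).log2() + 1.0) / 8.0).ceil() as usize
}

fn main() {
    // precision <= 9 takes the INT32 path, <= 18 the INT64 path,
    // anything larger becomes a FIXED_LEN_BYTE_ARRAY of this width
    assert_eq!(decimal_length_from_precision(9), 4);
    assert_eq!(decimal_length_from_precision(18), 8);
    assert_eq!(decimal_length_from_precision(38), 16);

    // an i128-backed decimal is stored as its `size` low-order big-endian bytes
    let value: i128 = 1_234_567; // 0x12D687
    let size = decimal_length_from_precision(7);
    assert_eq!(size, 4);
    assert_eq!(
        value.to_be_bytes()[16 - size..].to_vec(),
        vec![0x00u8, 0x12, 0xD6, 0x87]
    );
}
```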
+pub fn get_max_length(nested: &[Nested]) -> usize { + let mut length = 0; + for nested in nested.iter() { + match nested { + Nested::LargeList(l_nested) => length += l_nested.offsets.range() as usize, + Nested::List(l_nested) => length += l_nested.offsets.range() as usize, + _ => {}, + } + } + length +} + +/// Returns an iterator of [`Page`]. +pub fn array_to_pages( + primitive_array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> Result>> { + if let DataType::Dictionary(key_type, _, _) = primitive_array.data_type().to_logical_type() { + return match_integer_type!(key_type, |$T| { + dictionary::array_to_pages::<$T>( + primitive_array.as_any().downcast_ref().unwrap(), + type_, + &nested, + options, + encoding, + ) + }); + }; + + let nested = nested.to_vec(); + let primitive_array = primitive_array.to_boxed(); + + let number_of_rows = nested[0].len(); + + // note: this is not correct if the array is sliced - the estimation should happen on the + // primitive after sliced for parquet + let byte_size = estimated_bytes_size(primitive_array.as_ref()); + + const DEFAULT_PAGE_SIZE: usize = 1024 * 1024; + let max_page_size = options.data_pagesize_limit.unwrap_or(DEFAULT_PAGE_SIZE); + let max_page_size = max_page_size.min(2usize.pow(31) - 2usize.pow(25)); // allowed maximum page size + let bytes_per_row = if number_of_rows == 0 { + 0 + } else { + ((byte_size as f64) / (number_of_rows as f64)) as usize + }; + let rows_per_page = (max_page_size / (bytes_per_row + 1)).max(1); + + let pages = (0..number_of_rows) + .step_by(rows_per_page) + .map(move |offset| { + let length = if offset + rows_per_page > number_of_rows { + number_of_rows - offset + } else { + rows_per_page + }; + + let mut right_array = primitive_array.clone(); + let mut right_nested = nested.clone(); + slice_parquet_array(right_array.as_mut(), &mut right_nested, offset, length); + + array_to_page( + right_array.as_ref(), + type_.clone(), + &right_nested, + options, + encoding, + ) + }); + + Ok(DynIter::new(pages)) +} + +/// Converts an [`Array`] to a [`CompressedPage`] based on options, descriptor and `encoding`. +pub fn array_to_page( + array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + encoding: Encoding, +) -> Result { + if nested.len() == 1 { + // special case where validity == def levels + return array_to_page_simple(array, type_, options, encoding); + } + array_to_page_nested(array, type_, nested, options, encoding) +} + +/// Converts an [`Array`] to a [`CompressedPage`] based on options, descriptor and `encoding`. +pub fn array_to_page_simple( + array: &dyn Array, + type_: ParquetPrimitiveType, + options: WriteOptions, + encoding: Encoding, +) -> Result { + let data_type = array.data_type(); + if !can_encode(data_type, encoding) { + return Err(Error::InvalidArgumentError(format!( + "The datatype {data_type:?} cannot be encoded by {encoding:?}" + ))); + } + + match data_type.to_logical_type() { + DataType::Boolean => { + boolean::array_to_page(array.as_any().downcast_ref().unwrap(), options, type_) + }, + // casts below MUST match the casts done at the metadata (field -> parquet type). 
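The page-splitting heuristic in `array_to_pages` above is plain integer arithmetic on an estimated byte size; a worked example with made-up numbers (only the 1 MiB default page size is taken from the code, the array dimensions are hypothetical):

```rust
fn main() {
    let byte_size = 8_000_000usize;      // estimated in-memory size of the column
    let number_of_rows = 1_000_000usize; // rows in the column
    let max_page_size = 1024 * 1024;     // DEFAULT_PAGE_SIZE, used when data_pagesize_limit is None

    let bytes_per_row = (byte_size as f64 / number_of_rows as f64) as usize; // 8
    let rows_per_page = (max_page_size / (bytes_per_row + 1)).max(1);
    assert_eq!(rows_per_page, 116_508);

    // the column chunk is then emitted as ceil(1_000_000 / 116_508) = 9 data pages
    let pages = (number_of_rows + rows_per_page - 1) / rows_per_page;
    assert_eq!(pages, 9);
}
```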
+ DataType::UInt8 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::UInt16 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::UInt32 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::UInt64 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Int8 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Int16 => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Int32 | DataType::Date32 | DataType::Time32(_) => { + primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ) + }, + DataType::Int64 + | DataType::Date64 + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) => primitive::array_to_page_integer::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Float32 => primitive::array_to_page_plain::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + ), + DataType::Float64 => primitive::array_to_page_plain::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + ), + DataType::Utf8 => utf8::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::LargeUtf8 => utf8::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Binary => binary::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::LargeBinary => binary::array_to_page::( + array.as_any().downcast_ref().unwrap(), + options, + type_, + encoding, + ), + DataType::Null => { + let array = Int32Array::new_null(DataType::Int32, array.len()); + primitive::array_to_page_plain::(&array, options, type_) + }, + DataType::Interval(IntervalUnit::YearMonth) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let mut values = Vec::::with_capacity(12 * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_le_bytes(); + values.extend_from_slice(bytes); + values.extend_from_slice(&[0; 8]); + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(12), + values.into(), + array.validity().cloned(), + ); + let statistics = if options.write_statistics { + Some(fixed_len_bytes::build_statistics(&array, type_.clone())) + } else { + None + }; + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + }, + DataType::Interval(IntervalUnit::DayTime) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let mut values = Vec::::with_capacity(12 * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_le_bytes(); + values.extend_from_slice(&[0; 4]); // months + values.extend_from_slice(bytes); // days and seconds + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(12), + values.into(), + array.validity().cloned(), + ); + let statistics = if options.write_statistics { + Some(fixed_len_bytes::build_statistics(&array, type_.clone())) + } else { + None + }; + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + }, + DataType::FixedSizeBinary(_) => { + let array = 
array.as_any().downcast_ref().unwrap(); + let statistics = if options.write_statistics { + Some(fixed_len_bytes::build_statistics(array, type_.clone())) + } else { + None + }; + + fixed_len_bytes::array_to_page(array, options, type_, statistics) + }, + DataType::Decimal256(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i32()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i64()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else if precision <= 38 { + let size = decimal_length_from_precision(precision); + let statistics = if options.write_statistics { + let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + array, + type_.clone(), + size, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.0.low().to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } else { + let size = 32; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let statistics = if options.write_statistics { + let stats = + fixed_len_bytes::build_statistics_decimal256(array, type_.clone(), size); + Some(stats) + } else { + None + }; + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes(); + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + DataType::Decimal(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::array_to_page_integer::(&array, options, type_, encoding) + } else { + let size = decimal_length_from_precision(precision); + + let statistics = if options.write_statistics { + let stats = + fixed_len_bytes::build_statistics_decimal(array, type_.clone(), size); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + 
fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + other => Err(Error::NotYetImplemented(format!( + "Writing parquet pages for data type {other:?}" + ))), + } + .map(Page::Data) +} + +fn array_to_page_nested( + array: &dyn Array, + type_: ParquetPrimitiveType, + nested: &[Nested], + options: WriteOptions, + _encoding: Encoding, +) -> Result { + use DataType::*; + match array.data_type().to_logical_type() { + Null => { + let array = Int32Array::new_null(DataType::Int32, array.len()); + primitive::nested_array_to_page::(&array, options, type_, nested) + }, + Boolean => { + let array = array.as_any().downcast_ref().unwrap(); + boolean::nested_array_to_page(array, options, type_, nested) + }, + Utf8 => { + let array = array.as_any().downcast_ref().unwrap(); + utf8::nested_array_to_page::(array, options, type_, nested) + }, + LargeUtf8 => { + let array = array.as_any().downcast_ref().unwrap(); + utf8::nested_array_to_page::(array, options, type_, nested) + }, + Binary => { + let array = array.as_any().downcast_ref().unwrap(); + binary::nested_array_to_page::(array, options, type_, nested) + }, + LargeBinary => { + let array = array.as_any().downcast_ref().unwrap(); + binary::nested_array_to_page::(array, options, type_, nested) + }, + UInt8 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt16 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt32 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + UInt64 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int8 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int16 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int32 | Date32 | Time32(_) => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Int64 | Date64 | Time64(_) | Timestamp(_, _) | Duration(_) => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Float32 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Float64 => { + let array = array.as_any().downcast_ref().unwrap(); + primitive::nested_array_to_page::(array, options, type_, nested) + }, + Decimal(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else { + let size = decimal_length_from_precision(precision); + + let statistics = if options.write_statistics { + let 
stats = + fixed_len_bytes::build_statistics_decimal(array, type_.clone(), size); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + Decimal256(precision, _) => { + let precision = *precision; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + if precision <= 9 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i32()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int32, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 18 { + let values = array + .values() + .iter() + .map(|x| x.0.as_i64()) + .collect::>() + .into(); + + let array = + PrimitiveArray::::new(DataType::Int64, values, array.validity().cloned()); + primitive::nested_array_to_page::(&array, options, type_, nested) + } else if precision <= 38 { + let size = decimal_length_from_precision(precision); + let statistics = if options.write_statistics { + let stats = fixed_len_bytes::build_statistics_decimal256_with_i128( + array, + type_.clone(), + size, + ); + Some(stats) + } else { + None + }; + + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.0.low().to_be_bytes()[16 - size..]; + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } else { + let size = 32; + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let statistics = if options.write_statistics { + let stats = + fixed_len_bytes::build_statistics_decimal256(array, type_.clone(), size); + Some(stats) + } else { + None + }; + let mut values = Vec::::with_capacity(size * array.len()); + array.values().iter().for_each(|x| { + let bytes = &x.to_be_bytes(); + values.extend_from_slice(bytes) + }); + let array = FixedSizeBinaryArray::new( + DataType::FixedSizeBinary(size), + values.into(), + array.validity().cloned(), + ); + + fixed_len_bytes::array_to_page(&array, options, type_, statistics) + } + }, + other => Err(Error::NotYetImplemented(format!( + "Writing nested parquet pages for data type {other:?}" + ))), + } + .map(Page::Data) +} + +fn transverse_recursive T + Clone>( + data_type: &DataType, + map: F, + encodings: &mut Vec, +) { + use crate::datatypes::PhysicalType::*; + match data_type.to_physical_type() { + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | Dictionary(_) | LargeUtf8 => encodings.push(map(data_type)), + List | FixedSizeList | LargeList => { + let a = data_type.to_logical_type(); + if let DataType::List(inner) = a { + transverse_recursive(&inner.data_type, map, encodings) + } else if let DataType::LargeList(inner) = a { + transverse_recursive(&inner.data_type, map, encodings) + } else if let DataType::FixedSizeList(inner, _) = a { + transverse_recursive(&inner.data_type, map, encodings) + } else { + unreachable!() + } + }, + Struct => { + if let DataType::Struct(fields) = data_type.to_logical_type() { + for field in fields { + 
transverse_recursive(&field.data_type, map.clone(), encodings) + } + } else { + unreachable!() + } + }, + Map => { + if let DataType::Map(field, _) = data_type.to_logical_type() { + if let DataType::Struct(fields) = field.data_type.to_logical_type() { + for field in fields { + transverse_recursive(&field.data_type, map.clone(), encodings) + } + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + Union => todo!(), + } +} + +/// Transverses the `data_type` up to its (parquet) columns and returns a vector of +/// items based on `map`. +/// This is used to assign an [`Encoding`] to every parquet column based on the columns' type (see example) +/// # Example +/// ``` +/// use arrow2::io::parquet::write::{transverse, Encoding}; +/// use arrow2::datatypes::{DataType, Field}; +/// +/// let dt = DataType::Struct(vec![ +/// Field::new("a", DataType::Int64, true), +/// Field::new("b", DataType::List(Box::new(Field::new("item", DataType::Int32, true))), true), +/// ]); +/// +/// let encodings = transverse(&dt, |dt| Encoding::Plain); +/// assert_eq!(encodings, vec![Encoding::Plain, Encoding::Plain]); +/// ``` +pub fn transverse T + Clone>(data_type: &DataType, map: F) -> Vec { + let mut encodings = vec![]; + transverse_recursive(data_type, map, &mut encodings); + encodings +} diff --git a/crates/nano-arrow/src/io/parquet/write/nested/def.rs b/crates/nano-arrow/src/io/parquet/write/nested/def.rs new file mode 100644 index 000000000000..02947dd5bef9 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/nested/def.rs @@ -0,0 +1,584 @@ +use super::super::pages::{ListNested, Nested}; +use super::rep::num_values; +use super::to_length; +use crate::bitmap::Bitmap; +use crate::offset::Offset; + +trait DebugIter: Iterator + std::fmt::Debug {} + +impl + std::fmt::Debug> DebugIter for A {} + +fn single_iter<'a>( + validity: &'a Option, + is_optional: bool, + length: usize, +) -> Box { + match (is_optional, validity) { + (false, _) => { + Box::new(std::iter::repeat((0u32, 1usize)).take(length)) as Box + }, + (true, None) => { + Box::new(std::iter::repeat((1u32, 1usize)).take(length)) as Box + }, + (true, Some(validity)) => { + Box::new(validity.iter().map(|v| (v as u32, 1usize)).take(length)) as Box + }, + } +} + +fn single_list_iter<'a, O: Offset>(nested: &'a ListNested) -> Box { + match (nested.is_optional, &nested.validity) { + (false, _) => Box::new( + std::iter::repeat(0u32) + .zip(to_length(&nested.offsets)) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + (true, None) => Box::new( + std::iter::repeat(1u32) + .zip(to_length(&nested.offsets)) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + (true, Some(validity)) => Box::new( + validity + .iter() + .map(|x| (x as u32)) + .zip(to_length(&nested.offsets)) + .map(|(a, b)| (a + (b != 0) as u32, b)), + ) as Box, + } +} + +fn iter<'a>(nested: &'a [Nested]) -> Vec> { + nested + .iter() + .map(|nested| match nested { + Nested::Primitive(validity, is_optional, length) => { + single_iter(validity, *is_optional, *length) + }, + Nested::List(nested) => single_list_iter(nested), + Nested::LargeList(nested) => single_list_iter(nested), + Nested::Struct(validity, is_optional, length) => { + single_iter(validity, *is_optional, *length) + }, + }) + .collect() +} + +/// Iterator adapter of parquet / dremel definition levels +#[derive(Debug)] +pub struct DefLevelsIter<'a> { + // iterators of validities and lengths. E.g. 
[[[None,b,c], None], None] -> [[(true, 2), (false, 0)], [(true, 3), (false, 0)], [(false, 1), (true, 1), (true, 1)]] + iter: Vec>, + // vector containing the remaining number of values of each iterator. + // e.g. the iters [[2, 2], [3, 4, 1, 2]] after the first iteration will return [2, 3], + // and remaining will be [2, 3]. + // on the second iteration, it will be `[2, 2]` (since iterations consume the last items) + remaining: Vec, /* < remaining.len() == iter.len() */ + validity: Vec, + // cache of the first `remaining` that is non-zero. Examples: + // * `remaining = [2, 2] => current_level = 2` + // * `remaining = [2, 0] => current_level = 1` + // * `remaining = [0, 0] => current_level = 0` + current_level: usize, /* < iter.len() */ + // the total definition level at any given point during the iteration + total: u32, /* < iter.len() */ + // the total number of items that this iterator will return + remaining_values: usize, +} + +impl<'a> DefLevelsIter<'a> { + pub fn new(nested: &'a [Nested]) -> Self { + let remaining_values = num_values(nested); + + let iter = iter(nested); + let remaining = vec![0; iter.len()]; + let validity = vec![0; iter.len()]; + + Self { + iter, + remaining, + validity, + total: 0, + current_level: 0, + remaining_values, + } + } +} + +impl<'a> Iterator for DefLevelsIter<'a> { + type Item = u32; + + fn next(&mut self) -> Option { + if self.remaining_values == 0 { + return None; + } + + if self.remaining.is_empty() { + self.remaining_values -= 1; + return Some(0); + } + + let mut empty_contrib = 0u32; + for ((iter, remaining), validity) in self + .iter + .iter_mut() + .zip(self.remaining.iter_mut()) + .zip(self.validity.iter_mut()) + .skip(self.current_level) + { + let (is_valid, length): (u32, usize) = iter.next()?; + *validity = is_valid; + self.total += is_valid; + + *remaining = length; + if length == 0 { + *validity = 0; + self.total -= is_valid; + empty_contrib = is_valid; + break; + } + self.current_level += 1; + } + + // track + if let Some(x) = self.remaining.get_mut(self.current_level.saturating_sub(1)) { + *x = x.saturating_sub(1) + } + + let r = Some(self.total + empty_contrib); + + for index in (1..self.current_level).rev() { + if self.remaining[index] == 0 { + self.current_level -= 1; + self.remaining[index - 1] -= 1; + self.total -= self.validity[index]; + } + } + if self.remaining[0] == 0 { + self.current_level = self.current_level.saturating_sub(1); + self.total -= self.validity[0]; + } + self.remaining_values -= 1; + r + } + + fn size_hint(&self) -> (usize, Option) { + let length = self.remaining_values; + (length, Some(length)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test(nested: Vec, expected: Vec) { + let mut iter = DefLevelsIter::new(&nested); + assert_eq!(iter.size_hint().0, expected.len()); + let result = iter.by_ref().collect::>(); + assert_eq!(result, expected); + assert_eq!(iter.size_hint().0, 0); + } + + #[test] + fn struct_optional() { + let b = [ + true, false, true, true, false, true, false, false, true, true, + ]; + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(Some(b.into()), true, 10), + ]; + let expected = vec![2, 1, 2, 2, 1, 2, 1, 1, 2, 2]; + + test(nested, expected) + } + + #[test] + fn nested_edge_simple() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 2), + ]; + let expected = vec![3, 3]; + + test(nested, expected) + } + + #[test] + fn struct_optional_1() { + 
let b = [ + true, false, true, true, false, true, false, false, true, true, + ]; + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(Some(b.into()), true, 10), + ]; + let expected = vec![2, 1, 2, 2, 1, 2, 1, 1, 2, 2]; + + test(nested, expected) + } + + #[test] + fn struct_optional_optional() { + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(None, true, 10), + ]; + let expected = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + test(nested, expected) + } + + #[test] + fn l1_required_required() { + let nested = vec![ + // [[0, 1], [], [2, 0, 3], [4, 5, 6], [], [7, 8, 9], [], [10]] + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 12), + ]; + let expected = vec![1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1]; + + test(nested, expected) + } + + #[test] + fn l1_optional_optional() { + // [[0, 1], None, [2, None, 3], [4, 5, 6], [], [7, 8, 9], None, [10]] + + let v0 = [true, false, true, true, true, true, false, true]; + let v1 = [ + true, true, //[0, 1] + true, false, true, //[2, None, 3] + true, true, true, //[4, 5, 6] + true, true, true, //[7, 8, 9] + true, //[10] + ]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(v0.into()), + }), + Nested::Primitive(Some(v1.into()), true, 12), + ]; + let expected = vec![3u32, 3, 0, 3, 2, 3, 3, 3, 3, 1, 3, 3, 3, 0, 3]; + + test(nested, expected) + } + + #[test] + fn l2_required_required_required() { + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + [ + [8], + [9, 10] + ] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 2]; + + test(nested, expected) + } + + #[test] + fn l2_optional_required_required() { + let a = [true, false, true, true]; + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + None, + [ + [8], + [], + [9, 10] + ] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 2, 5].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![3, 3, 3, 3, 3, 3, 3, 0, 1, 3, 2, 3, 3]; + + test(nested, expected) + } + + #[test] + fn l2_optional_optional_required() { + let a = [true, false, true]; + let b = [true, true, true, true, false]; + /* + [ + [ + [1,2,3], + [4,5,6,7], + ], + None, + [ + [8], + [], + None, + ], + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 7, 8, 8, 8].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(None, false, 8), + ]; + let expected = vec![4, 4, 4, 4, 4, 4, 4, 0, 4, 3, 2]; + + test(nested, expected) + } + + #[test] + fn l2_optional_optional_optional() { + let a = [true, false, true]; + let b = [true, true, true, false]; + let c = [true, true, true, true, false, true, true, true]; + /* + [ + [ + [1,2,3], + [4,None,6,7], + ], + 
None, + [ + [8], + None, + ], + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 4].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 7, 8, 8].try_into().unwrap(), + validity: Some(b.into()), + }), + Nested::Primitive(Some(c.into()), true, 8), + ]; + let expected = vec![5, 5, 5, 5, 4, 5, 5, 0, 5, 2]; + + test(nested, expected) + } + + /* + [{"a": "a"}, {"a": "b"}], + None, + [{"a": "b"}, None, {"a": "b"}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": "d"}, {"a": "d"}, {"a": "d"}], + None, + [{"a": "e"}], + */ + #[test] + fn nested_list_struct_nullable() { + let a = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let c = [true, false, true, true, true, true, false, true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Struct(Some(b.into()), true, 12), + Nested::Primitive(Some(a.into()), true, 12), + ]; + let expected = vec![4, 4, 0, 4, 2, 4, 3, 3, 3, 1, 4, 4, 4, 0, 4]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_nullable1() { + let c = [true, false]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Struct(None, true, 1), + Nested::Primitive(None, true, 1), + ]; + let expected = vec![4, 0]; + + test(nested, expected) + } + + #[test] + fn nested_struct_list_nullable() { + let a = [true, false, true, true, true, true, false, true]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let nested = vec![ + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Primitive(Some(b.into()), true, 12), + ]; + let expected = vec![4, 4, 1, 4, 3, 4, 4, 4, 4, 2, 4, 4, 4, 1, 4]; + + test(nested, expected) + } + + #[test] + fn nested_struct_list_nullable1() { + let a = [true, true, false]; + let nested = vec![ + Nested::Struct(None, true, 3), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1, 1].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Primitive(None, true, 1), + ]; + let expected = vec![4, 2, 1]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable1() { + /* + [ + [{"a": ["b"]}, None], + ] + */ + + let a = [true]; + let b = [true, false]; + let c = [true, false]; + let d = [true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Struct(Some(b.into()), true, 2), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: Some(c.into()), + }), + Nested::Primitive(Some(d.into()), true, 1), + ]; + /* + 0 6 + 1 6 + 0 0 + 0 6 + 1 2 + */ + let expected = vec![6, 2]; + + test(nested, expected) + } + + #[test] + fn nested_list_struct_list_nullable() { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + None, + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": None}, {"a": None}, {"a": None}], + [], + [{"a": ["d"]}, {"a": [None]}, {"a": ["c", "d"]}], + None, + [{"a": []}], + ] + */ + let a = [true, 
false, true, true, true, true, false, true]; + let b = [ + true, true, true, false, true, true, true, true, true, true, true, true, + ]; + let c = [ + true, true, true, false, true, false, false, false, true, true, true, true, + ]; + let d = [true, true, true, true, true, false, true, true]; + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: Some(a.into()), + }), + Nested::Struct(Some(b.into()), true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8, 8] + .try_into() + .unwrap(), + validity: Some(c.into()), + }), + Nested::Primitive(Some(d.into()), true, 8), + ]; + let expected = vec![6, 6, 0, 6, 2, 6, 3, 3, 3, 1, 6, 5, 6, 6, 0, 4]; + + test(nested, expected) + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/nested/mod.rs b/crates/nano-arrow/src/io/parquet/write/nested/mod.rs new file mode 100644 index 000000000000..042d731c57de --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/nested/mod.rs @@ -0,0 +1,118 @@ +mod def; +mod rep; + +use parquet2::encoding::hybrid_rle::encode_u32; +use parquet2::read::levels::get_bit_width; +use parquet2::write::Version; +pub use rep::num_values; + +use super::Nested; +use crate::error::Result; +use crate::offset::Offset; + +fn write_levels_v1) -> Result<()>>( + buffer: &mut Vec, + encode: F, +) -> Result<()> { + buffer.extend_from_slice(&[0; 4]); + let start = buffer.len(); + + encode(buffer)?; + + let end = buffer.len(); + let length = end - start; + + // write the first 4 bytes as length + let length = (length as i32).to_le_bytes(); + (0..4).for_each(|i| buffer[start - 4 + i] = length[i]); + Ok(()) +} + +/// writes the rep levels to a `Vec`. +fn write_rep_levels(buffer: &mut Vec, nested: &[Nested], version: Version) -> Result<()> { + let max_level = max_rep_level(nested) as i16; + if max_level == 0 { + return Ok(()); + } + let num_bits = get_bit_width(max_level); + + let levels = rep::RepLevelsIter::new(nested); + + match version { + Version::V1 => { + write_levels_v1(buffer, |buffer: &mut Vec| { + encode_u32(buffer, levels, num_bits)?; + Ok(()) + })?; + }, + Version::V2 => { + encode_u32(buffer, levels, num_bits)?; + }, + } + + Ok(()) +} + +/// writes the rep levels to a `Vec`. 
+fn write_def_levels(buffer: &mut Vec, nested: &[Nested], version: Version) -> Result<()> { + let max_level = max_def_level(nested) as i16; + if max_level == 0 { + return Ok(()); + } + let num_bits = get_bit_width(max_level); + + let levels = def::DefLevelsIter::new(nested); + + match version { + Version::V1 => write_levels_v1(buffer, move |buffer: &mut Vec| { + encode_u32(buffer, levels, num_bits)?; + Ok(()) + }), + Version::V2 => Ok(encode_u32(buffer, levels, num_bits)?), + } +} + +fn max_def_level(nested: &[Nested]) -> usize { + nested + .iter() + .map(|nested| match nested { + Nested::Primitive(_, is_optional, _) => *is_optional as usize, + Nested::List(nested) => 1 + (nested.is_optional as usize), + Nested::LargeList(nested) => 1 + (nested.is_optional as usize), + Nested::Struct(_, is_optional, _) => *is_optional as usize, + }) + .sum() +} + +fn max_rep_level(nested: &[Nested]) -> usize { + nested + .iter() + .map(|nested| match nested { + Nested::LargeList(_) | Nested::List(_) => 1, + Nested::Primitive(_, _, _) | Nested::Struct(_, _, _) => 0, + }) + .sum() +} + +fn to_length( + offsets: &[O], +) -> impl Iterator + std::fmt::Debug + Clone + '_ { + offsets + .windows(2) + .map(|w| w[1].to_usize() - w[0].to_usize()) +} + +/// Write `repetition_levels` and `definition_levels` to buffer. +pub fn write_rep_and_def( + page_version: Version, + nested: &[Nested], + buffer: &mut Vec, +) -> Result<(usize, usize)> { + write_rep_levels(buffer, nested, page_version)?; + let repetition_levels_byte_length = buffer.len(); + + write_def_levels(buffer, nested, page_version)?; + let definition_levels_byte_length = buffer.len() - repetition_levels_byte_length; + + Ok((repetition_levels_byte_length, definition_levels_byte_length)) +} diff --git a/crates/nano-arrow/src/io/parquet/write/nested/rep.rs b/crates/nano-arrow/src/io/parquet/write/nested/rep.rs new file mode 100644 index 000000000000..2bfbe1ce24f4 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/nested/rep.rs @@ -0,0 +1,370 @@ +use super::super::pages::Nested; +use super::to_length; + +trait DebugIter: Iterator + std::fmt::Debug {} + +impl + std::fmt::Debug> DebugIter for A {} + +fn iter<'a>(nested: &'a [Nested]) -> Vec> { + nested + .iter() + .filter_map(|nested| match nested { + Nested::Primitive(_, _, _) => None, + Nested::List(nested) => { + Some(Box::new(to_length(&nested.offsets)) as Box) + }, + Nested::LargeList(nested) => { + Some(Box::new(to_length(&nested.offsets)) as Box) + }, + Nested::Struct(_, _, _) => None, + }) + .collect() +} + +/// return number values of the nested +pub fn num_values(nested: &[Nested]) -> usize { + let pr = match nested.last().unwrap() { + Nested::Primitive(_, _, len) => *len, + _ => todo!(), + }; + + iter(nested) + .into_iter() + .enumerate() + .map(|(_, lengths)| { + lengths + .map(|length| if length == 0 { 1 } else { 0 }) + .sum::() + }) + .sum::() + + pr +} + +/// Iterator adapter of parquet / dremel repetition levels +#[derive(Debug)] +pub struct RepLevelsIter<'a> { + // iterators of lengths. E.g. [[[a,b,c], [d,e,f,g]], [[h], [i,j]]] -> [[2, 2], [3, 4, 1, 2]] + iter: Vec>, + // vector containing the remaining number of values of each iterator. + // e.g. the iters [[2, 2], [3, 4, 1, 2]] after the first iteration will return [2, 3], + // and remaining will be [2, 3]. + // on the second iteration, it will be `[2, 2]` (since iterations consume the last items) + remaining: Vec, /* < remaining.len() == iter.len() */ + // cache of the first `remaining` that is non-zero. 
Examples: + // * `remaining = [2, 2] => current_level = 2` + // * `remaining = [2, 0] => current_level = 1` + // * `remaining = [0, 0] => current_level = 0` + current_level: usize, /* < iter.len() */ + // the number to discount due to being the first element of the iterators. + total: usize, /* < iter.len() */ + + // the total number of items that this iterator will return + remaining_values: usize, +} + +impl<'a> RepLevelsIter<'a> { + pub fn new(nested: &'a [Nested]) -> Self { + let remaining_values = num_values(nested); + + let iter = iter(nested); + let remaining = vec![0; iter.len()]; + + Self { + iter, + remaining, + total: 0, + current_level: 0, + remaining_values, + } + } +} + +impl<'a> Iterator for RepLevelsIter<'a> { + type Item = u32; + + fn next(&mut self) -> Option { + if self.remaining_values == 0 { + return None; + } + if self.remaining.is_empty() { + self.remaining_values -= 1; + return Some(0); + } + + for (iter, remaining) in self + .iter + .iter_mut() + .zip(self.remaining.iter_mut()) + .skip(self.current_level) + { + let length: usize = iter.next()?; + *remaining = length; + if length == 0 { + break; + } + self.current_level += 1; + self.total += 1; + } + + // track + if let Some(x) = self.remaining.get_mut(self.current_level.saturating_sub(1)) { + *x = x.saturating_sub(1) + } + let r = Some((self.current_level - self.total) as u32); + + // update + for index in (1..self.current_level).rev() { + if self.remaining[index] == 0 { + self.current_level -= 1; + self.remaining[index - 1] -= 1; + } + } + if self.remaining[0] == 0 { + self.current_level = self.current_level.saturating_sub(1); + } + self.total = 0; + self.remaining_values -= 1; + + r + } + + fn size_hint(&self) -> (usize, Option) { + let length = self.remaining_values; + (length, Some(length)) + } +} + +#[cfg(test)] +mod tests { + use super::super::super::pages::ListNested; + use super::*; + + fn test(nested: Vec, expected: Vec) { + let mut iter = RepLevelsIter::new(&nested); + assert_eq!(iter.size_hint().0, expected.len()); + assert_eq!(iter.by_ref().collect::>(), expected); + assert_eq!(iter.size_hint().0, 0); + } + + #[test] + fn struct_required() { + let nested = vec![ + Nested::Struct(None, false, 10), + Nested::Primitive(None, true, 10), + ]; + let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + test(nested, expected) + } + + #[test] + fn struct_optional() { + let nested = vec![ + Nested::Struct(None, true, 10), + Nested::Primitive(None, true, 10), + ]; + let expected = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + test(nested, expected) + } + + #[test] + fn l1() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 12), + ]; + let expected = vec![0u32, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0]; + + test(nested, expected) + } + + #[test] + fn l2() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 3, 7, 8, 10].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 10), + ]; + let expected = vec![0, 2, 2, 1, 2, 2, 2, 0, 0, 1, 2]; + + test(nested, expected) + } + + #[test] + fn list_of_struct() { + /* + [ + [{"a": "b"}],[{"a": "c"}] + ] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2].try_into().unwrap(), + validity: None, + }), + 
Nested::Struct(None, true, 2), + Nested::Primitive(None, true, 2), + ]; + let expected = vec![0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 3].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 3), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 3, 6, 7].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 7), + ]; + let expected = vec![0, 2, 2, 1, 2, 2, 0]; + + test(nested, expected) + } + + #[test] + fn struct_list_optional() { + /* + {"f1": ["a", "b", None, "c"]} + */ + let nested = vec![ + Nested::Struct(None, true, 1), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 4].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 4), + ]; + let expected = vec![0, 1, 1, 1]; + + test(nested, expected) + } + + #[test] + fn l2_other() { + let nested = vec![ + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 1, 1, 3, 5, 5, 8, 8, 9].try_into().unwrap(), + validity: None, + }), + Nested::List(ListNested { + is_optional: false, + offsets: vec![0, 2, 4, 5, 7, 8, 9, 10, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, false, 12), + ]; + let expected = vec![0, 2, 0, 0, 2, 1, 0, 2, 1, 0, 0, 1, 1, 0, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_1() { + /* + [ + [{"a": ["a"]}, {"a": ["b"]}], + [], + [{"a": ["b"]}, None, {"a": ["b"]}], + [{"a": []}, {"a": []}, {"a": []}], + [], + [{"a": ["d"]}, {"a": ["a"]}, {"a": ["c", "d"]}], + [], + [{"a": []}], + ] + // reps: [0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 2, 0, 0] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 2, 2, 5, 8, 8, 11, 11, 12].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 2, 3, 3, 4, 4, 4, 4, 5, 6, 8].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 8), + ]; + let expected = vec![0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 2, 0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_2() { + /* + [ + [{"a": []}], + ] + // reps: [0] + */ + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 0), + ]; + let expected = vec![0]; + + test(nested, expected) + } + + #[test] + fn list_struct_list_3() { + let nested = vec![ + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 1, 1].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 12), + Nested::List(ListNested { + is_optional: true, + offsets: vec![0, 0].try_into().unwrap(), + validity: None, + }), + Nested::Primitive(None, true, 0), + ]; + let expected = vec![0, 0]; + // [1, 0], [0] + // pick last + + test(nested, expected) + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/pages.rs b/crates/nano-arrow/src/io/parquet/write/pages.rs new file mode 100644 index 000000000000..ce51bcdcda89 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/pages.rs @@ -0,0 +1,633 @@ +use std::fmt::Debug; + +use parquet2::page::Page; +use parquet2::schema::types::{ParquetType, PrimitiveType as ParquetPrimitiveType}; +use 
parquet2::write::DynIter; + +use super::{array_to_pages, Encoding, WriteOptions}; +use crate::array::{Array, ListArray, MapArray, StructArray}; +use crate::bitmap::Bitmap; +use crate::datatypes::PhysicalType; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::offset::{Offset, OffsetsBuffer}; + +#[derive(Debug, Clone, PartialEq)] +pub struct ListNested { + pub is_optional: bool, + pub offsets: OffsetsBuffer, + pub validity: Option, +} + +impl ListNested { + pub fn new(offsets: OffsetsBuffer, validity: Option, is_optional: bool) -> Self { + Self { + is_optional, + offsets, + validity, + } + } +} + +/// Descriptor of nested information of a field +#[derive(Debug, Clone, PartialEq)] +pub enum Nested { + /// a primitive (leaf or parquet column) + /// bitmap, _, length + Primitive(Option, bool, usize), + /// a list + List(ListNested), + /// a list + LargeList(ListNested), + /// a struct + Struct(Option, bool, usize), +} + +impl Nested { + /// Returns the length (number of rows) of the element + pub fn len(&self) -> usize { + match self { + Nested::Primitive(_, _, length) => *length, + Nested::List(nested) => nested.offsets.len_proxy(), + Nested::LargeList(nested) => nested.offsets.len_proxy(), + Nested::Struct(_, _, len) => *len, + } + } +} + +/// Constructs the necessary `Vec>` to write the rep and def levels of `array` to parquet +pub fn to_nested(array: &dyn Array, type_: &ParquetType) -> Result>> { + let mut nested = vec![]; + + to_nested_recursive(array, type_, &mut nested, vec![])?; + Ok(nested) +} + +fn to_nested_recursive( + array: &dyn Array, + type_: &ParquetType, + nested: &mut Vec>, + mut parents: Vec, +) -> Result<()> { + let is_optional = is_nullable(type_.get_field_info()); + + use PhysicalType::*; + match array.data_type().to_physical_type() { + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + let fields = if let ParquetType::GroupType { fields, .. } = type_ { + fields + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a struct array".to_string(), + )); + }; + + parents.push(Nested::Struct( + array.validity().cloned(), + is_optional, + array.len(), + )); + + for (type_, array) in fields.iter().zip(array.values()) { + to_nested_recursive(array.as_ref(), type_, nested, parents.clone())?; + } + }, + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + }; + + parents.push(Nested::List(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. 
} = &fields[0] { + &fields[0] + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a list array".to_string(), + )); + }; + + parents.push(Nested::LargeList(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.values().as_ref(), type_, nested, parents)?; + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let type_ = if let ParquetType::GroupType { fields, .. } = type_ { + if let ParquetType::GroupType { fields, .. } = &fields[0] { + &fields[0] + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a map array".to_string(), + )); + } + } else { + return Err(Error::InvalidArgumentError( + "Parquet type must be a group for a map array".to_string(), + )); + }; + + parents.push(Nested::List(ListNested::new( + array.offsets().clone(), + array.validity().cloned(), + is_optional, + ))); + to_nested_recursive(array.field().as_ref(), type_, nested, parents)?; + }, + _ => { + parents.push(Nested::Primitive( + array.validity().cloned(), + is_optional, + array.len(), + )); + nested.push(parents) + }, + } + Ok(()) +} + +/// Convert [`Array`] to `Vec<&dyn Array>` leaves in DFS order. +pub fn to_leaves(array: &dyn Array) -> Vec<&dyn Array> { + let mut leaves = vec![]; + to_leaves_recursive(array, &mut leaves); + leaves +} + +fn to_leaves_recursive<'a>(array: &'a dyn Array, leaves: &mut Vec<&'a dyn Array>) { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + array + .values() + .iter() + .for_each(|a| to_leaves_recursive(a.as_ref(), leaves)); + }, + List => { + let array = array.as_any().downcast_ref::>().unwrap(); + to_leaves_recursive(array.values().as_ref(), leaves); + }, + LargeList => { + let array = array.as_any().downcast_ref::>().unwrap(); + to_leaves_recursive(array.values().as_ref(), leaves); + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + to_leaves_recursive(array.field().as_ref(), leaves); + }, + Null | Boolean | Primitive(_) | Binary | FixedSizeBinary | LargeBinary | Utf8 + | LargeUtf8 | Dictionary(_) => leaves.push(array), + other => todo!("Writing {:?} to parquet not yet implemented", other), + } +} + +/// Convert `ParquetType` to `Vec` leaves in DFS order. +pub fn to_parquet_leaves(type_: ParquetType) -> Vec { + let mut leaves = vec![]; + to_parquet_leaves_recursive(type_, &mut leaves); + leaves +} + +fn to_parquet_leaves_recursive(type_: ParquetType, leaves: &mut Vec) { + match type_ { + ParquetType::PrimitiveType(primitive) => leaves.push(primitive), + ParquetType::GroupType { fields, .. 
} => { + fields + .into_iter() + .for_each(|type_| to_parquet_leaves_recursive(type_, leaves)); + }, + } +} + +/// Returns a vector of iterators of [`Page`], one per leaf column in the array +pub fn array_to_columns + Send + Sync>( + array: A, + type_: ParquetType, + options: WriteOptions, + encoding: &[Encoding], +) -> Result>>> { + let array = array.as_ref(); + let nested = to_nested(array, &type_)?; + + let types = to_parquet_leaves(type_); + + let values = to_leaves(array); + + assert_eq!(encoding.len(), types.len()); + + values + .iter() + .zip(nested) + .zip(types) + .zip(encoding.iter()) + .map(|(((values, nested), type_), encoding)| { + array_to_pages(*values, type_, &nested, options, *encoding) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use parquet2::schema::types::{GroupLogicalType, PrimitiveConvertedType, PrimitiveLogicalType}; + use parquet2::schema::Repetition; + + use super::super::{FieldInfo, ParquetPhysicalType, ParquetPrimitiveType}; + use super::*; + use crate::array::*; + use crate::bitmap::Bitmap; + use crate::datatypes::*; + + #[test] + fn test_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + vec![ + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_struct_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let fields = vec![ + Field::new("b", array.data_type().clone(), true), + Field::new("c", array.data_type().clone(), true), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![Box::new(array.clone()), Box::new(array)], + None, + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: 
FieldInfo { + name: "b".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_.clone(), type_], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + // a.b.b + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + // a.b.c + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + // a.c.b + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + // a.c.c + vec![ + Nested::Struct(None, false, 4), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_list_struct() { + let boolean = BooleanArray::from_slice([false, false, true, true]).boxed(); + let int = Int32Array::from_slice([42, 28, 19, 31]).boxed(); + + let fields = vec![ + Field::new("b", DataType::Boolean, false), + Field::new("c", DataType::Int32, false), + ]; + + let array = StructArray::new( + DataType::Struct(fields), + vec![boolean.clone(), int.clone()], + Some(Bitmap::from([true, true, false, true])), + ); + + let array = ListArray::new( + DataType::List(Box::new(Field::new("l", array.data_type().clone(), true))), + vec![0i32, 2, 4].try_into().unwrap(), + Box::new(array), + None, + ); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "a".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "b".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Boolean, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "c".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "l".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ParquetType::GroupType { + field_info: FieldInfo { + name: "list".to_string(), + repetition: Repetition::Repeated, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_], + }], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + 
vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 4].try_into().unwrap(), + validity: None, + }), + Nested::Struct(Some(Bitmap::from([true, true, false, true])), true, 4), + Nested::Primitive(None, false, 4), + ], + ] + ); + } + + #[test] + fn test_map() { + let kv_type = DataType::Struct(vec![ + Field::new("k", DataType::Utf8, false), + Field::new("v", DataType::Int32, false), + ]); + let kv_field = Field::new("kv", kv_type.clone(), false); + let map_type = DataType::Map(Box::new(kv_field), false); + + let key_array = Utf8Array::::from_slice(["k1", "k2", "k3", "k4", "k5", "k6"]).boxed(); + let val_array = Int32Array::from_slice([42, 28, 19, 31, 21, 17]).boxed(); + let kv_array = StructArray::try_new(kv_type, vec![key_array, val_array], None) + .unwrap() + .boxed(); + let offsets = OffsetsBuffer::try_from(vec![0, 2, 3, 4, 6]).unwrap(); + + let array = MapArray::try_new(map_type, offsets, kv_array, None).unwrap(); + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "kv".to_string(), + repetition: Repetition::Optional, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![ + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "k".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: Some(PrimitiveLogicalType::String), + converted_type: Some(PrimitiveConvertedType::Utf8), + physical_type: ParquetPhysicalType::ByteArray, + }), + ParquetType::PrimitiveType(ParquetPrimitiveType { + field_info: FieldInfo { + name: "v".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: None, + converted_type: None, + physical_type: ParquetPhysicalType::Int32, + }), + ], + }; + + let type_ = ParquetType::GroupType { + field_info: FieldInfo { + name: "m".to_string(), + repetition: Repetition::Required, + id: None, + }, + logical_type: Some(GroupLogicalType::Map), + converted_type: None, + fields: vec![ParquetType::GroupType { + field_info: FieldInfo { + name: "map".to_string(), + repetition: Repetition::Repeated, + id: None, + }, + logical_type: None, + converted_type: None, + fields: vec![type_], + }], + }; + + let a = to_nested(&array, &type_).unwrap(); + + assert_eq!( + a, + vec![ + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 3, 4, 6].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 6), + Nested::Primitive(None, false, 6), + ], + vec![ + Nested::List(ListNested:: { + is_optional: false, + offsets: vec![0, 2, 3, 4, 6].try_into().unwrap(), + validity: None, + }), + Nested::Struct(None, true, 6), + Nested::Primitive(None, false, 6), + ], + ] + ); + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/primitive/basic.rs b/crates/nano-arrow/src/io/parquet/write/primitive/basic.rs new file mode 100644 index 000000000000..14d5f9077b49 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/primitive/basic.rs @@ -0,0 +1,192 @@ +use parquet2::encoding::delta_bitpacked::encode; +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, PrimitiveStatistics}; +use parquet2::types::NativeType as ParquetNativeType; + +use super::super::{utils, WriteOptions}; +use crate::array::{Array, PrimitiveArray}; +use crate::error::Error; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::utils::ExactSizedIter; +use crate::types::NativeType; + +pub(crate) fn encode_plain( + 
array: &PrimitiveArray<T>, + is_optional: bool, + mut buffer: Vec<u8>, +) -> Vec<u8> +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive<P>, +{ + if is_optional { + buffer.reserve(std::mem::size_of::<P>() * (array.len() - array.null_count())); + // append the non-null values + array.iter().for_each(|x| { + if let Some(x) = x { + let parquet_native: P = x.as_(); + buffer.extend_from_slice(parquet_native.to_le_bytes().as_ref()) + } + }); + } else { + buffer.reserve(std::mem::size_of::<P>() * array.len()); + // append all values + array.values().iter().for_each(|x| { + let parquet_native: P = x.as_(); + buffer.extend_from_slice(parquet_native.to_le_bytes().as_ref()) + }); + } + buffer +}
+ +pub(crate) fn encode_delta<T, P>( + array: &PrimitiveArray<T>, + is_optional: bool, + mut buffer: Vec<u8>, +) -> Vec<u8> +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive<P>, + P: num_traits::AsPrimitive<i64>, +{ + if is_optional { + // append the non-null values + let iterator = array.iter().flatten().map(|x| { + let parquet_native: P = x.as_(); + let integer: i64 = parquet_native.as_(); + integer + }); + let iterator = ExactSizedIter::new(iterator, array.len() - array.null_count()); + encode(iterator, &mut buffer) + } else { + // append all values + let iterator = array.values().iter().map(|x| { + let parquet_native: P = x.as_(); + let integer: i64 = parquet_native.as_(); + integer + }); + encode(iterator, &mut buffer) + } + buffer +}
+ +pub fn array_to_page_plain<T, P>( + array: &PrimitiveArray<T>, + options: WriteOptions, + type_: PrimitiveType, +) -> Result<DataPage, Error> +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive<P>, +{ + array_to_page(array, options, type_, Encoding::Plain, encode_plain) +}
+ +pub fn array_to_page_integer<T, P>( + array: &PrimitiveArray<T>, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> Result<DataPage, Error> +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive<P>, + P: num_traits::AsPrimitive<i64>, +{ + match encoding { + Encoding::DeltaBinaryPacked => array_to_page(array, options, type_, encoding, encode_delta), + Encoding::Plain => array_to_page(array, options, type_, encoding, encode_plain), + other => Err(Error::nyi(format!("Encoding integer as {other:?}"))), + } +}
+ +pub fn array_to_page<T, P, F: Fn(&PrimitiveArray<T>, bool, Vec<u8>) -> Vec<u8>>( + array: &PrimitiveArray<T>, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, + encode: F, +) -> Result<DataPage, Error> +where + T: NativeType, + P: ParquetNativeType, + // constraint required to build statistics + T: num_traits::AsPrimitive<P>, +{ + let is_optional = is_nullable(&type_.field_info); + + let validity = array.validity(); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + let buffer = encode(array, is_optional, buffer); + + let statistics = if options.write_statistics { + Some(serialize_statistics(&build_statistics( + array, + type_.clone(), + ))) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +}
+ +pub fn build_statistics<T, P>( + array: &PrimitiveArray<T>, + primitive_type: PrimitiveType, +) -> PrimitiveStatistics<P> +where + T: NativeType, + P: ParquetNativeType, + T: num_traits::AsPrimitive<P>, +{ + PrimitiveStatistics::<P>
{ + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .map(|x| { + let x: P = x.as_(); + x + }) + .max_by(|x, y| x.ord(y)), + min_value: array + .iter() + .flatten() + .map(|x| { + let x: P = x.as_(); + x + }) + .min_by(|x, y| x.ord(y)), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/primitive/mod.rs b/crates/nano-arrow/src/io/parquet/write/primitive/mod.rs new file mode 100644 index 000000000000..96318ab0a89b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/primitive/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use basic::{array_to_page_integer, array_to_page_plain}; +pub(crate) use basic::{build_statistics, encode_plain}; +pub use nested::array_to_page as nested_array_to_page; diff --git a/crates/nano-arrow/src/io/parquet/write/primitive/nested.rs b/crates/nano-arrow/src/io/parquet/write/primitive/nested.rs new file mode 100644 index 000000000000..fe859013c96b --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/primitive/nested.rs @@ -0,0 +1,56 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::serialize_statistics; +use parquet2::types::NativeType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, PrimitiveArray}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; +use crate::types::NativeType as ArrowNativeType; + +pub fn array_to_page( + array: &PrimitiveArray, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result +where + T: ArrowNativeType, + R: NativeType, + T: num_traits::AsPrimitive, +{ + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + let buffer = encode_plain(array, is_optional, buffer); + + let statistics = if options.write_statistics { + Some(serialize_statistics(&build_statistics( + array, + type_.clone(), + ))) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/row_group.rs b/crates/nano-arrow/src/io/parquet/write/row_group.rs new file mode 100644 index 000000000000..d281b63cebda --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/row_group.rs @@ -0,0 +1,126 @@ +use parquet2::error::Error as ParquetError; +use parquet2::schema::types::ParquetType; +use parquet2::write::Compressor; +use parquet2::FallibleStreamingIterator; + +use super::{ + array_to_columns, to_parquet_schema, DynIter, DynStreamingIterator, Encoding, RowGroupIter, + SchemaDescriptor, WriteOptions, +}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::{Error, Result}; + +/// Maps a [`Chunk`] and parquet-specific options to an [`RowGroupIter`] used to +/// write to parquet +/// # Panics +/// Iff +/// * `encodings.len() != fields.len()` or +/// * `encodings.len() != chunk.arrays().len()` +pub fn row_group_iter + 'static + Send + Sync>( + chunk: Chunk, + encodings: Vec>, + fields: Vec, + options: WriteOptions, +) -> RowGroupIter<'static, Error> 
{ + assert_eq!(encodings.len(), fields.len()); + assert_eq!(encodings.len(), chunk.arrays().len()); + DynIter::new( + chunk + .into_arrays() + .into_iter() + .zip(fields) + .zip(encodings) + .flat_map(move |((array, type_), encoding)| { + let encoded_columns = array_to_columns(array, type_, options, &encoding).unwrap(); + encoded_columns + .into_iter() + .map(|encoded_pages| { + let pages = encoded_pages; + + let pages = DynIter::new( + pages + .into_iter() + .map(|x| x.map_err(|e| ParquetError::OutOfSpec(e.to_string()))), + ); + + let compressed_pages = Compressor::new(pages, options.compression, vec![]) + .map_err(Error::from); + Ok(DynStreamingIterator::new(compressed_pages)) + }) + .collect::>() + }), + ) +} + +/// An iterator adapter that converts an iterator over [`Chunk`] into an iterator +/// of row groups. +/// Use it to create an iterator consumable by the parquet's API. +pub struct RowGroupIterator + 'static, I: Iterator>>> { + iter: I, + options: WriteOptions, + parquet_schema: SchemaDescriptor, + encodings: Vec>, +} + +impl + 'static, I: Iterator>>> RowGroupIterator { + /// Creates a new [`RowGroupIterator`] from an iterator over [`Chunk`]. + /// + /// # Errors + /// Iff + /// * the Arrow schema can't be converted to a valid Parquet schema. + /// * the length of the encodings is different from the number of fields in schema + pub fn try_new( + iter: I, + schema: &Schema, + options: WriteOptions, + encodings: Vec>, + ) -> Result { + if encodings.len() != schema.fields.len() { + return Err(Error::InvalidArgumentError( + "The number of encodings must equal the number of fields".to_string(), + )); + } + let parquet_schema = to_parquet_schema(schema)?; + + Ok(Self { + iter, + options, + parquet_schema, + encodings, + }) + } + + /// Returns the [`SchemaDescriptor`] of the [`RowGroupIterator`]. 
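+ /// The descriptor is derived from the Arrow [`Schema`] by `to_parquet_schema` in `try_new` and is what the file writer places in the parquet footer.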
+ pub fn parquet_schema(&self) -> &SchemaDescriptor { + &self.parquet_schema + } +} + +impl + 'static + Send + Sync, I: Iterator>>> Iterator + for RowGroupIterator +{ + type Item = Result>; + + fn next(&mut self) -> Option { + let options = self.options; + + self.iter.next().map(|maybe_chunk| { + let chunk = maybe_chunk?; + if self.encodings.len() != chunk.arrays().len() { + return Err(Error::InvalidArgumentError( + "The number of arrays in the chunk must equal the number of fields in the schema" + .to_string(), + )); + }; + let encodings = self.encodings.clone(); + Ok(row_group_iter( + chunk, + encodings, + self.parquet_schema.fields().to_vec(), + options, + )) + }) + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/schema.rs b/crates/nano-arrow/src/io/parquet/write/schema.rs new file mode 100644 index 000000000000..6f3ade5d46b3 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/schema.rs @@ -0,0 +1,379 @@ +use base64::engine::general_purpose; +use base64::Engine as _; +use parquet2::metadata::KeyValue; +use parquet2::schema::types::{ + GroupConvertedType, GroupLogicalType, IntegerType, ParquetType, PhysicalType, + PrimitiveConvertedType, PrimitiveLogicalType, TimeUnit as ParquetTimeUnit, +}; +use parquet2::schema::Repetition; + +use super::super::ARROW_SCHEMA_META_KEY; +use crate::datatypes::{DataType, Field, Schema, TimeUnit}; +use crate::error::{Error, Result}; +use crate::io::ipc::write::{default_ipc_fields, schema_to_bytes}; +use crate::io::parquet::write::decimal_length_from_precision; + +pub fn schema_to_metadata_key(schema: &Schema) -> KeyValue { + let serialized_schema = schema_to_bytes(schema, &default_ipc_fields(&schema.fields)); + + // manually prepending the length to the schema as arrow uses the legacy IPC format + // TODO: change after addressing ARROW-9777 + let schema_len = serialized_schema.len(); + let mut len_prefix_schema = Vec::with_capacity(schema_len + 8); + len_prefix_schema.extend_from_slice(&[255u8, 255, 255, 255]); + len_prefix_schema.extend_from_slice(&(schema_len as u32).to_le_bytes()); + len_prefix_schema.extend_from_slice(&serialized_schema); + + let encoded = general_purpose::STANDARD.encode(&len_prefix_schema); + + KeyValue { + key: ARROW_SCHEMA_META_KEY.to_string(), + value: Some(encoded), + } +} + +/// Creates a [`ParquetType`] from a [`Field`]. +pub fn to_parquet_type(field: &Field) -> Result { + let name = field.name.clone(); + let repetition = if field.is_nullable { + Repetition::Optional + } else { + Repetition::Required + }; + // create type from field + match field.data_type().to_logical_type() { + DataType::Null => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + Some(PrimitiveLogicalType::Unknown), + None, + )?), + DataType::Boolean => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Boolean, + repetition, + None, + None, + None, + )?), + DataType::Int32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + None, + None, + )?), + // DataType::Duration(_) has no parquet representation => do not apply any logical type + DataType::Int64 | DataType::Duration(_) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. 
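+ // (Arrow's Date64 stores milliseconds since the UNIX epoch, while parquet's DATE logical type is days in an Int32, so Date64 is written as a plain Int64 and consumers recover the Arrow type from the embedded Arrow schema.)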
+ DataType::Date64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + DataType::Float32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Float, + repetition, + None, + None, + None, + )?), + DataType::Float64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Double, + repetition, + None, + None, + None, + )?), + DataType::Binary | DataType::LargeBinary => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::ByteArray, + repetition, + None, + None, + None, + )?), + DataType::Utf8 | DataType::LargeUtf8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::ByteArray, + repetition, + Some(PrimitiveConvertedType::Utf8), + Some(PrimitiveLogicalType::String), + None, + )?), + DataType::Date32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Date), + Some(PrimitiveLogicalType::Date), + None, + )?), + DataType::Int8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Int8), + Some(PrimitiveLogicalType::Integer(IntegerType::Int8)), + None, + )?), + DataType::Int16 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Int16), + Some(PrimitiveLogicalType::Integer(IntegerType::Int16)), + None, + )?), + DataType::UInt8 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint8), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt8)), + None, + )?), + DataType::UInt16 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint16), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt16)), + None, + )?), + DataType::UInt32 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Uint32), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt32)), + None, + )?), + DataType::UInt64 => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + Some(PrimitiveConvertedType::Uint64), + Some(PrimitiveLogicalType::Integer(IntegerType::UInt64)), + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. + DataType::Timestamp(TimeUnit::Second, _) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + None, + None, + )?), + DataType::Timestamp(time_unit, zone) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + None, + Some(PrimitiveLogicalType::Timestamp { + is_adjusted_to_utc: matches!(zone, Some(z) if !z.as_str().is_empty()), + unit: match time_unit { + TimeUnit::Second => unreachable!(), + TimeUnit::Millisecond => ParquetTimeUnit::Milliseconds, + TimeUnit::Microsecond => ParquetTimeUnit::Microseconds, + TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds, + }, + }), + None, + )?), + // no natural representation in parquet; leave it as is. + // arrow consumers MAY use the arrow schema in the metadata to parse them. 
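+ // (parquet's TIME logical type is only defined for milli/micro/nanosecond units, so second-resolution times are written as plain Int32 values.)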
+ DataType::Time32(TimeUnit::Second) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + None, + None, + None, + )?), + DataType::Time32(TimeUnit::Millisecond) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::TimeMillis), + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: ParquetTimeUnit::Milliseconds, + }), + None, + )?), + DataType::Time64(time_unit) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + match time_unit { + TimeUnit::Microsecond => Some(PrimitiveConvertedType::TimeMicros), + TimeUnit::Nanosecond => None, + _ => unreachable!(), + }, + Some(PrimitiveLogicalType::Time { + is_adjusted_to_utc: false, + unit: match time_unit { + TimeUnit::Microsecond => ParquetTimeUnit::Microseconds, + TimeUnit::Nanosecond => ParquetTimeUnit::Nanoseconds, + _ => unreachable!(), + }, + }), + None, + )?), + DataType::Struct(fields) => { + if fields.is_empty() { + return Err(Error::InvalidArgumentError( + "Parquet does not support writing empty structs".to_string(), + )); + } + // recursively convert children to types/nodes + let fields = fields + .iter() + .map(to_parquet_type) + .collect::>>()?; + Ok(ParquetType::from_group( + name, repetition, None, None, fields, None, + )) + }, + DataType::Dictionary(_, value, _) => { + let dict_field = Field::new(name.as_str(), value.as_ref().clone(), field.is_nullable); + to_parquet_type(&dict_field) + }, + DataType::FixedSizeBinary(size) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(*size), + repetition, + None, + None, + None, + )?), + DataType::Decimal(precision, scale) => { + let precision = *precision; + let scale = *scale; + let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale)); + + let physical_type = if precision <= 9 { + PhysicalType::Int32 + } else if precision <= 18 { + PhysicalType::Int64 + } else { + let len = decimal_length_from_precision(precision); + PhysicalType::FixedLenByteArray(len) + }; + Ok(ParquetType::try_from_primitive( + name, + physical_type, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + }, + DataType::Decimal256(precision, scale) => { + let precision = *precision; + let scale = *scale; + let logical_type = Some(PrimitiveLogicalType::Decimal(precision, scale)); + + if precision <= 9 { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int32, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else if precision <= 18 { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::Int64, + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else if precision <= 38 { + let len = decimal_length_from_precision(precision); + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(len), + repetition, + Some(PrimitiveConvertedType::Decimal(precision, scale)), + logical_type, + None, + )?) + } else { + Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(32), + repetition, + None, + None, + None, + )?) 
+ } + }, + DataType::Interval(_) => Ok(ParquetType::try_from_primitive( + name, + PhysicalType::FixedLenByteArray(12), + repetition, + Some(PrimitiveConvertedType::Interval), + None, + None, + )?), + DataType::List(f) | DataType::FixedSizeList(f, _) | DataType::LargeList(f) => { + Ok(ParquetType::from_group( + name, + repetition, + Some(GroupConvertedType::List), + Some(GroupLogicalType::List), + vec![ParquetType::from_group( + "list".to_string(), + Repetition::Repeated, + None, + None, + vec![to_parquet_type(f)?], + None, + )], + None, + )) + }, + DataType::Map(f, _) => Ok(ParquetType::from_group( + name, + repetition, + Some(GroupConvertedType::Map), + Some(GroupLogicalType::Map), + vec![ParquetType::from_group( + "map".to_string(), + Repetition::Repeated, + None, + None, + vec![to_parquet_type(f)?], + None, + )], + None, + )), + other => Err(Error::NotYetImplemented(format!( + "Writing the data type {other:?} is not yet implemented" + ))), + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/sink.rs b/crates/nano-arrow/src/io/parquet/write/sink.rs new file mode 100644 index 000000000000..d357d7b89c2d --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/sink.rs @@ -0,0 +1,236 @@ +use std::pin::Pin; +use std::task::Poll; + +use ahash::AHashMap; +use futures::future::BoxFuture; +use futures::{AsyncWrite, AsyncWriteExt, FutureExt, Sink, TryFutureExt}; +use parquet2::metadata::KeyValue; +use parquet2::write::{FileStreamer, WriteOptions as ParquetWriteOptions}; + +use super::file::add_arrow_schema; +use super::{Encoding, SchemaDescriptor, WriteOptions}; +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::Schema; +use crate::error::Error; + +/// Sink that writes array [`chunks`](Chunk) as a Parquet file. +/// +/// Any values in the sink's `metadata` field will be written to the file's footer +/// when the sink is closed. +/// +/// # Examples +/// +/// ``` +/// use futures::SinkExt; +/// use arrow2::array::{Array, Int32Array}; +/// use arrow2::datatypes::{DataType, Field, Schema}; +/// use arrow2::chunk::Chunk; +/// use arrow2::io::parquet::write::{Encoding, WriteOptions, CompressionOptions, Version}; +/// # use arrow2::io::parquet::write::FileSink; +/// # futures::executor::block_on(async move { +/// +/// let schema = Schema::from(vec![ +/// Field::new("values", DataType::Int32, true), +/// ]); +/// let encoding = vec![vec![Encoding::Plain]]; +/// let options = WriteOptions { +/// write_statistics: true, +/// compression: CompressionOptions::Uncompressed, +/// version: Version::V2, +/// data_pagesize_limit: None, +/// }; +/// +/// let mut buffer = vec![]; +/// let mut sink = FileSink::try_new( +/// &mut buffer, +/// schema, +/// encoding, +/// options, +/// )?; +/// +/// for i in 0..3 { +/// let values = Int32Array::from(&[Some(i), None]); +/// let chunk = Chunk::new(vec![values.boxed()]); +/// sink.feed(chunk).await?; +/// } +/// sink.metadata.insert(String::from("key"), Some(String::from("value"))); +/// sink.close().await?; +/// # arrow2::error::Result::Ok(()) +/// # }).unwrap(); +/// ``` +pub struct FileSink<'a, W: AsyncWrite + Send + Unpin> { + writer: Option>, + task: Option>, Error>>>, + options: WriteOptions, + encodings: Vec>, + schema: Schema, + parquet_schema: SchemaDescriptor, + /// Key-value metadata that will be written to the file on close. + pub metadata: AHashMap>, +} + +impl<'a, W> FileSink<'a, W> +where + W: AsyncWrite + Send + Unpin + 'a, +{ + /// Create a new sink that writes arrays to the provided `writer`. 
+ /// + /// # Error + /// Iff + /// * the Arrow schema can't be converted to a valid Parquet schema. + /// * the length of the encodings is different from the number of fields in schema + pub fn try_new( + writer: W, + schema: Schema, + encodings: Vec>, + options: WriteOptions, + ) -> Result { + if encodings.len() != schema.fields.len() { + return Err(Error::InvalidArgumentError( + "The number of encodings must equal the number of fields".to_string(), + )); + } + + let parquet_schema = crate::io::parquet::write::to_parquet_schema(&schema)?; + let created_by = Some("Arrow2 - Native Rust implementation of Arrow".to_string()); + let writer = FileStreamer::new( + writer, + parquet_schema.clone(), + ParquetWriteOptions { + version: options.version, + write_statistics: options.write_statistics, + }, + created_by, + ); + Ok(Self { + writer: Some(writer), + task: None, + options, + schema, + encodings, + parquet_schema, + metadata: AHashMap::default(), + }) + } + + /// The Arrow [`Schema`] for the file. + pub fn schema(&self) -> &Schema { + &self.schema + } + + /// The Parquet [`SchemaDescriptor`] for the file. + pub fn parquet_schema(&self) -> &SchemaDescriptor { + &self.parquet_schema + } + + /// The write options for the file. + pub fn options(&self) -> &WriteOptions { + &self.options + } + + fn poll_complete( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if let Some(task) = &mut self.task { + match futures::ready!(task.poll_unpin(cx)) { + Ok(writer) => { + self.task = None; + self.writer = writer; + Poll::Ready(Ok(())) + }, + Err(error) => { + self.task = None; + Poll::Ready(Err(error)) + }, + } + } else { + Poll::Ready(Ok(())) + } + } +} + +impl<'a, W> Sink>> for FileSink<'a, W> +where + W: AsyncWrite + Send + Unpin + 'a, +{ + type Error = Error; + + fn start_send(self: Pin<&mut Self>, item: Chunk>) -> Result<(), Self::Error> { + if self.schema.fields.len() != item.arrays().len() { + return Err(Error::InvalidArgumentError( + "The number of arrays in the chunk must equal the number of fields in the schema" + .to_string(), + )); + } + let this = self.get_mut(); + if let Some(mut writer) = this.writer.take() { + let rows = crate::io::parquet::write::row_group_iter( + item, + this.encodings.clone(), + this.parquet_schema.fields().to_vec(), + this.options, + ); + this.task = Some(Box::pin(async move { + writer.write(rows).await?; + Ok(Some(writer)) + })); + Ok(()) + } else { + Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "writer closed".to_string(), + ))) + } + } + + fn poll_ready( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.get_mut().poll_complete(cx) + } + + fn poll_close( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let this = self.get_mut(); + match futures::ready!(this.poll_complete(cx)) { + Ok(()) => { + let writer = this.writer.take(); + if let Some(mut writer) = writer { + let meta = std::mem::take(&mut this.metadata); + let metadata = if meta.is_empty() { + None + } else { + Some( + meta.into_iter() + .map(|(k, v)| KeyValue::new(k, v)) + .collect::>(), + ) + }; + let kv_meta = add_arrow_schema(&this.schema, metadata); + + this.task = Some(Box::pin(async move { + writer.end(kv_meta).map_err(Error::from).await?; + writer.into_inner().close().map_err(Error::from).await?; + Ok(None) + })); + 
this.poll_complete(cx) + } else { + Poll::Ready(Ok(())) + } + }, + Err(error) => Poll::Ready(Err(error)), + } + } +} diff --git a/crates/nano-arrow/src/io/parquet/write/utf8/basic.rs b/crates/nano-arrow/src/io/parquet/write/utf8/basic.rs new file mode 100644 index 000000000000..39f9c157c988 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utf8/basic.rs @@ -0,0 +1,117 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::{serialize_statistics, BinaryStatistics, ParquetStatistics, Statistics}; + +use super::super::binary::{encode_delta, ord_binary}; +use super::super::{utils, WriteOptions}; +use crate::array::{Array, Utf8Array}; +use crate::error::{Error, Result}; +use crate::io::parquet::read::schema::is_nullable; +use crate::offset::Offset; + +pub(crate) fn encode_plain( + array: &Utf8Array, + is_optional: bool, + buffer: &mut Vec, +) { + if is_optional { + array.iter().for_each(|x| { + if let Some(x) = x { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x.as_bytes()); + } + }) + } else { + array.values_iter().for_each(|x| { + // BYTE_ARRAY: first 4 bytes denote length in littleendian. + let len = (x.len() as u32).to_le_bytes(); + buffer.extend_from_slice(&len); + buffer.extend_from_slice(x.as_bytes()); + }) + } +} + +pub fn array_to_page( + array: &Utf8Array, + options: WriteOptions, + type_: PrimitiveType, + encoding: Encoding, +) -> Result { + let validity = array.validity(); + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + utils::write_def_levels( + &mut buffer, + is_optional, + validity, + array.len(), + options.version, + )?; + + let definition_levels_byte_length = buffer.len(); + + match encoding { + Encoding::Plain => encode_plain(array, is_optional, &mut buffer), + Encoding::DeltaLengthByteArray => encode_delta( + array.values(), + array.offsets().buffer(), + array.validity(), + is_optional, + &mut buffer, + ), + _ => { + return Err(Error::InvalidArgumentError(format!( + "Datatype {:?} cannot be encoded by {:?} encoding", + array.data_type(), + encoding + ))) + }, + } + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + array.len(), + array.len(), + array.null_count(), + 0, + definition_levels_byte_length, + statistics, + type_, + options, + encoding, + ) +} + +pub(crate) fn build_statistics( + array: &Utf8Array, + primitive_type: PrimitiveType, +) -> ParquetStatistics { + let statistics = &BinaryStatistics { + primitive_type, + null_count: Some(array.null_count() as i64), + distinct_count: None, + max_value: array + .iter() + .flatten() + .map(|x| x.as_bytes()) + .max_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + min_value: array + .iter() + .flatten() + .map(|x| x.as_bytes()) + .min_by(|x, y| ord_binary(x, y)) + .map(|x| x.to_vec()), + } as &dyn Statistics; + serialize_statistics(statistics) +} diff --git a/crates/nano-arrow/src/io/parquet/write/utf8/mod.rs b/crates/nano-arrow/src/io/parquet/write/utf8/mod.rs new file mode 100644 index 000000000000..e4ef46599e2c --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utf8/mod.rs @@ -0,0 +1,6 @@ +mod basic; +mod nested; + +pub use basic::array_to_page; +pub(crate) use basic::{build_statistics, encode_plain}; +pub use nested::array_to_page as nested_array_to_page; 
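// Illustrative sketch, not part of this patch: driving the schema conversion added in
// `schema.rs` above. `to_parquet_schema` is the same helper the `FileSink` above calls;
// the field names and data types below are assumptions chosen only for the example.
use crate::datatypes::{DataType, Field, Schema};
use crate::error::Result;

fn schema_conversion_example() -> Result<()> {
    let schema = Schema::from(vec![
        Field::new("id", DataType::Int64, false), // required -> INT64
        Field::new("name", DataType::Utf8, true), // optional -> BYTE_ARRAY (String)
    ]);
    // One top-level Parquet column descriptor per Arrow field.
    let descriptor = crate::io::parquet::write::to_parquet_schema(&schema)?;
    assert_eq!(descriptor.fields().len(), 2);
    Ok(())
}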
diff --git a/crates/nano-arrow/src/io/parquet/write/utf8/nested.rs b/crates/nano-arrow/src/io/parquet/write/utf8/nested.rs new file mode 100644 index 000000000000..43767246d194 --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utf8/nested.rs @@ -0,0 +1,48 @@ +use parquet2::encoding::Encoding; +use parquet2::page::DataPage; +use parquet2::schema::types::PrimitiveType; + +use super::super::{nested, utils, WriteOptions}; +use super::basic::{build_statistics, encode_plain}; +use crate::array::{Array, Utf8Array}; +use crate::error::Result; +use crate::io::parquet::read::schema::is_nullable; +use crate::io::parquet::write::Nested; +use crate::offset::Offset; + +pub fn array_to_page( + array: &Utf8Array, + options: WriteOptions, + type_: PrimitiveType, + nested: &[Nested], +) -> Result +where + O: Offset, +{ + let is_optional = is_nullable(&type_.field_info); + + let mut buffer = vec![]; + let (repetition_levels_byte_length, definition_levels_byte_length) = + nested::write_rep_and_def(options.version, nested, &mut buffer)?; + + encode_plain(array, is_optional, &mut buffer); + + let statistics = if options.write_statistics { + Some(build_statistics(array, type_.clone())) + } else { + None + }; + + utils::build_plain_page( + buffer, + nested::num_values(nested), + nested[0].len(), + array.null_count(), + repetition_levels_byte_length, + definition_levels_byte_length, + statistics, + type_, + options, + Encoding::Plain, + ) +} diff --git a/crates/nano-arrow/src/io/parquet/write/utils.rs b/crates/nano-arrow/src/io/parquet/write/utils.rs new file mode 100644 index 000000000000..caaba98a07fe --- /dev/null +++ b/crates/nano-arrow/src/io/parquet/write/utils.rs @@ -0,0 +1,146 @@ +use parquet2::compression::CompressionOptions; +use parquet2::encoding::hybrid_rle::encode_bool; +use parquet2::encoding::Encoding; +use parquet2::metadata::Descriptor; +use parquet2::page::{DataPage, DataPageHeader, DataPageHeaderV1, DataPageHeaderV2}; +use parquet2::schema::types::PrimitiveType; +use parquet2::statistics::ParquetStatistics; + +use super::{Version, WriteOptions}; +use crate::bitmap::Bitmap; +use crate::error::Result; + +fn encode_iter_v1>(buffer: &mut Vec, iter: I) -> Result<()> { + buffer.extend_from_slice(&[0; 4]); + let start = buffer.len(); + encode_bool(buffer, iter)?; + let end = buffer.len(); + let length = end - start; + + // write the first 4 bytes as length + let length = (length as i32).to_le_bytes(); + (0..4).for_each(|i| buffer[start - 4 + i] = length[i]); + Ok(()) +} + +fn encode_iter_v2>(writer: &mut Vec, iter: I) -> Result<()> { + Ok(encode_bool(writer, iter)?) +} + +fn encode_iter>( + writer: &mut Vec, + iter: I, + version: Version, +) -> Result<()> { + match version { + Version::V1 => encode_iter_v1(writer, iter), + Version::V2 => encode_iter_v2(writer, iter), + } +} + +/// writes the def levels to a `Vec` and returns it. 
+pub fn write_def_levels( + writer: &mut Vec, + is_optional: bool, + validity: Option<&Bitmap>, + len: usize, + version: Version, +) -> Result<()> { + // encode def levels + match (is_optional, validity) { + (true, Some(validity)) => encode_iter(writer, validity.iter(), version), + (true, None) => encode_iter(writer, std::iter::repeat(true).take(len), version), + _ => Ok(()), // is required => no def levels + } +} + +#[allow(clippy::too_many_arguments)] +pub fn build_plain_page( + buffer: Vec, + num_values: usize, + num_rows: usize, + null_count: usize, + repetition_levels_byte_length: usize, + definition_levels_byte_length: usize, + statistics: Option, + type_: PrimitiveType, + options: WriteOptions, + encoding: Encoding, +) -> Result { + let header = match options.version { + Version::V1 => DataPageHeader::V1(DataPageHeaderV1 { + num_values: num_values as i32, + encoding: encoding.into(), + definition_level_encoding: Encoding::Rle.into(), + repetition_level_encoding: Encoding::Rle.into(), + statistics, + }), + Version::V2 => DataPageHeader::V2(DataPageHeaderV2 { + num_values: num_values as i32, + encoding: encoding.into(), + num_nulls: null_count as i32, + num_rows: num_rows as i32, + definition_levels_byte_length: definition_levels_byte_length as i32, + repetition_levels_byte_length: repetition_levels_byte_length as i32, + is_compressed: Some(options.compression != CompressionOptions::Uncompressed), + statistics, + }), + }; + Ok(DataPage::new( + header, + buffer, + Descriptor { + primitive_type: type_, + max_def_level: 0, + max_rep_level: 0, + }, + Some(num_rows), + )) +} + +/// Auxiliary iterator adapter to declare the size hint of an iterator. +pub(super) struct ExactSizedIter> { + iter: I, + remaining: usize, +} + +impl + Clone> Clone for ExactSizedIter { + fn clone(&self) -> Self { + Self { + iter: self.iter.clone(), + remaining: self.remaining, + } + } +} + +impl> ExactSizedIter { + pub fn new(iter: I, length: usize) -> Self { + Self { + iter, + remaining: length, + } + } +} + +impl> Iterator for ExactSizedIter { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + self.iter.next().map(|x| { + self.remaining -= 1; + x + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +/// Returns the number of bits needed to bitpack `max` +#[inline] +pub fn get_bit_width(max: u64) -> u32 { + 64 - max.leading_zeros() +} diff --git a/crates/nano-arrow/src/lib.rs b/crates/nano-arrow/src/lib.rs new file mode 100644 index 000000000000..c26b3e1a0b28 --- /dev/null +++ b/crates/nano-arrow/src/lib.rs @@ -0,0 +1,42 @@ +// So that we have more control over what is `unsafe` inside an `unsafe` block +#![allow(unused_unsafe)] +// +#![allow(clippy::len_without_is_empty)] +// this landed on 1.60. Let's not force everyone to bump just yet +#![allow(clippy::unnecessary_lazy_evaluations)] +// Trait objects must be returned as a &Box so that they can be cloned +#![allow(clippy::borrowed_box)] +// Allow type complexity warning to avoid API break. 
+#![allow(clippy::type_complexity)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(feature = "simd", feature(portable_simd))] +#![cfg_attr(feature = "nightly_build", feature(build_hasher_simple_hash_one))] + +#[macro_use] +pub mod array; +pub mod bitmap; +pub mod buffer; +pub mod chunk; +pub mod error; +#[cfg(feature = "io_ipc")] +#[cfg_attr(docsrs, doc(cfg(feature = "io_ipc")))] +pub mod mmap; + +pub mod offset; +pub mod scalar; +pub mod trusted_len; +pub mod types; + +pub mod compute; +pub mod io; +pub mod temporal_conversions; + +pub mod datatypes; + +pub mod ffi; +pub mod util; + +// re-exported because we return `Either` in our public API +// re-exported to construct dictionaries +pub use ahash::AHashMap; +pub use either::Either; diff --git a/crates/nano-arrow/src/mmap/array.rs b/crates/nano-arrow/src/mmap/array.rs new file mode 100644 index 000000000000..8efd6afcd671 --- /dev/null +++ b/crates/nano-arrow/src/mmap/array.rs @@ -0,0 +1,568 @@ +use std::collections::VecDeque; +use std::sync::Arc; + +use crate::array::{Array, DictionaryKey, FixedSizeListArray, ListArray, StructArray}; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::ffi::mmap::create_array; +use crate::ffi::{export_array_to_c, try_from, ArrowArray, InternalArrowArray}; +use crate::io::ipc::read::{Dictionaries, IpcBuffer, Node, OutOfSpecKind}; +use crate::io::ipc::IpcField; +use crate::offset::Offset; +use crate::types::NativeType; + +fn get_buffer_bounds(buffers: &mut VecDeque) -> Result<(usize, usize), Error> { + let buffer = buffers + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + let offset: usize = buffer + .offset() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let length: usize = buffer + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok((offset, length)) +} + +fn get_buffer<'a, T: NativeType>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, + num_rows: usize, +) -> Result<&'a [u8], Error> { + let (offset, length) = get_buffer_bounds(buffers)?; + + // verify that they are in-bounds + let values = data + .get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| Error::OutOfSpec("buffer out of bounds".to_string()))?; + + // validate alignment + let v: &[T] = bytemuck::try_cast_slice(values) + .map_err(|_| Error::OutOfSpec("buffer not aligned for mmap".to_string()))?; + + if v.len() < num_rows { + return Err(Error::OutOfSpec( + "buffer's length is too small in mmap".to_string(), + )); + } + + Ok(values) +} + +fn get_validity<'a>( + data: &'a [u8], + block_offset: usize, + buffers: &mut VecDeque, + null_count: usize, +) -> Result, Error> { + let validity = get_buffer_bounds(buffers)?; + let (offset, length) = validity; + + Ok(if null_count > 0 { + // verify that they are in-bounds and get its pointer + Some( + data.get(block_offset + offset..block_offset + offset + length) + .ok_or_else(|| Error::OutOfSpec("buffer out of bounds".to_string()))?, + ) + } else { + None + }) +} + +fn mmap_binary>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, 
null_count)?.map(|x| x.as_ptr()); + + let offsets = get_buffer::(data_ref, block_offset, buffers, num_rows + 1)?.as_ptr(); + let values = get_buffer::(data_ref, block_offset, buffers, 0)?.as_ptr(); + + // NOTE: offsets and values invariants are _not_ validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(offsets), Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_fixed_size_binary>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, + data_type: &DataType, +) -> Result { + let bytes_per_row = if let DataType::FixedSizeBinary(bytes_per_row) = data_type { + bytes_per_row + } else { + return Err(Error::from(OutOfSpecKind::InvalidDataType)); + }; + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + let values = + get_buffer::(data_ref, block_offset, buffers, num_rows * bytes_per_row)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_null>( + data: Arc, + node: &Node, + _block_offset: usize, + _buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_boolean>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer_bounds(buffers)?; + let (offset, length) = values; + + // verify that they are in-bounds and get its pointer + let values = data_ref[block_offset + offset..block_offset + offset + length].as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +fn mmap_primitive>( + data: Arc, + node: &Node, + block_offset: usize, + buffers: &mut VecDeque, +) -> Result { + let data_ref = data.as_ref().as_ref(); + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer::
<P>
(data_ref, block_offset, buffers, num_rows)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_list>( + data: Arc, + node: &Node, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let child = ListArray::::try_get_child(data_type)?.data_type(); + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let offsets = get_buffer::(data_ref, block_offset, buffers, num_rows + 1)?.as_ptr(); + + let values = get_array( + data.clone(), + block_offset, + child, + &ipc_field.fields[0], + dictionaries, + field_nodes, + buffers, + )?; + + // NOTE: offsets and values invariants are _not_ validated + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(offsets)].into_iter(), + [values].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_fixed_size_list>( + data: Arc, + node: &Node, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let child = FixedSizeListArray::try_child_and_size(data_type)? + .0 + .data_type(); + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_array( + data.clone(), + block_offset, + child, + &ipc_field.fields[0], + dictionaries, + field_nodes, + buffers, + )?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity].into_iter(), + [values].into_iter(), + None, + None, + ) + }) +} + +#[allow(clippy::too_many_arguments)] +fn mmap_struct>( + data: Arc, + node: &Node, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let children = StructArray::try_get_fields(data_type)?; + + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = children + .iter() + .map(|f| &f.data_type) + .zip(ipc_field.fields.iter()) + .map(|(child, ipc)| { + get_array( + data.clone(), + block_offset, + child, + ipc, + dictionaries, + field_nodes, + buffers, + ) + }) + .collect::, Error>>()?; + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity].into_iter(), + values.into_iter(), + None, + None, + ) + }) +} + 
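// Note (illustrative commentary, not part of the upstream patch): the `mmap_*` helpers
// above and below are all driven by `get_array` further down. For each field it pops one
// `Node` (row count + null count) from `field_nodes`, then pops that field's IPC buffers
// from `buffers` in spec order: validity bitmap first, then offsets and/or values where
// applicable. Nested types (list, fixed-size list, struct) recurse back into `get_array`
// for their children, so both queues are consumed depth-first in schema order.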
+#[allow(clippy::too_many_arguments)] +fn mmap_dict>( + data: Arc, + node: &Node, + block_offset: usize, + _: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + _: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + let num_rows: usize = node + .length() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let null_count: usize = node + .null_count() + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let data_ref = data.as_ref().as_ref(); + + let dictionary = dictionaries + .get(&ipc_field.dictionary_id.unwrap()) + .ok_or_else(|| Error::oos("Missing dictionary"))? + .clone(); + + let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr()); + + let values = get_buffer::(data_ref, block_offset, buffers, num_rows)?.as_ptr(); + + Ok(unsafe { + create_array( + data, + num_rows, + null_count, + [validity, Some(values)].into_iter(), + [].into_iter(), + Some(export_array_to_c(dictionary)), + None, + ) + }) +} + +fn get_array>( + data: Arc, + block_offset: usize, + data_type: &DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result { + use crate::datatypes::PhysicalType::*; + let node = field_nodes + .pop_front() + .ok_or_else(|| Error::from(OutOfSpecKind::ExpectedBuffer))?; + + match data_type.to_physical_type() { + Null => mmap_null(data, &node, block_offset, buffers), + Boolean => mmap_boolean(data, &node, block_offset, buffers), + Primitive(p) => with_match_primitive_type!(p, |$T| { + mmap_primitive::<$T, _>(data, &node, block_offset, buffers) + }), + Utf8 | Binary => mmap_binary::(data, &node, block_offset, buffers), + FixedSizeBinary => mmap_fixed_size_binary(data, &node, block_offset, buffers, data_type), + LargeBinary | LargeUtf8 => mmap_binary::(data, &node, block_offset, buffers), + List => mmap_list::( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + LargeList => mmap_list::( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + FixedSizeList => mmap_fixed_size_list( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + Struct => mmap_struct( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + mmap_dict::<$T, _>( + data, + &node, + block_offset, + data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + ) + }), + _ => todo!(), + } +} + +/// Maps a memory region to an [`Array`]. +pub(crate) unsafe fn mmap>( + data: Arc, + block_offset: usize, + data_type: DataType, + ipc_field: &IpcField, + dictionaries: &Dictionaries, + field_nodes: &mut VecDeque, + buffers: &mut VecDeque, +) -> Result, Error> { + let array = get_array( + data, + block_offset, + &data_type, + ipc_field, + dictionaries, + field_nodes, + buffers, + )?; + // The unsafety comes from the fact that `array` is not necessarily valid - + // the IPC file may be corrupted (e.g. invalid offsets or non-utf8 data) + unsafe { try_from(InternalArrowArray::new(array, data_type)) } +} diff --git a/crates/nano-arrow/src/mmap/mod.rs b/crates/nano-arrow/src/mmap/mod.rs new file mode 100644 index 000000000000..58265892ea57 --- /dev/null +++ b/crates/nano-arrow/src/mmap/mod.rs @@ -0,0 +1,227 @@ +//! 
Memory maps regions defined on the IPC format into [`Array`]. +use std::collections::VecDeque; +use std::sync::Arc; + +mod array; + +use arrow_format::ipc::planus::ReadAsRoot; +use arrow_format::ipc::{Block, MessageRef, RecordBatchRef}; + +use crate::array::Array; +use crate::chunk::Chunk; +use crate::datatypes::{DataType, Field}; +use crate::error::Error; +use crate::io::ipc::read::file::{get_dictionary_batch, get_record_batch}; +use crate::io::ipc::read::{ + first_dict_field, Dictionaries, FileMetadata, IpcBuffer, Node, OutOfSpecKind, +}; +use crate::io::ipc::{IpcField, CONTINUATION_MARKER}; + +fn read_message( + mut bytes: &[u8], + block: arrow_format::ipc::Block, +) -> Result<(MessageRef, usize), Error> { + let offset: usize = block + .offset + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let block_length: usize = block + .meta_data_length + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + bytes = &bytes[offset..]; + let mut message_length = bytes[..4].try_into().unwrap(); + bytes = &bytes[4..]; + + if message_length == CONTINUATION_MARKER { + // continuation marker encountered, read message next + message_length = bytes[..4].try_into().unwrap(); + bytes = &bytes[4..]; + }; + + let message_length: usize = i32::from_le_bytes(message_length) + .try_into() + .map_err(|_| Error::from(OutOfSpecKind::NegativeFooterLength))?; + + let message = arrow_format::ipc::MessageRef::read_as_root(&bytes[..message_length]) + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferMessage(err)))?; + + Ok((message, offset + block_length)) +} + +fn get_buffers_nodes( + batch: RecordBatchRef, +) -> Result<(VecDeque, VecDeque), Error> { + let compression = batch.compression()?; + if compression.is_some() { + return Err(Error::nyi( + "mmap can only be done on uncompressed IPC files", + )); + } + + let buffers = batch + .buffers() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferBuffers(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageBuffers))?; + let buffers = buffers.iter().collect::>(); + + let field_nodes = batch + .nodes() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferNodes(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingMessageNodes))?; + let field_nodes = field_nodes.iter().collect::>(); + + Ok((buffers, field_nodes)) +} + +unsafe fn _mmap_record>( + fields: &[Field], + ipc_fields: &[IpcField], + data: Arc, + batch: RecordBatchRef, + offset: usize, + dictionaries: &Dictionaries, +) -> Result>, Error> { + let (mut buffers, mut field_nodes) = get_buffers_nodes(batch)?; + + fields + .iter() + .map(|f| &f.data_type) + .cloned() + .zip(ipc_fields) + .map(|(data_type, ipc_field)| { + array::mmap( + data.clone(), + offset, + data_type, + ipc_field, + dictionaries, + &mut field_nodes, + &mut buffers, + ) + }) + .collect::>() + .and_then(Chunk::try_new) +} + +unsafe fn _mmap_unchecked>( + fields: &[Field], + ipc_fields: &[IpcField], + data: Arc, + block: Block, + dictionaries: &Dictionaries, +) -> Result>, Error> { + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_record_batch(message)?; + _mmap_record( + fields, + ipc_fields, + data.clone(), + batch, + offset, + dictionaries, + ) +} + +/// Memory maps an record batch from an IPC file into a [`Chunk`]. +/// # Errors +/// This function errors when: +/// * The IPC file is not valid +/// * the buffers on the file are un-aligned with their corresponding data. 
This can happen when: +/// * the file was written with 8-bit alignment +/// * the file contains type decimal 128 or 256 +/// # Safety +/// The caller must ensure that `data` contains a valid buffers, for example: +/// * Offsets in variable-sized containers must be in-bounds and increasing +/// * Utf8 data is valid +pub unsafe fn mmap_unchecked>( + metadata: &FileMetadata, + dictionaries: &Dictionaries, + data: Arc, + chunk: usize, +) -> Result>, Error> { + let block = metadata.blocks[chunk]; + + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_record_batch(message)?; + _mmap_record( + &metadata.schema.fields, + &metadata.ipc_schema.fields, + data.clone(), + batch, + offset, + dictionaries, + ) +} + +unsafe fn mmap_dictionary>( + metadata: &FileMetadata, + data: Arc, + block: Block, + dictionaries: &mut Dictionaries, +) -> Result<(), Error> { + let (message, offset) = read_message(data.as_ref().as_ref(), block)?; + let batch = get_dictionary_batch(&message)?; + + let id = batch + .id() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferId(err)))?; + let (first_field, first_ipc_field) = + first_dict_field(id, &metadata.schema.fields, &metadata.ipc_schema.fields)?; + + let batch = batch + .data() + .map_err(|err| Error::from(OutOfSpecKind::InvalidFlatbufferData(err)))? + .ok_or_else(|| Error::from(OutOfSpecKind::MissingData))?; + + let value_type = + if let DataType::Dictionary(_, value_type, _) = first_field.data_type.to_logical_type() { + value_type.as_ref() + } else { + return Err(Error::from(OutOfSpecKind::InvalidIdDataType { + requested_id: id, + })); + }; + + // Make a fake schema for the dictionary batch. + let field = Field::new("", value_type.clone(), false); + + let chunk = _mmap_record( + &[field], + &[first_ipc_field.clone()], + data.clone(), + batch, + offset, + dictionaries, + )?; + + dictionaries.insert(id, chunk.into_arrays().pop().unwrap()); + + Ok(()) +} + +/// Memory maps dictionaries from an IPC file into +/// # Safety +/// The caller must ensure that `data` contains a valid buffers, for example: +/// * Offsets in variable-sized containers must be in-bounds and increasing +/// * Utf8 data is valid +pub unsafe fn mmap_dictionaries_unchecked>( + metadata: &FileMetadata, + data: Arc, +) -> Result { + let blocks = if let Some(blocks) = &metadata.dictionaries { + blocks + } else { + return Ok(Default::default()); + }; + + let mut dictionaries = Default::default(); + + blocks + .iter() + .cloned() + .try_for_each(|block| mmap_dictionary(metadata, data.clone(), block, &mut dictionaries))?; + Ok(dictionaries) +} diff --git a/crates/nano-arrow/src/offset.rs b/crates/nano-arrow/src/offset.rs new file mode 100644 index 000000000000..409e695ba66a --- /dev/null +++ b/crates/nano-arrow/src/offset.rs @@ -0,0 +1,543 @@ +//! Contains the declaration of [`Offset`] +use std::hint::unreachable_unchecked; + +use crate::buffer::Buffer; +use crate::error::Error; +pub use crate::types::Offset; + +/// A wrapper type of [`Vec`] representing the invariants of Arrow's offsets. +/// It is guaranteed to (sound to assume that): +/// * every element is `>= 0` +/// * element at position `i` is >= than element at position `i-1`. 
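// Illustrative note, not from the upstream patch: offsets encode element lengths
// cumulatively, so lengths [2, 3, 0, 1] correspond to offsets [0, 2, 5, 5, 6] and
// element `i` occupies the range `offsets[i]..offsets[i + 1]` (see `start_end` and
// `try_from_lengths` below).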
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Offsets(Vec); + +impl Default for Offsets { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl TryFrom> for Offsets { + type Error = Error; + + #[inline] + fn try_from(offsets: Vec) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets)) + } +} + +impl TryFrom> for OffsetsBuffer { + type Error = Error; + + #[inline] + fn try_from(offsets: Buffer) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets)) + } +} + +impl TryFrom> for OffsetsBuffer { + type Error = Error; + + #[inline] + fn try_from(offsets: Vec) -> Result { + try_check_offsets(&offsets)?; + Ok(Self(offsets.into())) + } +} + +impl From> for OffsetsBuffer { + #[inline] + fn from(offsets: Offsets) -> Self { + Self(offsets.0.into()) + } +} + +impl Offsets { + /// Returns an empty [`Offsets`] (i.e. with a single element, the zero) + #[inline] + pub fn new() -> Self { + Self(vec![O::zero()]) + } + + /// Returns an [`Offsets`] whose all lengths are zero. + #[inline] + pub fn new_zeroed(length: usize) -> Self { + Self(vec![O::zero(); length + 1]) + } + + /// Creates a new [`Offsets`] from an iterator of lengths + #[inline] + pub fn try_from_iter>(iter: I) -> Result { + let iterator = iter.into_iter(); + let (lower, _) = iterator.size_hint(); + let mut offsets = Self::with_capacity(lower); + for item in iterator { + offsets.try_push_usize(item)? + } + Ok(offsets) + } + + /// Returns a new [`Offsets`] with a capacity, allocating at least `capacity + 1` entries. + pub fn with_capacity(capacity: usize) -> Self { + let mut offsets = Vec::with_capacity(capacity + 1); + offsets.push(O::zero()); + Self(offsets) + } + + /// Returns the capacity of [`Offsets`]. + pub fn capacity(&self) -> usize { + self.0.capacity() - 1 + } + + /// Reserves `additional` entries. + pub fn reserve(&mut self, additional: usize) { + self.0.reserve(additional); + } + + /// Shrinks the capacity of self to fit. + pub fn shrink_to_fit(&mut self) { + self.0.shrink_to_fit(); + } + + /// Pushes a new element with a given length. + /// # Error + /// This function errors iff the new last item is larger than what `O` supports. + /// # Panic + /// This function asserts that `length > 0`. + #[inline] + pub fn try_push(&mut self, length: O) -> Result<(), Error> { + let old_length = self.last(); + assert!(length >= O::zero()); + let new_length = old_length.checked_add(&length).ok_or(Error::Overflow)?; + self.0.push(new_length); + Ok(()) + } + + /// Pushes a new element with a given length. + /// # Error + /// This function errors iff the new last item is larger than what `O` supports. + /// # Implementation + /// This function: + /// * checks that this length does not overflow + #[inline] + pub fn try_push_usize(&mut self, length: usize) -> Result<(), Error> { + let length = O::from_usize(length).ok_or(Error::Overflow)?; + + let old_length = self.last(); + let new_length = old_length.checked_add(&length).ok_or(Error::Overflow)?; + self.0.push(new_length); + Ok(()) + } + + /// Returns [`Offsets`] assuming that `offsets` fulfills its invariants + /// # Safety + /// This is safe iff the invariants of this struct are guaranteed in `offsets`. + #[inline] + pub unsafe fn new_unchecked(offsets: Vec) -> Self { + Self(offsets) + } + + /// Returns the last offset of this container. 
+ #[inline] + pub fn last(&self) -> &O { + match self.0.last() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn start_end(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + assert!(index < self.len_proxy()); + unsafe { self.start_end_unchecked(index) } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Safety + /// `index` must be `< self.len()` + #[inline] + pub unsafe fn start_end_unchecked(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + let start = self.0.get_unchecked(index).to_usize(); + let end = self.0.get_unchecked(index + 1).to_usize(); + (start, end) + } + + /// Returns the length an array with these offsets would be. + #[inline] + pub fn len_proxy(&self) -> usize { + self.0.len() - 1 + } + + #[inline] + /// Returns the number of offsets in this container. + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[O] { + self.0.as_slice() + } + + /// Pops the last element + #[inline] + pub fn pop(&mut self) -> Option { + if self.len_proxy() == 0 { + None + } else { + self.0.pop() + } + } + + /// Extends itself with `additional` elements equal to the last offset. + /// This is useful to extend offsets with empty values, e.g. for null slots. + #[inline] + pub fn extend_constant(&mut self, additional: usize) { + let offset = *self.last(); + if additional == 1 { + self.0.push(offset) + } else { + self.0.resize(self.len() + additional, offset) + } + } + + /// Try to create a new [`Offsets`] from a sequence of `lengths` + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + #[inline] + pub fn try_from_lengths>(lengths: I) -> Result { + let mut self_ = Self::with_capacity(lengths.size_hint().0); + self_.try_extend_from_lengths(lengths)?; + Ok(self_) + } + + /// Try extend from an iterator of lengths + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + #[inline] + pub fn try_extend_from_lengths>( + &mut self, + lengths: I, + ) -> Result<(), Error> { + let mut total_length = 0; + let mut offset = *self.last(); + let original_offset = offset.to_usize(); + + let lengths = lengths.map(|length| { + total_length += length; + O::from_as_usize(length) + }); + + let offsets = lengths.map(|length| { + offset += length; // this may overflow, checked below + offset + }); + self.0.extend(offsets); + + let last_offset = original_offset + .checked_add(total_length) + .ok_or(Error::Overflow)?; + O::from_usize(last_offset).ok_or(Error::Overflow)?; + Ok(()) + } + + /// Extends itself from another [`Offsets`] + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. 
+ pub fn try_extend_from_self(&mut self, other: &Self) -> Result<(), Error> { + let mut length = *self.last(); + let other_length = *other.last(); + // check if the operation would overflow + length.checked_add(&other_length).ok_or(Error::Overflow)?; + + let lengths = other.as_slice().windows(2).map(|w| w[1] - w[0]); + let offsets = lengths.map(|new_length| { + length += new_length; + length + }); + self.0.extend(offsets); + Ok(()) + } + + /// Extends itself from another [`Offsets`] sliced by `start, length` + /// # Errors + /// This function errors iff this operation overflows for the maximum value of `O`. + pub fn try_extend_from_slice( + &mut self, + other: &OffsetsBuffer, + start: usize, + length: usize, + ) -> Result<(), Error> { + if length == 0 { + return Ok(()); + } + let other = &other.0[start..start + length + 1]; + let other_length = other.last().expect("Length to be non-zero"); + let mut length = *self.last(); + // check if the operation would overflow + length.checked_add(other_length).ok_or(Error::Overflow)?; + + let lengths = other.windows(2).map(|w| w[1] - w[0]); + let offsets = lengths.map(|new_length| { + length += new_length; + length + }); + self.0.extend(offsets); + Ok(()) + } + + /// Returns the inner [`Vec`]. + #[inline] + pub fn into_inner(self) -> Vec { + self.0 + } +} + +/// Checks that `offsets` is monotonically increasing. +fn try_check_offsets(offsets: &[O]) -> Result<(), Error> { + // this code is carefully constructed to auto-vectorize, don't change naively! + match offsets.first() { + None => Err(Error::oos("offsets must have at least one element")), + Some(first) => { + if *first < O::zero() { + return Err(Error::oos("offsets must be larger than 0")); + } + let mut previous = *first; + let mut any_invalid = false; + + // This loop will auto-vectorize because there is not any break, + // an invalid value will be returned once the whole offsets buffer is processed. + for offset in offsets { + if previous > *offset { + any_invalid = true + } + previous = *offset; + } + + if any_invalid { + Err(Error::oos("offsets must be monotonically increasing")) + } else { + Ok(()) + } + }, + } +} + +/// A wrapper type of [`Buffer`] that is guaranteed to: +/// * Always contain an element +/// * Every element is `>= 0` +/// * element at position `i` is >= than element at position `i-1`. +#[derive(Clone, PartialEq, Debug)] +pub struct OffsetsBuffer(Buffer); + +impl Default for OffsetsBuffer { + #[inline] + fn default() -> Self { + Self(vec![O::zero()].into()) + } +} + +impl OffsetsBuffer { + /// # Safety + /// This is safe iff the invariants of this struct are guaranteed in `offsets`. + #[inline] + pub unsafe fn new_unchecked(offsets: Buffer) -> Self { + Self(offsets) + } + + /// Returns an empty [`OffsetsBuffer`] (i.e. with a single element, the zero) + #[inline] + pub fn new() -> Self { + Self(vec![O::zero()].into()) + } + + /// Copy-on-write API to convert [`OffsetsBuffer`] into [`Offsets`]. + #[inline] + pub fn into_mut(self) -> either::Either> { + self.0 + .into_mut() + // Safety: Offsets and OffsetsBuffer share invariants + .map_right(|offsets| unsafe { Offsets::new_unchecked(offsets) }) + .map_left(Self) + } + + /// Returns a reference to its internal [`Buffer`]. + #[inline] + pub fn buffer(&self) -> &Buffer { + &self.0 + } + + /// Returns the length an array with these offsets would be. + #[inline] + pub fn len_proxy(&self) -> usize { + self.0.len() - 1 + } + + /// Returns the number of offsets in this container. 
+ #[inline] + pub fn len(&self) -> usize { + self.0.len() + } + + /// Returns the byte slice stored in this buffer + #[inline] + pub fn as_slice(&self) -> &[O] { + self.0.as_slice() + } + + /// Returns the range of the offsets. + #[inline] + pub fn range(&self) -> O { + *self.last() - *self.first() + } + + /// Returns the first offset. + #[inline] + pub fn first(&self) -> &O { + match self.0.first() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns the last offset. + #[inline] + pub fn last(&self) -> &O { + match self.0.last() { + Some(element) => element, + None => unsafe { unreachable_unchecked() }, + } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Panic + /// This function panics iff `index >= self.len()` + #[inline] + pub fn start_end(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + assert!(index < self.len_proxy()); + unsafe { self.start_end_unchecked(index) } + } + + /// Returns a range (start, end) corresponding to the position `index` + /// # Safety + /// `index` must be `< self.len()` + #[inline] + pub unsafe fn start_end_unchecked(&self, index: usize) -> (usize, usize) { + // soundness: the invariant of the function + let start = self.0.get_unchecked(index).to_usize(); + let end = self.0.get_unchecked(index + 1).to_usize(); + (start, end) + } + + /// Slices this [`OffsetsBuffer`]. + /// # Panics + /// Panics if `offset + length` is larger than `len` + /// or `length == 0`. + #[inline] + pub fn slice(&mut self, offset: usize, length: usize) { + assert!(length > 0); + self.0.slice(offset, length); + } + + /// Slices this [`OffsetsBuffer`] starting at `offset`. + /// # Safety + /// The caller must ensure `offset + length <= self.len()` + #[inline] + pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) { + self.0.slice_unchecked(offset, length); + } + + /// Returns an iterator with the lengths of the offsets + #[inline] + pub fn lengths(&self) -> impl Iterator + '_ { + self.0.windows(2).map(|w| (w[1] - w[0]).to_usize()) + } + + /// Returns the inner [`Buffer`]. 
+ #[inline] + pub fn into_inner(self) -> Buffer { + self.0 + } +} + +impl From<&OffsetsBuffer> for OffsetsBuffer { + fn from(offsets: &OffsetsBuffer) -> Self { + // this conversion is lossless and uphelds all invariants + Self( + offsets + .buffer() + .iter() + .map(|x| *x as i64) + .collect::>() + .into(), + ) + } +} + +impl TryFrom<&OffsetsBuffer> for OffsetsBuffer { + type Error = Error; + + fn try_from(offsets: &OffsetsBuffer) -> Result { + i32::try_from(*offsets.last()).map_err(|_| Error::Overflow)?; + + // this conversion is lossless and uphelds all invariants + Ok(Self( + offsets + .buffer() + .iter() + .map(|x| *x as i32) + .collect::>() + .into(), + )) + } +} + +impl From> for Offsets { + fn from(offsets: Offsets) -> Self { + // this conversion is lossless and uphelds all invariants + Self( + offsets + .as_slice() + .iter() + .map(|x| *x as i64) + .collect::>(), + ) + } +} + +impl TryFrom> for Offsets { + type Error = Error; + + fn try_from(offsets: Offsets) -> Result { + i32::try_from(*offsets.last()).map_err(|_| Error::Overflow)?; + + // this conversion is lossless and uphelds all invariants + Ok(Self( + offsets + .as_slice() + .iter() + .map(|x| *x as i32) + .collect::>(), + )) + } +} + +impl std::ops::Deref for OffsetsBuffer { + type Target = [O]; + + #[inline] + fn deref(&self) -> &[O] { + self.0.as_slice() + } +} diff --git a/crates/nano-arrow/src/scalar/README.md b/crates/nano-arrow/src/scalar/README.md new file mode 100644 index 000000000000..b780081b6131 --- /dev/null +++ b/crates/nano-arrow/src/scalar/README.md @@ -0,0 +1,28 @@ +# Scalar API + +Design choices: + +### `Scalar` is trait object + +There are three reasons: + +- a scalar should have a small memory footprint, which an enum would not ensure given the different physical types available. +- forward-compatibility: a new entry on an `enum` is backward-incompatible +- do not expose implementation details to users (reduce the surface of the public API) + +### `Scalar` MUST contain nullability information + +This is to be aligned with the general notion of arrow's `Array`. + +This API is a companion to the `Array`, and follows the same design as `Array`. +Specifically, a `Scalar` is a trait object that can be downcasted to concrete implementations. + +Like `Array`, `Scalar` implements + +- `data_type`, which is used to perform the correct downcast +- `is_valid`, to tell whether the scalar is null or not + +### There is one implementation per arrows' physical type + +- Reduces the number of `match` that users need to write +- Allows casting of logical types without changing the underlying physical type diff --git a/crates/nano-arrow/src/scalar/binary.rs b/crates/nano-arrow/src/scalar/binary.rs new file mode 100644 index 000000000000..0d33f6f8f7e4 --- /dev/null +++ b/crates/nano-arrow/src/scalar/binary.rs @@ -0,0 +1,55 @@ +use super::Scalar; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// The [`Scalar`] implementation of binary ([`Option>`]). +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BinaryScalar { + value: Option>, + phantom: std::marker::PhantomData, +} + +impl BinaryScalar { + /// Returns a new [`BinaryScalar`]. + #[inline] + pub fn new>>(value: Option
<P>
) -> Self { + Self { + value: value.map(|x| x.into()), + phantom: std::marker::PhantomData, + } + } + + /// Its value + #[inline] + pub fn value(&self) -> Option<&[u8]> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl>> From> for BinaryScalar { + #[inline] + fn from(v: Option
<P>
) -> Self { + Self::new(v) + } +} + +impl Scalar for BinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + if O::IS_LARGE { + &DataType::LargeBinary + } else { + &DataType::Binary + } + } +} diff --git a/crates/nano-arrow/src/scalar/boolean.rs b/crates/nano-arrow/src/scalar/boolean.rs new file mode 100644 index 000000000000..aa7d435859af --- /dev/null +++ b/crates/nano-arrow/src/scalar/boolean.rs @@ -0,0 +1,46 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// The [`Scalar`] implementation of a boolean. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BooleanScalar { + value: Option, +} + +impl BooleanScalar { + /// Returns a new [`BooleanScalar`] + #[inline] + pub fn new(value: Option) -> Self { + Self { value } + } + + /// The value + #[inline] + pub fn value(&self) -> Option { + self.value + } +} + +impl Scalar for BooleanScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Boolean + } +} + +impl From> for BooleanScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(v) + } +} diff --git a/crates/nano-arrow/src/scalar/dictionary.rs b/crates/nano-arrow/src/scalar/dictionary.rs new file mode 100644 index 000000000000..97e3e5916f52 --- /dev/null +++ b/crates/nano-arrow/src/scalar/dictionary.rs @@ -0,0 +1,54 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; + +/// The [`DictionaryArray`] equivalent of [`Array`] for [`Scalar`]. +#[derive(Debug, Clone)] +pub struct DictionaryScalar { + value: Option>, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for DictionaryScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) && (self.value.as_ref() == other.value.as_ref()) + } +} + +impl DictionaryScalar { + /// returns a new [`DictionaryScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, value: Option>) -> Self { + Self { + value, + phantom: std::marker::PhantomData, + data_type, + } + } + + /// The values of the [`DictionaryScalar`] + pub fn value(&self) -> Option<&Box> { + self.value.as_ref() + } +} + +impl Scalar for DictionaryScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.value.is_some() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/equal.rs b/crates/nano-arrow/src/scalar/equal.rs new file mode 100644 index 000000000000..34f98d23640d --- /dev/null +++ b/crates/nano-arrow/src/scalar/equal.rs @@ -0,0 +1,57 @@ +use std::sync::Arc; + +use super::*; +use crate::datatypes::PhysicalType; + +impl PartialEq for dyn Scalar + '_ { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(self, that) + } +} + +impl PartialEq for Arc { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(&**self, that) + } +} + +impl PartialEq for Box { + fn eq(&self, that: &dyn Scalar) -> bool { + equal(&**self, that) + } +} + +macro_rules! 
dyn_eq { + ($ty:ty, $lhs:expr, $rhs:expr) => {{ + let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap(); + let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap(); + lhs == rhs + }}; +} + +fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { + if lhs.data_type() != rhs.data_type() { + return false; + } + + use PhysicalType::*; + match lhs.data_type().to_physical_type() { + Null => dyn_eq!(NullScalar, lhs, rhs), + Boolean => dyn_eq!(BooleanScalar, lhs, rhs), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + dyn_eq!(PrimitiveScalar<$T>, lhs, rhs) + }), + LargeUtf8 => dyn_eq!(Utf8Scalar<i64>, lhs, rhs), + LargeBinary => dyn_eq!(BinaryScalar<i64>, lhs, rhs), + LargeList => dyn_eq!(ListScalar<i64>, lhs, rhs), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + dyn_eq!(DictionaryScalar<$T>, lhs, rhs) + }), + Struct => dyn_eq!(StructScalar, lhs, rhs), + FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs), + FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), + Union => dyn_eq!(UnionScalar, lhs, rhs), + Map => dyn_eq!(MapScalar, lhs, rhs), + _ => unimplemented!(), + } +} diff --git a/crates/nano-arrow/src/scalar/fixed_size_binary.rs b/crates/nano-arrow/src/scalar/fixed_size_binary.rs new file mode 100644 index 000000000000..d8fbb96bac2c --- /dev/null +++ b/crates/nano-arrow/src/scalar/fixed_size_binary.rs @@ -0,0 +1,58 @@ +use super::Scalar; +use crate::datatypes::DataType; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// The [`Scalar`] implementation of fixed size binary ([`Option<Box<[u8]>>`]). +pub struct FixedSizeBinaryScalar { + value: Option<Box<[u8]>>, + data_type: DataType, +} + +impl FixedSizeBinaryScalar { + /// Returns a new [`FixedSizeBinaryScalar`]. + /// # Panics + /// iff + /// * the `data_type` is not `FixedSizeBinary` + /// * the size of child binary is not equal + #[inline] + pub fn new<P: Into<Vec<u8>>>(data_type: DataType, value: Option<P>
) -> Self { + assert_eq!( + data_type.to_physical_type(), + crate::datatypes::PhysicalType::FixedSizeBinary + ); + Self { + value: value.map(|x| { + let x: Vec = x.into(); + assert_eq!( + data_type.to_logical_type(), + &DataType::FixedSizeBinary(x.len()) + ); + x.into_boxed_slice() + }), + data_type, + } + } + + /// Its value + #[inline] + pub fn value(&self) -> Option<&[u8]> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl Scalar for FixedSizeBinaryScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/fixed_size_list.rs b/crates/nano-arrow/src/scalar/fixed_size_list.rs new file mode 100644 index 000000000000..b8333c02c347 --- /dev/null +++ b/crates/nano-arrow/src/scalar/fixed_size_list.rs @@ -0,0 +1,60 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; + +/// The scalar equivalent of [`FixedSizeListArray`]. Like [`FixedSizeListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct FixedSizeListScalar { + values: Option>, + data_type: DataType, +} + +impl PartialEq for FixedSizeListScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.values.is_some() == other.values.is_some()) + && ((self.values.is_none()) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl FixedSizeListScalar { + /// returns a new [`FixedSizeListScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `FixedSizeList` + /// * the child of the `data_type` is not equal to the `values` + /// * the size of child array is not equal + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let (field, size) = FixedSizeListArray::get_child_and_size(&data_type); + let inner_data_type = field.data_type(); + let values = values.map(|x| { + assert_eq!(inner_data_type, x.data_type()); + assert_eq!(size, x.len()); + x + }); + Self { values, data_type } + } + + /// The values of the [`FixedSizeListScalar`] + pub fn values(&self) -> Option<&Box> { + self.values.as_ref() + } +} + +impl Scalar for FixedSizeListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.values.is_some() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/list.rs b/crates/nano-arrow/src/scalar/list.rs new file mode 100644 index 000000000000..d82bf02768bf --- /dev/null +++ b/crates/nano-arrow/src/scalar/list.rs @@ -0,0 +1,68 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// The scalar equivalent of [`ListArray`]. Like [`ListArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. 
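/// For example (illustrative): the scalar for one slot of a `LargeList<Int32>` array holds that slot's values as a boxed `Int32Array`.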
+#[derive(Debug, Clone)] +pub struct ListScalar { + values: Box, + is_valid: bool, + phantom: std::marker::PhantomData, + data_type: DataType, +} + +impl PartialEq for ListScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl ListScalar { + /// returns a new [`ListScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `List` or `LargeList` (depending on this scalar's offset `O`) + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let inner_data_type = ListArray::::get_child_type(&data_type); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_data_type, values.data_type()); + (true, values) + }, + None => (false, new_empty_array(inner_data_type.clone())), + }; + Self { + values, + is_valid, + phantom: std::marker::PhantomData, + data_type, + } + } + + /// The values of the [`ListScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for ListScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/map.rs b/crates/nano-arrow/src/scalar/map.rs new file mode 100644 index 000000000000..90145fb6a30f --- /dev/null +++ b/crates/nano-arrow/src/scalar/map.rs @@ -0,0 +1,66 @@ +use std::any::Any; + +use super::Scalar; +use crate::array::*; +use crate::datatypes::DataType; + +/// The scalar equivalent of [`MapArray`]. Like [`MapArray`], this struct holds a dynamically-typed +/// [`Array`]. The only difference is that this has only one element. +#[derive(Debug, Clone)] +pub struct MapScalar { + values: Box, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for MapScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values.as_ref() == other.values.as_ref())) + } +} + +impl MapScalar { + /// returns a new [`MapScalar`] + /// # Panics + /// iff + /// * the `data_type` is not `Map` + /// * the child of the `data_type` is not equal to the `values` + #[inline] + pub fn new(data_type: DataType, values: Option>) -> Self { + let inner_field = MapArray::try_get_field(&data_type).unwrap(); + let inner_data_type = inner_field.data_type(); + let (is_valid, values) = match values { + Some(values) => { + assert_eq!(inner_data_type, values.data_type()); + (true, values) + }, + None => (false, new_empty_array(inner_data_type.clone())), + }; + Self { + values, + is_valid, + data_type, + } + } + + /// The values of the [`MapScalar`] + pub fn values(&self) -> &Box { + &self.values + } +} + +impl Scalar for MapScalar { + fn as_any(&self) -> &dyn Any { + self + } + + fn is_valid(&self) -> bool { + self.is_valid + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/mod.rs b/crates/nano-arrow/src/scalar/mod.rs new file mode 100644 index 000000000000..7b78b93a44f2 --- /dev/null +++ b/crates/nano-arrow/src/scalar/mod.rs @@ -0,0 +1,187 @@ +//! contains the [`Scalar`] trait object representing individual items of [`Array`](crate::array::Array)s, +//! as well as concrete implementations such as [`BooleanScalar`]. 
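//!
//! Hypothetical usage sketch (illustrative only — `Int32Array` and `from_slice` are
//! assumed to exist in `crate::array`, as in arrow2): extract one slot of an array as a
//! [`Scalar`] and recover its concrete type via `as_any`.
//!
//! ```ignore
//! use crate::array::Int32Array;
//! use crate::scalar::{new_scalar, PrimitiveScalar};
//!
//! let array = Int32Array::from_slice([1, 2, 3]);
//! let scalar = new_scalar(&array, 0);
//! let scalar = scalar
//!     .as_any()
//!     .downcast_ref::<PrimitiveScalar<i32>>()
//!     .unwrap();
//! assert_eq!(scalar.value(), &Some(1));
//! ```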
+use std::any::Any; + +use crate::array::*; +use crate::datatypes::*; + +mod dictionary; +pub use dictionary::*; +mod equal; +mod primitive; +pub use primitive::*; +mod utf8; +pub use utf8::*; +mod binary; +pub use binary::*; +mod boolean; +pub use boolean::*; +mod list; +pub use list::*; +mod map; +pub use map::*; +mod null; +pub use null::*; +mod struct_; +pub use struct_::*; +mod fixed_size_list; +pub use fixed_size_list::*; +mod fixed_size_binary; +pub use fixed_size_binary::*; +mod union; +pub use union::UnionScalar; + +/// Trait object declaring an optional value with a [`DataType`]. +/// This strait is often used in APIs that accept multiple scalar types. +pub trait Scalar: std::fmt::Debug + Send + Sync + dyn_clone::DynClone + 'static { + /// convert itself to + fn as_any(&self) -> &dyn Any; + + /// whether it is valid + fn is_valid(&self) -> bool; + + /// the logical type. + fn data_type(&self) -> &DataType; +} + +dyn_clone::clone_trait_object!(Scalar); + +macro_rules! dyn_new_utf8 { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(Utf8Scalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_binary { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index)) + } else { + None + }; + Box::new(BinaryScalar::<$type>::new(value)) + }}; +} + +macro_rules! dyn_new_list { + ($array:expr, $index:expr, $type:ty) => {{ + let array = $array.as_any().downcast_ref::>().unwrap(); + let value = if array.is_valid($index) { + Some(array.value($index).into()) + } else { + None + }; + Box::new(ListScalar::<$type>::new(array.data_type().clone(), value)) + }}; +} + +/// creates a new [`Scalar`] from an [`Array`]. 
+pub fn new_scalar(array: &dyn Array, index: usize) -> Box { + use PhysicalType::*; + match array.data_type().to_physical_type() { + Null => Box::new(NullScalar::new()), + Boolean => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(BooleanScalar::new(value)) + }, + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(PrimitiveScalar::new(array.data_type().clone(), value)) + }), + Utf8 => dyn_new_utf8!(array, index, i32), + LargeUtf8 => dyn_new_utf8!(array, index, i64), + Binary => dyn_new_binary!(array, index, i32), + LargeBinary => dyn_new_binary!(array, index, i64), + List => dyn_new_list!(array, index, i32), + LargeList => dyn_new_list!(array, index, i64), + Struct => { + let array = array.as_any().downcast_ref::().unwrap(); + if array.is_valid(index) { + let values = array + .values() + .iter() + .map(|x| new_scalar(x.as_ref(), index)) + .collect(); + Box::new(StructScalar::new(array.data_type().clone(), Some(values))) + } else { + Box::new(StructScalar::new(array.data_type().clone(), None)) + } + }, + FixedSizeBinary => { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(FixedSizeBinaryScalar::new(array.data_type().clone(), value)) + }, + FixedSizeList => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(FixedSizeListScalar::new(array.data_type().clone(), value)) + }, + Union => { + let array = array.as_any().downcast_ref::().unwrap(); + Box::new(UnionScalar::new( + array.data_type().clone(), + array.types()[index], + array.value(index), + )) + }, + Map => { + let array = array.as_any().downcast_ref::().unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index)) + } else { + None + }; + Box::new(MapScalar::new(array.data_type().clone(), value)) + }, + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let value = if array.is_valid(index) { + Some(array.value(index).into()) + } else { + None + }; + Box::new(DictionaryScalar::<$T>::new( + array.data_type().clone(), + value, + )) + }), + } +} diff --git a/crates/nano-arrow/src/scalar/null.rs b/crates/nano-arrow/src/scalar/null.rs new file mode 100644 index 000000000000..2de7d7cde55b --- /dev/null +++ b/crates/nano-arrow/src/scalar/null.rs @@ -0,0 +1,37 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// The representation of a single entry of a [`crate::array::NullArray`]. 
+#[derive(Debug, Clone, PartialEq, Eq)] +pub struct NullScalar {} + +impl NullScalar { + /// A new [`NullScalar`] + #[inline] + pub fn new() -> Self { + Self {} + } +} + +impl Default for NullScalar { + fn default() -> Self { + Self::new() + } +} + +impl Scalar for NullScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + false + } + + #[inline] + fn data_type(&self) -> &DataType { + &DataType::Null + } +} diff --git a/crates/nano-arrow/src/scalar/primitive.rs b/crates/nano-arrow/src/scalar/primitive.rs new file mode 100644 index 000000000000..3288708f6755 --- /dev/null +++ b/crates/nano-arrow/src/scalar/primitive.rs @@ -0,0 +1,67 @@ +use super::Scalar; +use crate::datatypes::DataType; +use crate::error::Error; +use crate::types::NativeType; + +/// The implementation of [`Scalar`] for primitive, semantically equivalent to [`Option`] +/// with [`DataType`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrimitiveScalar { + value: Option, + data_type: DataType, +} + +impl PrimitiveScalar { + /// Returns a new [`PrimitiveScalar`]. + #[inline] + pub fn new(data_type: DataType, value: Option) -> Self { + if !data_type.to_physical_type().eq_primitive(T::PRIMITIVE) { + panic!( + "{:?}", + Error::InvalidArgumentError(format!( + "Type {} does not support logical type {:?}", + std::any::type_name::(), + data_type + )) + ) + } + Self { value, data_type } + } + + /// Returns the optional value. + #[inline] + pub fn value(&self) -> &Option { + &self.value + } + + /// Returns a new `PrimitiveScalar` with the same value but different [`DataType`] + /// # Panic + /// This function panics if the `data_type` is not valid for self's physical type `T`. + pub fn to(self, data_type: DataType) -> Self { + Self::new(data_type, self.value) + } +} + +impl From> for PrimitiveScalar { + #[inline] + fn from(v: Option) -> Self { + Self::new(T::PRIMITIVE.into(), v) + } +} + +impl Scalar for PrimitiveScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/struct_.rs b/crates/nano-arrow/src/scalar/struct_.rs new file mode 100644 index 000000000000..29c2c33ba295 --- /dev/null +++ b/crates/nano-arrow/src/scalar/struct_.rs @@ -0,0 +1,54 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// A single entry of a [`crate::array::StructArray`]. +#[derive(Debug, Clone)] +pub struct StructScalar { + values: Vec>, + is_valid: bool, + data_type: DataType, +} + +impl PartialEq for StructScalar { + fn eq(&self, other: &Self) -> bool { + (self.data_type == other.data_type) + && (self.is_valid == other.is_valid) + && ((!self.is_valid) | (self.values == other.values)) + } +} + +impl StructScalar { + /// Returns a new [`StructScalar`] + #[inline] + pub fn new(data_type: DataType, values: Option>>) -> Self { + let is_valid = values.is_some(); + Self { + values: values.unwrap_or_default(), + is_valid, + data_type, + } + } + + /// Returns the values irrespectively of the validity. 
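    /// If the scalar is not valid, the returned slice is empty.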
+ #[inline] + pub fn values(&self) -> &[Box] { + &self.values + } +} + +impl Scalar for StructScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.is_valid + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/union.rs b/crates/nano-arrow/src/scalar/union.rs new file mode 100644 index 000000000000..987e9f4e6044 --- /dev/null +++ b/crates/nano-arrow/src/scalar/union.rs @@ -0,0 +1,51 @@ +use super::Scalar; +use crate::datatypes::DataType; + +/// A single entry of a [`crate::array::UnionArray`]. +#[derive(Debug, Clone, PartialEq)] +pub struct UnionScalar { + value: Box, + type_: i8, + data_type: DataType, +} + +impl UnionScalar { + /// Returns a new [`UnionScalar`] + #[inline] + pub fn new(data_type: DataType, type_: i8, value: Box) -> Self { + Self { + value, + type_, + data_type, + } + } + + /// Returns the inner value + #[inline] + pub fn value(&self) -> &Box { + &self.value + } + + /// Returns the type of the union scalar + #[inline] + pub fn type_(&self) -> i8 { + self.type_ + } +} + +impl Scalar for UnionScalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + true + } + + #[inline] + fn data_type(&self) -> &DataType { + &self.data_type + } +} diff --git a/crates/nano-arrow/src/scalar/utf8.rs b/crates/nano-arrow/src/scalar/utf8.rs new file mode 100644 index 000000000000..ea08d30af578 --- /dev/null +++ b/crates/nano-arrow/src/scalar/utf8.rs @@ -0,0 +1,55 @@ +use super::Scalar; +use crate::datatypes::DataType; +use crate::offset::Offset; + +/// The implementation of [`Scalar`] for utf8, semantically equivalent to [`Option`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Utf8Scalar { + value: Option, + phantom: std::marker::PhantomData, +} + +impl Utf8Scalar { + /// Returns a new [`Utf8Scalar`] + #[inline] + pub fn new>(value: Option

) -> Self { + Self { + value: value.map(|x| x.into()), + phantom: std::marker::PhantomData, + } + } + + /// Returns the value irrespectively of the validity. + #[inline] + pub fn value(&self) -> Option<&str> { + self.value.as_ref().map(|x| x.as_ref()) + } +} + +impl<O: Offset, P: Into<String>> From<Option<P>> for Utf8Scalar<O> { + #[inline] + fn from(v: Option<P>
) -> Self { + Self::new(v) + } +} + +impl Scalar for Utf8Scalar { + #[inline] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline] + fn is_valid(&self) -> bool { + self.value.is_some() + } + + #[inline] + fn data_type(&self) -> &DataType { + if O::IS_LARGE { + &DataType::LargeUtf8 + } else { + &DataType::Utf8 + } + } +} diff --git a/crates/nano-arrow/src/temporal_conversions.rs b/crates/nano-arrow/src/temporal_conversions.rs new file mode 100644 index 000000000000..5058d1d887bd --- /dev/null +++ b/crates/nano-arrow/src/temporal_conversions.rs @@ -0,0 +1,543 @@ +//! Conversion methods for dates and times. + +use chrono::format::{parse, Parsed, StrftimeItems}; +use chrono::{Datelike, Duration, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; + +use crate::array::{PrimitiveArray, Utf8Array}; +use crate::datatypes::{DataType, TimeUnit}; +use crate::error::{Error, Result}; +use crate::offset::Offset; +use crate::types::months_days_ns; + +/// Number of seconds in a day +pub const SECONDS_IN_DAY: i64 = 86_400; +/// Number of milliseconds in a second +pub const MILLISECONDS: i64 = 1_000; +/// Number of microseconds in a second +pub const MICROSECONDS: i64 = 1_000_000; +/// Number of nanoseconds in a second +pub const NANOSECONDS: i64 = 1_000_000_000; +/// Number of milliseconds in a day +pub const MILLISECONDS_IN_DAY: i64 = SECONDS_IN_DAY * MILLISECONDS; +/// Number of days between 0001-01-01 and 1970-01-01 +pub const EPOCH_DAYS_FROM_CE: i32 = 719_163; + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime(v: i32) -> NaiveDateTime { + date32_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i32` representing a `date32` to [`NaiveDateTime`] +#[inline] +pub fn date32_to_datetime_opt(v: i32) -> Option { + NaiveDateTime::from_timestamp_opt(v as i64 * SECONDS_IN_DAY, 0) +} + +/// converts a `i32` representing a `date32` to [`NaiveDate`] +#[inline] +pub fn date32_to_date(days: i32) -> NaiveDate { + date32_to_date_opt(days).expect("out-of-range date") +} + +/// converts a `i32` representing a `date32` to [`NaiveDate`] +#[inline] +pub fn date32_to_date_opt(days: i32) -> Option { + NaiveDate::from_num_days_from_ce_opt(EPOCH_DAYS_FROM_CE + days) +} + +/// converts a `i64` representing a `date64` to [`NaiveDateTime`] +#[inline] +pub fn date64_to_datetime(v: i64) -> NaiveDateTime { + NaiveDateTime::from_timestamp_opt( + // extract seconds from milliseconds + v / MILLISECONDS, + // discard extracted seconds and convert milliseconds to nanoseconds + (v % MILLISECONDS * MICROSECONDS) as u32, + ) + .expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `date64` to [`NaiveDate`] +#[inline] +pub fn date64_to_date(milliseconds: i64) -> NaiveDate { + date64_to_datetime(milliseconds).date() +} + +/// converts a `i32` representing a `time32(s)` to [`NaiveTime`] +#[inline] +pub fn time32s_to_time(v: i32) -> NaiveTime { + NaiveTime::from_num_seconds_from_midnight_opt(v as u32, 0).expect("invalid time") +} + +/// converts a `i64` representing a `duration(s)` to [`Duration`] +#[inline] +pub fn duration_s_to_duration(v: i64) -> Duration { + Duration::seconds(v) +} + +/// converts a `i64` representing a `duration(ms)` to [`Duration`] +#[inline] +pub fn duration_ms_to_duration(v: i64) -> Duration { + Duration::milliseconds(v) +} + +/// converts a `i64` representing a `duration(us)` to [`Duration`] +#[inline] +pub fn duration_us_to_duration(v: i64) -> Duration { + Duration::microseconds(v) +} + 
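// Illustrative sketch (a hypothetical test, added only as an example of how the
// epoch-based helpers compose): `date32` counts whole days since 1970-01-01 and
// `timestamp(ms)` counts milliseconds since the same epoch.
#[cfg(test)]
#[test]
fn example_epoch_conversions() {
    // day 1 after the epoch is 1970-01-02
    assert_eq!(
        date32_to_date(1),
        chrono::NaiveDate::from_ymd_opt(1970, 1, 2).unwrap()
    );
    // 1_500 ms after the epoch is 00:00:01.500 on 1970-01-01
    assert_eq!(
        timestamp_ms_to_datetime(1_500),
        chrono::NaiveDate::from_ymd_opt(1970, 1, 1)
            .unwrap()
            .and_hms_milli_opt(0, 0, 1, 500)
            .unwrap()
    );
}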
+/// converts a `i64` representing a `duration(ns)` to [`Duration`] +#[inline] +pub fn duration_ns_to_duration(v: i64) -> Duration { + Duration::nanoseconds(v) +} + +/// converts a `i32` representing a `time32(ms)` to [`NaiveTime`] +#[inline] +pub fn time32ms_to_time(v: i32) -> NaiveTime { + let v = v as i64; + let seconds = v / MILLISECONDS; + + let milli_to_nano = 1_000_000; + let nano = (v - seconds * MILLISECONDS) * milli_to_nano; + NaiveTime::from_num_seconds_from_midnight_opt(seconds as u32, nano as u32) + .expect("invalid time") +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time(v: i64) -> NaiveTime { + time64us_to_time_opt(v).expect("invalid time") +} + +/// converts a `i64` representing a `time64(us)` to [`NaiveTime`] +#[inline] +pub fn time64us_to_time_opt(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from microseconds + (v / MICROSECONDS) as u32, + // discard extracted seconds and convert microseconds to + // nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time(v: i64) -> NaiveTime { + time64ns_to_time_opt(v).expect("invalid time") +} + +/// converts a `i64` representing a `time64(ns)` to [`NaiveTime`] +#[inline] +pub fn time64ns_to_time_opt(v: i64) -> Option { + NaiveTime::from_num_seconds_from_midnight_opt( + // extract seconds from nanoseconds + (v / NANOSECONDS) as u32, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime(seconds: i64) -> NaiveDateTime { + timestamp_s_to_datetime_opt(seconds).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(s)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_s_to_datetime_opt(seconds: i64) -> Option { + NaiveDateTime::from_timestamp_opt(seconds, 0) +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime(v: i64) -> NaiveDateTime { + timestamp_ms_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(ms)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ms_to_datetime_opt(v: i64) -> Option { + if v >= 0 { + NaiveDateTime::from_timestamp_opt( + // extract seconds from milliseconds + v / MILLISECONDS, + // discard extracted seconds and convert milliseconds to nanoseconds + (v % MILLISECONDS * MICROSECONDS) as u32, + ) + } else { + let secs_rem = (v / MILLISECONDS, v % MILLISECONDS); + if secs_rem.1 == 0 { + // whole/integer seconds; no adjustment required + NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) + } else { + // negative values with fractional seconds require 'div_floor' rounding behaviour. 
+ // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) + NaiveDateTime::from_timestamp_opt( + secs_rem.0 - 1, + (NANOSECONDS + (v % MILLISECONDS * MICROSECONDS)) as u32, + ) + } + } +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime(v: i64) -> NaiveDateTime { + timestamp_us_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(us)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_us_to_datetime_opt(v: i64) -> Option { + if v >= 0 { + NaiveDateTime::from_timestamp_opt( + // extract seconds from microseconds + v / MICROSECONDS, + // discard extracted seconds and convert microseconds to nanoseconds + (v % MICROSECONDS * MILLISECONDS) as u32, + ) + } else { + let secs_rem = (v / MICROSECONDS, v % MICROSECONDS); + if secs_rem.1 == 0 { + // whole/integer seconds; no adjustment required + NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) + } else { + // negative values with fractional seconds require 'div_floor' rounding behaviour. + // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) + NaiveDateTime::from_timestamp_opt( + secs_rem.0 - 1, + (NANOSECONDS + (v % MICROSECONDS * MILLISECONDS)) as u32, + ) + } + } +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime(v: i64) -> NaiveDateTime { + timestamp_ns_to_datetime_opt(v).expect("invalid or out-of-range datetime") +} + +/// converts a `i64` representing a `timestamp(ns)` to [`NaiveDateTime`] +#[inline] +pub fn timestamp_ns_to_datetime_opt(v: i64) -> Option { + if v >= 0 { + NaiveDateTime::from_timestamp_opt( + // extract seconds from nanoseconds + v / NANOSECONDS, + // discard extracted seconds + (v % NANOSECONDS) as u32, + ) + } else { + let secs_rem = (v / NANOSECONDS, v % NANOSECONDS); + if secs_rem.1 == 0 { + // whole/integer seconds; no adjustment required + NaiveDateTime::from_timestamp_opt(secs_rem.0, 0) + } else { + // negative values with fractional seconds require 'div_floor' rounding behaviour. + // (which isn't yet stabilised: https://github.com/rust-lang/rust/issues/88581) + NaiveDateTime::from_timestamp_opt( + secs_rem.0 - 1, + (NANOSECONDS + (v % NANOSECONDS)) as u32, + ) + } + } +} + +/// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. +#[inline] +pub fn timestamp_to_naive_datetime(timestamp: i64, time_unit: TimeUnit) -> chrono::NaiveDateTime { + match time_unit { + TimeUnit::Second => timestamp_s_to_datetime(timestamp), + TimeUnit::Millisecond => timestamp_ms_to_datetime(timestamp), + TimeUnit::Microsecond => timestamp_us_to_datetime(timestamp), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(timestamp), + } +} + +/// Converts a timestamp in `time_unit` and `timezone` into [`chrono::DateTime`]. +#[inline] +pub fn timestamp_to_datetime( + timestamp: i64, + time_unit: TimeUnit, + timezone: &T, +) -> chrono::DateTime { + timezone.from_utc_datetime(×tamp_to_naive_datetime(timestamp, time_unit)) +} + +/// Calculates the scale factor between two TimeUnits. The function returns the +/// scale that should multiply the TimeUnit "b" to have the same time scale as +/// the TimeUnit "a". 
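///
/// For example, `timeunit_scale(TimeUnit::Second, TimeUnit::Millisecond)` is `0.001`:
/// a value in milliseconds must be multiplied by `0.001` to express it in seconds.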
+pub fn timeunit_scale(a: TimeUnit, b: TimeUnit) -> f64 { + match (a, b) { + (TimeUnit::Second, TimeUnit::Second) => 1.0, + (TimeUnit::Second, TimeUnit::Millisecond) => 0.001, + (TimeUnit::Second, TimeUnit::Microsecond) => 0.000_001, + (TimeUnit::Second, TimeUnit::Nanosecond) => 0.000_000_001, + (TimeUnit::Millisecond, TimeUnit::Second) => 1_000.0, + (TimeUnit::Millisecond, TimeUnit::Millisecond) => 1.0, + (TimeUnit::Millisecond, TimeUnit::Microsecond) => 0.001, + (TimeUnit::Millisecond, TimeUnit::Nanosecond) => 0.000_001, + (TimeUnit::Microsecond, TimeUnit::Second) => 1_000_000.0, + (TimeUnit::Microsecond, TimeUnit::Millisecond) => 1_000.0, + (TimeUnit::Microsecond, TimeUnit::Microsecond) => 1.0, + (TimeUnit::Microsecond, TimeUnit::Nanosecond) => 0.001, + (TimeUnit::Nanosecond, TimeUnit::Second) => 1_000_000_000.0, + (TimeUnit::Nanosecond, TimeUnit::Millisecond) => 1_000_000.0, + (TimeUnit::Nanosecond, TimeUnit::Microsecond) => 1_000.0, + (TimeUnit::Nanosecond, TimeUnit::Nanosecond) => 1.0, + } +} + +/// Parses an offset of the form `"+WX:YZ"` or `"UTC"` into [`FixedOffset`]. +/// # Errors +/// If the offset is not in any of the allowed forms. +pub fn parse_offset(offset: &str) -> Result { + if offset == "UTC" { + return Ok(FixedOffset::east_opt(0).expect("FixedOffset::east out of bounds")); + } + let error = "timezone offset must be of the form [-]00:00"; + + let mut a = offset.split(':'); + let first = a + .next() + .map(Ok) + .unwrap_or_else(|| Err(Error::InvalidArgumentError(error.to_string())))?; + let last = a + .next() + .map(Ok) + .unwrap_or_else(|| Err(Error::InvalidArgumentError(error.to_string())))?; + let hours: i32 = first + .parse() + .map_err(|_| Error::InvalidArgumentError(error.to_string()))?; + let minutes: i32 = last + .parse() + .map_err(|_| Error::InvalidArgumentError(error.to_string()))?; + + Ok(FixedOffset::east_opt(hours * 60 * 60 + minutes * 60) + .expect("FixedOffset::east out of bounds")) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). +#[inline] +pub fn utf8_to_timestamp_ns_scalar( + value: &str, + fmt: &str, + tz: &T, +) -> Option { + utf8_to_timestamp_scalar(value, fmt, tz, &TimeUnit::Nanosecond) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp with timezone. +/// `tz` must be built from `timezone` (either via [`parse_offset`] or `chrono-tz`). +/// Returns in scale `tz` of `TimeUnit`. +#[inline] +pub fn utf8_to_timestamp_scalar( + value: &str, + fmt: &str, + tz: &T, + tu: &TimeUnit, +) -> Option { + let mut parsed = Parsed::new(); + let fmt = StrftimeItems::new(fmt); + let r = parse(&mut parsed, value, fmt).ok(); + if r.is_some() { + parsed + .to_datetime() + .map(|x| x.naive_utc()) + .map(|x| tz.from_utc_datetime(&x)) + .map(|x| match tu { + TimeUnit::Second => x.timestamp(), + TimeUnit::Millisecond => x.timestamp_millis(), + TimeUnit::Microsecond => x.timestamp_micros(), + TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + }) + .ok() + } else { + None + } +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. +#[inline] +pub fn utf8_to_naive_timestamp_ns_scalar(value: &str, fmt: &str) -> Option { + utf8_to_naive_timestamp_scalar(value, fmt, &TimeUnit::Nanosecond) +} + +/// Parses `value` to `Option` consistent with the Arrow's definition of timestamp without timezone. +/// Returns in scale `tz` of `TimeUnit`. 
+#[inline] +pub fn utf8_to_naive_timestamp_scalar(value: &str, fmt: &str, tu: &TimeUnit) -> Option { + let fmt = StrftimeItems::new(fmt); + let mut parsed = Parsed::new(); + parse(&mut parsed, value, fmt.clone()).ok(); + parsed + .to_naive_datetime_with_offset(0) + .map(|x| match tu { + TimeUnit::Second => x.timestamp(), + TimeUnit::Millisecond => x.timestamp_millis(), + TimeUnit::Microsecond => x.timestamp_micros(), + TimeUnit::Nanosecond => x.timestamp_nanos_opt().unwrap(), + }) + .ok() +} + +fn utf8_to_timestamp_ns_impl( + array: &Utf8Array, + fmt: &str, + timezone: String, + tz: T, +) -> PrimitiveArray { + let iter = array + .iter() + .map(|x| x.and_then(|x| utf8_to_timestamp_ns_scalar(x, fmt, &tz))); + + PrimitiveArray::from_trusted_len_iter(iter) + .to(DataType::Timestamp(TimeUnit::Nanosecond, Some(timezone))) +} + +/// Parses `value` to a [`chrono_tz::Tz`] with the Arrow's definition of timestamp with a timezone. +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +pub fn parse_offset_tz(timezone: &str) -> Result { + timezone.parse::().map_err(|_| { + Error::InvalidArgumentError(format!("timezone \"{timezone}\" cannot be parsed")) + }) +} + +#[cfg(feature = "chrono-tz")] +#[cfg_attr(docsrs, doc(cfg(feature = "chrono-tz")))] +fn chrono_tz_utf_to_timestamp_ns( + array: &Utf8Array, + fmt: &str, + timezone: String, +) -> Result> { + let tz = parse_offset_tz(&timezone)?; + Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz)) +} + +#[cfg(not(feature = "chrono-tz"))] +fn chrono_tz_utf_to_timestamp_ns( + _: &Utf8Array, + _: &str, + timezone: String, +) -> Result> { + Err(Error::InvalidArgumentError(format!( + "timezone \"{timezone}\" cannot be parsed (feature chrono-tz is not active)", + ))) +} + +/// Parses a [`Utf8Array`] to a timeozone-aware timestamp, i.e. [`PrimitiveArray`] with type `Timestamp(Nanosecond, Some(timezone))`. +/// # Implementation +/// * parsed values with timezone other than `timezone` are converted to `timezone`. +/// * parsed values without timezone are null. Use [`utf8_to_naive_timestamp_ns`] to parse naive timezones. +/// * Null elements remain null; non-parsable elements are null. +/// The feature `"chrono-tz"` enables IANA and zoneinfo formats for `timezone`. +/// # Error +/// This function errors iff `timezone` is not parsable to an offset. +pub fn utf8_to_timestamp_ns( + array: &Utf8Array, + fmt: &str, + timezone: String, +) -> Result> { + let tz = parse_offset(timezone.as_str()); + + if let Ok(tz) = tz { + Ok(utf8_to_timestamp_ns_impl(array, fmt, timezone, tz)) + } else { + chrono_tz_utf_to_timestamp_ns(array, fmt, timezone) + } +} + +/// Parses a [`Utf8Array`] to naive timestamp, i.e. +/// [`PrimitiveArray`] with type `Timestamp(Nanosecond, None)`. +/// Timezones are ignored. +/// Null elements remain null; non-parsable elements are set to null. 
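///
/// For example, with `fmt = "%Y-%m-%dT%H:%M:%S"` the value `"1970-01-01T00:00:01"`
/// parses to `1_000_000_000` nanoseconds since the epoch.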
+pub fn utf8_to_naive_timestamp_ns( + array: &Utf8Array, + fmt: &str, +) -> PrimitiveArray { + let iter = array + .iter() + .map(|x| x.and_then(|x| utf8_to_naive_timestamp_ns_scalar(x, fmt))); + + PrimitiveArray::from_trusted_len_iter(iter).to(DataType::Timestamp(TimeUnit::Nanosecond, None)) +} + +fn add_month(year: i32, month: u32, months: i32) -> chrono::NaiveDate { + let new_year = (year * 12 + (month - 1) as i32 + months) / 12; + let new_month = (year * 12 + (month - 1) as i32 + months) % 12 + 1; + chrono::NaiveDate::from_ymd_opt(new_year, new_month as u32, 1) + .expect("invalid or out-of-range date") +} + +fn get_days_between_months(year: i32, month: u32, months: i32) -> i64 { + add_month(year, month, months) + .signed_duration_since( + chrono::NaiveDate::from_ymd_opt(year, month, 1).expect("invalid or out-of-range date"), + ) + .num_days() +} + +/// Adds an `interval` to a `timestamp` in `time_unit` units without timezone. +#[inline] +pub fn add_naive_interval(timestamp: i64, time_unit: TimeUnit, interval: months_days_ns) -> i64 { + // convert seconds to a DateTime of a given offset. + let datetime = match time_unit { + TimeUnit::Second => timestamp_s_to_datetime(timestamp), + TimeUnit::Millisecond => timestamp_ms_to_datetime(timestamp), + TimeUnit::Microsecond => timestamp_us_to_datetime(timestamp), + TimeUnit::Nanosecond => timestamp_ns_to_datetime(timestamp), + }; + + // compute the number of days in the interval, which depends on the particular year and month (leap days) + let delta_days = get_days_between_months(datetime.year(), datetime.month(), interval.months()) + + interval.days() as i64; + + // add; no leap hours are considered + let new_datetime_tz = datetime + + chrono::Duration::nanoseconds(delta_days * 24 * 60 * 60 * 1_000_000_000 + interval.ns()); + + // convert back to the target unit + match time_unit { + TimeUnit::Second => new_datetime_tz.timestamp_millis() / 1000, + TimeUnit::Millisecond => new_datetime_tz.timestamp_millis(), + TimeUnit::Microsecond => new_datetime_tz.timestamp_nanos_opt().unwrap() / 1000, + TimeUnit::Nanosecond => new_datetime_tz.timestamp_nanos_opt().unwrap(), + } +} + +/// Adds an `interval` to a `timestamp` in `time_unit` units and timezone `timezone`. +#[inline] +pub fn add_interval( + timestamp: i64, + time_unit: TimeUnit, + interval: months_days_ns, + timezone: &T, +) -> i64 { + // convert seconds to a DateTime of a given offset. + let datetime_tz = timestamp_to_datetime(timestamp, time_unit, timezone); + + // compute the number of days in the interval, which depends on the particular year and month (leap days) + let delta_days = + get_days_between_months(datetime_tz.year(), datetime_tz.month(), interval.months()) + + interval.days() as i64; + + // add; tz will take care of leap hours + let new_datetime_tz = datetime_tz + + chrono::Duration::nanoseconds(delta_days * 24 * 60 * 60 * 1_000_000_000 + interval.ns()); + + // convert back to the target unit + match time_unit { + TimeUnit::Second => new_datetime_tz.timestamp_millis() / 1000, + TimeUnit::Millisecond => new_datetime_tz.timestamp_millis(), + TimeUnit::Microsecond => new_datetime_tz.timestamp_nanos_opt().unwrap() / 1000, + TimeUnit::Nanosecond => new_datetime_tz.timestamp_nanos_opt().unwrap(), + } +} diff --git a/crates/nano-arrow/src/trusted_len.rs b/crates/nano-arrow/src/trusted_len.rs new file mode 100644 index 000000000000..a1c38bd51c71 --- /dev/null +++ b/crates/nano-arrow/src/trusted_len.rs @@ -0,0 +1,57 @@ +//! Declares [`TrustedLen`]. 
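//!
//! Illustrative sketch of the contract: a consumer may rely on `size_hint` reporting the
//! exact remaining length, as slice iterators already do.
//!
//! ```ignore
//! let values = [1i32, 2, 3];
//! let iter = values.iter();
//! // the exact upper bound is what makes the `unsafe` TrustedLen marker sound
//! assert_eq!(iter.size_hint(), (3, Some(3)));
//! ```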
+use std::slice::Iter; + +/// An iterator of known, fixed size. +/// A trait denoting Rusts' unstable [TrustedLen](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). +/// This is re-defined here and implemented for some iterators until `std::iter::TrustedLen` +/// is stabilized. +/// +/// # Safety +/// This trait must only be implemented when the contract is upheld. +/// Consumers of this trait must inspect Iterator::size_hint()’s upper bound. +pub unsafe trait TrustedLen: Iterator {} + +unsafe impl TrustedLen for Iter<'_, T> {} + +unsafe impl B> TrustedLen for std::iter::Map {} + +unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Copied +where + I: TrustedLen, + T: Copy, +{ +} +unsafe impl<'a, I, T: 'a> TrustedLen for std::iter::Cloned +where + I: TrustedLen, + T: Clone, +{ +} + +unsafe impl TrustedLen for std::iter::Enumerate where I: TrustedLen {} + +unsafe impl TrustedLen for std::iter::Zip +where + A: TrustedLen, + B: TrustedLen, +{ +} + +unsafe impl TrustedLen for std::slice::ChunksExact<'_, T> {} + +unsafe impl TrustedLen for std::slice::Windows<'_, T> {} + +unsafe impl TrustedLen for std::iter::Chain +where + A: TrustedLen, + B: TrustedLen, +{ +} + +unsafe impl TrustedLen for std::iter::Once {} + +unsafe impl TrustedLen for std::vec::IntoIter {} + +unsafe impl TrustedLen for std::iter::Repeat {} +unsafe impl A> TrustedLen for std::iter::RepeatWith {} +unsafe impl TrustedLen for std::iter::Take {} diff --git a/crates/nano-arrow/src/types/bit_chunk.rs b/crates/nano-arrow/src/types/bit_chunk.rs new file mode 100644 index 000000000000..ef4b25fd28a2 --- /dev/null +++ b/crates/nano-arrow/src/types/bit_chunk.rs @@ -0,0 +1,161 @@ +use std::fmt::Binary; +use std::ops::{BitAndAssign, Not, Shl, ShlAssign, ShrAssign}; + +use num_traits::PrimInt; + +use super::NativeType; + +/// A chunk of bits. This is used to create masks of a given length +/// whose width is `1` bit. In `portable_simd` notation, this corresponds to `m1xY`. +/// +/// This (sealed) trait is implemented for [`u8`], [`u16`], [`u32`] and [`u64`]. +pub trait BitChunk: + super::private::Sealed + + PrimInt + + NativeType + + Binary + + ShlAssign + + Not + + ShrAssign + + ShlAssign + + Shl + + BitAndAssign +{ + /// convert itself into bytes. + fn to_ne_bytes(self) -> Self::Bytes; + /// convert itself from bytes. + fn from_ne_bytes(v: Self::Bytes) -> Self; +} + +macro_rules! bit_chunk { + ($ty:ty) => { + impl BitChunk for $ty { + #[inline(always)] + fn to_ne_bytes(self) -> Self::Bytes { + self.to_ne_bytes() + } + + #[inline(always)] + fn from_ne_bytes(v: Self::Bytes) -> Self { + Self::from_ne_bytes(v) + } + } + }; +} + +bit_chunk!(u8); +bit_chunk!(u16); +bit_chunk!(u32); +bit_chunk!(u64); + +/// An [`Iterator`] over a [`BitChunk`]. This iterator is often +/// compiled to SIMD. +/// The [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit) corresponds +/// to the first slot, as defined by the arrow specification. +/// # Example +/// ``` +/// use arrow2::types::BitChunkIter; +/// let a = 0b00010000u8; +/// let iter = BitChunkIter::new(a, 7); +/// let r = iter.collect::>(); +/// assert_eq!(r, vec![false, false, false, false, true, false, false]); +/// ``` +pub struct BitChunkIter { + value: T, + mask: T, + remaining: usize, +} + +impl BitChunkIter { + /// Creates a new [`BitChunkIter`] with `len` bits. 
+ #[inline] + pub fn new(value: T, len: usize) -> Self { + assert!(len <= std::mem::size_of::() * 8); + Self { + value, + remaining: len, + mask: T::one(), + } + } +} + +impl Iterator for BitChunkIter { + type Item = bool; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + }; + let result = Some(self.value & self.mask != T::zero()); + self.remaining -= 1; + self.mask <<= 1; + result + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +// # Safety +// a mathematical invariant of this iterator +unsafe impl crate::trusted_len::TrustedLen for BitChunkIter {} + +/// An [`Iterator`] over a [`BitChunk`] returning the index of each bit set in the chunk +/// See for details +/// # Example +/// ``` +/// use arrow2::types::BitChunkOnes; +/// let a = 0b00010000u8; +/// let iter = BitChunkOnes::new(a); +/// let r = iter.collect::>(); +/// assert_eq!(r, vec![4]); +/// ``` +pub struct BitChunkOnes { + value: T, + remaining: usize, +} + +impl BitChunkOnes { + /// Creates a new [`BitChunkOnes`] with `len` bits. + #[inline] + pub fn new(value: T) -> Self { + Self { + value, + remaining: value.count_ones() as usize, + } + } + + #[inline] + #[cfg(feature = "compute_filter")] + pub(crate) fn from_known_count(value: T, remaining: usize) -> Self { + Self { value, remaining } + } +} + +impl Iterator for BitChunkOnes { + type Item = usize; + + #[inline] + fn next(&mut self) -> Option { + if self.remaining == 0 { + return None; + } + let v = self.value.trailing_zeros() as usize; + self.value &= self.value - T::one(); + + self.remaining -= 1; + Some(v) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + (self.remaining, Some(self.remaining)) + } +} + +// # Safety +// a mathematical invariant of this iterator +unsafe impl crate::trusted_len::TrustedLen for BitChunkOnes {} diff --git a/crates/nano-arrow/src/types/index.rs b/crates/nano-arrow/src/types/index.rs new file mode 100644 index 000000000000..0aedea008fa3 --- /dev/null +++ b/crates/nano-arrow/src/types/index.rs @@ -0,0 +1,103 @@ +use std::convert::TryFrom; + +use super::NativeType; +use crate::trusted_len::TrustedLen; + +/// Sealed trait describing the subset of [`NativeType`] (`i32`, `i64`, `u32` and `u64`) +/// that can be used to index a slot of an array. +pub trait Index: + NativeType + + std::ops::AddAssign + + std::ops::Sub + + num_traits::One + + num_traits::Num + + num_traits::CheckedAdd + + PartialOrd + + Ord +{ + /// Convert itself to [`usize`]. + fn to_usize(&self) -> usize; + /// Convert itself from [`usize`]. + fn from_usize(index: usize) -> Option; + + /// Convert itself from [`usize`]. + fn from_as_usize(index: usize) -> Self; + + /// An iterator from (inclusive) `start` to (exclusive) `end`. + fn range(start: usize, end: usize) -> Option> { + let start = Self::from_usize(start); + let end = Self::from_usize(end); + match (start, end) { + (Some(start), Some(end)) => Some(IndexRange::new(start, end)), + _ => None, + } + } +} + +macro_rules! index { + ($t:ty) => { + impl Index for $t { + #[inline] + fn to_usize(&self) -> usize { + *self as usize + } + + #[inline] + fn from_usize(value: usize) -> Option { + Self::try_from(value).ok() + } + + #[inline] + fn from_as_usize(value: usize) -> Self { + value as $t + } + } + }; +} + +index!(i8); +index!(i16); +index!(i32); +index!(i64); +index!(u8); +index!(u16); +index!(u32); +index!(u64); + +/// Range of [`Index`], equivalent to `(a..b)`. 
+/// `Step` is unstable in Rust, which does not allow us to implement (a..b) for [`Index`]. +pub struct IndexRange { + start: I, + end: I, +} + +impl IndexRange { + /// Returns a new [`IndexRange`]. + pub fn new(start: I, end: I) -> Self { + assert!(end >= start); + Self { start, end } + } +} + +impl Iterator for IndexRange { + type Item = I; + + #[inline] + fn next(&mut self) -> Option { + if self.start == self.end { + return None; + } + let old = self.start; + self.start += I::one(); + Some(old) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = (self.end - self.start).to_usize(); + (len, Some(len)) + } +} + +/// Safety: a range is always of known length +unsafe impl TrustedLen for IndexRange {} diff --git a/crates/nano-arrow/src/types/mod.rs b/crates/nano-arrow/src/types/mod.rs new file mode 100644 index 000000000000..2ba57b4d784a --- /dev/null +++ b/crates/nano-arrow/src/types/mod.rs @@ -0,0 +1,89 @@ +//! Sealed traits and implementations to handle all _physical types_ used in this crate. +//! +//! Most physical types used in this crate are native Rust types, such as `i32`. +//! The trait [`NativeType`] describes the interfaces required by this crate to be conformant +//! with Arrow. +//! +//! Every implementation of [`NativeType`] has an associated variant in [`PrimitiveType`], +//! available via [`NativeType::PRIMITIVE`]. +//! Combined, these allow structs generic over [`NativeType`] to be trait objects downcastable +//! to concrete implementations based on the matched [`NativeType::PRIMITIVE`] variant. +//! +//! Another important trait in this module is [`Offset`], the subset of [`NativeType`] that can +//! be used in Arrow offsets (`i32` and `i64`). +//! +//! Another important trait in this module is [`BitChunk`], describing types that can be used to +//! represent chunks of bits (e.g. 8 bits via `u8`, 16 via `u16`), and [`BitChunkIter`], +//! that can be used to iterate over bitmaps in [`BitChunk`]s according to +//! Arrow's definition of bitmaps. +//! +//! Finally, this module contains traits used to compile code based on [`NativeType`] optimized +//! for SIMD, at [`mod@simd`]. + +mod bit_chunk; +pub use bit_chunk::{BitChunk, BitChunkIter, BitChunkOnes}; +mod index; +pub mod simd; +pub use index::*; +mod native; +pub use native::*; +mod offset; +pub use offset::*; +#[cfg(feature = "serde_types")] +use serde_derive::{Deserialize, Serialize}; + +/// The set of all implementations of the sealed trait [`NativeType`]. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde_types", derive(Serialize, Deserialize))] +pub enum PrimitiveType { + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// A signed 128-bit integer. + Int128, + /// A signed 256-bit integer. + Int256, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, + /// A 16-bit floating point number. + Float16, + /// A 32-bit floating point number. + Float32, + /// A 64-bit floating point number. 
+ Float64, + /// Two i32 representing days and ms + DaysMs, + /// months_days_ns(i32, i32, i64) + MonthDayNano, +} + +mod private { + pub trait Sealed {} + + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for i128 {} + impl Sealed for super::i256 {} + impl Sealed for super::f16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for super::days_ms {} + impl Sealed for super::months_days_ns {} +} diff --git a/crates/nano-arrow/src/types/native.rs b/crates/nano-arrow/src/types/native.rs new file mode 100644 index 000000000000..6e50a1454ead --- /dev/null +++ b/crates/nano-arrow/src/types/native.rs @@ -0,0 +1,639 @@ +use std::convert::TryFrom; +use std::ops::Neg; +use std::panic::RefUnwindSafe; + +use bytemuck::{Pod, Zeroable}; + +use super::PrimitiveType; + +/// Sealed trait implemented by all physical types that can be allocated, +/// serialized and deserialized by this crate. +/// All O(N) allocations in this crate are done for this trait alone. +pub trait NativeType: + super::private::Sealed + + Pod + + Send + + Sync + + Sized + + RefUnwindSafe + + std::fmt::Debug + + std::fmt::Display + + PartialEq + + Default +{ + /// The corresponding variant of [`PrimitiveType`]. + const PRIMITIVE: PrimitiveType; + + /// Type denoting its representation as bytes. + /// This is `[u8; N]` where `N = size_of::`. + type Bytes: AsRef<[u8]> + + std::ops::Index + + std::ops::IndexMut + + for<'a> TryFrom<&'a [u8]> + + std::fmt::Debug + + Default; + + /// To bytes in little endian + fn to_le_bytes(&self) -> Self::Bytes; + + /// To bytes in big endian + fn to_be_bytes(&self) -> Self::Bytes; + + /// From bytes in little endian + fn from_le_bytes(bytes: Self::Bytes) -> Self; + + /// From bytes in big endian + fn from_be_bytes(bytes: Self::Bytes) -> Self; +} + +macro_rules! native_type { + ($type:ty, $primitive_type:expr) => { + impl NativeType for $type { + const PRIMITIVE: PrimitiveType = $primitive_type; + + type Bytes = [u8; std::mem::size_of::()]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + Self::to_le_bytes(*self) + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + Self::to_be_bytes(*self) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self::from_le_bytes(bytes) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self::from_be_bytes(bytes) + } + } + }; +} + +native_type!(u8, PrimitiveType::UInt8); +native_type!(u16, PrimitiveType::UInt16); +native_type!(u32, PrimitiveType::UInt32); +native_type!(u64, PrimitiveType::UInt64); +native_type!(i8, PrimitiveType::Int8); +native_type!(i16, PrimitiveType::Int16); +native_type!(i32, PrimitiveType::Int32); +native_type!(i64, PrimitiveType::Int64); +native_type!(f32, PrimitiveType::Float32); +native_type!(f64, PrimitiveType::Float64); +native_type!(i128, PrimitiveType::Int128); + +/// The in-memory representation of the DayMillisecond variant of arrow's "Interval" logical type. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct days_ms(pub i32, pub i32); + +impl days_ms { + /// A new [`days_ms`]. 
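    /// For example, `days_ms::new(1, 500)` represents an interval of one day and 500 milliseconds.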
+ #[inline] + pub fn new(days: i32, milliseconds: i32) -> Self { + Self(days, milliseconds) + } + + /// The number of days + #[inline] + pub fn days(&self) -> i32 { + self.0 + } + + /// The number of milliseconds + #[inline] + pub fn milliseconds(&self) -> i32 { + self.1 + } +} + +impl NativeType for days_ms { + const PRIMITIVE: PrimitiveType = PrimitiveType::DaysMs; + type Bytes = [u8; 8]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let days = self.0.to_le_bytes(); + let ms = self.1.to_le_bytes(); + let mut result = [0; 8]; + result[0] = days[0]; + result[1] = days[1]; + result[2] = days[2]; + result[3] = days[3]; + result[4] = ms[0]; + result[5] = ms[1]; + result[6] = ms[2]; + result[7] = ms[3]; + result + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let days = self.0.to_be_bytes(); + let ms = self.1.to_be_bytes(); + let mut result = [0; 8]; + result[0] = days[0]; + result[1] = days[1]; + result[2] = days[2]; + result[3] = days[3]; + result[4] = ms[0]; + result[5] = ms[1]; + result[6] = ms[2]; + result[7] = ms[3]; + result + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_le_bytes(days), i32::from_le_bytes(ms)) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let mut days = [0; 4]; + days[0] = bytes[0]; + days[1] = bytes[1]; + days[2] = bytes[2]; + days[3] = bytes[3]; + let mut ms = [0; 4]; + ms[0] = bytes[4]; + ms[1] = bytes[5]; + ms[2] = bytes[6]; + ms[3] = bytes[7]; + Self(i32::from_be_bytes(days), i32::from_be_bytes(ms)) + } +} + +/// The in-memory representation of the MonthDayNano variant of the "Interval" logical type. +#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, Hash, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct months_days_ns(pub i32, pub i32, pub i64); + +impl months_days_ns { + /// A new [`months_days_ns`]. 
+ #[inline] + pub fn new(months: i32, days: i32, nanoseconds: i64) -> Self { + Self(months, days, nanoseconds) + } + + /// The number of months + #[inline] + pub fn months(&self) -> i32 { + self.0 + } + + /// The number of days + #[inline] + pub fn days(&self) -> i32 { + self.1 + } + + /// The number of nanoseconds + #[inline] + pub fn ns(&self) -> i64 { + self.2 + } +} + +impl NativeType for months_days_ns { + const PRIMITIVE: PrimitiveType = PrimitiveType::MonthDayNano; + type Bytes = [u8; 16]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let months = self.months().to_le_bytes(); + let days = self.days().to_le_bytes(); + let ns = self.ns().to_le_bytes(); + let mut result = [0; 16]; + result[0] = months[0]; + result[1] = months[1]; + result[2] = months[2]; + result[3] = months[3]; + result[4] = days[0]; + result[5] = days[1]; + result[6] = days[2]; + result[7] = days[3]; + (0..8).for_each(|i| { + result[8 + i] = ns[i]; + }); + result + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let months = self.months().to_be_bytes(); + let days = self.days().to_be_bytes(); + let ns = self.ns().to_be_bytes(); + let mut result = [0; 16]; + result[0] = months[0]; + result[1] = months[1]; + result[2] = months[2]; + result[3] = months[3]; + result[4] = days[0]; + result[5] = days[1]; + result[6] = days[2]; + result[7] = days[3]; + (0..8).for_each(|i| { + result[8 + i] = ns[i]; + }); + result + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_le_bytes(months), + i32::from_le_bytes(days), + i64::from_le_bytes(ns), + ) + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let mut months = [0; 4]; + months[0] = bytes[0]; + months[1] = bytes[1]; + months[2] = bytes[2]; + months[3] = bytes[3]; + let mut days = [0; 4]; + days[0] = bytes[4]; + days[1] = bytes[5]; + days[2] = bytes[6]; + days[3] = bytes[7]; + let mut ns = [0; 8]; + (0..8).for_each(|i| { + ns[i] = bytes[8 + i]; + }); + Self( + i32::from_be_bytes(months), + i32::from_be_bytes(days), + i64::from_be_bytes(ns), + ) + } +} + +impl std::fmt::Display for days_ms { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}d {}ms", self.days(), self.milliseconds()) + } +} + +impl std::fmt::Display for months_days_ns { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}m {}d {}ns", self.months(), self.days(), self.ns()) + } +} + +impl Neg for days_ms { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self::new(-self.days(), -self.milliseconds()) + } +} + +impl Neg for months_days_ns { + type Output = Self; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self::new(-self.months(), -self.days(), -self.ns()) + } +} + +/// Type representation of the Float16 physical type +#[derive(Copy, Clone, Default, Zeroable, Pod)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct f16(pub u16); + +impl PartialEq for f16 { + #[inline] + fn eq(&self, other: &f16) -> bool { + if self.is_nan() || other.is_nan() { + false + } else { + (self.0 == other.0) || ((self.0 | other.0) & 0x7FFFu16 == 0) + } + } +} + +// see https://github.com/starkat99/half-rs/blob/main/src/binary16.rs +impl f16 { + /// 
The difference between 1.0 and the next largest representable number. + pub const EPSILON: f16 = f16(0x1400u16); + + #[inline] + #[must_use] + pub(crate) const fn is_nan(self) -> bool { + self.0 & 0x7FFFu16 > 0x7C00u16 + } + + /// Casts from u16. + #[inline] + pub const fn from_bits(bits: u16) -> f16 { + f16(bits) + } + + /// Casts to u16. + #[inline] + pub const fn to_bits(self) -> u16 { + self.0 + } + + /// Casts this `f16` to `f32` + pub fn to_f32(self) -> f32 { + let i = self.0; + // Check for signed zero + if i & 0x7FFFu16 == 0 { + return f32::from_bits((i as u32) << 16); + } + + let half_sign = (i & 0x8000u16) as u32; + let half_exp = (i & 0x7C00u16) as u32; + let half_man = (i & 0x03FFu16) as u32; + + // Check for an infinity or NaN when all exponent bits set + if half_exp == 0x7C00u32 { + // Check for signed infinity if mantissa is zero + if half_man == 0 { + let number = (half_sign << 16) | 0x7F80_0000u32; + return f32::from_bits(number); + } else { + // NaN, keep current mantissa but also set most significiant mantissa bit + let number = (half_sign << 16) | 0x7FC0_0000u32 | (half_man << 13); + return f32::from_bits(number); + } + } + + // Calculate single-precision components with adjusted exponent + let sign = half_sign << 16; + // Unbias exponent + let unbiased_exp = ((half_exp as i32) >> 10) - 15; + + // Check for subnormals, which will be normalized by adjusting exponent + if half_exp == 0 { + // Calculate how much to adjust the exponent by + let e = (half_man as u16).leading_zeros() - 6; + + // Rebias and adjust exponent + let exp = (127 - 15 - e) << 23; + let man = (half_man << (14 + e)) & 0x7F_FF_FFu32; + return f32::from_bits(sign | exp | man); + } + + // Rebias exponent for a normalized normal + let exp = ((unbiased_exp + 127) as u32) << 23; + let man = (half_man & 0x03FFu32) << 13; + f32::from_bits(sign | exp | man) + } + + /// Casts an `f32` into `f16` + pub fn from_f32(value: f32) -> Self { + let x: u32 = value.to_bits(); + + // Extract IEEE754 components + let sign = x & 0x8000_0000u32; + let exp = x & 0x7F80_0000u32; + let man = x & 0x007F_FFFFu32; + + // Check for all exponent bits being set, which is Infinity or NaN + if exp == 0x7F80_0000u32 { + // Set mantissa MSB for NaN (and also keep shifted mantissa bits) + let nan_bit = if man == 0 { 0 } else { 0x0200u32 }; + return f16(((sign >> 16) | 0x7C00u32 | nan_bit | (man >> 13)) as u16); + } + + // The number is normalized, start assembling half precision version + let half_sign = sign >> 16; + // Unbias the exponent, then bias for half precision + let unbiased_exp = ((exp >> 23) as i32) - 127; + let half_exp = unbiased_exp + 15; + + // Check for exponent overflow, return +infinity + if half_exp >= 0x1F { + return f16((half_sign | 0x7C00u32) as u16); + } + + // Check for underflow + if half_exp <= 0 { + // Check mantissa for what we can do + if 14 - half_exp > 24 { + // No rounding possibility, so this is a full underflow, return signed zero + return f16(half_sign as u16); + } + // Don't forget about hidden leading mantissa bit when assembling mantissa + let man = man | 0x0080_0000u32; + let mut half_man = man >> (14 - half_exp); + // Check for rounding (see comment above functions) + let round_bit = 1 << (13 - half_exp); + if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 { + half_man += 1; + } + // No exponent for subnormals + return f16((half_sign | half_man) as u16); + } + + // Rebias the exponent + let half_exp = (half_exp as u32) << 10; + let half_man = man >> 13; + // Check for rounding (see 
comment above functions) + let round_bit = 0x0000_1000u32; + if (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0 { + // Round it + f16(((half_sign | half_exp | half_man) + 1) as u16) + } else { + f16((half_sign | half_exp | half_man) as u16) + } + } +} + +impl std::fmt::Debug for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.to_f32()) + } +} + +impl std::fmt::Display for f16 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl NativeType for f16 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Float16; + type Bytes = [u8; 2]; + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + self.0.to_le_bytes() + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + self.0.to_be_bytes() + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_be_bytes(bytes)) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + Self(u16::from_le_bytes(bytes)) + } +} + +/// Physical representation of a decimal +#[derive(Clone, Copy, Default, Eq, Hash, PartialEq, PartialOrd, Ord)] +#[allow(non_camel_case_types)] +#[repr(C)] +pub struct i256(pub ethnum::I256); + +impl i256 { + /// Returns a new [`i256`] from two `i128`. + pub fn from_words(hi: i128, lo: i128) -> Self { + Self(ethnum::I256::from_words(hi, lo)) + } +} + +impl Neg for i256 { + type Output = Self; + + #[inline] + fn neg(self) -> Self::Output { + let (a, b) = self.0.into_words(); + Self(ethnum::I256::from_words(-a, b)) + } +} + +impl std::fmt::Debug for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl std::fmt::Display for i256 { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +unsafe impl Pod for i256 {} +unsafe impl Zeroable for i256 {} + +impl NativeType for i256 { + const PRIMITIVE: PrimitiveType = PrimitiveType::Int256; + + type Bytes = [u8; 32]; + + #[inline] + fn to_le_bytes(&self) -> Self::Bytes { + let mut bytes = [0u8; 32]; + let (a, b) = self.0.into_words(); + let a = a.to_le_bytes(); + (0..16).for_each(|i| { + bytes[i] = a[i]; + }); + + let b = b.to_le_bytes(); + (0..16).for_each(|i| { + bytes[i + 16] = b[i]; + }); + + bytes + } + + #[inline] + fn to_be_bytes(&self) -> Self::Bytes { + let mut bytes = [0u8; 32]; + let (a, b) = self.0.into_words(); + + let a = a.to_be_bytes(); + (0..16).for_each(|i| { + bytes[i] = a[i]; + }); + + let b = b.to_be_bytes(); + (0..16).for_each(|i| { + bytes[i + 16] = b[i]; + }); + + bytes + } + + #[inline] + fn from_be_bytes(bytes: Self::Bytes) -> Self { + let (a, b) = bytes.split_at(16); + let a: [u8; 16] = a.try_into().unwrap(); + let b: [u8; 16] = b.try_into().unwrap(); + let a = i128::from_be_bytes(a); + let b = i128::from_be_bytes(b); + Self(ethnum::I256::from_words(a, b)) + } + + #[inline] + fn from_le_bytes(bytes: Self::Bytes) -> Self { + let (b, a) = bytes.split_at(16); + let a: [u8; 16] = a.try_into().unwrap(); + let b: [u8; 16] = b.try_into().unwrap(); + let a = i128::from_le_bytes(a); + let b = i128::from_le_bytes(b); + Self(ethnum::I256::from_words(a, b)) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_f16_to_f32() { + let f = f16::from_f32(7.0); + assert_eq!(f.to_f32(), 7.0f32); + + // 7.1 is NOT exactly representable in 16-bit, it's rounded + let f = f16::from_f32(7.1); + let diff = (f.to_f32() - 7.1f32).abs(); + // diff must be <= 4 * EPSILON, as 7 has two more significant 
bits than 1 + assert!(diff <= 4.0 * f16::EPSILON.to_f32()); + + assert_eq!(f16(0x0000_0001).to_f32(), 2.0f32.powi(-24)); + assert_eq!(f16(0x0000_0005).to_f32(), 5.0 * 2.0f32.powi(-24)); + + assert_eq!(f16(0x0000_0001), f16::from_f32(2.0f32.powi(-24))); + assert_eq!(f16(0x0000_0005), f16::from_f32(5.0 * 2.0f32.powi(-24))); + + assert_eq!(format!("{}", f16::from_f32(7.0)), "7".to_string()); + assert_eq!(format!("{:?}", f16::from_f32(7.0)), "7.0".to_string()); + } +} diff --git a/crates/nano-arrow/src/types/offset.rs b/crates/nano-arrow/src/types/offset.rs new file mode 100644 index 000000000000..e68bb7ceb6bd --- /dev/null +++ b/crates/nano-arrow/src/types/offset.rs @@ -0,0 +1,16 @@ +use super::Index; + +/// Sealed trait describing the subset (`i32` and `i64`) of [`Index`] that can be used +/// as offsets of variable-length Arrow arrays. +pub trait Offset: super::private::Sealed + Index { + /// Whether it is `i32` (false) or `i64` (true). + const IS_LARGE: bool; +} + +impl Offset for i32 { + const IS_LARGE: bool = false; +} + +impl Offset for i64 { + const IS_LARGE: bool = true; +} diff --git a/crates/nano-arrow/src/types/simd/mod.rs b/crates/nano-arrow/src/types/simd/mod.rs new file mode 100644 index 000000000000..d906c9d25e95 --- /dev/null +++ b/crates/nano-arrow/src/types/simd/mod.rs @@ -0,0 +1,167 @@ +//! Contains traits and implementations of multi-data used in SIMD. +//! The actual representation is driven by the feature flag `"simd"`, which, if set, +//! uses [`std::simd`]. +use super::{days_ms, f16, i256, months_days_ns, BitChunk, BitChunkIter, NativeType}; + +/// Describes the ability to convert itself from a [`BitChunk`]. +pub trait FromMaskChunk { + /// Convert itself from a slice. + fn from_chunk(v: T) -> Self; +} + +/// A struct that lends itself well to be compiled leveraging SIMD +/// # Safety +/// The `NativeType` and the `NativeSimd` must have possible a matching alignment. +/// e.g. slicing `&[NativeType]` by `align_of()` must be properly aligned/safe. +pub unsafe trait NativeSimd: Sized + Default + Copy { + /// Number of lanes + const LANES: usize; + /// The [`NativeType`] of this struct. E.g. `f32` for a `NativeSimd = f32x16`. + type Native: NativeType; + /// The type holding bits for masks. + type Chunk: BitChunk; + /// Type used for masking. + type Mask: FromMaskChunk; + + /// Sets values to `default` based on `mask`. + fn select(self, mask: Self::Mask, default: Self) -> Self; + + /// Convert itself from a slice. + /// # Panics + /// * iff `v.len()` != `T::LANES` + fn from_chunk(v: &[Self::Native]) -> Self; + + /// creates a new Self from `v` by populating items from `v` up to its length. + /// Items from `v` at positions larger than the number of lanes are ignored; + /// remaining items are populated with `remaining`. + fn from_incomplete_chunk(v: &[Self::Native], remaining: Self::Native) -> Self; + + /// Returns a tuple of 3 items whose middle item is itself, and the remaining + /// are the head and tail of the un-aligned parts. + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]); +} + +/// Trait implemented by some [`NativeType`] that have a SIMD representation. +pub trait Simd: NativeType { + /// The SIMD type associated with this trait. + /// This type supports SIMD operations + type Simd: NativeSimd; +} + +#[cfg(not(feature = "simd"))] +mod native; +#[cfg(not(feature = "simd"))] +pub use native::*; +#[cfg(feature = "simd")] +mod packed; +#[cfg(feature = "simd")] +pub use packed::*; + +macro_rules! 
native_simd { + ($name:tt, $type:ty, $lanes:expr, $mask:ty) => { + /// Multi-Data correspondence of the native type + #[allow(non_camel_case_types)] + #[derive(Copy, Clone)] + pub struct $name(pub [$type; $lanes]); + + unsafe impl NativeSimd for $name { + const LANES: usize = $lanes; + type Native = $type; + type Chunk = $mask; + type Mask = $mask; + + #[inline] + fn select(self, mask: $mask, default: Self) -> Self { + let mut reduced = default; + let iter = BitChunkIter::new(mask, Self::LANES); + for (i, b) in (0..Self::LANES).zip(iter) { + reduced[i] = if b { self[i] } else { reduced[i] }; + } + reduced + } + + #[inline] + fn from_chunk(v: &[$type]) -> Self { + ($name)(v.try_into().unwrap()) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; $lanes]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + Self(a) + } + + #[inline] + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]) { + unsafe { values.align_to::() } + } + } + + impl std::ops::Index for $name { + type Output = $type; + + #[inline] + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } + } + + impl std::ops::IndexMut for $name { + #[inline] + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } + } + + impl Default for $name { + #[inline] + fn default() -> Self { + ($name)([<$type>::default(); $lanes]) + } + } + }; +} + +pub(super) use native_simd; + +// Types do not have specific intrinsics and thus SIMD can't be specialized. +// Therefore, we can declare their MD representation as `[$t; 8]` irrespectively +// of how they are represented in the different channels. +native_simd!(f16x32, f16, 32, u32); +native_simd!(days_msx8, days_ms, 8, u8); +native_simd!(months_days_nsx8, months_days_ns, 8, u8); +native_simd!(i128x8, i128, 8, u8); +native_simd!(i256x8, i256, 8, u8); + +// In the native implementation, a mask is 1 bit wide, as per AVX512. +impl FromMaskChunk for T { + #[inline] + fn from_chunk(v: T) -> Self { + v + } +} + +macro_rules! 
native { + ($type:ty, $simd:ty) => { + impl Simd for $type { + type Simd = $simd; + } + }; +} + +native!(u8, u8x64); +native!(u16, u16x32); +native!(u32, u32x16); +native!(u64, u64x8); +native!(i8, i8x64); +native!(i16, i16x32); +native!(i32, i32x16); +native!(i64, i64x8); +native!(f16, f16x32); +native!(f32, f32x16); +native!(f64, f64x8); +native!(i128, i128x8); +native!(i256, i256x8); +native!(days_ms, days_msx8); +native!(months_days_ns, months_days_nsx8); diff --git a/crates/nano-arrow/src/types/simd/native.rs b/crates/nano-arrow/src/types/simd/native.rs new file mode 100644 index 000000000000..af31b8b26bc0 --- /dev/null +++ b/crates/nano-arrow/src/types/simd/native.rs @@ -0,0 +1,16 @@ +use std::convert::TryInto; + +use super::*; +use crate::types::BitChunkIter; + +native_simd!(u8x64, u8, 64, u64); +native_simd!(u16x32, u16, 32, u32); +native_simd!(u32x16, u32, 16, u16); +native_simd!(u64x8, u64, 8, u8); +native_simd!(i8x64, i8, 64, u64); +native_simd!(i16x32, i16, 32, u32); +native_simd!(i32x16, i32, 16, u16); +native_simd!(i64x8, i64, 8, u8); +native_simd!(f16x32, f16, 32, u32); +native_simd!(f32x16, f32, 16, u16); +native_simd!(f64x8, f64, 8, u8); diff --git a/crates/nano-arrow/src/types/simd/packed.rs b/crates/nano-arrow/src/types/simd/packed.rs new file mode 100644 index 000000000000..0d95b68882aa --- /dev/null +++ b/crates/nano-arrow/src/types/simd/packed.rs @@ -0,0 +1,197 @@ +pub use std::simd::{ + f32x16, f32x8, f64x8, i16x32, i16x8, i32x16, i32x8, i64x8, i8x64, i8x8, mask32x16 as m32x16, + mask64x8 as m64x8, mask8x64 as m8x64, u16x32, u16x8, u32x16, u32x8, u64x8, u8x64, u8x8, + SimdPartialEq, +}; + +/// Vector of 32 16-bit masks +#[allow(non_camel_case_types)] +pub type m16x32 = std::simd::Mask; + +use super::*; + +macro_rules! simd { + ($name:tt, $type:ty, $lanes:expr, $chunk:ty, $mask:tt) => { + unsafe impl NativeSimd for $name { + const LANES: usize = $lanes; + type Native = $type; + type Chunk = $chunk; + type Mask = $mask; + + #[inline] + fn select(self, mask: $mask, default: Self) -> Self { + mask.select(self, default) + } + + #[inline] + fn from_chunk(v: &[$type]) -> Self { + <$name>::from_slice(v) + } + + #[inline] + fn from_incomplete_chunk(v: &[$type], remaining: $type) -> Self { + let mut a = [remaining; $lanes]; + a.iter_mut().zip(v.iter()).for_each(|(a, b)| *a = *b); + <$name>::from_chunk(a.as_ref()) + } + + #[inline] + fn align(values: &[Self::Native]) -> (&[Self::Native], &[Self], &[Self::Native]) { + unsafe { values.align_to::() } + } + } + }; +} + +simd!(u8x64, u8, 64, u64, m8x64); +simd!(u16x32, u16, 32, u32, m16x32); +simd!(u32x16, u32, 16, u16, m32x16); +simd!(u64x8, u64, 8, u8, m64x8); +simd!(i8x64, i8, 64, u64, m8x64); +simd!(i16x32, i16, 32, u32, m16x32); +simd!(i32x16, i32, 16, u16, m32x16); +simd!(i64x8, i64, 8, u8, m64x8); +simd!(f32x16, f32, 16, u16, m32x16); +simd!(f64x8, f64, 8, u8, m64x8); + +macro_rules! 
chunk_macro { + ($type:ty, $chunk:ty, $simd:ty, $mask:tt, $m:expr) => { + impl FromMaskChunk<$chunk> for $mask { + #[inline] + fn from_chunk(chunk: $chunk) -> Self { + ($m)(chunk) + } + } + }; +} + +chunk_macro!(u8, u64, u8x64, m8x64, from_chunk_u64); +chunk_macro!(u16, u32, u16x32, m16x32, from_chunk_u32); +chunk_macro!(u32, u16, u32x16, m32x16, from_chunk_u16); +chunk_macro!(u64, u8, u64x8, m64x8, from_chunk_u8); + +#[inline] +fn from_chunk_u8(chunk: u8) -> m64x8 { + let idx = u64x8::from_array([1, 2, 4, 8, 16, 32, 64, 128]); + let vecmask = u64x8::splat(chunk as u64); + + (idx & vecmask).simd_eq(idx) +} + +#[inline] +fn from_chunk_u16(chunk: u16) -> m32x16 { + let idx = u32x16::from_array([ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, + ]); + let vecmask = u32x16::splat(chunk as u32); + + (idx & vecmask).simd_eq(idx) +} + +#[inline] +fn from_chunk_u32(chunk: u32) -> m16x32 { + let idx = u16x32::from_array([ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 1, 2, 4, 8, + 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, + ]); + let left = u16x32::from_chunk(&[ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ]); + let right = u16x32::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, + 1024, 2048, 4096, 8192, 16384, 32768, + ]); + + let a = chunk.to_ne_bytes(); + let a1 = u16::from_ne_bytes([a[2], a[3]]); + let a2 = u16::from_ne_bytes([a[0], a[1]]); + + let vecmask1 = u16x32::splat(a1); + let vecmask2 = u16x32::splat(a2); + + (idx & left & vecmask1).simd_eq(idx) | (idx & right & vecmask2).simd_eq(idx) +} + +#[inline] +fn from_chunk_u64(chunk: u64) -> m8x64 { + let idx = u8x64::from_array([ + 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, + 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, + 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128, + ]); + let idxs = [ + u8x64::from_chunk(&[ + 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, + 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128, + 0, 0, 0, 0, 0, 0, 0, 0, + ]), + u8x64::from_chunk(&[ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, + 4, 8, 16, 32, 64, 128, + ]), + ]; + + let a = chunk.to_ne_bytes(); + + let mut result = m8x64::default(); + for i in 0..8 { + result |= (idxs[i] & u8x64::splat(a[i])).simd_eq(idx) + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic1() { + let a = 0b00000001000000010000000100000001u32; + let a = from_chunk_u32(a); + for i in 0..32 { + assert_eq!(a.test(i), i % 8 == 0) + } + } + + #[test] + fn test_basic2() { + let a = 0b0000000100000001000000010000000100000001000000010000000100000001u64; + let a = from_chunk_u64(a); + for i in 0..64 { + assert_eq!(a.test(i), i % 8 == 0) + } + } +} diff --git a/crates/nano-arrow/src/util/bench_util.rs b/crates/nano-arrow/src/util/bench_util.rs new file mode 100644 index 000000000000..59fb88b198fc --- /dev/null +++ b/crates/nano-arrow/src/util/bench_util.rs @@ -0,0 +1,99 @@ +//! Utilities for benchmarking + +use rand::distributions::{Alphanumeric, Distribution, Standard}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; + +use crate::array::*; +use crate::offset::Offset; +use crate::types::NativeType; + +/// Returns fixed seedable RNG +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +/// Creates an random (but fixed-seeded) array of a given size and null density +pub fn create_primitive_array(size: usize, null_density: f32) -> PrimitiveArray +where + T: NativeType, + Standard: Distribution, +{ + let mut rng = seedable_rng(); + + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.gen()) + } + }) + .collect::>() +} + +/// Creates a new [`PrimitiveArray`] from random values with a pre-set seed. +pub fn create_primitive_array_with_seed( + size: usize, + null_density: f32, + seed: u64, +) -> PrimitiveArray +where + T: NativeType, + Standard: Distribution, +{ + let mut rng = StdRng::seed_from_u64(seed); + + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.gen()) + } + }) + .collect::>() +} + +/// Creates an random (but fixed-seeded) array of a given size and null density +pub fn create_boolean_array(size: usize, null_density: f32, true_density: f32) -> BooleanArray +where + Standard: Distribution, +{ + let mut rng = seedable_rng(); + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + let value = rng.gen::() < true_density; + Some(value) + } + }) + .collect() +} + +/// Creates an random (but fixed-seeded) [`Utf8Array`] of a given length, number of characters and null density. 
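///
/// A minimal usage sketch (sizes and densities are arbitrary; the fixed seed keeps runs reproducible):
///
/// ```ignore
/// let ints = create_primitive_array::<i64>(1024, 0.1);
/// let utf8 = create_string_array::<i32>(1024, 4, 0.1, 42);
/// assert_eq!(ints.len(), utf8.len());
/// ```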
+pub fn create_string_array( + length: usize, + size: usize, + null_density: f32, + seed: u64, +) -> Utf8Array { + let mut rng = StdRng::seed_from_u64(seed); + + (0..length) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + let value = (&mut rng) + .sample_iter(&Alphanumeric) + .take(size) + .map(char::from) + .collect::(); + Some(value) + } + }) + .collect() +} diff --git a/crates/nano-arrow/src/util/lexical.rs b/crates/nano-arrow/src/util/lexical.rs new file mode 100644 index 000000000000..047986cbbedd --- /dev/null +++ b/crates/nano-arrow/src/util/lexical.rs @@ -0,0 +1,42 @@ +/// Converts numeric type to a `String` +#[inline] +pub fn lexical_to_bytes(n: N) -> Vec { + let mut buf = Vec::::with_capacity(N::FORMATTED_SIZE_DECIMAL); + lexical_to_bytes_mut(n, &mut buf); + buf +} + +/// Converts numeric type to a `String` +#[inline] +pub fn lexical_to_bytes_mut(n: N, buf: &mut Vec) { + buf.clear(); + buf.reserve(N::FORMATTED_SIZE_DECIMAL); + unsafe { + // JUSTIFICATION + // Benefit + // Allows using the faster serializer lexical core and convert to string + // Soundness + // Length of buf is set as written length afterwards. lexical_core + // creates a valid string, so doesn't need to be checked. + let slice = std::slice::from_raw_parts_mut(buf.as_mut_ptr(), buf.capacity()); + + // Safety: + // Omits an unneeded bound check as we just ensured that we reserved `N::FORMATTED_SIZE_DECIMAL` + #[cfg(debug_assertions)] + { + let len = lexical_core::write(n, slice).len(); + buf.set_len(len); + } + #[cfg(not(debug_assertions))] + { + let len = lexical_core::write_unchecked(n, slice).len(); + buf.set_len(len); + } + } +} + +/// Converts numeric type to a `String` +#[inline] +pub fn lexical_to_string(n: N) -> String { + unsafe { String::from_utf8_unchecked(lexical_to_bytes(n)) } +} diff --git a/crates/nano-arrow/src/util/mod.rs b/crates/nano-arrow/src/util/mod.rs new file mode 100644 index 000000000000..90642b151a1a --- /dev/null +++ b/crates/nano-arrow/src/util/mod.rs @@ -0,0 +1,24 @@ +//! Misc utilities used in different places in the crate. 
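//!
//! A minimal sketch of the `lexical` helpers declared below (they are only compiled when one of
//! the listed features, e.g. `compute`, is enabled): a single buffer can be reused across many
//! numeric-to-text conversions instead of allocating per value.
//!
//! ```ignore
//! let mut buf = Vec::new();
//! for v in 0..1_000i64 {
//!     crate::util::lexical_to_bytes_mut(v, &mut buf);
//!     // `buf` now holds the decimal representation of `v`.
//! }
//! ```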
+ +#[cfg(any( + feature = "compute", + feature = "io_csv_write", + feature = "io_csv_read", + feature = "io_json", + feature = "io_json_write", + feature = "compute_cast" +))] +mod lexical; +#[cfg(any( + feature = "compute", + feature = "io_csv_write", + feature = "io_csv_read", + feature = "io_json", + feature = "io_json_write", + feature = "compute_cast" +))] +pub use lexical::*; + +#[cfg(feature = "benchmarks")] +#[cfg_attr(docsrs, doc(cfg(feature = "benchmarks")))] +pub mod bench_util; diff --git a/crates/polars-algo/Cargo.toml b/crates/polars-algo/Cargo.toml index 569cccec807e..ac589731c96f 100644 --- a/crates/polars-algo/Cargo.toml +++ b/crates/polars-algo/Cargo.toml @@ -9,9 +9,9 @@ repository = { workspace = true } description = "Algorithms built upon Polars primitives" [dependencies] -polars-core = { version = "0.32.0", path = "../polars-core", features = ["dtype-categorical", "asof_join"], default-features = false } -polars-lazy = { version = "0.32.0", path = "../polars-lazy", features = ["asof_join", "concat_str", "strings"] } -polars-ops = { version = "0.32.0", path = "../polars-ops", features = ["dtype-categorical", "asof_join"], default-features = false } +polars-core = { workspace = true, features = ["dtype-categorical", "asof_join"] } +polars-lazy = { workspace = true, features = ["asof_join", "concat_str", "strings"], default-features = true } +polars-ops = { workspace = true, features = ["dtype-categorical", "asof_join"] } [package.metadata.docs.rs] all-features = true diff --git a/crates/polars-arrow/Cargo.toml b/crates/polars-arrow/Cargo.toml index 6b47d0cecf96..ac44783ddd46 100644 --- a/crates/polars-arrow/Cargo.toml +++ b/crates/polars-arrow/Cargo.toml @@ -9,7 +9,7 @@ repository = { workspace = true } description = "Arrow interfaces for Polars DataFrame library" [dependencies] -polars-error = { version = "0.32.0", path = "../polars-error" } +polars-error = { workspace = true } arrow = { workspace = true } atoi = { workspace = true, optional = true } @@ -34,6 +34,5 @@ compute = ["arrow/compute_cast"] temporal = ["arrow/compute_temporal"] bigidx = [] performant = [] -like = ["arrow/compute_like"] timezones = ["chrono-tz", "chrono"] simd = [] diff --git a/crates/polars-arrow/src/trusted_len/push_unchecked.rs b/crates/polars-arrow/src/trusted_len/push_unchecked.rs index 5d268d070777..f3d830f76fa1 100644 --- a/crates/polars-arrow/src/trusted_len/push_unchecked.rs +++ b/crates/polars-arrow/src/trusted_len/push_unchecked.rs @@ -1,13 +1,13 @@ use super::*; pub trait TrustedLenPush { - /// Will push an item and not check if there is enough capacity + /// Will push an item and not check if there is enough capacity. /// /// # Safety /// Caller must ensure the array has enough capacity to hold `T`. unsafe fn push_unchecked(&mut self, value: T); - /// Extend the array with an iterator who's length can be trusted + /// Extend the array with an iterator who's length can be trusted. fn extend_trusted_len, J: TrustedLen>( &mut self, iter: I, @@ -16,9 +16,16 @@ pub trait TrustedLenPush { } /// # Safety - /// Caller must ensure the iterators reported length is correct + /// Caller must ensure the iterators reported length is correct. unsafe fn extend_trusted_len_unchecked>(&mut self, iter: I); + /// # Safety + /// Caller must ensure the iterators reported length is correct. 
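///
/// A minimal sketch of a sound call (illustrative): `0..4` reports an exact upper size hint,
/// so the reserved capacity matches what is actually written.
///
/// ```ignore
/// let mut v: Vec<u32> = Vec::new();
/// unsafe { v.try_extend_trusted_len_unchecked((0..4u32).map(Ok::<_, ()>)).unwrap() };
/// assert_eq!(v, [0, 1, 2, 3]);
/// ```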
+ unsafe fn try_extend_trusted_len_unchecked>>( + &mut self, + iter: I, + ) -> Result<(), E>; + fn from_trusted_len_iter, J: TrustedLen>( iter: I, ) -> Self @@ -28,8 +35,28 @@ pub trait TrustedLenPush { unsafe { Self::from_trusted_len_iter_unchecked(iter) } } /// # Safety - /// Caller must ensure the iterators reported length is correct + /// Caller must ensure the iterators reported length is correct. unsafe fn from_trusted_len_iter_unchecked>(iter: I) -> Self; + + fn try_from_trusted_len_iter< + E, + I: IntoIterator, IntoIter = J>, + J: TrustedLen, + >( + iter: I, + ) -> Result + where + Self: Sized, + { + unsafe { Self::try_from_trusted_len_iter_unchecked(iter) } + } + /// # Safety + /// Caller must ensure the iterators reported length is correct. + unsafe fn try_from_trusted_len_iter_unchecked>>( + iter: I, + ) -> Result + where + Self: Sized; } impl TrustedLenPush for Vec { @@ -55,10 +82,38 @@ impl TrustedLenPush for Vec { self.set_len(self.len() + upper) } + unsafe fn try_extend_trusted_len_unchecked>>( + &mut self, + iter: I, + ) -> Result<(), E> { + let iter = iter.into_iter(); + let upper = iter.size_hint().1.expect("must have an upper bound"); + self.reserve(upper); + + let mut dst = self.as_mut_ptr().add(self.len()); + for value in iter { + std::ptr::write(dst, value?); + dst = dst.add(1) + } + self.set_len(self.len() + upper); + Ok(()) + } + #[inline] unsafe fn from_trusted_len_iter_unchecked>(iter: I) -> Self { let mut v = vec![]; v.extend_trusted_len_unchecked(iter); v } + + unsafe fn try_from_trusted_len_iter_unchecked>>( + iter: I, + ) -> Result + where + Self: Sized, + { + let mut v = vec![]; + v.try_extend_trusted_len_unchecked(iter)?; + Ok(v) + } } diff --git a/crates/polars-core/Cargo.toml b/crates/polars-core/Cargo.toml index 11458b58d9d7..d4bddc3922f8 100644 --- a/crates/polars-core/Cargo.toml +++ b/crates/polars-core/Cargo.toml @@ -9,10 +9,10 @@ repository = { workspace = true } description = "Core of the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", features = ["compute"] } -polars-error = { version = "0.32.0", path = "../polars-error" } -polars-row = { version = "0.32.0", path = "../polars-row" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true, features = ["compute"] } +polars-error = { workspace = true } +polars-row = { workspace = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } @@ -26,7 +26,6 @@ indexmap = { workspace = true } itoap = { version = "1", optional = true, features = ["simd"] } ndarray = { version = "0.15", optional = true, default_features = false } num-traits = { workspace = true } -object_store = { workspace = true, optional = true } once_cell = { workspace = true } rand = { workspace = true, optional = true, features = ["small_rng", "std"] } rand_distr = { version = "0.4", optional = true } @@ -37,7 +36,6 @@ serde = { workspace = true, features = ["derive"], optional = true } serde_json = { workspace = true, optional = true } smartstring = { workspace = true } thiserror = { workspace = true } -url = { workspace = true, optional = true } xxhash-rust = { workspace = true } [dev-dependencies] @@ -54,7 +52,9 @@ avx512 = [] docs = [] temporal = ["regex", "chrono", "polars-error/regex"] random = ["rand", "rand_distr"] -default = ["docs", "temporal"] +algorithm_join = [] +algorithm_group_by = [] +default = ["algorithm_join", "algorithm_group_by"] lazy = [] # ~40% faster collect, needed until 
trustedlength iter stabilizes @@ -62,7 +62,7 @@ lazy = [] performant = ["polars-arrow/performant", "reinterpret"] # extra utilities for Utf8Chunked -strings = ["regex", "polars-arrow/strings", "arrow/compute_substring", "polars-error/regex"] +strings = ["regex", "polars-arrow/strings", "polars-error/regex"] # support for ObjectChunked (downcastable Series of any type) object = ["serde_json"] @@ -81,10 +81,10 @@ zip_with = [] round_series = [] checked_arithmetic = [] repeat_by = [] -is_first = [] -is_last = [] -asof_join = [] -cross_join = [] +is_first_distinct = [] +is_last_distinct = [] +asof_join = ["algorithm_join"] +cross_join = ["algorithm_join"] dot_product = [] concat_str = [] row_hash = [] @@ -97,7 +97,7 @@ group_by_list = [] cum_agg = [] # rolling window functions rolling_window = [] -rank = [] +rank = ["rand"] diff = [] pct_change = ["diff"] moment = [] @@ -109,7 +109,7 @@ dataframe_arithmetic = [] product = [] unique_counts = [] partition_by = [] -semi_anti_join = [] +semi_anti_join = ["algorithm_join"] chunked_ids = [] describe = [] timezones = ["chrono-tz", "arrow/chrono-tz", "polars-arrow/timezones"] @@ -132,7 +132,7 @@ dtype-struct = [] parquet = ["arrow/io_parquet"] # scale to terabytes? -bigidx = ["polars-arrow/bigidx"] +bigidx = ["polars-arrow/bigidx", "polars-utils/bigidx"] python = [] serde = ["dep:serde", "smartstring/serde", "bitflags/serde"] @@ -151,8 +151,8 @@ docs-selection = [ "round_series", "checked_arithmetic", "repeat_by", - "is_first", - "is_last", + "is_first_distinct", + "is_last_distinct", "asof_join", "cross_join", "dot_product", @@ -176,14 +176,10 @@ docs-selection = [ "chunked_ids", "semi_anti_join", "partition_by", + "algorithm_join", + "algorithm_group_by", ] -# Cloud support. -"async" = ["url", "object_store"] -"aws" = ["async", "object_store/aws"] -"azure" = ["async", "object_store/azure"] -"gcp" = ["async", "object_store/gcp"] - [package.metadata.docs.rs] # not all because arrow 4.3 does not compile with simd # all-features = true diff --git a/crates/polars-core/src/chunked_array/builder/mod.rs b/crates/polars-core/src/chunked_array/builder/mod.rs index c00521f125b8..1db996fe618f 100644 --- a/crates/polars-core/src/chunked_array/builder/mod.rs +++ b/crates/polars-core/src/chunked_array/builder/mod.rs @@ -241,8 +241,8 @@ mod test { // Test list collect. let out = [&s1, &s2].iter().copied().collect::(); - assert_eq!(out.get(0).unwrap().len(), 6); - assert_eq!(out.get(1).unwrap().len(), 3); + assert_eq!(out.get_as_series(0).unwrap().len(), 6); + assert_eq!(out.get_as_series(1).unwrap().len(), 3); let mut builder = ListPrimitiveChunkedBuilder::::new("a", 10, 5, DataType::Int32); diff --git a/crates/polars-core/src/chunked_array/collect.rs b/crates/polars-core/src/chunked_array/collect.rs new file mode 100644 index 000000000000..739cc0c6f5c8 --- /dev/null +++ b/crates/polars-core/src/chunked_array/collect.rs @@ -0,0 +1,171 @@ +//! Methods for collecting into a ChunkedArray. +//! +//! For types that don't have dtype parameters: +//! iter.(try_)collect_ca(_trusted) (name) +//! +//! For all types: +//! iter.(try_)collect_ca(_trusted)_like (other_df) Copies name/dtype from other_df +//! iter.(try_)collect_ca(_trusted)_with_dtype (name, df) +//! +//! The try variants work on iterators of Results, the trusted variants do not +//! check the length of the iterator. 
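//!
//! A minimal sketch (assuming these extension traits are in scope):
//!
//! ```ignore
//! let ca: ChunkedArray<Int32Type> = (0..3).map(Some).collect_ca("a");
//! // Copy name and dtype from an existing array:
//! let doubled: ChunkedArray<Int32Type> =
//!     (&ca).into_iter().map(|opt| opt.map(|v| v * 2)).collect_ca_like(&ca);
//! ```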
+ +use std::sync::Arc; + +use polars_arrow::trusted_len::TrustedLen; + +use crate::chunked_array::ChunkedArray; +use crate::datatypes::{ + ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype, DataType, Field, PolarsDataType, +}; + +pub trait ChunkedCollectIterExt: Iterator + Sized { + #[inline] + fn collect_ca_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.collect_arr_with_dtype(dtype); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_like(self, name_dtype_src: &ChunkedArray) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.collect_arr_with_dtype(field.dtype.clone()); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted_with_dtype(self, name: &str, dtype: DataType) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + Self: TrustedLen, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.collect_arr_trusted_with_dtype(dtype); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted_like(self, name_dtype_src: &ChunkedArray) -> ChunkedArray + where + T::Array: ArrayFromIterDtype, + Self: TrustedLen, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.collect_arr_trusted_with_dtype(field.dtype.clone()); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn try_collect_ca_with_dtype( + self, + name: &str, + dtype: DataType, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator>, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.try_collect_arr_with_dtype(dtype)?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_like( + self, + name_dtype_src: &ChunkedArray, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator>, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.try_collect_arr_with_dtype(field.dtype.clone())?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted_with_dtype( + self, + name: &str, + dtype: DataType, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + let field = Arc::new(Field::new(name, dtype.clone())); + let arr = self.try_collect_arr_trusted_with_dtype(dtype)?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted_like( + self, + name_dtype_src: &ChunkedArray, + ) -> Result, E> + where + T::Array: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + let field = Arc::clone(&name_dtype_src.field); + let arr = self.try_collect_arr_trusted_with_dtype(field.dtype.clone())?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } +} + +impl ChunkedCollectIterExt for I {} + +pub trait ChunkedCollectInferIterExt: Iterator + Sized { + #[inline] + fn collect_ca(self, name: &str) -> ChunkedArray + where + T::Array: ArrayFromIter, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.collect_arr(); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn collect_ca_trusted(self, name: &str) -> ChunkedArray + where + T::Array: ArrayFromIter, + Self: TrustedLen, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = 
self.collect_arr_trusted(); + ChunkedArray::from_chunk_iter_and_field(field, [arr]) + } + + #[inline] + fn try_collect_ca(self, name: &str) -> Result, E> + where + T::Array: ArrayFromIter, + Self: Iterator>, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.try_collect_arr()?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } + + #[inline] + fn try_collect_ca_trusted(self, name: &str) -> Result, E> + where + T::Array: ArrayFromIter, + Self: Iterator> + TrustedLen, + { + let field = Arc::new(Field::new(name, T::get_dtype())); + let arr = self.try_collect_arr_trusted()?; + Ok(ChunkedArray::from_chunk_iter_and_field(field, [arr])) + } +} + +impl ChunkedCollectInferIterExt for I {} diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 3515cd6e4678..fcfff9804dc6 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -687,7 +687,7 @@ where { match (lhs.len(), rhs.len()) { (_, 1) => { - let right = rhs.get(0).map(|s| s.with_name("")); + let right = rhs.get_as_series(0).map(|s| s.with_name("")); // SAFETY: values within iterator do not outlive the iterator itself unsafe { lhs.amortized_iter() @@ -696,7 +696,7 @@ where } }, (1, _) => { - let left = lhs.get(0).map(|s| s.with_name("")); + let left = lhs.get_as_series(0).map(|s| s.with_name("")); // SAFETY: values within iterator do not outlive the iterator itself unsafe { rhs.amortized_iter() diff --git a/crates/polars-core/src/chunked_array/from.rs b/crates/polars-core/src/chunked_array/from.rs index 9ad53faacbde..652b61dd858a 100644 --- a/crates/polars-core/src/chunked_array/from.rs +++ b/crates/polars-core/src/chunked_array/from.rs @@ -1,6 +1,6 @@ use super::*; -#[allow(clippy::ptr_arg)] +#[allow(clippy::all)] fn from_chunks_list_dtype(chunks: &mut Vec, dtype: DataType) -> DataType { // ensure we don't get List let dtype = if let Some(arr) = chunks.get(0) { @@ -105,6 +105,19 @@ where unsafe { Self::from_chunks(name, chunks) } } + pub fn from_chunk_iter_like(ca: &Self, iter: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + let chunks = iter + .into_iter() + .map(|x| Box::new(x) as Box) + .collect(); + unsafe { Self::from_chunks_and_dtype_unchecked(ca.name(), chunks, ca.dtype().clone()) } + } + pub fn try_from_chunk_iter(name: &str, iter: I) -> Result where I: IntoIterator>, @@ -118,6 +131,35 @@ where unsafe { Ok(Self::from_chunks(name, chunks?)) } } + pub(crate) fn from_chunk_iter_and_field(field: Arc, chunks: I) -> Self + where + I: IntoIterator, + T: PolarsDataType::Item>, + ::Item: Array, + { + assert_eq!( + std::mem::discriminant(&T::get_dtype()), + std::mem::discriminant(&field.dtype) + ); + + let mut length = 0; + let chunks = chunks + .into_iter() + .map(|x| { + length += x.len(); + Box::new(x) as Box + }) + .collect(); + + ChunkedArray { + field, + chunks, + phantom: PhantomData, + bit_settings: Default::default(), + length: length.try_into().unwrap(), + } + } + /// Create a new [`ChunkedArray`] from existing chunks. 
/// /// # Safety @@ -191,29 +233,7 @@ where } out } -} - -impl ListChunked { - pub(crate) unsafe fn from_chunks_and_dtype_unchecked( - name: &str, - chunks: Vec, - dtype: DataType, - ) -> Self { - let field = Arc::new(Field::new(name, dtype)); - let mut out = ChunkedArray { - field, - chunks, - phantom: PhantomData, - bit_settings: Default::default(), - length: 0, - }; - out.compute_len(); - out - } -} -#[cfg(feature = "dtype-array")] -impl ArrayChunked { pub(crate) unsafe fn from_chunks_and_dtype_unchecked( name: &str, chunks: Vec, diff --git a/crates/polars-core/src/chunked_array/kernels/mod.rs b/crates/polars-core/src/chunked_array/kernels/mod.rs deleted file mode 100644 index 66d56923fa2b..000000000000 --- a/crates/polars-core/src/chunked_array/kernels/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub(crate) mod take; diff --git a/crates/polars-core/src/chunked_array/kernels/take.rs b/crates/polars-core/src/chunked_array/kernels/take.rs deleted file mode 100644 index fc31e5450084..000000000000 --- a/crates/polars-core/src/chunked_array/kernels/take.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::convert::TryFrom; - -use polars_arrow::compute::take::bitmap::take_bitmap_unchecked; -use polars_arrow::compute::take::take_value_indices_from_list; -use polars_arrow::utils::combine_validities_and; - -use crate::prelude::*; - -/// Take kernel for multiple chunks. We directly return a [`ChunkedArray`] because that path chooses the fastest collection path. -pub(crate) fn take_primitive_iter_n_chunks>( - ca: &ChunkedArray, - indices: I, -) -> ChunkedArray { - let taker = ca.take_rand(); - indices.into_iter().map(|idx| taker.get(idx)).collect() -} - -/// Take kernel for multiple chunks where an iterator can produce None values. -/// Used in join operations. We directly return a [`ChunkedArray`] because that path chooses the fastest collection path. -pub(crate) fn take_primitive_opt_iter_n_chunks< - T: PolarsNumericType, - I: IntoIterator>, ->( - ca: &ChunkedArray, - indices: I, -) -> ChunkedArray { - let taker = ca.take_rand(); - indices - .into_iter() - .map(|opt_idx| opt_idx.and_then(|idx| taker.get(idx))) - .collect() -} - -/// This is faster because it does no bounds checks and allocates directly into aligned memory. -/// -/// # Safety -/// No bounds checks -pub(crate) unsafe fn take_list_unchecked( - values: &ListArray, - indices: &IdxArr, -) -> ListArray { - // Taking the whole list or a contiguous sublist. - let (list_indices, offsets) = take_value_indices_from_list(values, indices); - - // Temporary series so that we can take primitives from it. - let s = Series::try_from(("", values.values().clone() as ArrayRef)).unwrap(); - let taken = s.take_unchecked(&list_indices.into()).unwrap(); - - let taken = taken.array_ref(0).clone(); - let validity = if let Some(validity) = values.validity() { - let validity = take_bitmap_unchecked(validity, indices.values().as_slice()); - combine_validities_and(Some(&validity), indices.validity()) - } else { - indices.validity().cloned() - }; - - let dtype = ListArray::::default_datatype(taken.data_type().clone()); - // SAFETY: offsets are monotonically increasing. 
- ListArray::new(dtype, offsets.into(), taken, validity) -} diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 8bf43fcba21f..6bf9ed2f60d6 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -173,14 +173,19 @@ impl ListChunked { where V: PolarsDataType, F: FnMut(Option>) -> Option + Copy, - K: ArrayFromElementIter, + V::Array: ArrayFromIter>, { // TODO! make an amortized iter that does not flatten + // SAFETY: unstable series never lives longer than the iterator. + unsafe { self.amortized_iter().map(f).collect_ca(self.name()) } + } + pub fn for_each_amortized<'a, F>(&'a self, f: F) + where + F: FnMut(Option>), + { // SAFETY: unstable series never lives longer than the iterator. - let element_iter = unsafe { self.amortized_iter().map(f) }; - let array = K::array_from_iter(element_iter); - ChunkedArray::from_chunk_iter(self.name(), std::iter::once(array)) + unsafe { self.amortized_iter().for_each(f) } } /// Apply a closure `F` elementwise. diff --git a/crates/polars-core/src/chunked_array/list/mod.rs b/crates/polars-core/src/chunked_array/list/mod.rs index 2d0da0ef5e28..3a7a92d8192d 100644 --- a/crates/polars-core/src/chunked_array/list/mod.rs +++ b/crates/polars-core/src/chunked_array/list/mod.rs @@ -29,13 +29,6 @@ impl ListChunked { self.bit_settings.contains(Settings::FAST_EXPLODE_LIST) } - pub(crate) fn is_nested(&self) -> bool { - match self.dtype() { - DataType::List(inner) => matches!(&**inner, DataType::List(_)), - _ => unreachable!(), - } - } - /// Set the logical type of the [`ListChunked`]. pub fn to_logical(&mut self, inner_dtype: DataType) { debug_assert_eq!(inner_dtype.to_physical(), self.inner_dtype()); diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index 56ad826e8595..ff7b6fefd34c 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -6,7 +6,7 @@ use hashbrown::hash_map::{Entry, RawEntryMut}; use polars_arrow::trusted_len::TrustedLenPush; use crate::datatypes::PlHashMap; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; +use crate::hashing::HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::{using_string_cache, StringCache, POOL}; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 6841d6da5ba1..ec6cd04704ca 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -7,7 +7,6 @@ pub mod stringcache; use bitflags::bitflags; pub use builder::*; pub(crate) use merge::*; -pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal}; use polars_utils::sync::SyncPtr; use super::*; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs index 91f3e293e202..759628b322cb 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/mod.rs @@ -1,10 +1,8 @@ mod append; mod full; -mod take_random; +#[cfg(feature = "algorithm_group_by")] mod unique; #[cfg(feature = "zip_with")] mod zip; -pub(crate) use take_random::{CategoricalTakeRandomGlobal, 
CategoricalTakeRandomLocal}; - use super::*; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/take_random.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/take_random.rs deleted file mode 100644 index 222ab4c77f51..000000000000 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/take_random.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::cmp::Ordering; - -use arrow::array::Utf8Array; - -use crate::datatypes::UInt32Type; -use crate::prelude::compare_inner::PartialOrdInner; -use crate::prelude::{ - CategoricalChunked, IntoTakeRandom, PlHashMap, RevMapping, TakeRandBranch3, TakeRandom, - TakeRandomArray, TakeRandomArrayValues, TakeRandomChunked, -}; - -type TakeCats<'a> = TakeRandBranch3< - TakeRandomArrayValues<'a, UInt32Type>, - TakeRandomArray<'a, UInt32Type>, - TakeRandomChunked<'a, UInt32Type>, ->; - -pub(crate) struct CategoricalTakeRandomLocal<'a> { - rev_map: &'a Utf8Array, - cats: TakeCats<'a>, -} - -impl<'a> CategoricalTakeRandomLocal<'a> { - pub(crate) fn new(ca: &'a CategoricalChunked) -> Self { - // should be rechunked upstream - assert_eq!(ca.logical.chunks.len(), 1, "implementation error"); - if let RevMapping::Local(rev_map) = &**ca.get_rev_map() { - let cats = ca.logical().take_rand(); - Self { rev_map, cats } - } else { - unreachable!() - } - } -} - -impl PartialOrdInner for CategoricalTakeRandomLocal<'_> { - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { - let a = self - .cats - .get_unchecked(idx_a) - .map(|cat| self.rev_map.value_unchecked(cat as usize)); - let b = self - .cats - .get_unchecked(idx_b) - .map(|cat| self.rev_map.value_unchecked(cat as usize)); - a.partial_cmp(&b).unwrap() - } -} - -pub(crate) struct CategoricalTakeRandomGlobal<'a> { - rev_map_part_1: &'a PlHashMap, - rev_map_part_2: &'a Utf8Array, - cats: TakeCats<'a>, -} -impl<'a> CategoricalTakeRandomGlobal<'a> { - pub(crate) fn new(ca: &'a CategoricalChunked) -> Self { - // should be rechunked upstream - assert_eq!(ca.logical.chunks.len(), 1, "implementation error"); - if let RevMapping::Global(rev_map_part_1, rev_map_part_2, _) = &**ca.get_rev_map() { - let cats = ca.logical().take_rand(); - Self { - rev_map_part_1, - rev_map_part_2, - cats, - } - } else { - unreachable!() - } - } -} - -impl PartialOrdInner for CategoricalTakeRandomGlobal<'_> { - unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering { - let a = self.cats.get_unchecked(idx_a).map(|cat| { - let idx = self.rev_map_part_1.get(&cat).unwrap(); - self.rev_map_part_2.value_unchecked(*idx as usize) - }); - let b = self.cats.get_unchecked(idx_b).map(|cat| { - let idx = self.rev_map_part_1.get(&cat).unwrap(); - self.rev_map_part_2.value_unchecked(*idx as usize) - }); - a.partial_cmp(&b).unwrap() - } -} diff --git a/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs b/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs index 195579e1392b..8c7a206925a3 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs @@ -8,7 +8,7 @@ use once_cell::sync::Lazy; use smartstring::{LazyCompact, SmartString}; use crate::datatypes::PlIdHashMap; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; +use crate::hashing::HASHMAP_INIT_SIZE; use crate::prelude::InitHashMaps; /// We use atomic reference counting diff --git a/crates/polars-core/src/chunked_array/logical/date.rs 
b/crates/polars-core/src/chunked_array/logical/date.rs index 0151361e13ca..38b5593a92c2 100644 --- a/crates/polars-core/src/chunked_array/logical/date.rs +++ b/crates/polars-core/src/chunked_array/logical/date.rs @@ -43,6 +43,7 @@ impl LogicalType for DateChunked { .into_datetime(*tu, tz.clone()) .into_series()) }, + #[cfg(feature = "dtype-time")] (Date, Time) => Ok(Int64Chunked::full(self.name(), 0i64, self.len()) .into_time() .into_series()), diff --git a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs index 2a523ecf856d..c3e27d1438db 100644 --- a/crates/polars-core/src/chunked_array/logical/struct_/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/struct_/mod.rs @@ -11,7 +11,7 @@ use smartstring::alias::String as SmartString; use super::*; use crate::datatypes::*; -use crate::utils::index_to_chunked_index2; +use crate::utils::index_to_chunked_index; /// This is logical type [`StructChunked`] that /// dispatches most logic to the `fields` implementations @@ -425,7 +425,7 @@ impl LogicalType for StructChunked { } unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> { - let (chunk_idx, idx) = index_to_chunked_index2(&self.chunks, i); + let (chunk_idx, idx) = index_to_chunked_index(self.chunks.iter().map(|c| c.len()), i); if let DataType::Struct(flds) = self.dtype() { // safety: we already have a single chunk and we are // guarded by the type system. diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index 60784cac93f1..d7518ecec4b7 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -16,10 +16,10 @@ pub mod ops; pub mod arithmetic; pub mod builder; pub mod cast; +pub mod collect; pub mod comparison; pub mod float; pub mod iterator; -pub mod kernels; #[cfg(feature = "ndarray")] pub(crate) mod ndarray; @@ -370,6 +370,82 @@ impl ChunkedArray { } } +impl ChunkedArray +where + T: PolarsDataType, +{ + #[inline] + pub fn get(&self, idx: usize) -> Option> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + let arr = self.downcast_get(chunk_idx)?; + + // SAFETY: if index_to_chunked_index returns a valid chunk_idx, we know + // that arr_idx < arr.len(). + unsafe { arr.get_unchecked(arr_idx) } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + pub unsafe fn get_unchecked(&self, idx: usize) -> Option> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + + unsafe { + // SAFETY: up to the caller to make sure the index is valid. + self.downcast_get_unchecked(chunk_idx) + .get_unchecked(arr_idx) + } + } + + /// # Safety + /// It is the callers responsibility that the `idx < self.len()`. + #[inline] + pub unsafe fn value_unchecked(&self, idx: usize) -> T::Physical<'_> { + let (chunk_idx, arr_idx) = self.index_to_chunked_index(idx); + + unsafe { + // SAFETY: up to the caller to make sure the index is valid. + self.downcast_get_unchecked(chunk_idx) + .value_unchecked(arr_idx) + } + } + + #[inline] + pub fn last(&self) -> Option> { + unsafe { + let arr = self.downcast_get_unchecked(self.chunks.len().checked_sub(1)?); + arr.get_unchecked(arr.len().checked_sub(1)?) 
+ } + } +} + +impl ListChunked { + #[inline] + pub fn get_as_series(&self, idx: usize) -> Option { + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + self.name(), + vec![self.get(idx)?], + &self.inner_dtype().to_physical(), + )) + } + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayChunked { + #[inline] + pub fn get_as_series(&self, idx: usize) -> Option { + unsafe { + Some(Series::from_chunks_and_dtype_unchecked( + self.name(), + vec![self.get(idx)?], + &self.inner_dtype().to_physical(), + )) + } + } +} + impl ChunkedArray where T: PolarsDataType, @@ -663,7 +739,7 @@ pub(crate) mod test { #[test] fn take() { let a = get_chunked_array(); - let new = a.take([0usize, 1].iter().copied().into()).unwrap(); + let new = a.take(&[0 as IdxSize, 1]).unwrap(); assert_eq!(new.len(), 2) } diff --git a/crates/polars-core/src/chunked_array/ops/append.rs b/crates/polars-core/src/chunked_array/ops/append.rs index 4a1503826743..1bef4f8a4096 100644 --- a/crates/polars-core/src/chunked_array/ops/append.rs +++ b/crates/polars-core/src/chunked_array/ops/append.rs @@ -15,8 +15,7 @@ pub(super) fn update_sorted_flag_before_append<'a, T>( other: &'a ChunkedArray, ) where T: PolarsDataType, - &'a ChunkedArray: TakeRandom, - <&'a ChunkedArray as TakeRandom>::Item: PartialOrd, + T::Physical<'a>: PartialOrd, { let get_start_end = || { let end = { diff --git a/crates/polars-core/src/chunked_array/ops/apply.rs b/crates/polars-core/src/chunked_array/ops/apply.rs index b78cc9461bdc..3ef019176f3a 100644 --- a/crates/polars-core/src/chunked_array/ops/apply.rs +++ b/crates/polars-core/src/chunked_array/ops/apply.rs @@ -1,13 +1,9 @@ //! Implementations of the ChunkApply Trait. use std::borrow::Cow; use std::convert::TryFrom; -use std::error::Error; use arrow::array::{BooleanArray, PrimitiveArray}; use arrow::bitmap::utils::{get_bit_unchecked, set_bit_unchecked}; -use arrow::bitmap::Bitmap; -use arrow::trusted_len::TrustedLen; -use arrow::types::NativeType; use polars_arrow::bitmap::unary_mut; use crate::prelude::*; @@ -18,48 +14,96 @@ impl ChunkedArray where T: PolarsDataType, { - pub fn apply_values_generic<'a, U, K, F>(&'a self, op: F) -> ChunkedArray + // Applies a function to all elements , regardless of whether they + // are null or not, after which the null mask is copied from the + // original array. + pub fn apply_values_generic<'a, U, K, F>(&'a self, mut op: F) -> ChunkedArray where U: PolarsDataType, - F: FnMut(T::Physical<'a>) -> K + Copy, - K: ArrayFromElementIter, + F: FnMut(T::Physical<'a>) -> K, + U::Array: ArrayFromIter, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.values_iter().map(op); - let array = K::array_from_values_iter(element_iter); - array.with_validity_typed(arr.validity().cloned()) + let out: U::Array = arr.values_iter().map(&mut op).collect_arr(); + out.with_validity_typed(arr.validity().cloned()) }); ChunkedArray::from_chunk_iter(self.name(), iter) } - pub fn try_apply_values_generic<'a, U, K, F, E>(&'a self, op: F) -> Result, E> + /// Applies a function to all elements, regardless of whether they + /// are null or not, after which the null mask is copied from the + /// original array. 
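///
/// A rough sketch of a fallible element-wise conversion (illustrative; `utf8_ca` stands for
/// any `Utf8Chunked`):
///
/// ```ignore
/// let parsed: Result<Int32Chunked, std::num::ParseIntError> =
///     utf8_ca.try_apply_values_generic(|s| s.parse::<i32>());
/// ```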
+ pub fn try_apply_values_generic<'a, U, K, F, E>( + &'a self, + mut op: F, + ) -> Result, E> where U: PolarsDataType, - F: FnMut(T::Physical<'a>) -> Result + Copy, - K: ArrayFromElementIter, - E: Error, + F: FnMut(T::Physical<'a>) -> Result, + U::Array: ArrayFromIter, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.values_iter().map(op); - let array = K::try_array_from_values_iter(element_iter)?; + let element_iter = arr.values_iter().map(&mut op); + let array: U::Array = element_iter.try_collect_arr()?; Ok(array.with_validity_typed(arr.validity().cloned())) }); ChunkedArray::try_from_chunk_iter(self.name(), iter) } - pub fn try_apply_generic<'a, U, K, F, E>(&'a self, op: F) -> Result, E> + /// Applies a function only to the non-null elements, propagating nulls. + pub fn apply_nonnull_values_generic<'a, U, K, F>( + &'a self, + dtype: DataType, + mut op: F, + ) -> ChunkedArray where U: PolarsDataType, - F: FnMut(Option>) -> Result, E> + Copy, - K: ArrayFromElementIter, - E: Error, + F: FnMut(T::Physical<'a>) -> K, + U::Array: ArrayFromIterDtype + ArrayFromIterDtype>, { let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.iter().map(op); - let array = K::try_array_from_iter(element_iter)?; - Ok(array.with_validity_typed(arr.validity().cloned())) + if arr.null_count() == 0 { + let out: U::Array = arr + .values_iter() + .map(&mut op) + .collect_arr_with_dtype(dtype.clone()); + out.with_validity_typed(arr.validity().cloned()) + } else { + let out: U::Array = arr + .iter() + .map(|opt| opt.map(&mut op)) + .collect_arr_with_dtype(dtype.clone()); + out.with_validity_typed(arr.validity().cloned()) + } + }); + + ChunkedArray::from_chunk_iter(self.name(), iter) + } + + /// Applies a function only to the non-null elements, propagating nulls. + pub fn try_apply_nonnull_values_generic<'a, U, K, F, E>( + &'a self, + mut op: F, + ) -> Result, E> + where + U: PolarsDataType, + F: FnMut(T::Physical<'a>) -> Result, + U::Array: ArrayFromIter + ArrayFromIter>, + { + let iter = self.downcast_iter().map(|arr| { + let arr = if arr.null_count() == 0 { + let out: U::Array = arr.values_iter().map(&mut op).try_collect_arr()?; + out.with_validity_typed(arr.validity().cloned()) + } else { + let out: U::Array = arr + .iter() + .map(|opt| opt.map(&mut op).transpose()) + .try_collect_arr()?; + out.with_validity_typed(arr.validity().cloned()) + }; + Ok(arr) }); ChunkedArray::try_from_chunk_iter(self.name(), iter) @@ -69,55 +113,34 @@ where where U: PolarsDataType, F: FnMut(Option>) -> Option, - K: ArrayFromElementIter, + U::Array: ArrayFromIter>, { if self.null_count() == 0 { - let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.values_iter().map(|x| op(Some(x))); - K::array_from_iter(element_iter) - }); + let iter = self + .downcast_iter() + .map(|arr| arr.values_iter().map(|x| op(Some(x))).collect_arr()); ChunkedArray::from_chunk_iter(self.name(), iter) } else { - let iter = self.downcast_iter().map(|arr| { - let element_iter = arr.iter().map(&mut op); - K::array_from_iter(element_iter) - }); + let iter = self + .downcast_iter() + .map(|arr| arr.iter().map(&mut op).collect_arr()); ChunkedArray::from_chunk_iter(self.name(), iter) } } -} -fn collect_array>( - iter: I, - validity: Option, -) -> PrimitiveArray { - PrimitiveArray::from_trusted_len_values_iter(iter).with_validity(validity) -} - -macro_rules! 
try_apply { - ($self:expr, $f:expr) => {{ - if !$self.has_validity() { - $self.into_no_null_iter().map($f).collect() - } else { - $self - .into_iter() - .map(|opt_v| opt_v.map($f).transpose()) - .collect() - } - }}; -} + pub fn try_apply_generic<'a, U, K, F, E>(&'a self, op: F) -> Result, E> + where + U: PolarsDataType, + F: FnMut(Option>) -> Result, E> + Copy, + U::Array: ArrayFromIter>, + { + let iter = self.downcast_iter().map(|arr| { + let array: U::Array = arr.iter().map(op).try_collect_arr()?; + Ok(array.with_validity_typed(arr.validity().cloned())) + }); -macro_rules! apply { - ($self:expr, $f:expr) => {{ - if !$self.has_validity() { - $self.into_no_null_iter().map($f).collect_trusted() - } else { - $self - .into_iter() - .map(|opt_v| opt_v.map($f)) - .collect_trusted() - } - }}; + ChunkedArray::try_from_chunk_iter(self.name(), iter) + } } fn apply_in_place_impl(name: &str, chunks: Vec, f: F) -> ChunkedArray @@ -215,7 +238,8 @@ where .data_views() .zip(self.iter_validities()) .map(|(slice, validity)| { - collect_array(slice.iter().copied().map(f), validity.cloned()) + let arr: T::Array = slice.iter().copied().map(f).collect_arr(); + arr.with_validity(validity.cloned()) }); ChunkedArray::from_chunk_iter(self.name(), chunks) } @@ -370,6 +394,21 @@ impl Utf8Chunked { }); Utf8Chunked::from_chunk_iter(self.name(), chunks) } + + /// Utility that reuses an string buffer to amortize allocations. + /// Prefer this over an `apply` that returns an owned `String`. + pub fn apply_to_buffer<'a, F>(&'a self, mut f: F) -> Self + where + F: FnMut(&'a str, &mut String), + { + let mut buf = String::new(); + let outer = |s: &'a str| { + buf.clear(); + f(s, &mut buf); + unsafe { std::mem::transmute::<&str, &'a str>(buf.as_str()) } + }; + self.apply_mut(outer) + } } impl BinaryChunked { @@ -555,7 +594,17 @@ impl<'a> ChunkApply<'a, Series> for ListChunked { } out }; - let mut ca: ListChunked = apply!(self, &mut function); + let mut ca: ListChunked = { + if !self.has_validity() { + self.into_no_null_iter() + .map(&mut function) + .collect_trusted() + } else { + self.into_iter() + .map(|opt_v| opt_v.map(&mut function)) + .collect_trusted() + } + }; if fast_explode { ca.set_fast_explode() } @@ -580,7 +629,15 @@ impl<'a> ChunkApply<'a, Series> for ListChunked { } out }; - let ca: PolarsResult = try_apply!(self, &mut function); + let ca: PolarsResult = { + if !self.has_validity() { + self.into_no_null_iter().map(&mut function).collect() + } else { + self.into_iter() + .map(|opt_v| opt_v.map(&mut function).transpose()) + .collect() + } + }; let mut ca = ca?; if fast_explode { ca.set_fast_explode() diff --git a/crates/polars-core/src/chunked_array/ops/arity.rs b/crates/polars-core/src/chunked_array/ops/arity.rs index 2613c49b4008..6bd63a488c10 100644 --- a/crates/polars-core/src/chunked_array/ops/arity.rs +++ b/crates/polars-core/src/chunked_array/ops/arity.rs @@ -3,9 +3,9 @@ use std::error::Error; use arrow::array::Array; use polars_arrow::utils::combine_validities_and; -use crate::datatypes::{ArrayFromElementIter, PolarsNumericType, StaticArray}; +use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray}; use crate::prelude::{ChunkedArray, PolarsDataType}; -use crate::utils::align_chunks_binary; +use crate::utils::{align_chunks_binary, align_chunks_ternary}; #[inline] pub fn binary_elementwise( @@ -18,7 +18,7 @@ where U: PolarsDataType, V: PolarsDataType, F: for<'a> FnMut(Option>, Option>) -> Option, - K: ArrayFromElementIter, + V::Array: ArrayFromIter>, { let (lhs, rhs) = 
align_chunks_binary(lhs, rhs); let iter = lhs @@ -29,11 +29,29 @@ where .iter() .zip(rhs_arr.iter()) .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); - K::array_from_iter(element_iter) + element_iter.collect_arr() }); ChunkedArray::from_chunk_iter(lhs.name(), iter) } +#[inline] +pub fn binary_elementwise_for_each(lhs: &ChunkedArray, rhs: &ChunkedArray, mut op: F) +where + T: PolarsDataType, + U: PolarsDataType, + F: for<'a> FnMut(Option>, Option>), +{ + let (lhs, rhs) = align_chunks_binary(lhs, rhs); + lhs.downcast_iter() + .zip(rhs.downcast_iter()) + .for_each(|(lhs_arr, rhs_arr)| { + lhs_arr + .iter() + .zip(rhs_arr.iter()) + .for_each(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); + }) +} + #[inline] pub fn try_binary_elementwise( lhs: &ChunkedArray, @@ -45,8 +63,7 @@ where U: PolarsDataType, V: PolarsDataType, F: for<'a> FnMut(Option>, Option>) -> Result, E>, - K: ArrayFromElementIter, - E: Error, + V::Array: ArrayFromIter>, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -57,7 +74,7 @@ where .iter() .zip(rhs_arr.iter()) .map(|(lhs_opt_val, rhs_opt_val)| op(lhs_opt_val, rhs_opt_val)); - K::try_array_from_iter(element_iter) + element_iter.try_collect_arr() }); ChunkedArray::try_from_chunk_iter(lhs.name(), iter) } @@ -71,9 +88,9 @@ pub fn binary_elementwise_values( where T: PolarsDataType, U: PolarsDataType, - V: PolarsNumericType, + V: PolarsDataType, F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K, - K: ArrayFromElementIter, + V::Array: ArrayFromIter, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -87,7 +104,7 @@ where .zip(rhs_arr.values_iter()) .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val)); - let array = K::array_from_values_iter(element_iter); + let array: V::Array = element_iter.collect_arr(); array.with_validity_typed(validity) }); ChunkedArray::from_chunk_iter(lhs.name(), iter) @@ -102,10 +119,9 @@ pub fn try_binary_elementwise_values( where T: PolarsDataType, U: PolarsDataType, - V: PolarsNumericType, + V: PolarsDataType, F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result, - K: ArrayFromElementIter, - E: Error, + V::Array: ArrayFromIter, { let (lhs, rhs) = align_chunks_binary(lhs, rhs); let iter = lhs @@ -119,7 +135,7 @@ where .zip(rhs_arr.values_iter()) .map(|(lhs_val, rhs_val)| op(lhs_val, rhs_val)); - let array = K::try_array_from_values_iter(element_iter)?; + let array: V::Array = element_iter.try_collect_arr()?; Ok(array.with_validity_typed(validity)) }); ChunkedArray::try_from_chunk_iter(lhs.name(), iter) @@ -238,3 +254,38 @@ where .collect::, E>>()?; Ok(lhs.copy_with_chunks(chunks, keep_sorted, keep_fast_explode)) } + +#[inline] +pub fn try_ternary_elementwise( + ca1: &ChunkedArray, + ca2: &ChunkedArray, + ca3: &ChunkedArray, + mut op: F, +) -> Result, E> +where + T: PolarsDataType, + U: PolarsDataType, + V: PolarsDataType, + G: PolarsDataType, + F: for<'a> FnMut( + Option>, + Option>, + Option>, + ) -> Result, E>, + V::Array: ArrayFromIter>, +{ + let (ca1, ca2, ca3) = align_chunks_ternary(ca1, ca2, ca3); + let iter = ca1 + .downcast_iter() + .zip(ca2.downcast_iter()) + .zip(ca3.downcast_iter()) + .map(|((ca1_arr, ca2_arr), ca3_arr)| { + let element_iter = ca1_arr.iter().zip(ca2_arr.iter()).zip(ca3_arr.iter()).map( + |((ca1_opt_val, ca2_opt_val), ca3_opt_val)| { + op(ca1_opt_val, ca2_opt_val, ca3_opt_val) + }, + ); + element_iter.try_collect_arr() + }); + ChunkedArray::try_from_chunk_iter(ca1.name(), iter) +} diff --git a/crates/polars-core/src/chunked_array/ops/bit_repr.rs 
b/crates/polars-core/src/chunked_array/ops/bit_repr.rs index e89bf87a9ee6..f26c68716af6 100644 --- a/crates/polars-core/src/chunked_array/ops/bit_repr.rs +++ b/crates/polars-core/src/chunked_array/ops/bit_repr.rs @@ -24,6 +24,7 @@ fn reinterpret_chunked_array( /// Reinterprets the type of a [`ListChunked`]. T and U must have the same size /// and alignment. +#[cfg(feature = "reinterpret")] fn reinterpret_list_chunked( ca: &ListChunked, ) -> ListChunked { diff --git a/crates/polars-core/src/chunked_array/ops/compare_inner.rs b/crates/polars-core/src/chunked_array/ops/compare_inner.rs index 3352706b05ae..f34aa304e0c8 100644 --- a/crates/polars-core/src/chunked_array/ops/compare_inner.rs +++ b/crates/polars-core/src/chunked_array/ops/compare_inner.rs @@ -2,11 +2,45 @@ use std::cmp::{Ordering, PartialEq}; -use crate::chunked_array::ops::take::take_random::{ - TakeRandomArray, TakeRandomArrayValues, TakeRandomChunked, -}; +use crate::chunked_array::ChunkedArrayLayout; use crate::prelude::*; +#[repr(transparent)] +struct NonNull(T); + +trait GetInner { + type Item; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item; +} + +impl<'a, T: PolarsDataType> GetInner for &'a ChunkedArray { + type Item = Option>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + ChunkedArray::get_unchecked(self, idx) + } +} + +impl<'a, T: StaticArray> GetInner for &'a T { + type Item = Option>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + ::get_unchecked(self, idx) + } +} + +impl<'a, T: PolarsDataType> GetInner for NonNull<&'a ChunkedArray> { + type Item = T::Physical<'a>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + self.0.value_unchecked(idx) + } +} + +impl<'a, T: StaticArray> GetInner for NonNull<&'a T> { + type Item = T::ValueT<'a>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + self.0.value_unchecked(idx) + } +} + pub trait PartialEqInner: Send + Sync { /// # Safety /// Does not do any bound checks. @@ -21,7 +55,7 @@ pub trait PartialOrdInner: Send + Sync { impl PartialEqInner for T where - T: TakeRandom + Send + Sync, + T: GetInner + Send + Sync, T::Item: PartialEq, { #[inline] @@ -43,22 +77,11 @@ where T::Physical<'a>: PartialEq, { fn into_partial_eq_inner(self) -> Box { - let mut chunks = self.downcast_iter(); - - if self.chunks.len() == 1 { - let arr = chunks.next().unwrap(); - - if !self.has_validity() { - Box::new(TakeRandomArrayValues:: { arr }) - } else { - Box::new(TakeRandomArray:: { arr }) - } - } else { - let t = TakeRandomChunked:: { - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) + match self.layout() { + ChunkedArrayLayout::SingleNoNull(arr) => Box::new(NonNull(arr)), + ChunkedArrayLayout::Single(arr) => Box::new(arr), + ChunkedArrayLayout::MultiNoNull(ca) => Box::new(NonNull(ca)), + ChunkedArrayLayout::Multi(ca) => Box::new(ca), } } } @@ -66,9 +89,8 @@ where // Partial ordering implementations. #[inline] fn fallback(a: T) -> Ordering { - // nan != nan - // this is a simple way to check if it is nan - // without convincing the compiler we deal with floats + // This is a simple way to check if it is nan + // without convincing the compiler we deal with floats. 
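+    // `a != a` only holds for NaN, so a NaN argument here yields `Ordering::Less`.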
#[allow(clippy::eq_op)] if a != a { Ordering::Less @@ -79,7 +101,7 @@ fn fallback(a: T) -> Ordering { impl PartialOrdInner for T where - T: TakeRandom + Send + Sync, + T: GetInner + Send + Sync, T::Item: PartialOrd, { #[inline] @@ -96,39 +118,60 @@ pub(crate) trait IntoPartialOrdInner<'a> { fn into_partial_ord_inner(self) -> Box; } -/// We use a trait object because we want to call this from Series and cannot use a typed enum. impl<'a, T> IntoPartialOrdInner<'a> for &'a ChunkedArray where T: PolarsDataType, T::Physical<'a>: PartialOrd, { fn into_partial_ord_inner(self) -> Box { - let mut chunks = self.downcast_iter(); - - if self.chunks.len() == 1 { - let arr = chunks.next().unwrap(); - - if !self.has_validity() { - Box::new(TakeRandomArrayValues:: { arr }) - } else { - Box::new(TakeRandomArray:: { arr }) - } - } else { - let t = TakeRandomChunked:: { - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - Box::new(t) + match self.layout() { + ChunkedArrayLayout::SingleNoNull(arr) => Box::new(NonNull(arr)), + ChunkedArrayLayout::Single(arr) => Box::new(arr), + ChunkedArrayLayout::MultiNoNull(ca) => Box::new(NonNull(ca)), + ChunkedArrayLayout::Multi(ca) => Box::new(ca), } } } +#[cfg(feature = "dtype-categorical")] +struct LocalCategorical<'a> { + rev_map: &'a Utf8Array, + cats: &'a UInt32Chunked, +} + +#[cfg(feature = "dtype-categorical")] +impl<'a> GetInner for LocalCategorical<'a> { + type Item = Option<&'a str>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + let cat = self.cats.get_unchecked(idx)?; + Some(self.rev_map.value_unchecked(cat as usize)) + } +} + +#[cfg(feature = "dtype-categorical")] +struct GlobalCategorical<'a> { + p1: &'a PlHashMap, + p2: &'a Utf8Array, + cats: &'a UInt32Chunked, +} + +#[cfg(feature = "dtype-categorical")] +impl<'a> GetInner for GlobalCategorical<'a> { + type Item = Option<&'a str>; + unsafe fn get_unchecked(&self, idx: usize) -> Self::Item { + let cat = self.cats.get_unchecked(idx)?; + let idx = self.p1.get(&cat).unwrap(); + Some(self.p2.value_unchecked(*idx as usize)) + } +} + #[cfg(feature = "dtype-categorical")] impl<'a> IntoPartialOrdInner<'a> for &'a CategoricalChunked { fn into_partial_ord_inner(self) -> Box { + let cats = self.logical(); match &**self.get_rev_map() { - RevMapping::Local(_) => Box::new(CategoricalTakeRandomLocal::new(self)), - RevMapping::Global(_, _, _) => Box::new(CategoricalTakeRandomGlobal::new(self)), + RevMapping::Global(p1, p2, _) => Box::new(GlobalCategorical { p1, p2, cats }), + RevMapping::Local(rev_map) => Box::new(LocalCategorical { rev_map, cats }), } } } diff --git a/crates/polars-core/src/chunked_array/ops/filter.rs b/crates/polars-core/src/chunked_array/ops/filter.rs index 7543cff66583..8a50021147a1 100644 --- a/crates/polars-core/src/chunked_array/ops/filter.rs +++ b/crates/polars-core/src/chunked_array/ops/filter.rs @@ -104,6 +104,7 @@ impl ChunkFilter for ListChunked { )), }; } + check_filter_len!(self, filter); Ok(unsafe { arity::binary_unchecked_same_type( self, @@ -129,6 +130,7 @@ impl ChunkFilter for ArrayChunked { )), }; } + check_filter_len!(self, filter); Ok(unsafe { arity::binary_unchecked_same_type( self, @@ -157,7 +159,7 @@ where _ => Ok(ObjectChunked::new_empty(self.name())), }; } - polars_ensure!(!self.is_empty(), NoData: "cannot filter empty object array"); + check_filter_len!(self, filter); let chunks = self.downcast_iter().collect::>(); let mut builder = ObjectChunkedBuilder::::new(self.name(), self.len()); for (idx, mask) in 
filter.into_iter().enumerate() { diff --git a/crates/polars-core/src/chunked_array/ops/for_each.rs b/crates/polars-core/src/chunked_array/ops/for_each.rs new file mode 100644 index 000000000000..42713e0cdff2 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/for_each.rs @@ -0,0 +1,15 @@ +use crate::prelude::*; + +impl ChunkedArray +where + T: PolarsDataType, +{ + pub fn for_each<'a, F>(&'a self, mut op: F) + where + F: FnMut(Option>), + { + self.downcast_iter().for_each(|arr| { + arr.iter().for_each(&mut op); + }) + } +} diff --git a/crates/polars-core/src/chunked_array/ops/gather.rs b/crates/polars-core/src/chunked_array/ops/gather.rs new file mode 100644 index 000000000000..eef8f5fbe042 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/gather.rs @@ -0,0 +1,159 @@ +use arrow::array::Array; +use polars_error::{polars_bail, polars_ensure, PolarsResult}; + +use crate::chunked_array::ops::{ChunkTake, ChunkTakeUnchecked}; +use crate::chunked_array::ChunkedArray; +use crate::datatypes::{IdxCa, PolarsDataType, StaticArray}; +use crate::prelude::*; +use crate::utils::index_to_chunked_index; + +impl + ?Sized> ChunkTake for ChunkedArray +where + ChunkedArray: ChunkTakeUnchecked, +{ + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &I) -> PolarsResult { + let len = self.len(); + let all_valid = indices.as_ref().iter().all(|i| (*i as usize) < len); + polars_ensure!(all_valid, ComputeError: "invalid index in gather"); + + // SAFETY: we just checked the indices are valid. + Ok(unsafe { self.take_unchecked(indices) }) + } +} + +impl ChunkTake for ChunkedArray +where + ChunkedArray: ChunkTakeUnchecked, +{ + /// Gather values from ChunkedArray by index. + fn take(&self, indices: &IdxCa) -> PolarsResult { + let len = self.len(); + let all_valid = indices.downcast_iter().all(|a| { + if a.null_count() == 0 { + a.values_iter().all(|i| (*i as usize) < len) + } else { + a.iter().flatten().all(|i| (*i as usize) < len) + } + }); + polars_ensure!(all_valid, ComputeError: "take indices are out of bounds"); + + // SAFETY: we just checked the indices are valid. + Ok(unsafe { self.take_unchecked(indices) }) + } +} + +unsafe fn target_value_unchecked<'a, A: StaticArray>( + targets: &[&'a A], + idx: IdxSize, +) -> A::ValueT<'a> { + let (chunk_idx, arr_idx) = + index_to_chunked_index(targets.iter().map(|a| a.len()), idx as usize); + let arr = targets.get_unchecked(chunk_idx); + arr.value_unchecked(arr_idx) +} + +unsafe fn target_get_unchecked<'a, A: StaticArray>( + targets: &[&'a A], + idx: IdxSize, +) -> Option> { + let (chunk_idx, arr_idx) = + index_to_chunked_index(targets.iter().map(|a| a.len()), idx as usize); + let arr = targets.get_unchecked(chunk_idx); + arr.get_unchecked(arr_idx) +} + +unsafe fn gather_idx_array_unchecked( + dtype: DataType, + targets: &[&A], + has_nulls: bool, + indices: &[IdxSize], +) -> A { + let it = indices.iter().copied(); + if targets.len() == 1 { + let arr = targets.iter().next().unwrap(); + if has_nulls { + it.map(|i| arr.get_unchecked(i as usize)) + .collect_arr_with_dtype(dtype) + } else { + it.map(|i| arr.value_unchecked(i as usize)) + .collect_arr_with_dtype(dtype) + } + } else if has_nulls { + it.map(|i| target_get_unchecked(targets, i)) + .collect_arr_with_dtype(dtype) + } else { + it.map(|i| target_value_unchecked(targets, i)) + .collect_arr_with_dtype(dtype) + } +} + +impl + ?Sized> ChunkTakeUnchecked for ChunkedArray { + /// Gather values from ChunkedArray by index. 
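+    ///
+    /// A minimal sketch (illustrative; `ca` stands for any ChunkedArray and the
+    /// caller has already checked every index against `ca.len()`):
+    ///
+    /// ```ignore
+    /// let idx: Vec<IdxSize> = vec![0, 2];
+    /// // SAFETY: all indices are in bounds for `ca`.
+    /// let gathered = unsafe { ca.take_unchecked(idx.as_slice()) };
+    /// ```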
+ unsafe fn take_unchecked(&self, indices: &I) -> Self { + let targets: Vec<_> = self.downcast_iter().collect(); + let arr = gather_idx_array_unchecked( + self.dtype().clone(), + &targets, + self.null_count() > 0, + indices.as_ref(), + ); + ChunkedArray::from_chunk_iter_like(self, [arr]) + } +} + +impl ChunkTakeUnchecked for ChunkedArray { + /// Gather values from ChunkedArray by index. + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self { + let targets_have_nulls = self.null_count() > 0; + let targets: Vec<_> = self.downcast_iter().collect(); + + let chunks = indices.downcast_iter().map(|idx_arr| { + if idx_arr.null_count() == 0 { + gather_idx_array_unchecked( + self.dtype().clone(), + &targets, + targets_have_nulls, + idx_arr.values(), + ) + } else if targets.len() == 1 { + let target = targets.first().unwrap(); + if targets_have_nulls { + idx_arr + .iter() + .map(|i| target.get_unchecked(*i? as usize)) + .collect_arr_with_dtype(self.dtype().clone()) + } else { + idx_arr + .iter() + .map(|i| Some(target.value_unchecked(*i? as usize))) + .collect_arr_with_dtype(self.dtype().clone()) + } + } else if targets_have_nulls { + idx_arr + .iter() + .map(|i| target_get_unchecked(&targets, *i?)) + .collect_arr_with_dtype(self.dtype().clone()) + } else { + idx_arr + .iter() + .map(|i| Some(target_value_unchecked(&targets, *i?))) + .collect_arr_with_dtype(self.dtype().clone()) + } + }); + + let mut out = ChunkedArray::from_chunk_iter_like(self, chunks); + + use crate::series::IsSorted::*; + let sorted_flag = match (self.is_sorted_flag(), indices.is_sorted_flag()) { + (_, Not) => Not, + (Not, _) => Not, + (Ascending, Ascending) => Ascending, + (Ascending, Descending) => Descending, + (Descending, Ascending) => Descending, + (Descending, Descending) => Ascending, + }; + out.set_sorted_flag(sorted_flag); + out + } +} diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 8107b287c05f..b1cb0ed9ec22 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -2,7 +2,6 @@ use arrow::offset::OffsetsBuffer; use polars_arrow::prelude::QuantileInterpolOptions; -pub use self::take::*; #[cfg(feature = "object")] use crate::datatypes::ObjectType; use crate::prelude::*; @@ -29,7 +28,9 @@ mod explode_and_offsets; mod extend; mod fill_null; mod filter; +mod for_each; pub mod full; +pub mod gather; #[cfg(feature = "interpolate")] mod interpolate; mod len; @@ -46,6 +47,7 @@ mod shift; pub mod sort; pub(crate) mod take; mod tile; +#[cfg(feature = "algorithm_group_by")] pub(crate) mod unique; #[cfg(feature = "zip_with")] pub mod zip; @@ -144,86 +146,19 @@ pub trait ChunkRollApply: AsRefDataType { } } -/// Random access -pub trait TakeRandom { - type Item; - - /// Get a nullable value by index. - /// - /// # Panics - /// Panics if `index >= self.len()` - fn get(&self, index: usize) -> Option; - - /// Get a value by index and ignore the null bit. - /// - /// # Safety - /// - /// Does not do bound checks. - unsafe fn get_unchecked(&self, index: usize) -> Option +pub trait ChunkTake: ChunkTakeUnchecked { + /// Gather values from ChunkedArray by index. 
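+    ///
+    /// A minimal sketch of the checked variant (illustrative):
+    ///
+    /// ```ignore
+    /// let ca = Int32Chunked::from_slice("a", &[10, 20, 30]);
+    /// let idx = IdxCa::from_vec("idx", vec![2, 0]);
+    /// // Out-of-bounds indices surface as an error instead of undefined behavior.
+    /// let gathered = ca.take(&idx).unwrap();
+    /// assert_eq!(gathered.get(0), Some(30));
+    /// ```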
+ fn take(&self, indices: &Idx) -> PolarsResult where - Self: Sized, - { - self.get(index) - } - - /// This is much faster if we have many chunks as we don't have to compute the index - /// # Panics - /// Panics if `index >= self.len()` - fn last(&self) -> Option; -} -// Utility trait because associated type needs a lifetime -pub trait TakeRandomUtf8 { - type Item; - - /// Get a nullable value by index. - /// - /// # Panics - /// Panics if `index >= self.len()` - fn get(self, index: usize) -> Option; - - /// Get a value by index and ignore the null bit. - /// - /// # Safety - /// - /// Does not do bound checks. - unsafe fn get_unchecked(self, index: usize) -> Option - where - Self: Sized, - { - self.get(index) - } - - /// This is much faster if we have many chunks - /// # Panics - /// Panics if `index >= self.len()` - fn last(&self) -> Option; -} - -/// Fast access by index. -pub trait ChunkTake: ChunkTakeUnchecked { - /// Take values from ChunkedArray by index. - /// Note that the iterator will be cloned, so prefer an iterator that takes the owned memory - /// by reference. - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: Sized, - I: TakeIterator, - INulls: TakeIteratorNulls; + Self: Sized; } -/// Fast access by index. -pub trait ChunkTakeUnchecked { - /// Take values from ChunkedArray by index. +pub trait ChunkTakeUnchecked { + /// Gather values from ChunkedArray by index. /// /// # Safety - /// - /// Doesn't do any bound checking. - #[must_use] - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: Sized, - I: TakeIterator, - INulls: TakeIteratorNulls; + /// The non-null indices must be valid. + unsafe fn take_unchecked(&self, indices: &Idx) -> Self; } /// Create a `ChunkedArray` with new values by index or by boolean mask. @@ -594,7 +529,7 @@ macro_rules! 
impl_chunk_expand { impl ChunkExpandAtIndex for ChunkedArray where - ChunkedArray: ChunkFull + TakeRandom, + ChunkedArray: ChunkFull, { fn new_from_index(&self, index: usize, length: usize) -> ChunkedArray { let mut out = impl_chunk_expand!(self, length, index); @@ -629,7 +564,7 @@ impl ChunkExpandAtIndex for BinaryChunked { impl ChunkExpandAtIndex for ListChunked { fn new_from_index(&self, index: usize, length: usize) -> ListChunked { - let opt_val = self.get(index); + let opt_val = self.get_as_series(index); match opt_val { Some(val) => { let mut ca = ListChunked::full(self.name(), &val, length); @@ -644,7 +579,7 @@ impl ChunkExpandAtIndex for ListChunked { #[cfg(feature = "dtype-array")] impl ChunkExpandAtIndex for ArrayChunked { fn new_from_index(&self, index: usize, length: usize) -> ArrayChunked { - let opt_val = self.get(index); + let opt_val = self.get_as_series(index); match opt_val { Some(val) => { let mut ca = ArrayChunked::full(self.name(), &val, length); @@ -723,19 +658,19 @@ pub trait RepeatBy { } } -#[cfg(feature = "is_first")] +#[cfg(feature = "is_first_distinct")] /// Mask the first unique values as `true` -pub trait IsFirst { - fn is_first(&self) -> PolarsResult { - polars_bail!(opq = is_first, T::get_dtype()); +pub trait IsFirstDistinct { + fn is_first_distinct(&self) -> PolarsResult { + polars_bail!(opq = is_first_distinct, T::get_dtype()); } } -#[cfg(feature = "is_last")] +#[cfg(feature = "is_last_distinct")] /// Mask the last unique values as `true` -pub trait IsLast { - fn is_last(&self) -> PolarsResult { - polars_bail!(opq = is_last, T::get_dtype()); +pub trait IsLastDistinct { + fn is_last_distinct(&self) -> PolarsResult { + polars_bail!(opq = is_last_distinct, T::get_dtype()); } } diff --git a/crates/polars-core/src/chunked_array/ops/reverse.rs b/crates/polars-core/src/chunked_array/ops/reverse.rs index 4f950a8b78d0..d658991058fa 100644 --- a/crates/polars-core/src/chunked_array/ops/reverse.rs +++ b/crates/polars-core/src/chunked_array/ops/reverse.rs @@ -82,8 +82,7 @@ impl ChunkReverse for ArrayChunked { #[cfg(feature = "object")] impl ChunkReverse for ObjectChunked { fn reverse(&self) -> Self { - // Safety - // we we know we don't get out of bounds - unsafe { self.take_unchecked((0..self.len()).rev().into()) } + // SAFETY: we know we don't go out of bounds. + unsafe { self.take_unchecked(&(0..self.len() as IdxSize).rev().collect_ca("")) } } } diff --git a/crates/polars-core/src/chunked_array/ops/take/mod.rs b/crates/polars-core/src/chunked_array/ops/take/mod.rs index 75fb48b7424d..ccb11d118ba3 100644 --- a/crates/polars-core/src/chunked_array/ops/take/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/take/mod.rs @@ -1,506 +1,9 @@ //! Traits to provide fast Random access to ChunkedArrays data. //! This prevents downcasting every iteration. -//! IntoTakeRandom provides structs that implement the TakeRandom trait. -//! There are several structs that implement the fastest path for random access. -//! -use std::borrow::Cow; -use polars_arrow::compute::take::*; -pub use take_random::*; -pub use traits::*; - -use crate::chunked_array::kernels::take::*; use crate::prelude::*; use crate::utils::NoNull; mod take_chunked; -mod take_every; -pub(crate) mod take_random; -pub(crate) mod take_single; -mod traits; #[cfg(feature = "chunked_ids")] pub(crate) use take_chunked::*; - -macro_rules! 
take_iter_n_chunks_unchecked { - ($cat:ty, $ca:expr, $indices:expr) => {{ - let taker = $ca.take_rand(); - $indices - .into_iter() - .map(|idx| taker.get_unchecked(idx)) - .collect::<$cat>() - }}; -} - -macro_rules! take_opt_iter_n_chunks_unchecked { - ($cat:ty, $ca:expr, $indices:expr) => {{ - let taker = $ca.take_rand(); - $indices - .into_iter() - .map(|opt_idx| taker.get_unchecked(opt_idx?)) - .collect::<$cat>() - }}; -} - -impl ChunkedArray -where - T: PolarsDataType, -{ - fn finish_from_array(&self, array: Box) -> Self { - let keep_fast_explode = array.null_count() == 0; - unsafe { self.copy_with_chunks(vec![array], false, keep_fast_explode) } - } -} - -impl ChunkTake for ChunkedArray -where - ChunkedArray: ChunkTakeUnchecked, -{ - fn take(&self, indices: TakeIdx) -> PolarsResult - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - indices.check_bounds(self.len())?; - // SAFETY: just checked bounds. - Ok(unsafe { self.take_unchecked(indices) }) - } -} - -impl ChunkTakeUnchecked for ChunkedArray -where - T: PolarsNumericType, -{ - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let mut chunks = self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - let array = match (self.null_count(), self.chunks.len()) { - (0, 1) => { - take_no_null_primitive_unchecked::(chunks.next().unwrap(), array) - as ArrayRef - }, - (_, 1) => take_primitive_unchecked::(chunks.next().unwrap(), array) - as ArrayRef, - _ => { - return if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - let mut ca = take_primitive_iter_n_chunks(self, iter); - ca.rename(self.name()); - ca - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let mut ca = take_primitive_opt_iter_n_chunks(self, iter); - ca.rename(self.name()); - ca - } - }, - }; - self.finish_from_array(array) - }, - TakeIdx::Iter(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => take_no_null_primitive_iter_unchecked::( - chunks.next().unwrap(), - iter, - ) as ArrayRef, - (_, 1) => { - take_primitive_iter_unchecked::(chunks.next().unwrap(), iter) - as ArrayRef - }, - _ => { - let mut ca = take_primitive_iter_n_chunks(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - TakeIdx::IterNulls(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => take_no_null_primitive_opt_iter_unchecked::( - chunks.next().unwrap(), - iter, - ) as ArrayRef, - (_, 1) => take_primitive_opt_iter_unchecked::( - chunks.next().unwrap(), - iter, - ) as ArrayRef, - _ => { - let mut ca = take_primitive_opt_iter_n_chunks(self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - } - } -} - -impl ChunkTakeUnchecked for BooleanChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let mut chunks = self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), 
array.len()); - } - let array = match self.chunks.len() { - 1 => take::take_unchecked(chunks.next().unwrap(), array), - _ => { - return if !array.has_validity() { - let taker = self.take_rand(); - array - .values_iter() - .map(|i| taker.get(*i as usize)) - .collect::() - .with_name(self.name()) - } else { - let taker = self.take_rand(); - array - .into_iter() - .map(|opt_idx| taker.get(*(opt_idx?) as usize)) - .collect::() - .with_name(self.name()) - } - }, - }; - self.finish_from_array(array) - }, - TakeIdx::Iter(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_bool_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - (_, 1) => take_bool_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef, - _ => { - return take_iter_n_chunks_unchecked!(Self, self, iter) - .with_name(self.name()); - }, - }; - self.finish_from_array(array) - }, - TakeIdx::IterNulls(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_bool_opt_iter_unchecked(chunks.next().unwrap(), iter) - as ArrayRef - }, - (_, 1) => { - take_bool_opt_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - _ => { - let mut ca: BooleanChunked = - take_opt_iter_n_chunks_unchecked!(Self, self, iter); - ca.rename(self.name()); - return ca; - }, - }; - self.finish_from_array(array) - }, - } - } -} - -impl ChunkTakeUnchecked for Utf8Chunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - self.as_binary().take_unchecked(indices).to_utf8() - } -} - -impl ChunkTakeUnchecked for BinaryChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let mut chunks = self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - let array = match self.chunks.len() { - 1 => take_binary_unchecked(chunks.next().unwrap(), array) as ArrayRef, - _ => { - return if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - take_iter_n_chunks_unchecked!(Self, self, iter).with_name(self.name()) - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - take_opt_iter_n_chunks_unchecked!(Self, self, iter) - .with_name(self.name()) - } - }, - }; - self.finish_from_array(array) - }, - TakeIdx::Iter(iter) => { - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_binary_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - (_, 1) => take_binary_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef, - _ => { - return take_iter_n_chunks_unchecked!(Self, self, iter) - .with_name(self.name()); - }, - }; - self.finish_from_array(array) - }, - TakeIdx::IterNulls(iter) => { - let array = match (self.has_validity(), self.chunks.len()) { - (false, 1) => { - take_no_null_binary_opt_iter_unchecked(chunks.next().unwrap(), iter) - as ArrayRef - }, - (_, 1) => { - take_binary_opt_iter_unchecked(chunks.next().unwrap(), iter) as ArrayRef - }, - _ => { - return take_opt_iter_n_chunks_unchecked!(Self, self, iter) - .with_name(self.name()); - }, - }; - 
self.finish_from_array(array) - }, - } - } -} - -impl ChunkTakeUnchecked for ListChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let ca_self = if self.is_nested() { - Cow::Owned(self.rechunk()) - } else { - Cow::Borrowed(self) - }; - let mut chunks = ca_self.downcast_iter(); - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null_with_dtype( - self.name(), - array.len(), - &self.inner_dtype(), - ); - } - let array = match ca_self.chunks.len() { - 1 => Box::new(take_list_unchecked(chunks.next().unwrap(), array)) as ArrayRef, - _ => { - if !array.has_validity() { - let iter = array.values().iter().map(|i| *i as usize); - let mut ca = - take_iter_n_chunks_unchecked!(Self, ca_self.as_ref(), iter); - ca.chunks.pop().unwrap() - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let mut ca = - take_opt_iter_n_chunks_unchecked!(Self, ca_self.as_ref(), iter); - ca.chunks.pop().unwrap() - } - }, - }; - self.finish_from_array(array) - }, - // todo! fast path for single chunk - TakeIdx::Iter(iter) => { - if ca_self.chunks.len() == 1 { - let idx: NoNull = iter.map(|v| v as IdxSize).collect(); - ca_self.take_unchecked((&idx.into_inner()).into()) - } else { - let mut ca = take_iter_n_chunks_unchecked!(Self, ca_self.as_ref(), iter); - self.finish_from_array(ca.chunks.pop().unwrap()) - } - }, - TakeIdx::IterNulls(iter) => { - if ca_self.chunks.len() == 1 { - let idx: IdxCa = iter.map(|v| v.map(|v| v as IdxSize)).collect(); - ca_self.take_unchecked((&idx).into()) - } else { - let mut ca = take_opt_iter_n_chunks_unchecked!(Self, ca_self.as_ref(), iter); - self.finish_from_array(ca.chunks.pop().unwrap()) - } - }, - } - } -} - -#[cfg(feature = "dtype-array")] -impl ChunkTakeUnchecked for ArrayChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - let ca_self = self.rechunk(); - match indices { - TakeIdx::Array(idx_array) => { - if idx_array.null_count() == idx_array.len() { - return Self::full_null_with_dtype( - self.name(), - idx_array.len(), - &self.inner_dtype(), - ca_self.width(), - ); - } - let arr = self.chunks[0].as_ref(); - let arr = take_unchecked(arr, idx_array); - self.finish_from_array(arr) - }, - TakeIdx::Iter(iter) => { - let idx: NoNull = iter.map(|v| v as IdxSize).collect(); - ca_self.take_unchecked((&idx.into_inner()).into()) - }, - TakeIdx::IterNulls(iter) => { - let idx: IdxCa = iter.map(|v| v.map(|v| v as IdxSize)).collect(); - ca_self.take_unchecked((&idx).into()) - }, - } - } -} - -#[cfg(feature = "object")] -impl ChunkTakeUnchecked for ObjectChunked { - unsafe fn take_unchecked(&self, indices: TakeIdx) -> Self - where - Self: std::marker::Sized, - I: TakeIterator, - INulls: TakeIteratorNulls, - { - // current implementation is suboptimal, every iterator is allocated to UInt32Array - match indices { - TakeIdx::Array(array) => { - if array.null_count() == array.len() { - return Self::full_null(self.name(), array.len()); - } - - match self.chunks.len() { - 1 => { - let values = self.downcast_chunks().get(0).unwrap().values(); - - let mut ca: Self = array - .into_iter() - .map(|opt_idx| { - opt_idx.map(|idx| values.get_unchecked(*idx as usize).clone()) - }) - .collect(); - ca.rename(self.name()); - ca - }, - _ => { - return if !array.has_validity() { - let iter = 
array.values().iter().map(|i| *i as usize); - - let taker = self.take_rand(); - let mut ca: ObjectChunked = - iter.map(|idx| taker.get_unchecked(idx).cloned()).collect(); - ca.rename(self.name()); - ca - } else { - let iter = array - .into_iter() - .map(|opt_idx| opt_idx.map(|idx| *idx as usize)); - let taker = self.take_rand(); - - let mut ca: ObjectChunked = iter - .map(|opt_idx| { - opt_idx.and_then(|idx| taker.get_unchecked(idx).cloned()) - }) - .collect(); - - ca.rename(self.name()); - ca - } - }, - } - }, - TakeIdx::Iter(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - - let taker = self.take_rand(); - let mut ca: ObjectChunked = - iter.map(|idx| taker.get_unchecked(idx).cloned()).collect(); - ca.rename(self.name()); - ca - }, - TakeIdx::IterNulls(iter) => { - if self.is_empty() { - return Self::full_null(self.name(), iter.size_hint().0); - } - let taker = self.take_rand(); - - let mut ca: ObjectChunked = iter - .map(|opt_idx| opt_idx.and_then(|idx| taker.get(idx).cloned())) - .collect(); - - ca.rename(self.name()); - ca - }, - } - } -} - -#[cfg(test)] -mod test { - use crate::prelude::*; - - #[test] - fn test_take_random() { - let ca = Int32Chunked::from_slice("a", &[1, 2, 3]); - assert_eq!(ca.get(0), Some(1)); - assert_eq!(ca.get(1), Some(2)); - assert_eq!(ca.get(2), Some(3)); - - let ca = Utf8Chunked::from_slice("a", &["a", "b", "c"]); - assert_eq!(ca.get(0), Some("a")); - assert_eq!(ca.get(1), Some("b")); - assert_eq!(ca.get(2), Some("c")); - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/take_every.rs b/crates/polars-core/src/chunked_array/ops/take/take_every.rs deleted file mode 100644 index 6401eb2f133d..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_every.rs +++ /dev/null @@ -1,11 +0,0 @@ -use crate::prelude::*; - -impl Series { - /// Traverse and collect every nth element in a new array. - pub fn take_every(&self, n: usize) -> Series { - let mut idx = (0..self.len()).step_by(n); - - // safety: we are in bounds - unsafe { self.take_iter_unchecked(&mut idx) } - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/take_random.rs b/crates/polars-core/src/chunked_array/ops/take/take_random.rs deleted file mode 100644 index 996b44e0c5ac..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_random.rs +++ /dev/null @@ -1,183 +0,0 @@ -use arrow::array::Array; - -use crate::prelude::*; -use crate::utils::index_to_chunked_index; - -/// Create a type that implements a faster `TakeRandom`. -pub trait IntoTakeRandom<'a> { - type Item; - type TakeRandom; - /// Create a type that implements `TakeRandom`. 
- fn take_rand(&self) -> Self::TakeRandom; -} - -pub enum TakeRandBranch3 { - SingleNoNull(N), - Single(S), - Multi(M), -} - -impl TakeRandom for TakeRandBranch3 -where - N: TakeRandom, - S: TakeRandom, - M: TakeRandom, -{ - type Item = I; - - #[inline] - fn get(&self, index: usize) -> Option { - match self { - Self::SingleNoNull(s) => s.get(index), - Self::Single(s) => s.get(index), - Self::Multi(m) => m.get(index), - } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - match self { - Self::SingleNoNull(s) => s.get_unchecked(index), - Self::Single(s) => s.get_unchecked(index), - Self::Multi(m) => m.get_unchecked(index), - } - } - - fn last(&self) -> Option { - match self { - Self::SingleNoNull(s) => s.last(), - Self::Single(s) => s.last(), - Self::Multi(m) => m.last(), - } - } -} - -#[allow(clippy::type_complexity)] -impl<'a, T> IntoTakeRandom<'a> for &'a ChunkedArray -where - T: PolarsDataType, -{ - type Item = T::Physical<'a>; - type TakeRandom = TakeRandBranch3< - TakeRandomArrayValues<'a, T>, - TakeRandomArray<'a, T>, - TakeRandomChunked<'a, T>, - >; - - #[inline] - fn take_rand(&self) -> Self::TakeRandom { - let mut chunks = self.downcast_iter(); - if self.chunks.len() == 1 { - let arr = chunks.next().unwrap(); - - if !self.has_validity() { - let t = TakeRandomArrayValues { arr }; - TakeRandBranch3::SingleNoNull(t) - } else { - let t = TakeRandomArray { arr }; - TakeRandBranch3::Single(t) - } - } else { - let t = TakeRandomChunked { - chunks: chunks.collect(), - chunk_lens: self.chunks.iter().map(|a| a.len() as IdxSize).collect(), - }; - TakeRandBranch3::Multi(t) - } - } -} - -pub struct TakeRandomChunked<'a, T> -where - T: PolarsDataType, -{ - pub(crate) chunks: Vec<&'a T::Array>, - pub(crate) chunk_lens: Vec, -} - -impl<'a, T> TakeRandom for TakeRandomChunked<'a, T> -where - T: PolarsDataType, -{ - type Item = T::Physical<'a>; - - #[inline] - fn get(&self, index: usize) -> Option { - let (chunk_idx, arr_idx) = - index_to_chunked_index(self.chunk_lens.iter().copied(), index as IdxSize); - let arr = self.chunks.get(chunk_idx as usize)?; - - // SAFETY: if index_to_chunked_index returns a valid chunk_idx, we know - // that arr_idx < arr.len(). - unsafe { arr.get_unchecked(arr_idx as usize) } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - let (chunk_idx, arr_idx) = - index_to_chunked_index(self.chunk_lens.iter().copied(), index as IdxSize); - - unsafe { - // SAFETY: up to the caller to make sure the index is valid. 
- self.chunks - .get_unchecked(chunk_idx as usize) - .get_unchecked(arr_idx as usize) - } - } - - fn last(&self) -> Option { - self.chunks - .last() - .and_then(|arr| arr.get(arr.len().saturating_sub(1))) - } -} - -pub struct TakeRandomArrayValues<'a, T: PolarsDataType> { - pub(crate) arr: &'a T::Array, -} - -impl<'a, T> TakeRandom for TakeRandomArrayValues<'a, T> -where - T: PolarsDataType, -{ - type Item = T::Physical<'a>; - - #[inline] - fn get(&self, index: usize) -> Option { - Some(self.arr.value(index)) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - Some(self.arr.value_unchecked(index)) - } - - fn last(&self) -> Option { - self.arr.last() - } -} - -pub struct TakeRandomArray<'a, T: PolarsDataType> { - pub(crate) arr: &'a T::Array, -} - -impl<'a, T> TakeRandom for TakeRandomArray<'a, T> -where - T: PolarsDataType, -{ - type Item = T::Physical<'a>; - - #[inline] - fn get(&self, index: usize) -> Option { - self.arr.get(index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - self.arr.get_unchecked(index) - } - - fn last(&self) -> Option { - self.arr.last() - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/take_single.rs b/crates/polars-core/src/chunked_array/ops/take/take_single.rs deleted file mode 100644 index b5494d337471..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/take_single.rs +++ /dev/null @@ -1,338 +0,0 @@ -use arrow::array::*; -use polars_arrow::is_valid::IsValid; - -#[cfg(feature = "object")] -use crate::chunked_array::object::ObjectArray; -use crate::prelude::*; - -macro_rules! impl_take_random_get { - ($self:ident, $index:ident, $array_type:ty) => {{ - assert!($index < $self.len()); - let (chunk_idx, idx) = $self.index_to_chunked_index($index); - // Safety: - // bounds are checked above - let arr = $self.chunks.get_unchecked(chunk_idx); - - // Safety: - // caller should give right array type - let arr = &*(arr as *const ArrayRef as *const Box<$array_type>); - - // Safety: - // index should be in bounds - if arr.is_valid(idx) { - Some(arr.value_unchecked(idx)) - } else { - None - } - }}; -} - -macro_rules! 
impl_take_random_get_unchecked { - ($self:ident, $index:ident, $array_type:ty) => {{ - let (chunk_idx, idx) = $self.index_to_chunked_index($index); - debug_assert!(chunk_idx < $self.chunks.len()); - // Safety: - // bounds are checked above - let arr = $self.chunks.get_unchecked(chunk_idx); - - // Safety: - // caller should give right array type - let arr = &*(&**arr as *const dyn Array as *const $array_type); - - // Safety: - // index should be in bounds - debug_assert!(idx < arr.len()); - if arr.is_valid_unchecked(idx) { - Some(arr.value_unchecked(idx)) - } else { - None - } - }}; -} - -impl TakeRandom for ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - - #[inline] - fn get(&self, index: usize) -> Option { - unsafe { impl_take_random_get!(self, index, PrimitiveArray) } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, PrimitiveArray) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a, T> TakeRandom for &'a ChunkedArray -where - T: PolarsNumericType, -{ - type Item = T::Native; - - #[inline] - fn get(&self, index: usize) -> Option { - (*self).get(index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - (*self).get_unchecked(index) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl TakeRandom for BooleanChunked { - type Item = bool; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, BooleanArray) } - } - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, BooleanArray) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a> TakeRandom for &'a BooleanChunked { - type Item = bool; - - #[inline] - fn get(&self, index: usize) -> Option { - (*self).get(index) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - (*self).get_unchecked(index) - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a> TakeRandom for &'a Utf8Chunked { - type Item = &'a str; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, LargeStringArray) } - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl<'a> TakeRandom for &'a BinaryChunked { - type Item = &'a [u8]; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, LargeBinaryArray) } - } - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr 
= chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -// extra trait such that it also works without extra reference. -// Autoref will insert the reference and -impl<'a> TakeRandomUtf8 for &'a Utf8Chunked { - type Item = &'a str; - - #[inline] - fn get(self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, LargeStringArray) } - } - - #[inline] - unsafe fn get_unchecked(self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, LargeStringArray) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -#[cfg(feature = "object")] -impl<'a, T: PolarsObject> TakeRandom for &'a ObjectChunked { - type Item = &'a T; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - unsafe { impl_take_random_get!(self, index, ObjectArray) } - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - impl_take_random_get_unchecked!(self, index, ObjectArray) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1) - } else { - None - } - } -} - -impl TakeRandom for ListChunked { - type Item = Series; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - let opt_arr = unsafe { impl_take_random_get!(self, index, LargeListArray) }; - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - let opt_arr = impl_take_random_get_unchecked!(self, index, LargeListArray); - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1).map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } else { - None - } - } -} - -#[cfg(feature = "dtype-array")] -impl TakeRandom for ArrayChunked { - type Item = Series; - - #[inline] - fn get(&self, index: usize) -> Option { - // Safety: - // Out of bounds is checked and downcast is of correct type - let opt_arr = unsafe { impl_take_random_get!(self, index, FixedSizeListArray) }; - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - #[inline] - unsafe fn get_unchecked(&self, index: usize) -> Option { - let opt_arr = impl_take_random_get_unchecked!(self, index, FixedSizeListArray); - opt_arr.map(|arr| unsafe { - Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } - - fn last(&self) -> Option { - let chunks = self.downcast_chunks(); - let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap(); - if arr.len() > 0 { - arr.get(arr.len() - 1).map(|arr| unsafe { - 
Series::from_chunks_and_dtype_unchecked( - self.name(), - vec![arr], - &self.inner_dtype().to_physical(), - ) - }) - } else { - None - } - } -} diff --git a/crates/polars-core/src/chunked_array/ops/take/traits.rs b/crates/polars-core/src/chunked_array/ops/take/traits.rs deleted file mode 100644 index 818681f831e6..000000000000 --- a/crates/polars-core/src/chunked_array/ops/take/traits.rs +++ /dev/null @@ -1,210 +0,0 @@ -//! Traits that indicate the allowed arguments in a ChunkedArray::take operation. -use crate::frame::group_by::GroupsProxyIter; -use crate::prelude::*; - -// Utility traits -pub trait TakeIterator: Iterator + TrustedLen { - fn check_bounds(&self, bound: usize) -> PolarsResult<()>; - // a sort of clone - fn boxed_clone(&self) -> Box; -} -pub trait TakeIteratorNulls: Iterator> + TrustedLen { - fn check_bounds(&self, bound: usize) -> PolarsResult<()>; - - fn boxed_clone(&self) -> Box; -} - -unsafe impl TrustedLen for &mut dyn TakeIterator {} -unsafe impl TrustedLen for &mut dyn TakeIteratorNulls {} -unsafe impl TrustedLen for GroupsProxyIter<'_> {} - -// Implement for the ref as well -impl TakeIterator for &mut dyn TakeIterator { - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - (**self).check_bounds(bound) - } - - fn boxed_clone(&self) -> Box { - (**self).boxed_clone() - } -} -impl TakeIteratorNulls for &mut dyn TakeIteratorNulls { - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - (**self).check_bounds(bound) - } - - fn boxed_clone(&self) -> Box { - (**self).boxed_clone() - } -} - -// Clonable iterators may implement the traits above -impl TakeIterator for I -where - I: Iterator + Clone + Sized + TrustedLen, -{ - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - // clone so that the iterator can be used again. - let iter = self.clone(); - let mut inbounds = true; - - for i in iter { - if i >= bound { - // we will not break here as that prevents SIMD - inbounds = false; - } - } - polars_ensure!(inbounds, ComputeError: "take indices are out of bounds"); - Ok(()) - } - - fn boxed_clone(&self) -> Box { - Box::new(self.clone()) - } -} -impl TakeIteratorNulls for I -where - I: Iterator> + Clone + Sized + TrustedLen, -{ - fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - // clone so that the iterator can be used again. 
- let iter = self.clone(); - let mut inbounds = true; - - for i in iter.flatten() { - if i >= bound { - // we will not break here as that prevents SIMD - inbounds = false; - } - } - polars_ensure!(inbounds, ComputeError: "take indices are out of bounds"); - Ok(()) - } - - fn boxed_clone(&self) -> Box { - Box::new(self.clone()) - } -} - -/// One of the three arguments allowed in unchecked_take -pub enum TakeIdx<'a, I, INulls> -where - I: TakeIterator, - INulls: TakeIteratorNulls, -{ - Array(&'a IdxArr), - Iter(I), - // will return a null where None - IterNulls(INulls), -} - -impl<'a, I, INulls> TakeIdx<'a, I, INulls> -where - I: TakeIterator, - INulls: TakeIteratorNulls, -{ - pub(crate) fn check_bounds(&self, bound: usize) -> PolarsResult<()> { - match self { - TakeIdx::Iter(i) => i.check_bounds(bound), - TakeIdx::IterNulls(i) => i.check_bounds(bound), - TakeIdx::Array(arr) => { - let values = arr.values().as_slice(); - let mut inbounds = true; - let len = bound as IdxSize; - if arr.null_count() == 0 { - for &i in values { - // we will not break here as that prevents SIMD - if i >= len { - inbounds = false; - } - } - } else { - for opt_v in *arr { - match opt_v { - Some(&v) if v >= len => { - inbounds = false; - }, - _ => {}, - } - } - } - polars_ensure!(inbounds, ComputeError: "take indices are out of bounds"); - Ok(()) - }, - } - } -} - -/// Dummy type, we need to instantiate all generic types, so we fill one with a dummy. -pub type Dummy = std::iter::Once; - -// Below the conversions from -// * UInt32Chunked -// * Iterator -// * Iterator> -// -// To the checked and unchecked TakeIdx enums - -// Unchecked conversions - -/// Conversion from UInt32Chunked to Unchecked TakeIdx -impl<'a> From<&'a IdxCa> for TakeIdx<'a, Dummy, Dummy>> { - fn from(ca: &'a IdxCa) -> Self { - if ca.chunks.len() == 1 { - TakeIdx::Array(ca.downcast_iter().next().unwrap()) - } else { - panic!("implementation error, should be transformed to an iterator by the caller") - } - } -} - -/// Conversion from Iterator to Unchecked TakeIdx -impl<'a, I> From for TakeIdx<'a, I, Dummy>> -where - I: TakeIterator, -{ - fn from(iter: I) -> Self { - TakeIdx::Iter(iter) - } -} - -/// Conversion from [`Iterator>`] to Unchecked [`TakeIdx`] -impl<'a, I> From for TakeIdx<'a, Dummy, I> -where - I: TakeIteratorNulls, -{ - fn from(iter: I) -> Self { - TakeIdx::IterNulls(iter) - } -} - -#[inline] -fn to_usize(idx: &IdxSize) -> usize { - *idx as usize -} - -/// Conversion from `&[IdxSize]` to Unchecked TakeIdx -impl<'a> From<&'a [IdxSize]> - for TakeIdx< - 'a, - std::iter::Map, fn(&IdxSize) -> usize>, - Dummy>, - > -{ - fn from(slice: &'a [IdxSize]) -> Self { - TakeIdx::Iter(slice.iter().map(to_usize)) - } -} - -/// Conversion from `&[IdxSize]` to Unchecked TakeIdx -impl<'a> From<&'a Vec> - for TakeIdx< - 'a, - std::iter::Map, fn(&IdxSize) -> usize>, - Dummy>, - > -{ - fn from(slice: &'a Vec) -> Self { - (&**slice).into() - } -} diff --git a/crates/polars-core/src/chunked_array/ops/unique/mod.rs b/crates/polars-core/src/chunked_array/ops/unique/mod.rs index 217e7a5494b0..c8954e691e61 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/mod.rs @@ -8,10 +8,10 @@ use arrow::bitmap::MutableBitmap; #[cfg(feature = "object")] use crate::datatypes::ObjectType; use crate::datatypes::PlHashSet; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; use crate::frame::group_by::GroupsProxy; #[cfg(feature = "mode")] use crate::frame::group_by::IntoGroupsProxy; +use 
crate::hashing::HASHMAP_INIT_SIZE; use crate::prelude::*; use crate::series::IsSorted; @@ -112,7 +112,7 @@ fn mode_indices(groups: GroupsProxy) -> Vec { #[cfg(feature = "mode")] fn mode(ca: &ChunkedArray) -> ChunkedArray where - ChunkedArray: IntoGroupsProxy + ChunkTake, + ChunkedArray: IntoGroupsProxy + ChunkTake<[IdxSize]>, { if ca.is_empty() { return ca.clone(); @@ -122,7 +122,7 @@ where // Safety: // group indices are in bounds - unsafe { ca.take_unchecked(idx.as_slice().into()) } + unsafe { ca.take_unchecked(idx.as_slice()) } } macro_rules! arg_unique_ca { diff --git a/crates/polars-core/src/chunked_array/ops/unique/rank.rs b/crates/polars-core/src/chunked_array/ops/unique/rank.rs index 3e532404506f..d77152a56a03 100644 --- a/crates/polars-core/src/chunked_array/ops/unique/rank.rs +++ b/crates/polars-core/src/chunked_array/ops/unique/rank.rs @@ -133,7 +133,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Optio Random => { // Safety: // in bounds - let arr = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() }; + let arr = unsafe { s.take_unchecked(&sort_idx_ca) }; let not_consecutive_same = arr .slice(1, len - 1) .not_equal(&arr.slice(0, len - 1)) @@ -185,7 +185,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Optio _ => { let inv_ca = IdxCa::from_vec(s.name(), inv); // SAFETY: in bounds. - let arr = unsafe { s.take_unchecked(&sort_idx_ca).unwrap() }; + let arr = unsafe { s.take_unchecked(&sort_idx_ca) }; let validity = arr.chunks()[0].validity().cloned(); let not_consecutive_same = arr .slice(1, len - 1) @@ -228,7 +228,7 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Optio let dense = IdxCa::with_chunk(s.name(), arr); // SAFETY: in bounds. - let dense = unsafe { dense.take_unchecked((&inv_ca).into()) }; + let dense = unsafe { dense.take_unchecked(&inv_ca) }; if let RankMethod::Dense = method { return if s.null_count() == 0 { @@ -280,18 +280,18 @@ pub(crate) fn rank(s: &Series, method: RankMethod, descending: bool, seed: Optio match method { Max => { // SAFETY: in bounds. - unsafe { count.take_unchecked((&dense).into()).into_series() } + unsafe { count.take_unchecked(&dense).into_series() } }, Min => { // SAFETY: in bounds. - unsafe { (count.take_unchecked((&dense).into()) + 1).into_series() } + unsafe { (count.take_unchecked(&dense) + 1).into_series() } }, Average => { // SAFETY: in bounds. - let a = unsafe { count.take_unchecked((&dense).into()) } + let a = unsafe { count.take_unchecked(&dense) } .cast(&DataType::Float64) .unwrap(); - let b = unsafe { count.take_unchecked((&(dense - 1)).into()) } + let b = unsafe { count.take_unchecked(&(dense - 1)) } .cast(&DataType::Float64) .unwrap() + 1.0; diff --git a/crates/polars-core/src/chunked_array/random.rs b/crates/polars-core/src/chunked_array/random.rs index 61fb1433d5e9..99641dc0b2a9 100644 --- a/crates/polars-core/src/chunked_array/random.rs +++ b/crates/polars-core/src/chunked_array/random.rs @@ -85,15 +85,15 @@ impl Series { match with_replacement { true => { let idx = create_rand_index_with_replacement(n, len, seed); - // Safety we know that we never go out of bounds + // SAFETY: we know that we never go out of bounds. debug_assert_eq!(len, self.len()); - unsafe { self.take_unchecked(&idx) } + unsafe { Ok(self.take_unchecked(&idx)) } }, false => { let idx = create_rand_index_no_replacement(n, len, seed, shuffle); - // Safety we know that we never go out of bounds + // SAFETY: we know that we never go out of bounds. 
debug_assert_eq!(len, self.len()); - unsafe { self.take_unchecked(&idx) } + unsafe { Ok(self.take_unchecked(&idx)) } }, } } @@ -116,14 +116,14 @@ impl Series { let idx = create_rand_index_no_replacement(n, len, seed, true); // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { self.take_unchecked(&idx).unwrap() } + unsafe { self.take_unchecked(&idx) } } } impl ChunkedArray where T: PolarsDataType, - ChunkedArray: ChunkTake, + ChunkedArray: ChunkTake, { /// Sample n datapoints from this [`ChunkedArray`]. pub fn sample_n( @@ -141,13 +141,13 @@ where let idx = create_rand_index_with_replacement(n, len, seed); // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { Ok(self.take_unchecked((&idx).into())) } + unsafe { Ok(self.take_unchecked(&idx)) } }, false => { let idx = create_rand_index_no_replacement(n, len, seed, shuffle); // Safety we know that we never go out of bounds debug_assert_eq!(len, self.len()); - unsafe { Ok(self.take_unchecked((&idx).into())) } + unsafe { Ok(self.take_unchecked(&idx)) } }, } } diff --git a/crates/polars-core/src/chunked_array/temporal/conversion.rs b/crates/polars-core/src/chunked_array/temporal/conversion.rs index 1657e6a754f3..34baa7c7533e 100644 --- a/crates/polars-core/src/chunked_array/temporal/conversion.rs +++ b/crates/polars-core/src/chunked_array/temporal/conversion.rs @@ -36,7 +36,7 @@ impl From<&AnyValue<'_>> for NaiveTime { // Used by lazy for literal conversion pub fn datetime_to_timestamp_ns(v: NaiveDateTime) -> i64 { - v.timestamp_nanos() + v.timestamp_nanos_opt().unwrap() } // Used by lazy for literal conversion diff --git a/crates/polars-core/src/chunked_array/temporal/datetime.rs b/crates/polars-core/src/chunked_array/temporal/datetime.rs index cca5597e4cf7..8f7dde50e0fa 100644 --- a/crates/polars-core/src/chunked_array/temporal/datetime.rs +++ b/crates/polars-core/src/chunked_array/temporal/datetime.rs @@ -4,6 +4,7 @@ use arrow::temporal_conversions::{ timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_us_to_datetime, }; use chrono::format::{DelayedFormat, StrftimeItems}; +use chrono::NaiveDate; #[cfg(feature = "timezones")] use chrono::TimeZone as TimeZoneTrait; #[cfg(feature = "timezones")] diff --git a/crates/polars-core/src/chunked_array/upstream_traits.rs b/crates/polars-core/src/chunked_array/upstream_traits.rs index 458b7fbf1298..af24444fdf14 100644 --- a/crates/polars-core/src/chunked_array/upstream_traits.rs +++ b/crates/polars-core/src/chunked_array/upstream_traits.rs @@ -40,31 +40,8 @@ where T: PolarsNumericType, { fn from_iter>>(iter: I) -> Self { - let iter = iter.into_iter(); - - let arr: PrimitiveArray = match iter.size_hint() { - (a, Some(b)) if a == b => { - // 2021-02-07: ~40% faster than builder. - // It is unsafe because we cannot be certain that the iterators length can be trusted. - // For most iterators that report the same upper bound as lower bound it is, but still - // somebody can create an iterator that incorrectly gives those bounds. - // This will not lead to UB, but will panic. - #[cfg(feature = "performant")] - unsafe { - let arr = PrimitiveArray::from_trusted_len_iter_unchecked(iter) - .to(T::get_dtype().to_arrow()); - assert_eq!(arr.len(), a); - arr - } - #[cfg(not(feature = "performant"))] - iter.collect::>() - .to(T::get_dtype().to_arrow()) - }, - _ => iter - .collect::>() - .to(T::get_dtype().to_arrow()), - }; - arr.into() + // TODO: eliminate this FromIterator implementation entirely. 
+ iter.into_iter().collect_ca("") } } diff --git a/crates/polars-core/src/datatypes/aliases.rs b/crates/polars-core/src/datatypes/aliases.rs index 1d7e930ddc44..87cb707da2c4 100644 --- a/crates/polars-core/src/datatypes/aliases.rs +++ b/crates/polars-core/src/datatypes/aliases.rs @@ -1,12 +1,15 @@ +pub use polars_arrow::index::{IdxArr, IdxSize}; + use super::*; +use crate::hashing::IdBuildHasher; + +/// [ChunkIdx, DfIdx] +pub type ChunkId = [IdxSize; 2]; #[cfg(not(feature = "bigidx"))] pub type IdxCa = UInt32Chunked; #[cfg(feature = "bigidx")] pub type IdxCa = UInt64Chunked; -pub use polars_arrow::index::{IdxArr, IdxSize}; - -use crate::hashing::IdBuildHasher; #[cfg(not(feature = "bigidx"))] pub const IDX_DTYPE: DataType = DataType::UInt32; diff --git a/crates/polars-core/src/datatypes/from_values.rs b/crates/polars-core/src/datatypes/from_values.rs deleted file mode 100644 index 07341355caa9..000000000000 --- a/crates/polars-core/src/datatypes/from_values.rs +++ /dev/null @@ -1,185 +0,0 @@ -use std::borrow::Cow; -use std::error::Error; - -use arrow::array::{ - BinaryArray, BooleanArray, MutableBinaryArray, MutableBinaryValuesArray, MutablePrimitiveArray, - MutableUtf8Array, MutableUtf8ValuesArray, PrimitiveArray, Utf8Array, -}; -use arrow::bitmap::Bitmap; -use polars_arrow::array::utf8::{BinaryFromIter, Utf8FromIter}; -use polars_arrow::prelude::FromData; -use polars_arrow::trusted_len::TrustedLen; - -use crate::datatypes::NumericNative; -use crate::prelude::StaticArray; - -pub trait ArrayFromElementIter -where - Self: Sized, -{ - type ArrayType: StaticArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType; - - fn array_from_values_iter>(iter: I) -> Self::ArrayType; - - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result; - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result; -} - -impl ArrayFromElementIter for bool { - type ArrayType = BooleanArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BooleanArray::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BooleanArray::from_trusted_len_values_iter_unchecked(iter) } - } - - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BooleanArray::try_from_trusted_len_iter_unchecked(iter) } - } - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - // SAFETY: guarded by `TrustedLen` trait - let values = unsafe { Bitmap::try_from_trusted_len_iter_unchecked(iter) }?; - Ok(BooleanArray::from_data_default(values, None)) - } -} - -impl ArrayFromElementIter for T -where - T: NumericNative, -{ - type ArrayType = PrimitiveArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { PrimitiveArray::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { PrimitiveArray::from_trusted_len_values_iter_unchecked(iter) } - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - // SAFETY: guarded by `TrustedLen` trait - unsafe { Ok(MutablePrimitiveArray::try_from_trusted_len_iter_unchecked(iter)?.into()) } - } - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let values: Vec<_> = iter.collect::, _>>()?; - Ok(PrimitiveArray::from_vec(values)) - } -} - -impl ArrayFromElementIter for &str { - type ArrayType = Utf8Array; - - fn 
array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { Utf8Array::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - let len = iter.size_hint().0; - Utf8Array::from_values_iter(iter, len, len * 24) - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8Array::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8ValuesArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } -} - -impl ArrayFromElementIter for Cow<'_, str> { - type ArrayType = Utf8Array; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { Utf8Array::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - let len = iter.size_hint().0; - Utf8Array::from_values_iter(iter, len, len * 24) - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8Array::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableUtf8ValuesArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } -} - -impl ArrayFromElementIter for Cow<'_, [u8]> { - type ArrayType = BinaryArray; - - fn array_from_iter>>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - unsafe { BinaryArray::from_trusted_len_iter_unchecked(iter) } - } - - fn array_from_values_iter>(iter: I) -> Self::ArrayType { - // SAFETY: guarded by `TrustedLen` trait - let len = iter.size_hint().0; - BinaryArray::from_values_iter(iter, len, len * 24) - } - fn try_array_from_iter, E>>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableBinaryArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } - - fn try_array_from_values_iter>>( - iter: I, - ) -> Result { - let len = iter.size_hint().0; - let mut mutable = MutableBinaryValuesArray::::with_capacities(len, len * 24); - mutable.extend_fallible(iter)?; - Ok(mutable.into()) - } -} diff --git a/crates/polars-core/src/datatypes/mod.rs b/crates/polars-core/src/datatypes/mod.rs index 37d2d2d24a6a..25a55ae07210 100644 --- a/crates/polars-core/src/datatypes/mod.rs +++ b/crates/polars-core/src/datatypes/mod.rs @@ -12,8 +12,8 @@ mod aliases; mod any_value; mod dtype; mod field; -mod from_values; mod static_array; +mod static_array_collect; mod time_unit; use std::cmp::Ordering; @@ -32,7 +32,6 @@ use arrow::types::simd::Simd; use arrow::types::NativeType; pub use dtype::*; pub use field::*; -pub use from_values::ArrayFromElementIter; use num_traits::{Bounded, FromPrimitive, Num, NumCast, One, Zero}; use polars_arrow::data_types::IsFloat; #[cfg(feature = "serde")] @@ -42,6 +41,7 @@ use serde::{Deserialize, Serialize}; #[cfg(any(feature = "serde", feature = "serde-lazy"))] use serde::{Deserializer, Serializer}; pub use static_array::StaticArray; +pub use static_array_collect::{ArrayCollectIterExt, ArrayFromIter, ArrayFromIterDtype}; pub use time_unit::*; use 
crate::chunked_array::arithmetic::ArrayArithmetics; @@ -270,6 +270,3 @@ impl NumericNative for f32 { impl NumericNative for f64 { type PolarsType = Float64Type; } - -// Provide options to cloud providers (credentials, region). -pub type CloudOptions = PlHashMap; diff --git a/crates/polars-core/src/datatypes/static_array.rs b/crates/polars-core/src/datatypes/static_array.rs index 162bbd88bfbd..d4cc0aa61960 100644 --- a/crates/polars-core/src/datatypes/static_array.rs +++ b/crates/polars-core/src/datatypes/static_array.rs @@ -3,9 +3,14 @@ use arrow::bitmap::Bitmap; #[cfg(feature = "object")] use crate::chunked_array::object::{ObjectArray, ObjectValueIter}; +use crate::datatypes::static_array_collect::ArrayFromIterDtype; use crate::prelude::*; -pub trait StaticArray: Array { +pub trait StaticArray: + Array + + for<'a> ArrayFromIterDtype> + + for<'a> ArrayFromIterDtype>> +{ type ValueT<'a> where Self: 'a; @@ -55,6 +60,10 @@ pub trait StaticArray: Array { fn with_validity_typed(self, validity: Option) -> Self; } +pub trait ParameterFreeDtypeStaticArray: StaticArray { + fn get_dtype() -> DataType; +} + impl StaticArray for PrimitiveArray { type ValueT<'a> = T; type ValueIterT<'a> = std::iter::Copied>; @@ -77,6 +86,12 @@ impl StaticArray for PrimitiveArray { } } +impl ParameterFreeDtypeStaticArray for PrimitiveArray { + fn get_dtype() -> DataType { + T::PolarsType::get_dtype() + } +} + impl StaticArray for BooleanArray { type ValueT<'a> = bool; type ValueIterT<'a> = BitmapIter<'a>; @@ -99,6 +114,12 @@ impl StaticArray for BooleanArray { } } +impl ParameterFreeDtypeStaticArray for BooleanArray { + fn get_dtype() -> DataType { + DataType::Boolean + } +} + impl StaticArray for Utf8Array { type ValueT<'a> = &'a str; type ValueIterT<'a> = Utf8ValuesIter<'a, i64>; @@ -121,6 +142,12 @@ impl StaticArray for Utf8Array { } } +impl ParameterFreeDtypeStaticArray for Utf8Array { + fn get_dtype() -> DataType { + DataType::Utf8 + } +} + impl StaticArray for BinaryArray { type ValueT<'a> = &'a [u8]; type ValueIterT<'a> = BinaryValueIter<'a, i64>; @@ -143,6 +170,12 @@ impl StaticArray for BinaryArray { } } +impl ParameterFreeDtypeStaticArray for BinaryArray { + fn get_dtype() -> DataType { + DataType::Binary + } +} + impl StaticArray for ListArray { type ValueT<'a> = Box; type ValueIterT<'a> = ListValuesIter<'a, i64>; diff --git a/crates/polars-core/src/datatypes/static_array_collect.rs b/crates/polars-core/src/datatypes/static_array_collect.rs new file mode 100644 index 000000000000..63b15e5afc38 --- /dev/null +++ b/crates/polars-core/src/datatypes/static_array_collect.rs @@ -0,0 +1,881 @@ +use std::borrow::Cow; +use std::sync::Arc; + +#[cfg(feature = "dtype-array")] +use arrow::array::FixedSizeListArray; +use arrow::array::{ + Array, BinaryArray, BooleanArray, ListArray, MutableBinaryArray, MutableBinaryValuesArray, + PrimitiveArray, Utf8Array, +}; +use arrow::bitmap::Bitmap; +#[cfg(feature = "dtype-array")] +use polars_arrow::prelude::fixed_size_list::AnonymousBuilder as AnonymousFixedSizeListArrayBuilder; +use polars_arrow::prelude::list::AnonymousBuilder as AnonymousListArrayBuilder; +use polars_arrow::trusted_len::{TrustedLen, TrustedLenPush}; + +#[cfg(feature = "object")] +use crate::chunked_array::object::{ObjectArray, PolarsObject}; +use crate::datatypes::static_array::ParameterFreeDtypeStaticArray; +use crate::datatypes::{DataType, NumericNative, PolarsDataType, StaticArray}; + +pub trait ArrayFromIterDtype: Sized { + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self; + + #[inline(always)] 
+ fn arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter_with_dtype(dtype, iter) + } + + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result; + + #[inline(always)] + fn try_arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::try_arr_from_iter_with_dtype(dtype, iter) + } +} + +pub trait ArrayFromIter: Sized { + fn arr_from_iter>(iter: I) -> Self; + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + Self::arr_from_iter(iter) + } + + fn try_arr_from_iter>>(iter: I) -> Result; + + #[inline(always)] + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + Self::try_arr_from_iter(iter) + } +} + +impl> ArrayFromIterDtype for A { + #[inline(always)] + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::arr_from_iter(iter) + } + + #[inline(always)] + fn arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::arr_from_iter_with_dtype(dtype, iter) + } + + #[inline(always)] + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::try_arr_from_iter(iter) + } + + #[inline(always)] + fn try_arr_from_iter_trusted_with_dtype(dtype: DataType, iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + debug_assert!(std::mem::discriminant(&dtype) == std::mem::discriminant(&A::get_dtype())); + Self::try_arr_from_iter_with_dtype(dtype, iter) + } +} + +pub trait ArrayCollectIterExt: Iterator + Sized { + #[inline(always)] + fn collect_arr(self) -> A + where + A: ArrayFromIter, + { + A::arr_from_iter(self) + } + + #[inline(always)] + fn collect_arr_trusted(self) -> A + where + A: ArrayFromIter, + Self: TrustedLen, + { + A::arr_from_iter_trusted(self) + } + + #[inline(always)] + fn try_collect_arr(self) -> Result + where + A: ArrayFromIter, + Self: Iterator>, + { + A::try_arr_from_iter(self) + } + + #[inline(always)] + fn try_collect_arr_trusted(self) -> Result + where + A: ArrayFromIter, + Self: Iterator> + TrustedLen, + { + A::try_arr_from_iter_trusted(self) + } + + #[inline(always)] + fn collect_arr_with_dtype(self, dtype: DataType) -> A + where + A: ArrayFromIterDtype, + { + A::arr_from_iter_with_dtype(dtype, self) + } + + #[inline(always)] + fn collect_arr_trusted_with_dtype(self, dtype: DataType) -> A + where + A: ArrayFromIterDtype, + Self: TrustedLen, + { + A::arr_from_iter_trusted_with_dtype(dtype, self) + } + + #[inline(always)] + fn try_collect_arr_with_dtype(self, dtype: DataType) -> Result + where + A: ArrayFromIterDtype, + Self: Iterator>, + { + A::try_arr_from_iter_with_dtype(dtype, self) + } + + #[inline(always)] + fn try_collect_arr_trusted_with_dtype(self, dtype: DataType) -> Result + where + A: ArrayFromIterDtype, + Self: Iterator> + TrustedLen, + { + A::try_arr_from_iter_trusted_with_dtype(dtype, self) + } +} + +impl ArrayCollectIterExt for I {} + +// --------------- +// Implementations +// --------------- +macro_rules! 
impl_collect_vec_validity { + ($iter: ident, $x:ident, $unpack:expr) => {{ + let mut iter = $iter.into_iter(); + let mut buf: Vec = Vec::new(); + let mut bitmap: Vec = Vec::new(); + let lo = iter.size_hint().0; + buf.reserve(8 + lo); + bitmap.reserve(8 + 8 * (lo / 64)); + + let mut nonnull_count = 0; + let mut mask = 0u8; + 'exhausted: loop { + unsafe { + // SAFETY: when we enter this loop we always have at least one + // capacity in bitmap, and at least 8 in buf. + for i in 0..8 { + let Some($x) = iter.next() else { + break 'exhausted; + }; + #[allow(clippy::all)] + // #[allow(clippy::redundant_locals)] Clippy lint too new + let x = $unpack; + let nonnull = x.is_some(); + mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + buf.push_unchecked(x.unwrap_or_default()); + } + + bitmap.push_unchecked(mask); + mask = 0; + } + + buf.reserve(8); + if bitmap.len() == bitmap.capacity() { + bitmap.reserve(8); // Waste some space to make branch more predictable. + } + } + + unsafe { + // SAFETY: when we broke to 'exhausted we had capacity by the loop invariant. + // It's also no problem if we make the mask bigger than strictly necessary. + bitmap.push_unchecked(mask); + } + + let null_count = buf.len() - nonnull_count; + let arrow_bitmap = if null_count > 0 { + unsafe { + // SAFETY: we made sure the null_count is correct. + Some(Bitmap::from_inner(Arc::new(bitmap.into()), 0, buf.len(), null_count).unwrap()) + } + } else { + None + }; + + (buf, arrow_bitmap) + }}; +} + +macro_rules! impl_trusted_collect_vec_validity { + ($iter: ident, $x:ident, $unpack:expr) => {{ + let mut iter = $iter.into_iter(); + let mut buf: Vec = Vec::new(); + let mut bitmap: Vec = Vec::new(); + let n = iter.size_hint().1.expect("must have an upper bound"); + buf.reserve(n); + bitmap.reserve(8 + 8 * (n / 64)); + + let mut nonnull_count = 0; + while buf.len() + 8 <= n { + unsafe { + let mut mask = 0u8; + for i in 0..8 { + let $x = iter.next().unwrap_unchecked(); + #[allow(clippy::all)] + // #[allow(clippy::redundant_locals)] Clippy lint too new + let x = $unpack; + let nonnull = x.is_some(); + mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + buf.push_unchecked(x.unwrap_or_default()); + } + bitmap.push_unchecked(mask); + } + } + + if buf.len() < n { + unsafe { + let mut mask = 0u8; + for i in 0..n - buf.len() { + let $x = iter.next().unwrap_unchecked(); + let x = $unpack; + let nonnull = x.is_some(); + mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + buf.push_unchecked(x.unwrap_or_default()); + } + bitmap.push_unchecked(mask); + } + } + + let null_count = buf.len() - nonnull_count; + let arrow_bitmap = if null_count > 0 { + unsafe { + // SAFETY: we made sure the null_count is correct. 
+ Some(Bitmap::from_inner(Arc::new(bitmap.into()), 0, buf.len(), null_count).unwrap()) + } + } else { + None + }; + + (buf, arrow_bitmap) + }}; +} + +impl ArrayFromIter for PrimitiveArray { + fn arr_from_iter>(iter: I) -> Self { + PrimitiveArray::from_vec(iter.into_iter().collect()) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + PrimitiveArray::from_vec(Vec::from_trusted_len_iter(iter)) + } + + fn try_arr_from_iter>>(iter: I) -> Result { + let v: Result, E> = iter.into_iter().collect(); + Ok(PrimitiveArray::from_vec(v?)) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + let v = Vec::try_from_trusted_len_iter(iter); + Ok(PrimitiveArray::from_vec(v?)) + } +} + +impl ArrayFromIter> for PrimitiveArray { + fn arr_from_iter>>(iter: I) -> Self { + let (buf, validity) = impl_collect_vec_validity!(iter, x, x); + PrimitiveArray::new(T::PolarsType::get_dtype().to_arrow(), buf.into(), validity) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + let (buf, validity) = impl_trusted_collect_vec_validity!(iter, x, x); + PrimitiveArray::new(T::PolarsType::get_dtype().to_arrow(), buf.into(), validity) + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let (buf, validity) = impl_collect_vec_validity!(iter, x, x?); + Ok(PrimitiveArray::new( + T::PolarsType::get_dtype().to_arrow(), + buf.into(), + validity, + )) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator, E>>, + I::IntoIter: TrustedLen, + { + let (buf, validity) = impl_trusted_collect_vec_validity!(iter, x, x?); + Ok(PrimitiveArray::new( + T::PolarsType::get_dtype().to_arrow(), + buf.into(), + validity, + )) + } +} + +// We don't use AsRef here because it leads to problems with conflicting implementations, +// as Rust considers that AsRef<[u8]> for Option<&[u8]> could be implemented. +trait IntoBytes { + type AsRefT: AsRef<[u8]>; + fn into_bytes(self) -> Self::AsRefT; +} +trait TrivialIntoBytes: AsRef<[u8]> {} +impl IntoBytes for T { + type AsRefT = Self; + fn into_bytes(self) -> Self { + self + } +} +impl TrivialIntoBytes for Vec {} +impl<'a> TrivialIntoBytes for Cow<'a, [u8]> {} +impl<'a> TrivialIntoBytes for &'a [u8] {} +impl TrivialIntoBytes for String {} +impl<'a> TrivialIntoBytes for &'a str {} +impl<'a> IntoBytes for Cow<'a, str> { + type AsRefT = Cow<'a, [u8]>; + fn into_bytes(self) -> Cow<'a, [u8]> { + match self { + Cow::Borrowed(a) => Cow::Borrowed(a.as_bytes()), + Cow::Owned(s) => Cow::Owned(s.into_bytes()), + } + } +} + +impl ArrayFromIter for BinaryArray { + fn arr_from_iter>(iter: I) -> Self { + BinaryArray::from_iter_values(iter.into_iter().map(|s| s.into_bytes())) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: our iterator is TrustedLen. + MutableBinaryArray::from_trusted_len_values_iter_unchecked( + iter.into_iter().map(|s| s.into_bytes()), + ) + .into() + } + } + + fn try_arr_from_iter>>(iter: I) -> Result { + // No built-in for this? + let mut arr = MutableBinaryValuesArray::new(); + let mut iter = iter.into_iter(); + arr.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push(x?.into_bytes()); + Ok(()) + })?; + Ok(arr.into()) + } + + // No faster implementation than this available, fall back to default. 
+ // fn try_arr_from_iter_trusted(iter: I) -> Result +} + +impl ArrayFromIter> for BinaryArray { + fn arr_from_iter>>(iter: I) -> Self { + BinaryArray::from_iter(iter.into_iter().map(|s| Some(s?.into_bytes()))) + } + + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: the iterator is TrustedLen. + BinaryArray::from_trusted_len_iter_unchecked( + iter.into_iter().map(|s| Some(s?.into_bytes())), + ) + } + } + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + // No built-in for this? + let mut arr = MutableBinaryArray::new(); + let mut iter = iter.into_iter(); + arr.reserve(iter.size_hint().0, 0); + iter.try_for_each(|x| -> Result<(), E> { + arr.push(x?.map(|s| s.into_bytes())); + Ok(()) + })?; + Ok(arr.into()) + } + + fn try_arr_from_iter_trusted(iter: I) -> Result + where + I: IntoIterator, E>>, + I::IntoIter: TrustedLen, + { + unsafe { + // SAFETY: the iterator is TrustedLen. + BinaryArray::try_from_trusted_len_iter_unchecked( + iter.into_iter().map(|s| s.map(|s| Some(s?.into_bytes()))), + ) + } + } +} + +/// We use this to re-use the binary collect implementation for strings. +/// # Safety +/// The array must be valid UTF-8. +unsafe fn into_utf8array(arr: BinaryArray) -> Utf8Array { + unsafe { + let (_dt, offsets, values, validity) = arr.into_inner(); + let dt = arrow::datatypes::DataType::LargeUtf8; + Utf8Array::try_new_unchecked(dt, offsets, values, validity).unwrap_unchecked() + } +} + +trait StrIntoBytes: IntoBytes {} +impl StrIntoBytes for String {} +impl<'a> StrIntoBytes for &'a str {} +impl<'a> StrIntoBytes for Cow<'a, str> {} + +impl ArrayFromIter for Utf8Array { + #[inline(always)] + fn arr_from_iter>(iter: I) -> Self { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator, + I::IntoIter: TrustedLen, + { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn try_arr_from_iter>>(iter: I) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } + + #[inline(always)] + fn try_arr_from_iter_trusted>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } +} + +impl ArrayFromIter> for Utf8Array { + #[inline(always)] + fn arr_from_iter>>(iter: I) -> Self { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn arr_from_iter_trusted(iter: I) -> Self + where + I: IntoIterator>, + I::IntoIter: TrustedLen, + { + unsafe { into_utf8array(iter.into_iter().collect_arr()) } + } + + #[inline(always)] + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } + + #[inline(always)] + fn try_arr_from_iter_trusted, E>>>( + iter: I, + ) -> Result { + let arr = iter.into_iter().try_collect_arr()?; + unsafe { Ok(into_utf8array(arr)) } + } +} + +macro_rules! 
impl_collect_bool_validity { + ($iter: ident, $x:ident, $unpack:expr, $truth:expr, $nullity:expr, $with_valid:literal) => {{ + let mut iter = $iter.into_iter(); + let mut buf: Vec = Vec::new(); + let mut validity: Vec = Vec::new(); + let lo = iter.size_hint().0; + buf.reserve(8 + 8 * (lo / 64)); + if $with_valid { + validity.reserve(8 + 8 * (lo / 64)); + } + + let mut len = 0; + let mut buf_mask = 0u8; + let mut true_count = 0; + let mut valid_mask = 0u8; + let mut nonnull_count = 0; + 'exhausted: loop { + unsafe { + for i in 0..8 { + let Some($x) = iter.next() else { + break 'exhausted; + }; + #[allow(clippy::all)] + // #[allow(clippy::redundant_locals)] Clippy lint too new + let $x = $unpack; + let is_true: bool = $truth; + buf_mask |= (is_true as u8) << i; + true_count += is_true as usize; + if $with_valid { + let nonnull: bool = $nullity; + valid_mask |= (nonnull as u8) << i; + nonnull_count += nonnull as usize; + } + len += 1; + } + + buf.push_unchecked(buf_mask); + buf_mask = 0; + if $with_valid { + validity.push_unchecked(valid_mask); + valid_mask = 0; + } + } + + if buf.len() == buf.capacity() { + buf.reserve(8); // Waste some space to make branch more predictable. + if $with_valid { + validity.reserve(8); + } + } + } + + unsafe { + // SAFETY: when we broke to 'exhausted we had capacity by the loop invariant. + // It's also no problem if we make the mask bigger than strictly necessary. + buf.push_unchecked(buf_mask); + if $with_valid { + validity.push_unchecked(valid_mask); + } + } + + let false_count = len - true_count; + let values = + unsafe { Bitmap::from_inner(Arc::new(buf.into()), 0, len, false_count).unwrap() }; + + let null_count = len - nonnull_count; + let validity_bitmap = if $with_valid && null_count > 0 { + unsafe { + // SAFETY: we made sure the null_count is correct. + Some(Bitmap::from_inner(Arc::new(validity.into()), 0, len, null_count).unwrap()) + } + } else { + None + }; + + (values, validity_bitmap) + }}; +} + +impl ArrayFromIter for BooleanArray { + fn arr_from_iter>(iter: I) -> Self { + let dt = arrow::datatypes::DataType::Boolean; + let (values, _valid) = impl_collect_bool_validity!(iter, x, x, x, false, false); + BooleanArray::new(dt, values, None) + } + + // TODO: are efficient trusted collects for booleans worth it? + // fn arr_from_iter_trusted(iter: I) -> Self + + fn try_arr_from_iter>>(iter: I) -> Result { + let dt = arrow::datatypes::DataType::Boolean; + let (values, _valid) = impl_collect_bool_validity!(iter, x, x?, x, false, false); + Ok(BooleanArray::new(dt, values, None)) + } + + // fn try_arr_from_iter_trusted>>( +} + +impl ArrayFromIter> for BooleanArray { + fn arr_from_iter>>(iter: I) -> Self { + let dt = arrow::datatypes::DataType::Boolean; + let (values, valid) = + impl_collect_bool_validity!(iter, x, x, x.unwrap_or(false), x.is_some(), true); + BooleanArray::new(dt, values, valid) + } + + // fn arr_from_iter_trusted(iter: I) -> Self + + fn try_arr_from_iter, E>>>( + iter: I, + ) -> Result { + let dt = arrow::datatypes::DataType::Boolean; + let (values, valid) = + impl_collect_bool_validity!(iter, x, x?, x.unwrap_or(false), x.is_some(), true); + Ok(BooleanArray::new(dt, values, valid)) + } + + // fn try_arr_from_iter_trusted, E>>>( +} + +// We don't use AsRef here because it leads to problems with conflicting implementations, +// as Rust considers that AsRef for Option<&dyn Array> could be implemented. +trait AsArray { + fn as_array(&self) -> &dyn Array; + fn into_boxed_array(self) -> Box; // Prevents unnecessary re-boxing. 
+} +impl AsArray for Box { + fn as_array(&self) -> &dyn Array { + self.as_ref() + } + fn into_boxed_array(self) -> Box { + self + } +} +impl<'a> AsArray for &'a dyn Array { + fn as_array(&self) -> &'a dyn Array { + *self + } + fn into_boxed_array(self) -> Box { + self.to_boxed() + } +} + +// TODO: more efficient (fixed size) list collect routines. +impl ArrayFromIterDtype for ListArray { + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self { + let iter_values: Vec = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push(arr.as_array()); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +impl ArrayFromIterDtype> for ListArray { + fn arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Self { + let iter_values: Vec> = iter.into_iter().collect(); + let mut builder = AnonymousListArrayBuilder::new(iter_values.len()); + for arr in &iter_values { + builder.push_opt(arr.as_ref().map(|a| a.as_array())); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayFromIterDtype> for FixedSizeListArray { + fn arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Self { + let DataType::Array(_, width) = &dtype else { + panic!("FixedSizeListArray::arr_from_iter_with_dtype called with non-Array dtype"); + }; + let iter_values: Vec<_> = iter.into_iter().collect(); + let mut builder = AnonymousFixedSizeListArrayBuilder::new(iter_values.len(), *width); + for arr in iter_values { + builder.push(arr.into_boxed_array()); + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +#[cfg(feature = "dtype-array")] +impl ArrayFromIterDtype>> for FixedSizeListArray { + fn arr_from_iter_with_dtype>>>( + dtype: DataType, + iter: I, + ) -> Self { + let DataType::Array(_, width) = &dtype else { + panic!("FixedSizeListArray::arr_from_iter_with_dtype called with non-Array dtype"); + }; + let iter_values: Vec<_> = iter.into_iter().collect(); + let mut builder = AnonymousFixedSizeListArrayBuilder::new(iter_values.len(), *width); + for arr in iter_values { + match arr { + Some(a) => builder.push(a.into_boxed_array()), + None => builder.push_null(), + } + } + let inner = dtype + .inner_dtype() + .expect("expected nested type in ListArray collect"); + builder + .finish(Some(&inner.to_physical().to_arrow())) + .unwrap() + } + + fn try_arr_from_iter_with_dtype< + E, + I: IntoIterator>, E>>, + >( + dtype: DataType, + iter: I, + ) -> Result { + let iter_values = iter.into_iter().collect::, E>>()?; + 
Ok(Self::arr_from_iter_with_dtype(dtype, iter_values)) + } +} + +// TODO: more efficient implementations, I really took the short path here. +#[cfg(feature = "object")] +impl<'a, T: PolarsObject> ArrayFromIterDtype<&'a T> for ObjectArray { + fn arr_from_iter_with_dtype>(dtype: DataType, iter: I) -> Self { + Self::try_arr_from_iter_with_dtype( + dtype, + iter.into_iter().map(|o| -> Result<_, ()> { Ok(Some(o)) }), + ) + .unwrap() + } + + fn try_arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Result { + Self::try_arr_from_iter_with_dtype(dtype, iter.into_iter().map(|o| Ok(Some(o?)))) + } +} + +#[cfg(feature = "object")] +impl<'a, T: PolarsObject> ArrayFromIterDtype> for ObjectArray { + fn arr_from_iter_with_dtype>>( + dtype: DataType, + iter: I, + ) -> Self { + Self::try_arr_from_iter_with_dtype( + dtype, + iter.into_iter().map(|o| -> Result<_, ()> { Ok(o) }), + ) + .unwrap() + } + + fn try_arr_from_iter_with_dtype, E>>>( + _dtype: DataType, + iter: I, + ) -> Result { + let iter = iter.into_iter(); + let size = iter.size_hint().0; + + let mut null_mask_builder = arrow::bitmap::MutableBitmap::with_capacity(size); + let values: Vec = iter + .map(|value| match value? { + Some(value) => { + null_mask_builder.push(true); + Ok(value.clone()) + }, + None => { + null_mask_builder.push(false); + Ok(T::default()) + }, + }) + .collect::, E>>()?; + + let null_bit_buffer: Option = null_mask_builder.into(); + let null_bitmap = null_bit_buffer; + let len = values.len(); + Ok(ObjectArray { + values: Arc::new(values), + null_bitmap, + offset: 0, + len, + }) + } +} diff --git a/crates/polars-core/src/doc/changelog/mod.rs b/crates/polars-core/src/doc/changelog/mod.rs deleted file mode 100644 index 40f167264afc..000000000000 --- a/crates/polars-core/src/doc/changelog/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub mod v0_10_0_11; -pub mod v0_3; -pub mod v0_4; -pub mod v0_5; -pub mod v0_6; -pub mod v0_7; -pub mod v0_8; -pub mod v0_9; diff --git a/crates/polars-core/src/doc/changelog/v0_10_0_11.rs b/crates/polars-core/src/doc/changelog/v0_10_0_11.rs deleted file mode 100644 index 8136f24f8f80..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_10_0_11.rs +++ /dev/null @@ -1,21 +0,0 @@ -//! # Changelog v0.10 / v0.11 -//! -//! * CSV Read IO -//! - Parallel csv reader -//! * Sample DataFrames/ Series -//! * Performance increase in take kernel -//! * Performance increase in ChunkedArray builders -//! * Join operation on multiple columns. -//! * ~3.5 x performance increase in group_by operations (measured on db-benchmark), -//! due to embarrassingly parallel grouping and better branch prediction (tight loops). -//! * Performance increase on join operation due to better branch prediction. -//! * Categorical datatype and global string cache (BETA). -//! -//! * Lazy -//! - Lot's of bug fixes in optimizer. -//! - Parallel execution of Physical plan -//! - Partition window function -//! - More simplify expression optimizations. -//! - Caching -//! - Alpha release of Aggregate pushdown optimization. -//! * Start of general Object type in ChunkedArray/DataFrames/Series diff --git a/crates/polars-core/src/doc/changelog/v0_3.rs b/crates/polars-core/src/doc/changelog/v0_3.rs deleted file mode 100644 index 738021313cdc..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_3.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! # Changelog v0.3 -//! -//! * Utf8 type is nullable [#37](https://github.com/pola-rs/polars/issues/37) -//! * Support all ARROW numeric types [#40](https://github.com/pola-rs/polars/issues/40) -//! 
* Support all ARROW temporal types [#46](https://github.com/pola-rs/polars/issues/46) -//! * ARROW IPC Reader/ Writer [#50](https://github.com/pola-rs/polars/issues/50) -//! * Implement DoubleEndedIterator trait for ChunkedArray's [#34](https://github.com/pola-rs/polars/issues/34) -//! diff --git a/crates/polars-core/src/doc/changelog/v0_4.rs b/crates/polars-core/src/doc/changelog/v0_4.rs deleted file mode 100644 index d357526134ef..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_4.rs +++ /dev/null @@ -1,9 +0,0 @@ -//! # Changelog v0.4 -//! -//! * median aggregation added to `ChunkedArray` -//! * Arrow LargeList datatype support (and group_by aggregation into LargeList). -//! * Shift operation. -//! * Fill None operation. -//! * Buffered serialization (less memory requirements) -//! * Temporal utilities -//! diff --git a/crates/polars-core/src/doc/changelog/v0_5.rs b/crates/polars-core/src/doc/changelog/v0_5.rs deleted file mode 100644 index 7d82f3271cc0..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_5.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! # Changelog v0.5 -//! -//! * `DataFrame.column` returns `Result<_>` **breaking change**. -//! * Define idiomatic way to do inplace operations on a `DataFrame` with `apply`, `try_apply` and `ChunkSet` -//! * `ChunkSet` Trait. -//! * `Groupby` aggregations can be done on a selection of multiple columns. -//! * `Groupby` operation can be done on multiple keys. -//! * `Groupby` `first` operation. -//! * `Pivot` operation. -//! * Random access to `ChunkedArray` types via `.get` and `.get_unchecked`. -//! diff --git a/crates/polars-core/src/doc/changelog/v0_6.rs b/crates/polars-core/src/doc/changelog/v0_6.rs deleted file mode 100644 index 23e38f3d2369..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_6.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! # Changelog v0.6 -//! -//! * Add more distributions for random sampling. -//! * Fix float aggregations with NaNs. -//! * Comparisons are more performant. -//! * Outer join is more performant. -//! * Start with parallel iterator support for ChunkedArrays. -//! * Remove crossbeam dependency. diff --git a/crates/polars-core/src/doc/changelog/v0_7.rs b/crates/polars-core/src/doc/changelog/v0_7.rs deleted file mode 100644 index 55996f2fcaa5..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_7.rs +++ /dev/null @@ -1,32 +0,0 @@ -//! # Changelog v0.7 -//! -//! * More group by aggregations: -//! - n_unique -//! - quantile -//! - median -//! - last -//! - group indexes -//! - agg (combined aggregations) -//! * explode operation -//! * melt operation -//! * df! macro -//! * Rem trait implemented for Series and ChunkedArrays -//! * ChunkedArrays broadcasting arithmetic -//! * ChunkedArray/Series `zip_with` operation -//! * ChunkedArray/Series `new_from_index` operation -//! * laziness api initiated. -//! - Predicate pushdown optimizer -//! - Projection pushdown optimizer -//! - Type coercion optimizer -//! - Selection (filter, where clause) -//! - Projection (select foo from bar) -//! - Aggregation (group_by) -//! - all eager aggregations supported -//! - Joins -//! - WithColumn operation -//! - DSL -//! * (col, lit, lt, lt_eq, alias, etc.) -//! * arithmetic -//! * when / then /otherwise -//! * 1.3-1.7 performance increase of filter -//! * ChunkedArray/ Series creation speedup: No nulls: 10X speedup, Nulls: 1.1-2.2x speedup. 
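Many of the hunks in this patch replace the old `take_opt_iter_unchecked` / `TakeIdx` call sites with a two-step pattern: collect the (optional) row indices into an `IdxCa`, then pass a reference to `take_unchecked`, which now returns the taken `Series` directly rather than a `PolarsResult`. The sketch below is a minimal illustration of that pattern and is not code from the patch; it assumes `collect_ca` is reachable through the polars-core prelude, and it uses a plain `&[Vec<IdxSize>]` as a simplified, hypothetical stand-in for the real grouping structures, mirroring the `agg_first`-style rewrites found in the group_by and join hunks elsewhere in this patch.

use polars_core::prelude::*;

/// Illustrative helper: take the first row of every group, yielding a null row
/// for empty groups. Names and the `&[Vec<IdxSize>]` parameter are assumptions
/// made for this sketch only.
fn first_per_group(s: &Series, groups: &[Vec<IdxSize>]) -> Series {
    // One Option<IdxSize> per group; `None` becomes a null row in the result.
    let indices: IdxCa = groups
        .iter()
        .map(|idx| idx.first().copied())
        .collect_ca("first-indices");
    // SAFETY: the caller must guarantee every index is in bounds for `s`.
    unsafe { s.take_unchecked(&indices) }
}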
diff --git a/crates/polars-core/src/doc/changelog/v0_8.rs b/crates/polars-core/src/doc/changelog/v0_8.rs deleted file mode 100644 index 3d7c6fdabb8f..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_8.rs +++ /dev/null @@ -1,24 +0,0 @@ -//! # Changelog v0.8 -//! -//! * Upgrade to Arrow 2.0 -//! * Add quantile aggregation to `ChunkedArray` -//! * Option to stop reading CSV after n rows. -//! * Read parquet file in a single batch reducing reading time. -//! * Faster kernel for zip_with and set_with operation -//! * String utilities -//! - Utf8Chunked::str_lengths method -//! - Utf8Chunked::contains method -//! - Utf8Chunked::replace method -//! - Utf8Chunked::replace_all method -//! * Temporal utilities -//! - Utf8Chunked to dat32 / datetime -//! * Lazy -//! - fill_null expression -//! - shift expression -//! - Series aggregations -//! - aggregations on DataFrame level -//! - aggregate to largelist -//! - a lot of bugs fixed in optimizers -//! - UDF's / closures in lazy dsl -//! - DataFrame reverse operation -//! diff --git a/crates/polars-core/src/doc/changelog/v0_9.rs b/crates/polars-core/src/doc/changelog/v0_9.rs deleted file mode 100644 index f0ece2b79bcf..000000000000 --- a/crates/polars-core/src/doc/changelog/v0_9.rs +++ /dev/null @@ -1,19 +0,0 @@ -//! # Changelog v0.9 -//! -//! * CSV Read IO -//! - large performance increase -//! - skip_rows -//! - ignore parser errors -//! * Overall performance increase by using aHash in favor of FNV. -//! * Groupby floating point keys -//! * DataFrame operations -//! - drop_nulls -//! - drop duplicate rows -//! * Temporal handling -//! * Lazy -//! - a lot of bug fixes in the optimizer -//! - start of optimizer framework -//! - start of simplify expression optimizer -//! - csv scan -//! - various operations -//! * Start of general Object type in ChunkedArray/DataFrames/Series diff --git a/crates/polars-core/src/doc/mod.rs b/crates/polars-core/src/doc/mod.rs deleted file mode 100644 index 18169f152474..000000000000 --- a/crates/polars-core/src/doc/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -//! Other documentation -pub mod changelog; diff --git a/crates/polars-core/src/frame/asof_join/groups.rs b/crates/polars-core/src/frame/asof_join/groups.rs index 9c980c935b3b..616d74b92992 100644 --- a/crates/polars-core/src/frame/asof_join/groups.rs +++ b/crates/polars-core/src/frame/asof_join/groups.rs @@ -9,13 +9,12 @@ use rayon::prelude::*; use smartstring::alias::String as SmartString; use super::*; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; #[cfg(feature = "dtype-categorical")] use crate::frame::hash_join::_check_categorical_src; use crate::frame::hash_join::{ build_tables, get_hash_tbl_threaded_join_partitioned, multiple_keys as mk, prepare_bytes, }; -use crate::hashing::{df_rows_to_hashes_threaded_vertical, AsU64}; +use crate::hashing::{df_rows_to_hashes_threaded_vertical, AsU64, HASHMAP_INIT_SIZE}; use crate::utils::{split_ca, split_df}; use crate::POOL; @@ -846,15 +845,9 @@ impl DataFrame { right_join_tuples = slice_slice(right_join_tuples, offset, len); } - // Safety: - // join tuples are in bounds - let right_df = unsafe { - other.take_opt_iter_unchecked( - right_join_tuples - .iter() - .map(|opt_idx| opt_idx.map(|idx| idx as usize)), - ) - }; + // SAFETY: join tuples are in bounds. 
+ let right_df = + unsafe { other.take_unchecked(&right_join_tuples.iter().copied().collect_ca("")) }; _finish_join(left, right_df, suffix) } diff --git a/crates/polars-core/src/frame/asof_join/mod.rs b/crates/polars-core/src/frame/asof_join/mod.rs index c496c670d696..bc5ca09acb99 100644 --- a/crates/polars-core/src/frame/asof_join/mod.rs +++ b/crates/polars-core/src/frame/asof_join/mod.rs @@ -8,6 +8,7 @@ use num_traits::Bounded; use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; +use crate::frame::hash_join::_finish_join; use crate::prelude::*; use crate::utils::{ensure_sorted_arg, slice_slice}; @@ -194,15 +195,8 @@ impl DataFrame { take_idx = slice_slice(take_idx, offset, len); } - // Safety: - // join tuples are in bounds - let right_df = unsafe { - other.take_opt_iter_unchecked( - take_idx - .iter() - .map(|opt_idx| opt_idx.map(|idx| idx as usize)), - ) - }; + // SAFETY: join tuples are in bounds. + let right_df = unsafe { other.take_unchecked(&take_idx.iter().copied().collect_ca("")) }; _finish_join(left, right_df, suffix.as_deref()) } diff --git a/crates/polars-core/src/frame/cross_join.rs b/crates/polars-core/src/frame/cross_join.rs index 46d73cb85a74..bdd514a15bd7 100644 --- a/crates/polars-core/src/frame/cross_join.rs +++ b/crates/polars-core/src/frame/cross_join.rs @@ -1,5 +1,6 @@ use smartstring::alias::String as SmartString; +use crate::frame::hash_join::_finish_join; use crate::prelude::*; use crate::series::IsSorted; use crate::utils::{concat_df_unchecked, slice_offsets, CustomIterTools, NoNull}; diff --git a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs index b636e1883f0d..caeea68b8fef 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/agg_list.rs @@ -157,7 +157,7 @@ impl AggList for BooleanChunked { let mut builder = ListBooleanChunkedBuilder::new(self.name(), groups.len(), self.len()); for idx in groups.all().iter() { - let ca = { self.take_unchecked(idx.into()) }; + let ca = { self.take_unchecked(idx) }; builder.append(&ca) } builder.finish().into_series() @@ -183,7 +183,7 @@ impl AggList for Utf8Chunked { let mut builder = ListUtf8ChunkedBuilder::new(self.name(), groups.len(), self.len()); for idx in groups.all().iter() { - let ca = { self.take_unchecked(idx.into()) }; + let ca = { self.take_unchecked(idx) }; builder.append(&ca) } builder.finish().into_series() @@ -208,7 +208,7 @@ impl AggList for BinaryChunked { let mut builder = ListBinaryChunkedBuilder::new(self.name(), groups.len(), self.len()); for idx in groups.all().iter() { - let ca = { self.take_unchecked(idx.into()) }; + let ca = { self.take_unchecked(idx) }; builder.append(&ca) } builder.finish().into_series() @@ -292,7 +292,7 @@ impl AggList for ListChunked { // SAFETY: // group tuples are in bounds { - let mut s = ca.take_unchecked(idx.into()); + let mut s = ca.take_unchecked(idx); let arr = s.chunks.pop().unwrap_unchecked_release(); list_values.push_unchecked(arr); @@ -362,7 +362,7 @@ impl AggList for ArrayChunked { // SAFETY: group tuples are in bounds { - let mut s = ca.take_unchecked(idx.into()); + let mut s = ca.take_unchecked(idx); let arr = s.chunks.pop().unwrap_unchecked_release(); list_values.push_unchecked(arr); } @@ -419,7 +419,7 @@ impl AggList for ObjectChunked { GroupsIndicator::Idx((_first, idx)) => { // SAFETY: // group tuples always in bounds - let group_vals = self.take_unchecked(idx.into()); + 
let group_vals = self.take_unchecked(idx); (group_vals, idx.len() as IdxSize) }, @@ -481,7 +481,7 @@ impl AggList for StructChunked { Some(self.dtype().clone()), ); for idx in groups.all().iter() { - let taken = s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)); + let taken = s.take_slice_unchecked(idx); builder.append_series(&taken).unwrap(); } builder.finish().into_series() diff --git a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs index 0e353fe12be2..abee55a44fd0 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/dispatch.rs @@ -24,8 +24,7 @@ impl Series { } else if !self.has_validity() { Some(idx.len() as IdxSize) } else { - let take = - unsafe { self.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) }; + let take = unsafe { self.take_slice_unchecked(idx) }; Some((take.len() - take.null_count()) as IdxSize) } }), @@ -49,31 +48,28 @@ impl Series { pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Series { let mut out = match groups { GroupsProxy::Idx(groups) => { - let mut iter = groups.iter().map(|(first, idx)| { - if idx.is_empty() { - None - } else { - Some(first as usize) - } - }); - // Safety: - // groups are always in bounds - self.take_opt_iter_unchecked(&mut iter) - }, - GroupsProxy::Slice { groups, .. } => { - let mut iter = - groups.iter().map( - |&[first, len]| { - if len == 0 { + let indices = groups + .iter() + .map( + |(first, idx)| { + if idx.is_empty() { None } else { - Some(first as usize) + Some(first) } }, - ); - // Safety: - // groups are always in bounds - self.take_opt_iter_unchecked(&mut iter) + ) + .collect_ca(""); + // SAFETY: groups are always in bounds. + self.take_unchecked(&indices) + }, + GroupsProxy::Slice { groups, .. } => { + let indices = groups + .iter() + .map(|&[first, len]| if len == 0 { None } else { Some(first) }) + .collect_ca(""); + // SAFETY: groups are always in bounds. + self.take_unchecked(&indices) }, }; if groups.is_sorted_flag() { @@ -90,7 +86,7 @@ impl Series { if idx.is_empty() { None } else { - let take = self.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)); + let take = self.take_slice_unchecked(idx); take.n_unique().ok().map(|v| v as IdxSize) } }), @@ -186,24 +182,31 @@ impl Series { pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series { let out = match groups { GroupsProxy::Idx(groups) => { - let mut iter = groups.all().iter().map(|idx| { - if idx.is_empty() { - None - } else { - Some(idx[idx.len() - 1] as usize) - } - }); - self.take_opt_iter_unchecked(&mut iter) + let indices = groups + .all() + .iter() + .map(|idx| { + if idx.is_empty() { + None + } else { + Some(idx[idx.len() - 1]) + } + }) + .collect_ca(""); + self.take_unchecked(&indices) }, GroupsProxy::Slice { groups, .. 
} => { - let mut iter = groups.iter().map(|&[first, len]| { - if len == 0 { - None - } else { - Some((first + len - 1) as usize) - } - }); - self.take_opt_iter_unchecked(&mut iter) + let indices = groups + .iter() + .map(|&[first, len]| { + if len == 0 { + None + } else { + Some(first + len - 1) + } + }) + .collect_ca(""); + self.take_unchecked(&indices) }, }; self.restore_logical(out) diff --git a/crates/polars-core/src/frame/group_by/aggregations/mod.rs b/crates/polars-core/src/frame/group_by/aggregations/mod.rs index 0a0b1d00808b..51b1f24d3e9d 100644 --- a/crates/polars-core/src/frame/group_by/aggregations/mod.rs +++ b/crates/polars-core/src/frame/group_by/aggregations/mod.rs @@ -355,7 +355,7 @@ where if idx.is_empty() { return None; } - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; // checked with invalid quantile check take._quantile(quantile, interpol).unwrap_unchecked() }) @@ -429,7 +429,7 @@ where if idx.is_empty() { return None; } - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; take._median() }) }, @@ -984,7 +984,7 @@ where }) }, _ => { - let take = { self.take_unchecked(idx.into()) }; + let take = { self.take_unchecked(idx) }; take.mean() }, } @@ -1111,5 +1111,3 @@ where agg_median_generic::<_, Float64Type>(self, groups) } } - -impl ChunkedArray where ChunkedArray: ChunkTake + IntoSeries {} diff --git a/crates/polars-core/src/frame/group_by/hashing.rs b/crates/polars-core/src/frame/group_by/hashing.rs index 94c93a2d3c96..7b2d2e5efae9 100644 --- a/crates/polars-core/src/frame/group_by/hashing.rs +++ b/crates/polars-core/src/frame/group_by/hashing.rs @@ -11,17 +11,13 @@ use crate::datatypes::PlHashMap; use crate::frame::group_by::{GroupsIdx, IdxItem}; use crate::hashing::{ df_rows_to_hashes_threaded_vertical, series_to_hashes, this_partition, AsU64, IdBuildHasher, - IdxHash, + IdxHash, *, }; use crate::prelude::compare_inner::PartialEqInner; use crate::prelude::*; use crate::utils::{flatten, split_df, CustomIterTools}; use crate::POOL; -// We must strike a balance between cache coherence and resizing costs. -// Overallocation seems a lot more expensive than resizing so we start reasonable small. -pub(crate) const HASHMAP_INIT_SIZE: usize = 512; - fn get_init_size() -> usize { // we check if this is executed from the main thread // we don't want to pre-allocate this much if executed @@ -319,75 +315,6 @@ where finish_group_order(out, sorted) } -/// Utility function used as comparison function in the hashmap. -/// The rationale is that equality is an AND operation and therefore its probability of success -/// declines rapidly with the number of keys. Instead of first copying an entire row from both -/// sides and then do the comparison, we do the comparison value by value catching early failures -/// eagerly. -/// -/// # Safety -/// Doesn't check any bounds -#[inline] -pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usize) -> bool { - for s in keys.get_columns() { - if !s.equal_element(idx_a, idx_b, s) { - return false; - } - } - true -} - -/// Populate a multiple key hashmap with row indexes. -/// Instead of the keys (which could be very large), the row indexes are stored. -/// To check if a row is equal the original DataFrame is also passed as ref. -/// When a hash collision occurs the indexes are ptrs to the rows and the rows are compared -/// on equality. 
-pub(crate) fn populate_multiple_key_hashmap( - hash_tbl: &mut HashMap, - // row index - idx: IdxSize, - // hash - original_h: u64, - // keys of the hash table (will not be inserted, the indexes will be used) - // the keys are needed for the equality check - keys: &DataFrame, - // value to insert - vacant_fn: G, - // function that gets a mutable ref to the occupied value in the hash table - mut occupied_fn: F, -) where - G: Fn() -> V, - F: FnMut(&mut V), - H: BuildHasher, -{ - let entry = hash_tbl - .raw_entry_mut() - // uses the idx to probe rows in the original DataFrame with keys - // to check equality to find an entry - // this does not invalidate the hashmap as this equality function is not used - // during rehashing/resize (then the keys are already known to be unique). - // Only during insertion and probing an equality function is needed - .from_hash(original_h, |idx_hash| { - // first check the hash values - // before we incur a cache miss - idx_hash.hash == original_h && { - let key_idx = idx_hash.idx; - // Safety: - // indices in a group_by operation are always in bounds. - unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) } - } - }); - match entry { - RawEntryMut::Vacant(entry) => { - entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn()); - }, - RawEntryMut::Occupied(mut entry) => { - let (_k, v) = entry.get_key_value_mut(); - occupied_fn(v); - }, - } -} - #[inline] pub(crate) unsafe fn compare_keys<'a>( keys_cmp: &'a [Box], diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index 02452cd0a9ce..08ca7f5e9f47 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -265,10 +265,8 @@ impl<'df> GroupBy<'df> { .map(|s| { match groups { GroupsProxy::Idx(groups) => { - let mut iter = groups.first().iter().map(|first| *first as usize); - // Safety: - // groups are always in bounds - let mut out = unsafe { s.take_iter_unchecked(&mut iter) }; + // SAFETY: groups are always in bounds. + let mut out = unsafe { s.take_slice_unchecked(groups.first()) }; if groups.sorted { out.set_sorted_flag(s.is_sorted_flag()); }; @@ -276,7 +274,7 @@ impl<'df> GroupBy<'df> { }, GroupsProxy::Slice { groups, rolling } => { if *rolling && !groups.is_empty() { - // groups can be sliced + // Groups can be sliced. let offset = groups[0][0]; let [upper_offset, upper_len] = groups[groups.len() - 1]; return s.slice( @@ -285,11 +283,10 @@ impl<'df> GroupBy<'df> { ); } - let mut iter = groups.iter().map(|&[first, _len]| first as usize); - // Safety: - // groups are always in bounds - let mut out = unsafe { s.take_iter_unchecked(&mut iter) }; - // sliced groups are always in order of discovery + let indices = groups.iter().map(|&[first, _len]| first).collect_ca(""); + // SAFETY: groups are always in bounds. + let mut out = unsafe { s.take_unchecked(&indices) }; + // Sliced groups are always in order of discovery. 
out.set_sorted_flag(s.is_sorted_flag()); out }, @@ -838,7 +835,7 @@ impl<'df> GroupBy<'df> { unsafe fn take_df(df: &DataFrame, g: GroupsIndicator) -> DataFrame { match g { - GroupsIndicator::Idx(idx) => df.take_iter_unchecked(idx.1.iter().map(|i| *i as usize)), + GroupsIndicator::Idx(idx) => df.take_slice_unchecked(idx.1), GroupsIndicator::Slice([first, len]) => df.slice(first as i64, len as usize), } } diff --git a/crates/polars-core/src/frame/hash_join/args.rs b/crates/polars-core/src/frame/hash_join/args.rs index 4b632263f315..f95b294d6787 100644 --- a/crates/polars-core/src/frame/hash_join/args.rs +++ b/crates/polars-core/src/frame/hash_join/args.rs @@ -15,9 +15,6 @@ pub type ChunkJoinOptIds = Vec>; #[cfg(not(feature = "chunked_ids"))] pub type ChunkJoinIds = Vec; -/// [ChunkIdx, DfIdx] -pub type ChunkId = [IdxSize; 2]; - #[derive(Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct JoinArgs { diff --git a/crates/polars-core/src/frame/hash_join/mod.rs b/crates/polars-core/src/frame/hash_join/mod.rs index 8facfcbb323f..3e47d4f9e2c1 100644 --- a/crates/polars-core/src/frame/hash_join/mod.rs +++ b/crates/polars-core/src/frame/hash_join/mod.rs @@ -36,7 +36,6 @@ pub(crate) use zip_outer::*; pub use self::multiple_keys::private_left_join_multiple_keys; use crate::datatypes::PlHashMap; -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; pub use crate::frame::hash_join::multiple_keys::{ _inner_join_multiple_keys, _left_join_multiple_keys, _outer_join_multiple_keys, }; @@ -46,7 +45,7 @@ pub use crate::frame::hash_join::multiple_keys::{ }; use crate::hashing::{ create_hash_and_keys_threaded_vectorized, prepare_hashed_relation_threaded, this_partition, - AsU64, BytesHash, + AsU64, BytesHash, HASHMAP_INIT_SIZE, }; use crate::prelude::*; use crate::utils::{_set_partition_size, slice_slice, split_ca}; @@ -213,11 +212,7 @@ impl DataFrame { if let Some((offset, len)) = args.slice { right_idx = slice_slice(right_idx, offset, len); } - unsafe { - other.take_opt_iter_unchecked( - right_idx.iter().map(|opt_i| opt_i.map(|i| i as usize)), - ) - } + unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } }; let (df_left, df_right) = POOL.join(materialize_left, materialize_right); @@ -257,11 +252,7 @@ impl DataFrame { if let Some((offset, len)) = slice { right_idx = slice_slice(right_idx, offset, len); } - unsafe { - other.take_opt_iter_unchecked( - right_idx.iter().map(|opt_i| opt_i.map(|i| i as usize)), - ) - } + unsafe { other.take_unchecked(&right_idx.iter().copied().collect_ca("")) } }, ChunkJoinOptIds::Right(right_idx) => { let mut right_idx = &*right_idx; @@ -359,17 +350,21 @@ impl DataFrame { // Take the left and right dataframes by join tuples let (mut df_left, df_right) = POOL.join( || unsafe { - self.drop(s_left.name()).unwrap().take_opt_iter_unchecked( - opt_join_tuples + self.drop(s_left.name()).unwrap().take_unchecked( + &opt_join_tuples .iter() - .map(|(left, _right)| left.map(|i| i as usize)), + .copied() + .map(|(left, _right)| left) + .collect_ca("outer-join-left-indices"), ) }, || unsafe { - other.drop(s_right.name()).unwrap().take_opt_iter_unchecked( - opt_join_tuples + other.drop(s_right.name()).unwrap().take_unchecked( + &opt_join_tuples .iter() - .map(|(_left, right)| right.map(|i| i as usize)), + .copied() + .map(|(_left, right)| right) + .collect_ca("outer-join-right-indices"), ) }, ); diff --git a/crates/polars-core/src/frame/hash_join/multiple_keys.rs b/crates/polars-core/src/frame/hash_join/multiple_keys.rs 
index a1eb835f3607..59502b08ccb2 100644 --- a/crates/polars-core/src/frame/hash_join/multiple_keys.rs +++ b/crates/polars-core/src/frame/hash_join/multiple_keys.rs @@ -3,11 +3,13 @@ use hashbrown::HashMap; use rayon::prelude::*; use super::*; -use crate::frame::group_by::hashing::{populate_multiple_key_hashmap, HASHMAP_INIT_SIZE}; use crate::frame::hash_join::{ get_hash_tbl_threaded_join_mut_partitioned, get_hash_tbl_threaded_join_partitioned, }; -use crate::hashing::{df_rows_to_hashes_threaded_vertical, this_partition, IdBuildHasher, IdxHash}; +use crate::hashing::{ + df_rows_to_hashes_threaded_vertical, populate_multiple_key_hashmap, this_partition, + IdBuildHasher, IdxHash, HASHMAP_INIT_SIZE, +}; use crate::prelude::*; use crate::utils::series::_to_physical_and_bit_repr; use crate::utils::{_set_partition_size, split_df}; diff --git a/crates/polars-core/src/frame/hash_join/sort_merge.rs b/crates/polars-core/src/frame/hash_join/sort_merge.rs index 6db0b0a4e086..1d00732e74a3 100644 --- a/crates/polars-core/src/frame/hash_join/sort_merge.rs +++ b/crates/polars-core/src/frame/hash_join/sort_merge.rs @@ -226,7 +226,7 @@ pub fn _sort_or_hash_inner( multithreaded: true, maintain_order: false, }); - let s_right = unsafe { s_right.take_unchecked(&sort_idx).unwrap() }; + let s_right = unsafe { s_right.take_unchecked(&sort_idx) }; let ids = par_sorted_merge_inner_no_nulls(s_left, &s_right); let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); @@ -253,7 +253,7 @@ pub fn _sort_or_hash_inner( multithreaded: true, maintain_order: false, }); - let s_left = unsafe { s_left.take_unchecked(&sort_idx).unwrap() }; + let s_left = unsafe { s_left.take_unchecked(&sort_idx) }; let ids = par_sorted_merge_inner_no_nulls(&s_left, s_right); let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); @@ -322,7 +322,7 @@ pub(super) fn sort_or_hash_left( multithreaded: true, maintain_order: false, }); - let s_right = unsafe { s_right.take_unchecked(&sort_idx).unwrap() }; + let s_right = unsafe { s_right.take_unchecked(&sort_idx) }; let ids = par_sorted_merge_left(s_left, &s_right); let reverse_idx_map = create_reverse_map_from_arg_sort(sort_idx); diff --git a/crates/polars-core/src/frame/hash_join/zip_outer.rs b/crates/polars-core/src/frame/hash_join/zip_outer.rs index 205aefabfa40..927ffcfa726f 100644 --- a/crates/polars-core/src/frame/hash_join/zip_outer.rs +++ b/crates/polars-core/src/frame/hash_join/zip_outer.rs @@ -12,7 +12,7 @@ pub trait ZipOuterJoinColumn { impl ZipOuterJoinColumn for ChunkedArray where - T: PolarsIntegerType, + T: PolarsDataType, ChunkedArray: IntoSeries, { unsafe fn zip_outer_join_column( @@ -22,102 +22,36 @@ where ) -> Series { let right_ca = self.unpack_series_matching_type(right_column).unwrap(); - let left_rand_access = self.take_rand(); - let right_rand_access = right_ca.take_rand(); - - opt_join_tuples - .iter() - .map(|(opt_left_idx, opt_right_idx)| { - if let Some(left_idx) = opt_left_idx { - unsafe { left_rand_access.get_unchecked(*left_idx as usize) } - } else { - unsafe { - let right_idx = opt_right_idx.unwrap_unchecked(); - right_rand_access.get_unchecked(right_idx as usize) + if self.null_count() == 0 && right_ca.null_count() == 0 { + opt_join_tuples + .iter() + .map(|(opt_left_idx, opt_right_idx)| { + if let Some(left_idx) = opt_left_idx { + unsafe { self.value_unchecked(*left_idx as usize) } + } else { + unsafe { + let right_idx = opt_right_idx.unwrap_unchecked(); + right_ca.value_unchecked(right_idx as usize) + } } - } - }) - .collect_trusted::>() - 
.into_series() - } -} - -macro_rules! impl_zip_outer_join { - ($chunkedtype:ident) => { - impl ZipOuterJoinColumn for $chunkedtype { - unsafe fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - let right_ca = self.unpack_series_matching_type(right_column).unwrap(); - - let left_rand_access = self.take_rand(); - let right_rand_access = right_ca.take_rand(); - - opt_join_tuples - .iter() - .map(|(opt_left_idx, opt_right_idx)| { - if let Some(left_idx) = opt_left_idx { - unsafe { left_rand_access.get_unchecked(*left_idx as usize) } - } else { - unsafe { - let right_idx = opt_right_idx.unwrap_unchecked(); - right_rand_access.get_unchecked(right_idx as usize) - } + }) + .collect_ca_like(self) + .into_series() + } else { + opt_join_tuples + .iter() + .map(|(opt_left_idx, opt_right_idx)| { + if let Some(left_idx) = opt_left_idx { + unsafe { self.get_unchecked(*left_idx as usize) } + } else { + unsafe { + let right_idx = opt_right_idx.unwrap_unchecked(); + right_ca.get_unchecked(right_idx as usize) } - }) - .collect::<$chunkedtype>() - .into_series() - } - } - }; -} -impl_zip_outer_join!(BooleanChunked); -impl_zip_outer_join!(BinaryChunked); - -impl ZipOuterJoinColumn for Utf8Chunked { - unsafe fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - unsafe { - let out = self.as_binary().zip_outer_join_column( - &right_column.cast_unchecked(&DataType::Binary).unwrap(), - opt_join_tuples, - ); - out.cast_unchecked(&DataType::Utf8).unwrap_unchecked() + } + }) + .collect_ca_like(self) + .into_series() } } } - -impl ZipOuterJoinColumn for Float32Chunked { - unsafe fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - self.apply_as_ints(|s| { - s.zip_outer_join_column( - &right_column.bit_repr_small().into_series(), - opt_join_tuples, - ) - }) - } -} - -impl ZipOuterJoinColumn for Float64Chunked { - unsafe fn zip_outer_join_column( - &self, - right_column: &Series, - opt_join_tuples: &[(Option, Option)], - ) -> Series { - self.apply_as_ints(|s| { - s.zip_outer_join_column( - &right_column.bit_repr_large().into_series(), - opt_join_tuples, - ) - }) - } -} diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index b995366137d2..03bcf7b14cc7 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -7,6 +7,7 @@ use ahash::AHashSet; use polars_arrow::prelude::QuantileInterpolOptions; use rayon::prelude::*; +#[cfg(feature = "algorithm_group_by")] use crate::chunked_array::ops::unique::is_unique_helper; use crate::prelude::*; #[cfg(feature = "describe")] @@ -22,7 +23,9 @@ mod chunks; pub(crate) mod cross_join; pub mod explode; mod from; +#[cfg(feature = "algorithm_group_by")] pub mod group_by; +#[cfg(feature = "algorithm_join")] pub mod hash_join; #[cfg(feature = "rows")] pub mod row; @@ -34,6 +37,7 @@ pub use chunks::*; use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::GroupsIndicator; #[cfg(feature = "row_hash")] use crate::hashing::df_rows_to_hashes_threaded_vertical; @@ -1670,103 +1674,6 @@ impl DataFrame { Ok(DataFrame::new_no_checks(new_col)) } - /// Take [`DataFrame`] value by indexes from an iterator. 
- /// - /// # Example - /// - /// ``` - /// # use polars_core::prelude::*; - /// fn example(df: &DataFrame) -> PolarsResult { - /// let iterator = (0..9).into_iter(); - /// df.take_iter(iterator) - /// } - /// ``` - pub fn take_iter(&self, iter: I) -> PolarsResult - where - I: Iterator + Clone + Sync + TrustedLen, - { - let new_col = self.try_apply_columns_par(&|s| { - let mut i = iter.clone(); - s.take_iter(&mut i) - })?; - - Ok(DataFrame::new_no_checks(new_col)) - } - - /// Take [`DataFrame`] values by indexes from an iterator. - /// - /// # Safety - /// - /// This doesn't do any bound checking but checks null validity. - #[must_use] - pub unsafe fn take_iter_unchecked(&self, mut iter: I) -> Self - where - I: Iterator + Clone + Sync + TrustedLen, - { - let n_chunks = self.n_chunks(); - let has_utf8 = self - .columns - .iter() - .any(|s| matches!(s.dtype(), DataType::Utf8)); - - if (n_chunks == 1 && self.width() > 1) || has_utf8 { - let idx_ca: NoNull = iter.map(|idx| idx as IdxSize).collect(); - let idx_ca = idx_ca.into_inner(); - return self.take_unchecked(&idx_ca); - } - - let new_col = if self.width() == 1 { - self.columns - .iter() - .map(|s| s.take_iter_unchecked(&mut iter)) - .collect::>() - } else { - self.apply_columns_par(&|s| { - let mut i = iter.clone(); - s.take_iter_unchecked(&mut i) - }) - }; - DataFrame::new_no_checks(new_col) - } - - /// Take [`DataFrame`] values by indexes from an iterator that may contain None values. - /// - /// # Safety - /// - /// This doesn't do any bound checking. Out of bounds may access uninitialized memory. - /// Null validity is checked - #[must_use] - pub unsafe fn take_opt_iter_unchecked(&self, mut iter: I) -> Self - where - I: Iterator> + Clone + Sync + TrustedLen, - { - let n_chunks = self.n_chunks(); - - let has_utf8 = self - .columns - .iter() - .any(|s| matches!(s.dtype(), DataType::Utf8)); - - if (n_chunks == 1 && self.width() > 1) || has_utf8 { - let idx_ca: IdxCa = iter.map(|opt| opt.map(|v| v as IdxSize)).collect(); - return self.take_unchecked(&idx_ca); - } - - let new_col = if self.width() == 1 { - self.columns - .iter() - .map(|s| s.take_opt_iter_unchecked(&mut iter)) - .collect::>() - } else { - self.apply_columns_par(&|s| { - let mut i = iter.clone(); - s.take_opt_iter_unchecked(&mut i) - }) - }; - - DataFrame::new_no_checks(new_col) - } - /// Take [`DataFrame`] rows by index values. /// /// # Example @@ -1779,22 +1686,19 @@ impl DataFrame { /// } /// ``` pub fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; let new_col = POOL.install(|| { self.try_apply_columns_par(&|s| match s.dtype() { - DataType::Utf8 => s.take_threaded(&indices, true), - _ => s.take(&indices), + DataType::Utf8 => s.take_threaded(indices, true), + _ => s.take(indices), }) })?; Ok(DataFrame::new_no_checks(new_col)) } - pub(crate) unsafe fn take_unchecked(&self, idx: &IdxCa) -> Self { + /// # Safety + /// The indices must be in-bounds. 
+ pub unsafe fn take_unchecked(&self, idx: &IdxCa) -> Self { self.take_unchecked_impl(idx, true) } @@ -1802,14 +1706,32 @@ impl DataFrame { let cols = if allow_threads { POOL.install(|| { self.apply_columns_par(&|s| match s.dtype() { - DataType::Utf8 => s.take_unchecked_threaded(idx, true).unwrap(), - _ => s.take_unchecked(idx).unwrap(), + DataType::Utf8 => s.take_unchecked_threaded(idx, true), + _ => s.take_unchecked(idx), + }) + }) + } else { + self.columns.iter().map(|s| s.take_unchecked(idx)).collect() + }; + DataFrame::new_no_checks(cols) + } + + pub(crate) unsafe fn take_slice_unchecked(&self, idx: &[IdxSize]) -> Self { + self.take_slice_unchecked_impl(idx, true) + } + + unsafe fn take_slice_unchecked_impl(&self, idx: &[IdxSize], allow_threads: bool) -> Self { + let cols = if allow_threads { + POOL.install(|| { + self.apply_columns_par(&|s| match s.dtype() { + DataType::Utf8 => s.take_slice_unchecked_threaded(idx, true), + _ => s.take_slice_unchecked(idx), + }) + }) + } else { + self.columns .iter() - .map(|s| s.take_unchecked(idx).unwrap()) + .map(|s| s.take_slice_unchecked(idx)) + .collect() }; DataFrame::new_no_checks(cols) @@ -3065,6 +2987,7 @@ impl DataFrame { /// | 3 | 3 | "c" | /// +-----+-----+-----+ /// ``` + #[cfg(feature = "algorithm_group_by")] pub fn unique_stable( &self, subset: Option<&[String]>, @@ -3075,6 +2998,7 @@ impl DataFrame { } /// Unstable distinct. See [`DataFrame::unique_stable`]. + #[cfg(feature = "algorithm_group_by")] pub fn unique( &self, subset: Option<&[String]>, @@ -3084,6 +3008,7 @@ impl DataFrame { self.unique_impl(false, subset, keep, slice) } + #[cfg(feature = "algorithm_group_by")] pub fn unique_impl( &self, maintain_order: bool, @@ -3169,6 +3094,7 @@ impl DataFrame { /// assert!(ca.all()); /// # Ok::<(), PolarsError>(()) /// ``` + #[cfg(feature = "algorithm_group_by")] pub fn is_unique(&self) -> PolarsResult { let gb = self.group_by(self.get_column_names())?; let groups = gb.take_groups(); @@ -3193,6 +3119,7 @@ impl DataFrame { /// assert!(!ca.all()); /// # Ok::<(), PolarsError>(()) /// ``` + #[cfg(feature = "algorithm_group_by")] pub fn is_duplicated(&self) -> PolarsResult { let gb = self.group_by(self.get_column_names())?; let groups = gb.take_groups(); @@ -3327,7 +3254,7 @@ impl DataFrame { self.take_unchecked_impl(&ca, allow_threads) } - #[cfg(feature = "partition_by")] + #[cfg(all(feature = "partition_by", feature = "algorithm_group_by"))] #[doc(hidden)] pub fn _partition_by_impl( &self, diff --git a/crates/polars-core/src/hashing/mod.rs b/crates/polars-core/src/hashing/mod.rs index 110ead6db68e..c94bc8cc1b8c 100644 --- a/crates/polars-core/src/hashing/mod.rs +++ b/crates/polars-core/src/hashing/mod.rs @@ -7,6 +7,8 @@ use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; use ahash::RandomState; pub use fx::*; +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; pub use identity::*; pub(crate) use partition::*; pub use vector_hasher::*; @@ -18,3 +20,76 @@ use crate::prelude::*; pub fn _boost_hash_combine(l: u64, r: u64) -> u64 { l ^ r.wrapping_add(0x9e3779b9u64.wrapping_add(l << 6).wrapping_add(r >> 2)) } + +// We must strike a balance between cache coherence and resizing costs. +// Overallocation seems a lot more expensive than resizing so we start reasonably small. +pub(crate) const HASHMAP_INIT_SIZE: usize = 512; + +/// Utility function used as comparison function in the hashmap. +/// The rationale is that equality is an AND operation and therefore its probability of success +/// declines rapidly with the number of keys.
Instead of first copying an entire row from both +/// sides and then do the comparison, we do the comparison value by value catching early failures +/// eagerly. +/// +/// # Safety +/// Doesn't check any bounds +#[inline] +pub(crate) unsafe fn compare_df_rows(keys: &DataFrame, idx_a: usize, idx_b: usize) -> bool { + for s in keys.get_columns() { + if !s.equal_element(idx_a, idx_b, s) { + return false; + } + } + true +} + +/// Populate a multiple key hashmap with row indexes. +/// Instead of the keys (which could be very large), the row indexes are stored. +/// To check if a row is equal the original DataFrame is also passed as ref. +/// When a hash collision occurs the indexes are ptrs to the rows and the rows are compared +/// on equality. +pub(crate) fn populate_multiple_key_hashmap( + hash_tbl: &mut HashMap, + // row index + idx: IdxSize, + // hash + original_h: u64, + // keys of the hash table (will not be inserted, the indexes will be used) + // the keys are needed for the equality check + keys: &DataFrame, + // value to insert + vacant_fn: G, + // function that gets a mutable ref to the occupied value in the hash table + mut occupied_fn: F, +) where + G: Fn() -> V, + F: FnMut(&mut V), + H: BuildHasher, +{ + let entry = hash_tbl + .raw_entry_mut() + // uses the idx to probe rows in the original DataFrame with keys + // to check equality to find an entry + // this does not invalidate the hashmap as this equality function is not used + // during rehashing/resize (then the keys are already known to be unique). + // Only during insertion and probing an equality function is needed + .from_hash(original_h, |idx_hash| { + // first check the hash values + // before we incur a cache miss + idx_hash.hash == original_h && { + let key_idx = idx_hash.idx; + // Safety: + // indices in a group_by operation are always in bounds. 
+ unsafe { compare_df_rows(keys, key_idx as usize, idx as usize) } + } + }); + match entry { + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck(original_h, IdxHash::new(idx, original_h), vacant_fn()); + }, + RawEntryMut::Occupied(mut entry) => { + let (_k, v) = entry.get_key_value_mut(); + occupied_fn(v); + }, + } +} diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs index 17ead65b8daa..a734db66e38b 100644 --- a/crates/polars-core/src/lib.rs +++ b/crates/polars-core/src/lib.rs @@ -10,11 +10,8 @@ extern crate core; #[macro_use] pub mod utils; pub mod chunked_array; -pub mod cloud; pub mod config; pub mod datatypes; -#[cfg(feature = "docs")] -pub mod doc; pub mod error; pub mod export; pub mod fmt; diff --git a/crates/polars-core/src/prelude.rs b/crates/polars-core/src/prelude.rs index e80e37899f85..dd603402e6ab 100644 --- a/crates/polars-core/src/prelude.rs +++ b/crates/polars-core/src/prelude.rs @@ -14,6 +14,7 @@ pub use crate::chunked_array::builder::{ ListBooleanChunkedBuilder, ListBuilderTrait, ListPrimitiveChunkedBuilder, ListUtf8ChunkedBuilder, NewChunkedArray, PrimitiveChunkedBuilder, Utf8ChunkedBuilder, }; +pub use crate::chunked_array::collect::{ChunkedCollectInferIterExt, ChunkedCollectIterExt}; pub use crate::chunked_array::iterator::PolarsIterator; #[cfg(feature = "dtype-categorical")] pub use crate::chunked_array::logical::categorical::*; @@ -31,16 +32,18 @@ pub use crate::chunked_array::ops::*; pub use crate::chunked_array::temporal::conversion::*; pub(crate) use crate::chunked_array::ChunkIdIter; pub use crate::chunked_array::ChunkedArray; -pub use crate::datatypes::*; +pub use crate::datatypes::{ArrayCollectIterExt, *}; pub use crate::error::{ polars_bail, polars_ensure, polars_err, polars_warn, PolarsError, PolarsResult, }; #[cfg(feature = "asof_join")] pub use crate::frame::asof_join::*; pub use crate::frame::explode::MeltArgs; +#[cfg(feature = "algorithm_group_by")] pub(crate) use crate::frame::group_by::aggregations::*; +#[cfg(feature = "algorithm_group_by")] pub use crate::frame::group_by::{GroupsIdx, GroupsProxy, GroupsSlice, IntoGroupsProxy}; -pub(crate) use crate::frame::hash_join::*; +#[cfg(feature = "algorithm_join")] pub use crate::frame::hash_join::{JoinArgs, JoinType}; pub use crate::frame::{DataFrame, UniqueKeepStrategy}; pub use crate::hashing::{FxHash, VecHash}; @@ -53,4 +56,4 @@ pub use crate::series::{IntoSeries, Series, SeriesTrait}; pub use crate::testing::*; pub(crate) use crate::utils::CustomIterTools; pub use crate::utils::IntoVec; -pub use crate::{cloud, datatypes, df}; +pub use crate::{datatypes, df}; diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index 09455842eab4..29625de7e11a 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -129,7 +129,7 @@ impl Schema { ) -> PolarsResult { polars_ensure!( index <= self.len(), - ComputeError: + OutOfBounds: "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", index, self.len() @@ -167,7 +167,7 @@ impl Schema { ) -> PolarsResult> { polars_ensure!( index <= self.len(), - ComputeError: + OutOfBounds: "index {} is out of bounds for schema with length {} (the max index allowed is self.len())", index, self.len() diff --git a/crates/polars-core/src/series/from.rs b/crates/polars-core/src/series/from.rs index 3c4876d75e80..05b00a5a9992 100644 --- a/crates/polars-core/src/series/from.rs +++ b/crates/polars-core/src/series/from.rs @@ -4,7 +4,8 @@ use 
arrow::compute::cast::utf8_to_large_utf8; #[cfg(any( feature = "dtype-date", feature = "dtype-datetime", - feature = "dtype-time" + feature = "dtype-time", + feature = "dtype-duration" ))] use arrow::temporal_conversions::*; use polars_arrow::compute::cast::cast; diff --git a/crates/polars-core/src/series/implementations/array.rs b/crates/polars-core/src/series/implementations/array.rs index a08feca9fe13..698003b873ea 100644 --- a/crates/polars-core/src/series/implementations/array.rs +++ b/crates/polars-core/src/series/implementations/array.rs @@ -4,7 +4,10 @@ use std::borrow::Cow; use super::{private, IntoSeries, SeriesTrait}; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::{AsSinglePtr, Settings}; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -43,10 +46,12 @@ impl private::PrivateSeries for SeriesWrap { ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -104,38 +109,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(self.0.take_unchecked((&*idx).into()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index a1030bc18b5f..09e4623327af 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -9,7 +9,9 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use 
crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -59,10 +61,12 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -85,6 +89,7 @@ impl private::PrivateSeries for SeriesWrap { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -148,47 +153,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - - let mut out = self.0.take_unchecked((&*idx).into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_series()) - } - - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -232,14 +209,17 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 5a4150b5c7f7..8dcf007385b4 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -10,7 +10,9 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::{AsSinglePtr, ChunkIdIter}; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use 
crate::series::implementations::SeriesWrap; @@ -60,27 +62,33 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.0.agg_sum(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { self.0 .cast(&DataType::Float64) .unwrap() .agg_std(groups, _ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { self.0 .cast(&DataType::Float64) @@ -88,6 +96,7 @@ impl private::PrivateSeries for SeriesWrap { .agg_var(groups, _ddof) } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -95,6 +104,7 @@ impl private::PrivateSeries for SeriesWrap { ) -> Series { ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -179,38 +189,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(self.0.take_unchecked((&*idx).into()).into_series()) - } - - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -254,14 +245,17 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 
a09a802ea47a..836c339e986f 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ b/crates/polars-core/src/series/implementations/categorical.rs @@ -8,7 +8,9 @@ use crate::chunked_array::comparison::*; use crate::chunked_array::ops::compare_inner::{IntoPartialOrdInner, PartialOrdInner}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -105,6 +107,7 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect let list = self.0.logical().agg_list(groups); @@ -113,6 +116,7 @@ impl private::PrivateSeries for SeriesWrap { list.into_series() } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -137,6 +141,7 @@ impl private::PrivateSeries for SeriesWrap { CategoricalChunked::from_cats_and_rev_map_unchecked(cats, new_rev_map).into_series() } } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { #[cfg(feature = "performant")] { @@ -214,45 +219,23 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - self.try_with_state(false, |cats| cats.take((&*indices).into())) + self.try_with_state(false, |cats| cats.take(indices)) .map(|ca| ca.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - let cats = self.0.logical().take(iter.into())?; - Ok(self.finish_with_state(false, cats).into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - let cats = self.0.logical().take_unchecked(iter.into()); - self.finish_with_state(false, cats).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(self - .with_state(false, |cats| cats.take_unchecked((&*idx).into())) - .into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.with_state(false, |cats| cats.take_unchecked(indices)) + .into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - let cats = self.0.logical().take_unchecked(iter.into()); - self.finish_with_state(false, cats).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + self.try_with_state(false, |cats| cats.take(indices)) + .map(|ca| ca.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - let cats = self.0.logical().take(iter.into())?; - Ok(self.finish_with_state(false, cats).into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.with_state(false, |cats| cats.take_unchecked(indices)) + .into_series() } fn len(&self) -> usize { @@ -297,14 +280,17 @@ impl SeriesTrait for SeriesWrap { self.0.logical().has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| ca.into_series()) } + 
#[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.logical().arg_unique() } diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index e2af95050038..0b4ea77fc194 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -17,7 +17,9 @@ use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::ops::ToBitRepr; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::*; use crate::prelude::*; @@ -90,14 +92,17 @@ macro_rules! impl_dyn_series { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups).$into_logical().into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups).$into_logical().into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -106,6 +111,7 @@ macro_rules! impl_dyn_series { .unwrap() } +#[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -153,6 +159,7 @@ macro_rules! impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { polars_bail!(opq = rem, self.0.dtype(), rhs.dtype()); } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -238,43 +245,19 @@ macro_rules! 
impl_dyn_series { } fn take(&self, indices: &IdxCa) -> PolarsResult { - self.0.deref().take(indices.into()) - .map(|ca| ca.$into_logical().into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - self.0.deref().take(iter.into()) - .map(|ca| ca.$into_logical().into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.deref().take_unchecked(iter.into()) - .$into_logical() - .into_series() + Ok(self.0.take(indices)?.$into_logical().into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = self.0.deref().take_unchecked(idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.$into_logical().into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).$into_logical().into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.deref().take_unchecked(iter.into()) - .$into_logical() - .into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.$into_logical().into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - self.0.deref().take(iter.into()) - .map(|ca| ca.$into_logical().into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).$into_logical().into_series() } fn len(&self) -> usize { @@ -294,6 +277,7 @@ macro_rules! impl_dyn_series { fn cast(&self, data_type: &DataType) -> PolarsResult { match (self.dtype(), data_type) { + #[cfg(feature="dtype-date")] (DataType::Date, DataType::Utf8) => Ok(self .0 .clone() @@ -302,6 +286,7 @@ macro_rules! impl_dyn_series { .unwrap() .to_string("%Y-%m-%d") .into_series()), + #[cfg(feature="dtype-time")] (DataType::Time, DataType::Utf8) => Ok(self .0 .clone() @@ -352,14 +337,17 @@ macro_rules! 
impl_dyn_series { self.0.has_validity() } +#[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| ca.$into_logical().into_series()) } +#[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } +#[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index e600a19f307c..3740612f2ea4 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -6,7 +6,9 @@ use ahash::RandomState; use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::*; use crate::prelude::*; @@ -84,6 +86,7 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0 .agg_min(groups) @@ -91,12 +94,14 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0 .agg_max(groups) .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -105,6 +110,7 @@ impl private::PrivateSeries for SeriesWrap { .unwrap() } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -160,6 +166,7 @@ impl private::PrivateSeries for SeriesWrap { fn remainder(&self, rhs: &Series) -> PolarsResult { polars_bail!(opq = rem, self.dtype(), rhs.dtype()); } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -243,53 +250,31 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - self.0.deref().take(indices.into()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - self.0.deref().take(iter.into()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) + let ca = self.0.take(indices)?; + Ok(ca + .into_datetime(self.0.time_unit(), self.0.time_zone().clone()) + .into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - let ca = self.0.deref().take_unchecked(iter.into()); + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + let ca = self.0.take_unchecked(indices); ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = self.0.deref().take_unchecked(idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + let ca = self.0.take(indices)?; + Ok(ca 
.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series()) } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - let ca = self.0.deref().take_unchecked(iter.into()); + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + let ca = self.0.take_unchecked(indices); ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) .into_series() } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - self.0.deref().take(iter.into()).map(|ca| { - ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) - .into_series() - }) - } - fn len(&self) -> usize { self.0.len() } @@ -351,6 +336,7 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0.unique().map(|ca| { ca.into_datetime(self.0.time_unit(), self.0.time_zone().clone()) @@ -358,10 +344,12 @@ impl SeriesTrait for SeriesWrap { }) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } diff --git a/crates/polars-core/src/series/implementations/decimal.rs b/crates/polars-core/src/series/implementations/decimal.rs index 370fa7c3edda..0b1d6954886b 100644 --- a/crates/polars-core/src/series/implementations/decimal.rs +++ b/crates/polars-core/src/series/implementations/decimal.rs @@ -67,18 +67,22 @@ impl private::PrivateSeries for SeriesWrap { .into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.agg_helper(|ca| ca.agg_sum(groups)) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.agg_helper(|ca| ca.agg_min(groups)) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.agg_helper(|ca| ca.agg_max(groups)) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } @@ -158,48 +162,36 @@ impl SeriesTrait for SeriesWrap { self.apply_physical(|ca| ca.take_opt_chunked_unchecked(by)) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - self.0.deref().take(iter.into()).map(|ca| { - ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() - }) + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_decimal_unchecked(self.0.precision(), self.0.scale()) + .into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - let ca = self.0.deref().take_unchecked(iter.into()); - ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0 + .take_unchecked(indices) + .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = self.0.deref().take_unchecked(idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self + .0 + .take(indices)? 
.into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series()) } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { self.0 - .deref() - .take_unchecked(iter.into()) + .take_unchecked(indices) .into_decimal_unchecked(self.0.precision(), self.0.scale()) .into_series() } - fn take(&self, indices: &IdxCa) -> PolarsResult { - self.0.deref().take(indices.into()).map(|ca| { - ca.into_decimal_unchecked(self.0.precision(), self.0.scale()) - .into_series() - }) - } - fn len(&self) -> usize { self.0.len() } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index a62fe654ec27..58d2d6558e14 100644 --- a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -7,7 +7,9 @@ use super::{private, IntoSeries, SeriesTrait, SeriesWrap, *}; use crate::chunked_array::comparison::*; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::*; use crate::prelude::*; @@ -89,6 +91,7 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0 .agg_min(groups) @@ -96,6 +99,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0 .agg_max(groups) @@ -103,6 +107,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.0 .agg_sum(groups) @@ -110,6 +115,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0 .agg_std(groups, ddof) @@ -120,6 +126,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0 .agg_var(groups, ddof) @@ -130,6 +137,7 @@ impl private::PrivateSeries for SeriesWrap { .into_series() } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { // we cannot cast and dispatch as the inner type of the list would be incorrect self.0 @@ -138,6 +146,7 @@ impl private::PrivateSeries for SeriesWrap { .unwrap() } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -197,6 +206,7 @@ impl private::PrivateSeries for SeriesWrap { .into_duration(self.0.time_unit()) .into_series()) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { self.0.group_tuples(multithreaded, sorted) } @@ -277,41 +287,33 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let ca = self.0.deref().take(indices.into())?; - Ok(ca.into_duration(self.0.time_unit()).into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - let ca = self.0.deref().take(iter.into())?; - Ok(ca.into_duration(self.0.time_unit()).into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> 
Series { - let ca = self.0.deref().take_unchecked(iter.into()); - ca.into_duration(self.0.time_unit()).into_series() + Ok(self + .0 + .take(indices)? + .into_duration(self.0.time_unit()) + .into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let mut out = self.0.deref().take_unchecked(idx.into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_duration(self.0.time_unit()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0 + .take_unchecked(indices) + .into_duration(self.0.time_unit()) + .into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - let ca = self.0.deref().take_unchecked(iter.into()); - ca.into_duration(self.0.time_unit()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self + .0 + .take(indices)? + .into_duration(self.0.time_unit()) + .into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - let ca = self.0.deref().take(iter.into())?; - Ok(ca.into_duration(self.0.time_unit()).into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .take_unchecked(indices) + .into_duration(self.0.time_unit()) + .into_series() } fn len(&self) -> usize { @@ -364,16 +366,19 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { self.0 .unique() .map(|ca| ca.into_duration(self.0.time_unit()).into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { self.0.n_unique() } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { self.0.arg_unique() } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 120d4664f51b..4d0ac12d5b9c 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -11,7 +11,9 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; #[cfg(feature = "checked_arithmetic")] @@ -89,30 +91,37 @@ macro_rules! impl_dyn_series { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { self.0.agg_sum(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.agg_std(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.agg_var(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -135,6 +144,7 @@ macro_rules! 
impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -214,47 +224,19 @@ macro_rules! impl_dyn_series { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - - let mut out = self.0.take_unchecked((&*idx).into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_series()) - } - - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -298,14 +280,17 @@ macro_rules! 
impl_dyn_series { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } diff --git a/crates/polars-core/src/series/implementations/list.rs b/crates/polars-core/src/series/implementations/list.rs index 0fe84c30b69e..0ae8b9ced37b 100644 --- a/crates/polars-core/src/series/implementations/list.rs +++ b/crates/polars-core/src/series/implementations/list.rs @@ -9,6 +9,7 @@ use crate::chunked_array::comparison::*; use crate::chunked_array::ops::compare_inner::{IntoPartialEqInner, PartialEqInner}; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::{AsSinglePtr, Settings}; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -45,10 +46,12 @@ impl private::PrivateSeries for SeriesWrap { ChunkZip::zip_with(&self.0, mask, other.as_ref().as_ref()).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -125,38 +128,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(self.0.take_unchecked((&*idx).into()).into_series()) - } - - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -193,6 +177,7 @@ impl SeriesTrait for SeriesWrap { } #[cfg(feature = "group_by_list")] + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { if !self.inner_dtype().is_numeric() { polars_bail!(opq = unique, self.dtype()); @@ -209,6 +194,7 @@ impl SeriesTrait for SeriesWrap { } #[cfg(feature = "group_by_list")] + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { if !self.inner_dtype().is_numeric() { polars_bail!(opq = n_unique, self.dtype()); @@ -226,6 +212,7 @@ impl SeriesTrait for 
SeriesWrap { } #[cfg(feature = "group_by_list")] + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { if !self.inner_dtype().is_numeric() { polars_bail!(opq = arg_unique, self.dtype()); diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index c05a72ef59b9..36fac927f7dd 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -40,8 +40,12 @@ use crate::chunked_array::ops::compare_inner::{ IntoPartialEqInner, IntoPartialOrdInner, PartialEqInner, PartialOrdInner, }; use crate::chunked_array::ops::explode::ExplodeByOffsets; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; #[cfg(feature = "checked_arithmetic")] @@ -152,14 +156,17 @@ macro_rules! impl_dyn_series { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { use DataType::*; match self.dtype() { @@ -168,18 +175,22 @@ macro_rules! impl_dyn_series { } } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0.agg_std(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, ddof: u8) -> Series { self.0.agg_var(groups, ddof) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -202,6 +213,7 @@ macro_rules! impl_dyn_series { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -311,44 +323,19 @@ macro_rules! 
impl_dyn_series { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - let mut out = self.0.take_unchecked((&*idx).into()); - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - Ok(out.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -392,14 +379,17 @@ macro_rules! impl_dyn_series { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } diff --git a/crates/polars-core/src/series/implementations/null.rs b/crates/polars-core/src/series/implementations/null.rs index 7efded5440f0..aeeeced2aab2 100644 --- a/crates/polars-core/src/series/implementations/null.rs +++ b/crates/polars-core/src/series/implementations/null.rs @@ -107,24 +107,20 @@ impl SeriesTrait for NullChunked { NullChunked::new(self.name.clone(), by.len()).into_series() } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(NullChunked::new(self.name.clone(), iter.size_hint().0).into_series()) - } - - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - NullChunked::new(self.name.clone(), iter.size_hint().0).into_series() + fn take(&self, indices: &IdxCa) -> PolarsResult { + Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - Ok(NullChunked::new(self.name.clone(), idx.len()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + NullChunked::new(self.name.clone(), indices.len()).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - NullChunked::new(self.name.clone(), iter.size_hint().0).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(NullChunked::new(self.name.clone(), indices.len()).into_series()) } - fn take(&self, indices: &IdxCa) -> PolarsResult { - Ok(NullChunked::new(self.name.clone(), 
indices.len()).into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + NullChunked::new(self.name.clone(), indices.len()).into_series() } fn len(&self) -> usize { diff --git a/crates/polars-core/src/series/implementations/object.rs b/crates/polars-core/src/series/implementations/object.rs index 29667d7c2965..c58f37159d90 100644 --- a/crates/polars-core/src/series/implementations/object.rs +++ b/crates/polars-core/src/series/implementations/object.rs @@ -5,7 +5,10 @@ use ahash::RandomState; use crate::chunked_array::object::PolarsObjectSafe; use crate::chunked_array::ops::compare_inner::{IntoPartialEqInner, PartialEqInner}; +#[cfg(feature = "chunked_ids")] +use crate::chunked_array::ops::take::TakeChunked; use crate::chunked_array::Settings; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::{GroupsProxy, IntoGroupsProxy}; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -64,6 +67,7 @@ where Ok(()) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -132,33 +136,19 @@ where } fn take(&self, indices: &IdxCa) -> PolarsResult { - Ok(self.0.take(indices.into())?.into_series()) - } - - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - Ok(self.0.take_unchecked((&*idx).into()).into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, _iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - todo!() + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index da183572d7b7..16d758fd5257 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -58,10 +58,12 @@ impl private::PrivateSeries for SeriesWrap { Ok(StructChunked::new_unchecked(self.0.name(), &fields).into_series()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { let df = DataFrame::new_no_checks(vec![]); let gb = df @@ -176,16 +178,6 @@ impl SeriesTrait for SeriesWrap { .map(|ca| ca.into_series()) } - /// Take by index from an iterator. This operation clones the data. 
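// Editor's aside (illustrative sketch, not part of this patch): the
// iterator-based take_iter / take_opt_iter methods being removed in these
// hunks are replaced by a smaller gather API on SeriesTrait: checked `take`
// and `take_slice` returning PolarsResult, plus unsafe `take_unchecked` and
// `take_slice_unchecked`. A hedged example of the checked variants follows;
// the column name and indices are placeholders, and IdxSize is u32 unless
// the `bigidx` feature is enabled.
use polars_core::prelude::*;

fn gather_example(s: &Series) -> PolarsResult<Series> {
    // Gather rows 0 and 2 via an index ChunkedArray; bounds are checked.
    let by_ca = s.take(&IdxCa::from_vec("idx", vec![0, 2]))?;
    // Gather the same rows from a plain index slice, without building an IdxCa.
    let by_slice = s.take_slice(&[0, 2])?;
    assert_eq!(by_ca.len(), by_slice.len());
    Ok(by_slice)
}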
- fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - self.0 - .try_apply_fields(|s| { - let mut iter = iter.boxed_clone(); - s.take_iter(&mut *iter) - }) - .map(|ca| ca.into_series()) - } - #[cfg(feature = "chunked_ids")] unsafe fn _take_chunked_unchecked(&self, by: &[ChunkId], sorted: IsSorted) -> Series { self.0 @@ -200,53 +192,30 @@ impl SeriesTrait for SeriesWrap { .into_series() } - /// Take by index from an iterator. This operation clones the data. - /// - /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0 - .apply_fields(|s| { - let mut iter = iter.boxed_clone(); - s.take_iter_unchecked(&mut *iter) - }) - .into_series() - } - - /// Take by index if ChunkedArray contains a single chunk. - /// - /// # Safety - /// This doesn't check any bounds. - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { + fn take(&self, indices: &IdxCa) -> PolarsResult { self.0 - .try_apply_fields(|s| s.take_unchecked(idx)) + .try_apply_fields(|s| s.take(indices)) .map(|ca| ca.into_series()) } - /// Take by index from an iterator. This operation clones the data. - /// - /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { self.0 - .apply_fields(|s| { - let mut iter = iter.boxed_clone(); - s.take_opt_iter_unchecked(&mut *iter) - }) + .apply_fields(|s| s.take_unchecked(indices)) .into_series() } - /// Take by index. This operation is clone. - fn take(&self, indices: &IdxCa) -> PolarsResult { + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { self.0 - .try_apply_fields(|s| s.take(indices)) + .try_apply_fields(|s| s.take_slice(indices)) .map(|ca| ca.into_series()) } + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0 + .apply_fields(|s| s.take_slice_unchecked(indices)) + .into_series() + } + /// Get length of series. fn len(&self) -> usize { self.0.len() @@ -283,6 +252,7 @@ impl SeriesTrait for SeriesWrap { } /// Get unique values in the Series. + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot if self.len() < 2 { @@ -296,6 +266,7 @@ impl SeriesTrait for SeriesWrap { } /// Get unique values in the Series. + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot match self.len() { @@ -311,6 +282,7 @@ impl SeriesTrait for SeriesWrap { } /// Get first indexes of unique values. 
+ #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { // this can called in aggregation, so this fast path can be worth a lot if self.len() == 1 { diff --git a/crates/polars-core/src/series/implementations/utf8.rs b/crates/polars-core/src/series/implementations/utf8.rs index 15e664f26f62..693ffb306c04 100644 --- a/crates/polars-core/src/series/implementations/utf8.rs +++ b/crates/polars-core/src/series/implementations/utf8.rs @@ -9,7 +9,9 @@ use crate::chunked_array::ops::compare_inner::{ }; use crate::chunked_array::ops::explode::ExplodeByOffsets; use crate::chunked_array::AsSinglePtr; +#[cfg(feature = "algorithm_group_by")] use crate::frame::group_by::*; +#[cfg(feature = "algorithm_join")] use crate::frame::hash_join::ZipOuterJoinColumn; use crate::prelude::*; use crate::series::implementations::SeriesWrap; @@ -60,18 +62,22 @@ impl private::PrivateSeries for SeriesWrap { Ok(()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { self.0.agg_list(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { self.0.agg_min(groups) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { self.0.agg_max(groups) } + #[cfg(feature = "algorithm_join")] unsafe fn zip_outer_join_column( &self, right_column: &Series, @@ -94,6 +100,7 @@ impl private::PrivateSeries for SeriesWrap { fn remainder(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::remainder(&self.0, rhs) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } @@ -163,47 +170,19 @@ impl SeriesTrait for SeriesWrap { } fn take(&self, indices: &IdxCa) -> PolarsResult { - let indices = if indices.chunks.len() > 1 { - Cow::Owned(indices.rechunk()) - } else { - Cow::Borrowed(indices) - }; - Ok(self.0.take((&*indices).into())?.into_series()) + Ok(self.0.take(indices)?.into_series()) } - fn take_iter(&self, iter: &mut dyn TakeIterator) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_unchecked(&self, indices: &IdxCa) -> Series { + self.0.take_unchecked(indices).into_series() } - unsafe fn take_iter_unchecked(&self, iter: &mut dyn TakeIterator) -> Series { - self.0.take_unchecked(iter.into()).into_series() + fn take_slice(&self, indices: &[IdxSize]) -> PolarsResult { + Ok(self.0.take(indices)?.into_series()) } - unsafe fn take_unchecked(&self, idx: &IdxCa) -> PolarsResult { - let idx = if idx.chunks.len() > 1 { - Cow::Owned(idx.rechunk()) - } else { - Cow::Borrowed(idx) - }; - - let mut out = self.0.take_unchecked((&*idx).into()); - - if self.0.is_sorted_ascending_flag() - && (idx.is_sorted_ascending_flag() || idx.is_sorted_descending_flag()) - { - out.set_sorted_flag(idx.is_sorted_flag()) - } - - Ok(out.into_series()) - } - - unsafe fn take_opt_iter_unchecked(&self, iter: &mut dyn TakeIteratorNulls) -> Series { - self.0.take_unchecked(iter.into()).into_series() - } - - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - Ok(self.0.take(iter.into())?.into_series()) + unsafe fn take_slice_unchecked(&self, indices: &[IdxSize]) -> Series { + self.0.take_unchecked(indices).into_series() } fn len(&self) -> usize { @@ -247,14 +226,17 @@ impl SeriesTrait for SeriesWrap { self.0.has_validity() } + #[cfg(feature = "algorithm_group_by")] fn unique(&self) -> PolarsResult 
{ ChunkUnique::unique(&self.0).map(|ca| ca.into_series()) } + #[cfg(feature = "algorithm_group_by")] fn n_unique(&self) -> PolarsResult { ChunkUnique::n_unique(&self.0) } + #[cfg(feature = "algorithm_group_by")] fn arg_unique(&self) -> PolarsResult { ChunkUnique::arg_unique(&self.0) } diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 4e91bb8b8d22..9ae01a942c51 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -480,24 +480,31 @@ impl Series { /// /// # Safety /// This doesn't check any bounds. Null validity is checked. - pub unsafe fn take_unchecked_from_slice(&self, idx: &[IdxSize]) -> PolarsResult { - let idx = IdxCa::mmap_slice("", idx); - self.take_unchecked(&idx) + pub unsafe fn take_unchecked_from_slice(&self, idx: &[IdxSize]) -> Series { + self.take_slice_unchecked(idx) } /// Take by index if ChunkedArray contains a single chunk. /// /// # Safety /// This doesn't check any bounds. Null validity is checked. - pub unsafe fn take_unchecked_threaded( - &self, - idx: &IdxCa, - rechunk: bool, - ) -> PolarsResult { + pub unsafe fn take_unchecked_threaded(&self, idx: &IdxCa, rechunk: bool) -> Series { self.threaded_op(rechunk, idx.len(), &|offset, len| { let idx = idx.slice(offset as i64, len); - self.take_unchecked(&idx) + Ok(self.take_unchecked(&idx)) }) + .unwrap() + } + + /// Take by index if ChunkedArray contains a single chunk. + /// + /// # Safety + /// This doesn't check any bounds. Null validity is checked. + pub unsafe fn take_slice_unchecked_threaded(&self, idx: &[IdxSize], rechunk: bool) -> Series { + self.threaded_op(rechunk, idx.len(), &|offset, len| { + Ok(self.take_slice_unchecked(&idx[offset..offset + len])) + }) + .unwrap() } /// # Safety @@ -542,6 +549,13 @@ impl Series { }) } + /// Traverse and collect every nth element in a new array. + pub fn take_every(&self, n: usize) -> Series { + let idx = (0..self.len() as IdxSize).step_by(n).collect_ca(""); + // SAFETY: we stay in-bounds. + unsafe { self.take_unchecked(&idx) } + } + /// Filter by boolean mask. This operation clones data. pub fn filter_threaded(&self, filter: &BooleanChunked, rechunk: bool) -> PolarsResult { // This would fail if there is a broadcasting filter, because we cannot @@ -901,7 +915,7 @@ impl Series { pub fn unique_stable(&self) -> PolarsResult { let idx = self.arg_unique()?; // SAFETY: Indices are in bounds. 
- unsafe { self.take_unchecked(&idx) } + unsafe { Ok(self.take_unchecked(&idx)) } } pub fn idx(&self) -> PolarsResult<&IdxCa> { diff --git a/crates/polars-core/src/series/ops/unique.rs b/crates/polars-core/src/series/ops/unique.rs index cfae77d687e7..397fc85d4e8c 100644 --- a/crates/polars-core/src/series/ops/unique.rs +++ b/crates/polars-core/src/series/ops/unique.rs @@ -2,7 +2,7 @@ use std::hash::Hash; #[cfg(feature = "unique_counts")] -use crate::frame::group_by::hashing::HASHMAP_INIT_SIZE; +use crate::hashing::HASHMAP_INIT_SIZE; use crate::prelude::*; #[cfg(feature = "unique_counts")] use crate::utils::NoNull; diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index 49a5cbcc5378..8d1ee384bdae 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -126,23 +126,29 @@ pub(crate) mod private { ) -> PolarsResult<()> { polars_bail!(opq = vec_hash_combine, self._dtype()); } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_min(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_max(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } /// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is /// first cast to `Int64` to prevent overflow issues. + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_std(&self, groups: &GroupsProxy, _ddof: u8) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_var(&self, groups: &GroupsProxy, _ddof: u8) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } + #[cfg(feature = "algorithm_group_by")] unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series { Series::full_null(self._field().name(), groups.len(), self._dtype()) } @@ -170,6 +176,7 @@ pub(crate) mod private { fn remainder(&self, _rhs: &Series) -> PolarsResult { invalid_operation_panic!(rem, self) } + #[cfg(feature = "algorithm_group_by")] fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult { invalid_operation_panic!(group_tuples, self) } @@ -268,40 +275,23 @@ pub trait SeriesTrait: #[cfg(feature = "chunked_ids")] unsafe fn _take_opt_chunked_unchecked(&self, by: &[Option]) -> Series; - /// Take by index from an iterator. This operation clones the data. - fn take_iter(&self, _iter: &mut dyn TakeIterator) -> PolarsResult; - - /// Take by index from an iterator. This operation clones the data. - /// - /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_iter_unchecked(&self, _iter: &mut dyn TakeIterator) -> Series; + /// Take by index. This operation is clone. + fn take(&self, _indices: &IdxCa) -> PolarsResult; - /// Take by index if ChunkedArray contains a single chunk. + /// Take by index. /// /// # Safety /// This doesn't check any bounds. - unsafe fn take_unchecked(&self, _idx: &IdxCa) -> PolarsResult; + unsafe fn take_unchecked(&self, _idx: &IdxCa) -> Series; + + /// Take by index. This operation is clone. + fn take_slice(&self, _indices: &[IdxSize]) -> PolarsResult; - /// Take by index from an iterator. 
This operation clones the data. + /// Take by index. /// /// # Safety - /// - /// - This doesn't check any bounds. - /// - Iterator must be TrustedLen - unsafe fn take_opt_iter_unchecked(&self, _iter: &mut dyn TakeIteratorNulls) -> Series; - - /// Take by index from an iterator. This operation clones the data. - /// todo! remove? - #[cfg(feature = "take_opt_iter")] - fn take_opt_iter(&self, _iter: &mut dyn TakeIteratorNulls) -> PolarsResult { - invalid_operation_panic!(take_opt_iter, self) - } - - /// Take by index. This operation is clone. - fn take(&self, _indices: &IdxCa) -> PolarsResult; + /// This doesn't check any bounds. + unsafe fn take_slice_unchecked(&self, _idx: &[IdxSize]) -> Series; /// Get length of series. fn len(&self) -> usize; diff --git a/crates/polars-core/src/utils/mod.rs b/crates/polars-core/src/utils/mod.rs index e521aa274d3b..0eb20558cfca 100644 --- a/crates/polars-core/src/utils/mod.rs +++ b/crates/polars-core/src/utils/mod.rs @@ -673,7 +673,6 @@ where } #[allow(clippy::type_complexity)] -#[cfg(feature = "zip_with")] pub fn align_chunks_ternary<'a, A, B, C>( a: &'a ChunkedArray, b: &'a ChunkedArray, @@ -809,22 +808,6 @@ pub(crate) fn index_to_chunked_index< (current_chunk_idx, index_remainder) } -#[cfg(feature = "dtype-struct")] -pub(crate) fn index_to_chunked_index2(chunks: &[ArrayRef], index: usize) -> (usize, usize) { - let mut index_remainder = index; - let mut current_chunk_idx = 0; - - for chunk in chunks { - if chunk.len() > index_remainder { - break; - } else { - index_remainder -= chunk.len(); - current_chunk_idx += 1; - } - } - (current_chunk_idx, index_remainder) -} - #[cfg(feature = "chunked_ids")] pub(crate) fn create_chunked_index_mapping(chunks: &[ArrayRef], len: usize) -> Vec { let mut vals = Vec::with_capacity(len); diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index c1d8491813e2..db190d63370b 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -55,6 +55,8 @@ pub enum PolarsError { Io(#[from] io::Error), #[error("no data: {0}")] NoData(ErrString), + #[error("{0}")] + OutOfBounds(ErrString), #[error("field not found: {0}")] SchemaFieldNotFound(ErrString), #[error("data types don't match: {0}")] @@ -105,6 +107,7 @@ impl PolarsError { InvalidOperation(msg) => InvalidOperation(func(msg).into()), Io(err) => ComputeError(func(&format!("IO: {err}")).into()), NoData(msg) => NoData(func(msg).into()), + OutOfBounds(msg) => OutOfBounds(func(msg).into()), SchemaFieldNotFound(msg) => SchemaFieldNotFound(func(msg).into()), SchemaMismatch(msg) => SchemaMismatch(func(msg).into()), ShapeMismatch(msg) => ShapeMismatch(func(msg).into()), @@ -205,7 +208,7 @@ on startup."#.trim_start()) polars_err!(Duplicate: "column with name '{}' has more than one occurrences", $name) }; (oob = $idx:expr, $len:expr) => { - polars_err!(ComputeError: "index {} is out of bounds for sequence of size {}", $idx, $len) + polars_err!(OutOfBounds: "index {} is out of bounds for sequence of length {}", $idx, $len) }; (agg_len = $agg_len:expr, $groups_len:expr) => { polars_err!( diff --git a/crates/polars-ffi/Cargo.toml b/crates/polars-ffi/Cargo.toml new file mode 100644 index 000000000000..40e7376fce70 --- /dev/null +++ b/crates/polars-ffi/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "polars-ffi" +version.workspace = true +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +description = "FFI utils for the Polars project." 
+ +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +polars-core = { workspace = true } diff --git a/crates/polars-ffi/src/lib.rs b/crates/polars-ffi/src/lib.rs new file mode 100644 index 000000000000..699d5e7a7fd5 --- /dev/null +++ b/crates/polars-ffi/src/lib.rs @@ -0,0 +1,128 @@ +use std::mem::ManuallyDrop; + +use arrow::ffi; +use arrow::ffi::{ArrowArray, ArrowSchema}; +use polars_core::error::PolarsResult; +use polars_core::prelude::{ArrayRef, ArrowField, Series}; + +// A utility that helps releasing/owning memory. +#[allow(dead_code)] +struct PrivateData { + schema: Box, + arrays: Box<[*mut ArrowArray]>, +} + +/// An FFI exported `Series`. +#[repr(C)] +pub struct SeriesExport { + field: *mut ArrowSchema, + // A double ptr, so we can easily release the buffer + // without dropping the arrays. + arrays: *mut *mut ArrowArray, + len: usize, + release: Option, + private_data: *mut std::os::raw::c_void, +} + +impl Drop for SeriesExport { + fn drop(&mut self) { + if let Some(release) = self.release { + unsafe { release(self) } + } + } +} + +// callback used to drop [SeriesExport] when it is exported. +unsafe extern "C" fn c_release_series_export(e: *mut SeriesExport) { + if e.is_null() { + return; + } + let e = &mut *e; + let private = Box::from_raw(e.private_data as *mut PrivateData); + for ptr in private.arrays.iter() { + // drop the box, not the array + let _ = Box::from_raw(*ptr as *mut ManuallyDrop); + } + + e.release = None; +} + +pub fn export_series(s: &Series) -> SeriesExport { + let field = ArrowField::new(s.name(), s.dtype().to_arrow(), true); + let schema = Box::new(ffi::export_field_to_c(&field)); + let mut arrays = s + .chunks() + .iter() + .map(|arr| Box::into_raw(Box::new(ffi::export_array_to_c(arr.clone())))) + .collect::>(); + let len = arrays.len(); + let ptr = arrays.as_mut_ptr(); + SeriesExport { + field: schema.as_ref() as *const ArrowSchema as *mut ArrowSchema, + arrays: ptr, + len, + release: Some(c_release_series_export), + private_data: Box::into_raw(Box::new(PrivateData { arrays, schema })) + as *mut std::os::raw::c_void, + } +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_series(e: SeriesExport) -> PolarsResult { + let field = ffi::import_field_from_c(&(*e.field))?; + + let pointers = std::slice::from_raw_parts_mut(e.arrays, e.len); + let chunks = pointers + .iter() + .map(|ptr| { + let arr = std::ptr::read(*ptr); + import_array(arr, &(*e.field)) + }) + .collect::>>()?; + + Ok(Series::from_chunks_and_dtype_unchecked( + &field.name, + chunks, + &(&field.data_type).into(), + )) +} + +/// # Safety +/// `SeriesExport` must be valid +pub unsafe fn import_series_buffer(e: *mut SeriesExport, len: usize) -> PolarsResult> { + let mut out = Vec::with_capacity(len); + for i in 0..len { + let e = std::ptr::read(e.add(i)); + out.push(import_series(e)?) 
+ } + Ok(out) +} + +/// # Safety +/// `ArrowArray` and `ArrowSchema` must be valid +unsafe fn import_array( + array: ffi::ArrowArray, + schema: &ffi::ArrowSchema, +) -> PolarsResult { + let field = ffi::import_field_from_c(schema)?; + let out = ffi::import_array_from_c(array, field.data_type)?; + Ok(out) +} + +#[cfg(test)] +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_ffi() { + let s = Series::new("a", [1, 2]); + let e = export_series(&s); + + unsafe { + assert_eq!(import_series(e).unwrap(), s); + }; + } +} diff --git a/crates/polars-io/Cargo.toml b/crates/polars-io/Cargo.toml index 3a6d3908b4da..1524641b0ed7 100644 --- a/crates/polars-io/Cargo.toml +++ b/crates/polars-io/Cargo.toml @@ -9,12 +9,12 @@ repository = { workspace = true } description = "IO related logic for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow" } -polars-core = { version = "0.32.0", path = "../polars-core", features = [], default-features = false } -polars-error = { version = "0.32.0", path = "../polars-error", default-features = false } -polars-json = { version = "0.32.0", optional = true, path = "../polars-json" } -polars-time = { version = "0.32.0", path = "../polars-time", features = [], default-features = false, optional = true } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true } +polars-core = { workspace = true } +polars-error = { workspace = true, default-features = false } +polars-json = { workspace = true, optional = true } +polars-time = { workspace = true, features = [], optional = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } @@ -25,19 +25,21 @@ chrono-tz = { workspace = true, optional = true } fast-float = { version = "0.2", optional = true } flate2 = { version = "1", optional = true, default-features = false } futures = { workspace = true, optional = true } +itoa = { workspace = true, optional = true } lexical = { version = "6", optional = true, default-features = false, features = ["std", "parse-integers"] } -lexical-core = { version = "0.8", optional = true } +lexical-core = { workspace = true, optional = true } memchr = { workspace = true } -memmap = { package = "memmap2", version = "0.7", optional = true } +memmap = { package = "memmap2", version = "0.7" } num-traits = { workspace = true } object_store = { workspace = true, optional = true } once_cell = { workspace = true } rayon = { workspace = true } regex = { workspace = true } +ryu = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"], optional = true } serde_json = { version = "1", default-features = false, features = ["alloc", "raw_value"], optional = true } simd-json = { workspace = true, optional = true } -simdutf8 = { version = "0.1", optional = true } +simdutf8 = { workspace = true, optional = true } tokio = { version = "1.26", features = ["net"], optional = true } tokio-util = { version = "0.7.8", features = ["io", "io-util"], optional = true } url = { workspace = true, optional = true } @@ -52,22 +54,21 @@ tempdir = "0.3.7" default = ["decompress"] # support for arrows json parsing json = [ - "arrow/io_json_write", "polars-json", "simd-json", - "memmap", "lexical", "lexical-core", "serde_json", "dtype-struct", + "csv", ] # support for arrows ipc file parsing -ipc = ["arrow/io_ipc", "arrow/io_ipc_compression", "memmap"] +ipc = ["arrow/io_ipc", "arrow/io_ipc_compression"] # support for arrows streaming 
ipc file parsing ipc_streaming = ["arrow/io_ipc", "arrow/io_ipc_compression"] # support for arrow avro parsing avro = ["arrow/io_avro", "arrow/io_avro_compression"] -csv = ["memmap", "lexical", "polars-core/rows", "lexical-core", "fast-float", "simdutf8"] +csv = ["lexical", "polars-core/rows", "itoa", "ryu", "fast-float", "simdutf8"] decompress = ["flate2/rust_backend"] decompress-fast = ["flate2/zlib-ng"] dtype-categorical = ["polars-core/dtype-categorical"] @@ -88,12 +89,12 @@ dtype-struct = ["polars-core/dtype-struct"] dtype-decimal = ["polars-core/dtype-decimal"] fmt = ["polars-core/fmt"] lazy = [] -parquet = ["polars-core/parquet", "arrow/io_parquet", "arrow/io_parquet_compression", "memmap"] +parquet = ["polars-core/parquet", "arrow/io_parquet", "arrow/io_parquet_compression"] async = ["async-trait", "futures", "tokio", "tokio-util", "arrow/io_ipc_write_async", "polars-error/regex"] -cloud = ["object_store", "async", "polars-core/async", "polars-error/object_store", "url"] -aws = ["object_store/aws", "cloud", "polars-core/aws"] -azure = ["object_store/azure", "cloud", "polars-core/azure"] -gcp = ["object_store/gcp", "cloud", "polars-core/gcp"] +cloud = ["object_store", "async", "polars-error/object_store", "url"] +aws = ["object_store/aws", "cloud"] +azure = ["object_store/azure", "cloud"] +gcp = ["object_store/gcp", "cloud"] partition = ["polars-core/partition_by"] temporal = ["dtype-datetime", "dtype-date", "dtype-time"] simd = [] diff --git a/crates/polars-io/src/cloud/adaptors.rs b/crates/polars-io/src/cloud/adaptors.rs index d39c8d8b6226..5850f33006b9 100644 --- a/crates/polars-io/src/cloud/adaptors.rs +++ b/crates/polars-io/src/cloud/adaptors.rs @@ -13,10 +13,11 @@ use futures::lock::Mutex; use futures::{AsyncRead, AsyncSeek, Future, TryFutureExt}; use object_store::path::Path; use object_store::{MultipartId, ObjectStore}; -use polars_core::cloud::CloudOptions; use polars_error::{PolarsError, PolarsResult}; use tokio::io::{AsyncWrite, AsyncWriteExt}; +use super::*; + type OptionalFuture = Arc>>>>>; /// Adaptor to translate from AsyncSeek and AsyncRead to the object_store get_range API. diff --git a/crates/polars-io/src/cloud/glob.rs b/crates/polars-io/src/cloud/glob.rs index f1798be9fa3d..191ed7e60e63 100644 --- a/crates/polars-io/src/cloud/glob.rs +++ b/crates/polars-io/src/cloud/glob.rs @@ -2,12 +2,13 @@ use futures::future::ready; use futures::{StreamExt, TryStreamExt}; use object_store::path::Path; use polars_arrow::error::polars_bail; -use polars_core::cloud::CloudOptions; use polars_core::error::to_compute_err; use polars_core::prelude::{polars_ensure, polars_err, PolarsError, PolarsResult}; use regex::Regex; use url::Url; +use super::*; + const DELIMITER: char = '/'; /// Split the url in @@ -95,7 +96,9 @@ impl CloudLocation { let key = parsed.path(); let bucket = parsed .host() - .ok_or(polars_err!(ComputeError: "cannot parse bucket (host) from url: {}", url))? + .ok_or_else( + || polars_err!(ComputeError: "cannot parse bucket (host) from url: {}", url), + )? .to_string(); (bucket, key) }; diff --git a/crates/polars-io/src/cloud/mod.rs b/crates/polars-io/src/cloud/mod.rs index 075d6a29f9d3..297ed860a579 100644 --- a/crates/polars-io/src/cloud/mod.rs +++ b/crates/polars-io/src/cloud/mod.rs @@ -1,20 +1,33 @@ //! Interface with cloud storage through the object_store crate. 
+#[cfg(feature = "cloud")] +use std::borrow::Cow; +#[cfg(feature = "cloud")] use std::str::FromStr; +#[cfg(feature = "cloud")] use object_store::local::LocalFileSystem; +#[cfg(feature = "cloud")] use object_store::ObjectStore; -use polars_core::cloud::{CloudOptions, CloudType}; +#[cfg(feature = "cloud")] use polars_core::prelude::{polars_bail, PolarsError, PolarsResult}; +#[cfg(feature = "cloud")] mod adaptors; +#[cfg(feature = "cloud")] mod glob; +pub mod options; +#[cfg(feature = "cloud")] pub use adaptors::*; +#[cfg(feature = "cloud")] pub use glob::*; +pub use options::*; +#[cfg(feature = "cloud")] type BuildResult = PolarsResult<(CloudLocation, Box)>; #[allow(dead_code)] +#[cfg(feature = "cloud")] fn err_missing_feature(feature: &str, scheme: &str) -> BuildResult { polars_bail!( ComputeError: @@ -28,7 +41,9 @@ fn err_missing_configuration(feature: &str, scheme: &str) -> BuildResult { "configuration '{}' must be provided in order to use '{}' cloud urls", feature, scheme, ); } + /// Build an [`ObjectStore`] based on the URL and passed in url. Return the cloud location and an implementation of the object store. +#[cfg(feature = "cloud")] pub fn build(url: &str, _options: Option<&CloudOptions>) -> BuildResult { let cloud_location = CloudLocation::new(url)?; let store = match CloudType::from_str(url)? { @@ -38,12 +53,12 @@ pub fn build(url: &str, _options: Option<&CloudOptions>) -> BuildResult { }, CloudType::Aws => { #[cfg(feature = "aws")] - match _options { - Some(options) => { - let store = options.build_aws(&cloud_location.bucket)?; - Ok::<_, PolarsError>(Box::new(store) as Box) - }, - _ => return err_missing_configuration("aws", &cloud_location.scheme), + { + let options = _options + .map(Cow::Borrowed) + .unwrap_or_else(|| Cow::Owned(Default::default())); + let store = options.build_aws(&cloud_location.bucket)?; + Ok::<_, PolarsError>(Box::new(store) as Box) } #[cfg(not(feature = "aws"))] return err_missing_feature("aws", &cloud_location.scheme); diff --git a/crates/polars-core/src/cloud.rs b/crates/polars-io/src/cloud/options.rs similarity index 88% rename from crates/polars-core/src/cloud.rs rename to crates/polars-io/src/cloud/options.rs index 5797b2ff748f..5b7de5088c38 100644 --- a/crates/polars-core/src/cloud.rs +++ b/crates/polars-io/src/cloud/options.rs @@ -14,14 +14,13 @@ use object_store::gcp::GoogleCloudStorageBuilder; pub use object_store::gcp::GoogleConfigKey; #[cfg(feature = "async")] use object_store::ObjectStore; -use polars_error::{polars_bail, polars_err}; -#[cfg(feature = "serde-lazy")] +use polars_core::error::{PolarsError, PolarsResult}; +use polars_error::*; +#[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; #[cfg(feature = "async")] use url::Url; -use crate::error::{PolarsError, PolarsResult}; - /// The type of the config keys must satisfy the following requirements: /// 1. must be easily collected into a HashMap, the type required by the object_crate API. /// 2. be Serializable, required when the serde-lazy feature is defined. @@ -32,7 +31,7 @@ use crate::error::{PolarsError, PolarsResult}; type Configs = Vec<(T, String)>; #[derive(Clone, Debug, Default, PartialEq)] -#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] /// Options to connect to various cloud providers. pub struct CloudOptions { #[cfg(feature = "aws")] @@ -110,15 +109,26 @@ impl CloudOptions { /// Build the [`ObjectStore`] implementation for AWS. 
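// Editor's aside (illustrative sketch, not part of this patch): with the
// build_aws change below, omitting an explicit AWS configuration makes the
// builder fall back to credentials read from the environment (typically
// AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY), so a default CloudOptions is
// enough when the environment is already configured. The bucket name here
// is a placeholder.
#[cfg(feature = "aws")]
fn object_store_from_env() -> polars_core::error::PolarsResult<()> {
    use polars_io::cloud::CloudOptions;
    // Build an object_store backend for the given bucket from env credentials.
    let store = CloudOptions::default().build_aws("my-bucket")?;
    let _ = store;
    Ok(())
}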
#[cfg(feature = "aws")] pub fn build_aws(&self, bucket_name: &str) -> PolarsResult { - let options = self - .aws - .as_ref() - .ok_or_else(|| polars_err!(ComputeError: "`aws` configuration missing"))?; + let options = self.aws.as_ref(); + + let builder = match options { + Some(options) => { + let mut builder = AmazonS3Builder::new(); + for (key, value) in options.iter() { + builder = builder.with_config(*key, value); + } + builder + }, + None => { + let builder = AmazonS3Builder::from_env(); + polars_ensure!( + builder.get_config_value(&AmazonS3ConfigKey::AccessKeyId).is_some() && + builder.get_config_value(&AmazonS3ConfigKey::SecretAccessKey).is_some(), + ComputeError: "`aws` configuration and env vars missing"); + builder + }, + }; - let mut builder = AmazonS3Builder::new(); - for (key, value) in options.iter() { - builder = builder.with_config(*key, value); - } builder .with_bucket_name(bucket_name) .build() diff --git a/crates/polars-io/src/csv/mod.rs b/crates/polars-io/src/csv/mod.rs index 483012a42299..2d4de6ac21c5 100644 --- a/crates/polars-io/src/csv/mod.rs +++ b/crates/polars-io/src/csv/mod.rs @@ -66,8 +66,7 @@ pub use write::{BatchedWriter, CsvWriter, QuoteStyle}; pub use write_impl::SerializeOptions; use crate::csv::read_impl::CoreReader; -use crate::csv::utils::get_reader_bytes; use crate::mmap::MmapBytesReader; use crate::predicates::PhysicalIoExpr; -use crate::utils::resolve_homedir; +use crate::utils::{get_reader_bytes, resolve_homedir}; use crate::{RowCount, SerReader, SerWriter}; diff --git a/crates/polars-io/src/csv/read_impl/mod.rs b/crates/polars-io/src/csv/read_impl/mod.rs index 2b13585f80bf..c544694f973c 100644 --- a/crates/polars-io/src/csv/read_impl/mod.rs +++ b/crates/polars-io/src/csv/read_impl/mod.rs @@ -515,7 +515,7 @@ impl<'a> CoreReader<'a> { for i in projection { let (_, dtype) = self.schema.get_at_index(*i).ok_or_else(|| { polars_err!( - ComputeError: + OutOfBounds: "projection index {} is out of bounds for CSV schema with {} columns", i, self.schema.len(), ) diff --git a/crates/polars-io/src/csv/utils.rs b/crates/polars-io/src/csv/utils.rs index e9c89d1a2ab7..c0a0c5b09624 100644 --- a/crates/polars-io/src/csv/utils.rs +++ b/crates/polars-io/src/csv/utils.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +#[cfg(any(feature = "decompress", feature = "decompress-fast"))] use std::io::Read; use std::mem::MaybeUninit; @@ -16,7 +17,7 @@ use crate::csv::parser::next_line_position_naive; use crate::csv::parser::{next_line_position, skip_bom, skip_line_ending, SplitLines}; use crate::csv::splitfields::SplitFields; use crate::csv::CsvEncoding; -use crate::mmap::{MmapBytesReader, ReaderBytes}; +use crate::mmap::*; use crate::prelude::NullValues; pub(crate) fn get_file_chunks( @@ -57,37 +58,6 @@ pub(crate) fn get_file_chunks( offsets } -pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( - reader: &'a mut R, -) -> PolarsResult> { - // we have a file so we can mmap - if let Some(file) = reader.to_file() { - let mmap = unsafe { memmap::Mmap::map(file)? }; - - // somehow bck thinks borrows alias - // this is sound as file was already bound to 'a - use std::fs::File; - let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; - Ok(ReaderBytes::Mapped(mmap, file)) - } else { - // we can get the bytes for free - if reader.to_bytes().is_some() { - // duplicate .to_bytes() is necessary to satisfy the borrow checker - Ok(ReaderBytes::Borrowed((*reader).to_bytes().unwrap())) - } else { - // we have to read to an owned buffer to get the bytes. 
- let mut bytes = Vec::with_capacity(1024 * 128); - reader.read_to_end(&mut bytes)?; - if !bytes.is_empty() - && (bytes[bytes.len() - 1] != b'\n' || bytes[bytes.len() - 1] != b'\r') - { - bytes.push(b'\n') - } - Ok(ReaderBytes::Owned(bytes)) - } - } -} - static FLOAT_RE: Lazy = Lazy::new(|| { Regex::new(r"^\s*[-+]?((\d*\.\d+)([eE][-+]?\d+)?|inf|NaN|(\d+)[eE][-+]?\d+|\d+\.)$").unwrap() }); diff --git a/crates/polars-io/src/csv/write.rs b/crates/polars-io/src/csv/write.rs index f9c31a827805..f0db058c3855 100644 --- a/crates/polars-io/src/csv/write.rs +++ b/crates/polars-io/src/csv/write.rs @@ -15,6 +15,8 @@ pub enum QuoteStyle { Necessary, /// This puts quotes around all fields that are non-numeric. Namely, when writing a field that does not parse as a valid float or integer, then quotes will be used even if they aren’t strictly necessary. NonNumeric, + /// Never quote any fields, even if it would produce invalid CSV data. + Never, } /// Write a DataFrame to csv. diff --git a/crates/polars-io/src/csv/write_impl.rs b/crates/polars-io/src/csv/write_impl.rs index 4b02a8f39e0a..3869aa5137cf 100644 --- a/crates/polars-io/src/csv/write_impl.rs +++ b/crates/polars-io/src/csv/write_impl.rs @@ -8,7 +8,6 @@ use std::io::Write; use arrow::temporal_conversions; #[cfg(feature = "timezones")] use chrono::TimeZone; -use lexical_core::{FormattedSize, ToLexical}; use memchr::{memchr, memchr2}; use polars_arrow::time_zone::Tz; use polars_core::prelude::*; @@ -22,11 +21,12 @@ use serde::{Deserialize, Serialize}; use super::write::QuoteStyle; fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> std::io::Result<()> { - if v.is_empty() { + if options.quote_style == QuoteStyle::Never { + write!(f, "{v}") + } else if v.is_empty() { write!(f, "\"\"") } else { let needs_escaping = memchr(options.quote, v.as_bytes()).is_some(); - if needs_escaping { let replaced = unsafe { // Replace from single quote " to double quote "". 
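// Editor's aside (illustrative sketch, not part of this patch): the
// write_impl changes in this file switch CSV number formatting from
// lexical-core to the `itoa` and `ryu` crates (see the write_integer /
// fast_float_write helpers just below). Both crates provide a small stack
// Buffer whose format() returns a &str view into that buffer, which is then
// appended to the output Vec<u8>.
fn number_formatting_example() {
    let mut out: Vec<u8> = Vec::new();
    // Integers via itoa.
    let mut int_buf = itoa::Buffer::new();
    out.extend_from_slice(int_buf.format(42_i64).as_bytes());
    out.push(b',');
    // Floats via ryu (shortest round-trippable representation).
    let mut float_buf = ryu::Buffer::new();
    out.extend_from_slice(float_buf.format(0.25_f64).as_bytes());
    assert_eq!(out, b"42,0.25".to_vec());
}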
@@ -40,6 +40,7 @@ fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> s let surround_with_quotes = match options.quote_style { QuoteStyle::Always | QuoteStyle::NonNumeric => true, QuoteStyle::Necessary => memchr2(options.delimiter, b'\n', v.as_bytes()).is_some(), + QuoteStyle::Never => false, }; let quote = options.quote as char; @@ -51,15 +52,16 @@ fn fmt_and_escape_str(f: &mut Vec, v: &str, options: &SerializeOptions) -> s } } -fn fast_float_write(f: &mut Vec, n: N, write_size: usize) -> std::io::Result<()> { - let len = f.len(); - f.reserve(write_size); - unsafe { - let buffer = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size); - let written_n = n.to_lexical(buffer).len(); - f.set_len(len + written_n); - } - Ok(()) +fn fast_float_write(f: &mut Vec, val: I) { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) +} + +fn write_integer(f: &mut Vec, val: I) { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) } unsafe fn write_anyvalue( @@ -94,20 +96,50 @@ unsafe fn write_anyvalue( match value { AnyValue::Null => write!(f, "{}", &options.null), - AnyValue::Int8(v) => write!(f, "{v}"), - AnyValue::Int16(v) => write!(f, "{v}"), - AnyValue::Int32(v) => write!(f, "{v}"), - AnyValue::Int64(v) => write!(f, "{v}"), - AnyValue::UInt8(v) => write!(f, "{v}"), - AnyValue::UInt16(v) => write!(f, "{v}"), - AnyValue::UInt32(v) => write!(f, "{v}"), - AnyValue::UInt64(v) => write!(f, "{v}"), + AnyValue::Int8(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::Int16(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::Int32(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::Int64(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt8(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt16(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt32(v) => { + write_integer(f, v); + Ok(()) + }, + AnyValue::UInt64(v) => { + write_integer(f, v); + Ok(()) + }, AnyValue::Float32(v) => match &options.float_precision { - None => fast_float_write(f, v, f32::FORMATTED_SIZE_DECIMAL), + None => { + fast_float_write(f, v); + Ok(()) + }, Some(precision) => write!(f, "{v:.precision$}"), }, AnyValue::Float64(v) => match &options.float_precision { - None => fast_float_write(f, v, f64::FORMATTED_SIZE_DECIMAL), + None => { + fast_float_write(f, v); + Ok(()) + }, Some(precision) => write!(f, "{v:.precision$}"), }, _ => { diff --git a/crates/polars-io/src/ipc/mmap.rs b/crates/polars-io/src/ipc/mmap.rs index ee8e7cc9560e..d8320e0c48fa 100644 --- a/crates/polars-io/src/ipc/mmap.rs +++ b/crates/polars-io/src/ipc/mmap.rs @@ -67,6 +67,7 @@ impl ArrowReader for MMapChunkIter<'_> { } } +#[cfg(feature = "ipc")] impl IpcReader { pub(super) fn finish_memmapped( &mut self, diff --git a/crates/polars-io/src/ipc/mod.rs b/crates/polars-io/src/ipc/mod.rs index 7b74b8486935..1366aa84324f 100644 --- a/crates/polars-io/src/ipc/mod.rs +++ b/crates/polars-io/src/ipc/mod.rs @@ -6,7 +6,7 @@ mod ipc_file; #[cfg(feature = "ipc_streaming")] mod ipc_stream; mod mmap; -#[cfg(feature = "ipc")] +#[cfg(any(feature = "ipc", feature = "ipc_streaming"))] mod write; #[cfg(all(feature = "async", feature = "ipc"))] mod write_async; diff --git a/crates/polars-io/src/json/mod.rs b/crates/polars-io/src/json/mod.rs index c5c7d8fe503f..f64709a52a19 100644 --- a/crates/polars-io/src/json/mod.rs +++ b/crates/polars-io/src/json/mod.rs @@ -67,7 +67,6 @@ use std::ops::Deref; 
use arrow::array::StructArray; pub use arrow::error::Result as ArrowResult; -pub use arrow::io::json; use polars_arrow::conversion::chunk_to_struct; use polars_arrow::utils::CustomIterTools; use polars_core::error::to_compute_err; @@ -141,13 +140,14 @@ where match self.json_format { JsonFormat::JsonLines => { - let serializer = arrow_ndjson::write::Serializer::new(batches, vec![]); - let writer = arrow_ndjson::write::FileWriter::new(&mut self.buffer, serializer); + let serializer = polars_json::ndjson::write::Serializer::new(batches, vec![]); + let writer = + polars_json::ndjson::write::FileWriter::new(&mut self.buffer, serializer); writer.collect::>()?; }, JsonFormat::Json => { - let serializer = json::write::Serializer::new(batches, vec![]); - json::write::write(&mut self.buffer, serializer)?; + let serializer = polars_json::json::write::Serializer::new(batches, vec![]); + polars_json::json::write::write(&mut self.buffer, serializer)?; }, } diff --git a/crates/polars-io/src/lib.rs b/crates/polars-io/src/lib.rs index fb9860e7bcc5..deac059882c7 100644 --- a/crates/polars-io/src/lib.rs +++ b/crates/polars-io/src/lib.rs @@ -4,7 +4,6 @@ #[cfg(feature = "avro")] pub mod avro; -#[cfg(feature = "cloud")] pub mod cloud; #[cfg(any(feature = "csv", feature = "json"))] pub mod csv; @@ -19,12 +18,6 @@ pub mod ndjson; #[cfg(feature = "cloud")] pub use crate::cloud::glob as async_glob; -#[cfg(any( - feature = "csv", - feature = "parquet", - feature = "ipc", - feature = "json" -))] pub mod mmap; mod options; #[cfg(feature = "parquet")] @@ -33,7 +26,7 @@ pub mod predicates; pub mod prelude; #[cfg(all(test, feature = "csv"))] mod tests; -pub(crate) mod utils; +pub mod utils; #[cfg(feature = "partition")] pub mod partition; diff --git a/crates/polars-io/src/ndjson/core.rs b/crates/polars-io/src/ndjson/core.rs index 726ea8d78bd3..dda5108d2f9c 100644 --- a/crates/polars-io/src/ndjson/core.rs +++ b/crates/polars-io/src/ndjson/core.rs @@ -3,14 +3,12 @@ use std::io::Cursor; use std::path::PathBuf; pub use arrow::array::StructArray; -pub use arrow::io::ndjson as arrow_ndjson; use num_traits::pow::Pow; use polars_core::prelude::*; use polars_core::utils::accumulate_dataframes_vertical; use polars_core::POOL; use rayon::prelude::*; -use crate::csv::utils::*; use crate::mmap::{MmapBytesReader, ReaderBytes}; use crate::ndjson::buffer::*; use crate::prelude::*; diff --git a/crates/polars-io/src/parquet/async_impl.rs b/crates/polars-io/src/parquet/async_impl.rs index 961776f8dfb0..0b45f3a8821d 100644 --- a/crates/polars-io/src/parquet/async_impl.rs +++ b/crates/polars-io/src/parquet/async_impl.rs @@ -11,7 +11,6 @@ use futures::lock::Mutex; use futures::{stream, StreamExt, TryFutureExt, TryStreamExt}; use object_store::path::Path as ObjectPath; use object_store::ObjectStore; -use polars_core::cloud::CloudOptions; use polars_core::config::verbose; use polars_core::datatypes::PlHashMap; use polars_core::error::{to_compute_err, PolarsResult}; @@ -22,6 +21,7 @@ use super::cloud::{build, CloudLocation, CloudReader}; use super::mmap; use super::mmap::ColumnStore; use super::read_impl::FetchRowGroups; +use crate::cloud::CloudOptions; pub struct ParquetObjectStore { store: Arc>>, diff --git a/crates/polars-io/src/parquet/predicates.rs b/crates/polars-io/src/parquet/predicates.rs index 9454ac431f3b..1dfc6b231b5a 100644 --- a/crates/polars-io/src/parquet/predicates.rs +++ b/crates/polars-io/src/parquet/predicates.rs @@ -26,9 +26,14 @@ impl ColumnStats { _ => { // the array holds the null count for every row group // so we sum 
them to get them of the whole file. - Series::try_from(("", self.0.null_count.clone())) - .unwrap() - .sum() + let s = Series::try_from(("", self.0.null_count.clone())).unwrap(); + + // if all null, there are no statistics. + if s.null_count() != s.len() { + s.sum() + } else { + None + } }, } } diff --git a/crates/polars-io/src/parquet/read.rs b/crates/polars-io/src/parquet/read.rs index c21a2892c8ba..496d61a6a118 100644 --- a/crates/polars-io/src/parquet/read.rs +++ b/crates/polars-io/src/parquet/read.rs @@ -3,13 +3,15 @@ use std::sync::Arc; use arrow::io::parquet::read; use arrow::io::parquet::write::FileMetaData; -#[cfg(feature = "cloud")] -use polars_core::cloud::CloudOptions; use polars_core::prelude::*; +#[cfg(feature = "cloud")] +use polars_core::utils::concat_df; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use super::read_impl::FetchRowGroupsFromMmapReader; +#[cfg(feature = "cloud")] +use crate::cloud::CloudOptions; use crate::mmap::MmapBytesReader; #[cfg(feature = "cloud")] use crate::parquet::async_impl::FetchRowGroupsFromObjectStore; @@ -17,6 +19,8 @@ use crate::parquet::async_impl::FetchRowGroupsFromObjectStore; use crate::parquet::async_impl::ParquetObjectStore; use crate::parquet::read_impl::read_parquet; pub use crate::parquet::read_impl::BatchedParquetReader; +#[cfg(feature = "cloud")] +use crate::predicates::apply_predicate; use crate::predicates::PhysicalIoExpr; use crate::prelude::*; use crate::RowCount; @@ -221,11 +225,10 @@ impl SerReader for ParquetReader { #[cfg(feature = "cloud")] pub struct ParquetAsyncReader { reader: ParquetObjectStore, - rechunk: bool, n_rows: Option, + rechunk: bool, projection: Option>, row_count: Option, - low_memory: bool, use_statistics: bool, } @@ -241,7 +244,6 @@ impl ParquetAsyncReader { n_rows: None, projection: None, row_count: None, - low_memory: false, use_statistics: true, }) } @@ -280,11 +282,6 @@ impl ParquetAsyncReader { self } - pub fn set_low_memory(mut self, low_memory: bool) -> Self { - self.low_memory = low_memory; - self - } - pub fn with_projection(mut self, projection: Option>) -> Self { self.projection = projection; self @@ -300,6 +297,7 @@ impl ParquetAsyncReader { #[tokio::main(flavor = "current_thread")] pub async fn batched(mut self, chunk_size: usize) -> PolarsResult { let metadata = self.reader.get_metadata().await?.to_owned(); + // row group fetched deals with projection let row_group_fetcher = Box::new(FetchRowGroupsFromObjectStore::new( self.reader, &metadata, @@ -315,4 +313,26 @@ impl ParquetAsyncReader { self.use_statistics, ) } + + pub fn finish(self, predicate: Option>) -> PolarsResult { + let rechunk = self.rechunk; + + // batched reader deals with slice pushdown + let reader = self.batched(usize::MAX)?; + let chunks = reader + .iter(16) // todo! 
tune this parameter + .map(|df_result| { + df_result.and_then(|mut df| { + apply_predicate(&mut df, predicate.as_deref(), true)?; + Ok(df) + }) + }) + .collect::>>()?; + let mut df = concat_df(&chunks)?; + + if rechunk { + df.as_single_chunk_par(); + } + Ok(df) + } } diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index e9cf8d71d47d..aadbf06e6f84 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -18,8 +18,7 @@ use crate::parquet::mmap::mmap_columns; use crate::parquet::predicates::read_this_row_group; use crate::parquet::{mmap, ParallelStrategy}; use crate::predicates::{apply_predicate, arrow_schema_to_empty_df, PhysicalIoExpr}; -use crate::prelude::utils::get_reader_bytes; -use crate::utils::apply_projection; +use crate::utils::{apply_projection, get_reader_bytes}; use crate::RowCount; fn column_idx_to_series( @@ -511,9 +510,9 @@ impl BatchedParquetReader { } /// Turn the batched reader into an iterator. - pub fn iter(self, batch_size: usize) -> BatchedParquetIter { + pub fn iter(self, batches_per_iter: usize) -> BatchedParquetIter { BatchedParquetIter { - batch_size, + batches_per_iter, inner: self, current_batch: vec![].into_iter(), } @@ -521,7 +520,7 @@ impl BatchedParquetReader { } pub struct BatchedParquetIter { - batch_size: usize, + batches_per_iter: usize, inner: BatchedParquetReader, current_batch: std::vec::IntoIter, } @@ -532,7 +531,7 @@ impl Iterator for BatchedParquetIter { fn next(&mut self) -> Option { match self.current_batch.next() { Some(df) => Some(Ok(df)), - None => match self.inner.next_batches(self.batch_size) { + None => match self.inner.next_batches(self.batches_per_iter) { Err(e) => Some(Err(e)), Ok(opt_batch) => { let batch = opt_batch?; diff --git a/crates/polars-io/src/prelude.rs b/crates/polars-io/src/prelude.rs index f62a5e357187..2d1362c6970f 100644 --- a/crates/polars-io/src/prelude.rs +++ b/crates/polars-io/src/prelude.rs @@ -12,7 +12,8 @@ pub use crate::ndjson::core::*; #[cfg(feature = "parquet")] pub use crate::parquet::*; pub use crate::utils::*; -pub use crate::{SerReader, SerWriter}; +pub use crate::{cloud, SerReader, SerWriter}; + #[cfg(test)] pub(crate) fn create_df() -> DataFrame { let s0 = Series::new("days", [0, 1, 2, 3, 4].as_ref()); diff --git a/crates/polars-io/src/utils.rs b/crates/polars-io/src/utils.rs index 64ce8f30a79a..287420a1cf71 100644 --- a/crates/polars-io/src/utils.rs +++ b/crates/polars-io/src/utils.rs @@ -1,8 +1,10 @@ +use std::io::Read; use std::path::{Path, PathBuf}; use polars_core::frame::DataFrame; use polars_core::prelude::*; +use crate::mmap::{MmapBytesReader, ReaderBytes}; #[cfg(any( feature = "ipc", feature = "ipc_streaming", @@ -11,6 +13,37 @@ use polars_core::prelude::*; ))] use crate::ArrowSchema; +pub fn get_reader_bytes<'a, R: Read + MmapBytesReader + ?Sized>( + reader: &'a mut R, +) -> PolarsResult> { + // we have a file so we can mmap + if let Some(file) = reader.to_file() { + let mmap = unsafe { memmap::Mmap::map(file)? 
}; + + // somehow bck thinks borrows alias + // this is sound as file was already bound to 'a + use std::fs::File; + let file = unsafe { std::mem::transmute::<&File, &'a File>(file) }; + Ok(ReaderBytes::Mapped(mmap, file)) + } else { + // we can get the bytes for free + if reader.to_bytes().is_some() { + // duplicate .to_bytes() is necessary to satisfy the borrow checker + Ok(ReaderBytes::Borrowed((*reader).to_bytes().unwrap())) + } else { + // we have to read to an owned buffer to get the bytes. + let mut bytes = Vec::with_capacity(1024 * 128); + reader.read_to_end(&mut bytes)?; + if !bytes.is_empty() + && (bytes[bytes.len() - 1] != b'\n' || bytes[bytes.len() - 1] != b'\r') + { + bytes.push(b'\n') + } + Ok(ReaderBytes::Owned(bytes)) + } + } +} + + // used by python polars pub fn resolve_homedir(path: &Path) -> PathBuf { // replace "~" with home directory diff --git a/crates/polars-json/Cargo.toml b/crates/polars-json/Cargo.toml index 8a9d0d53fb9a..7c0e69bd7fab 100644 --- a/crates/polars-json/Cargo.toml +++ b/crates/polars-json/Cargo.toml @@ -9,14 +9,18 @@ repository = { workspace = true } description = "JSON related logic for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", default-features = false } -polars-error = { version = "0.32.0", path = "../polars-error" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true } +polars-error = { workspace = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } +chrono = { workspace = true } fallible-streaming-iterator = { version = "0.1" } hashbrown = { workspace = true } indexmap = { workspace = true } +itoa = { workspace = true } num-traits = { workspace = true } +ryu = { workspace = true } simd-json = { workspace = true } +streaming-iterator = { workspace = true } diff --git a/crates/polars-json/src/json/deserialize.rs b/crates/polars-json/src/json/deserialize.rs index 3baa2ea0b61f..a27a14184b8e 100644 --- a/crates/polars-json/src/json/deserialize.rs +++ b/crates/polars-json/src/json/deserialize.rs @@ -88,6 +88,17 @@ fn deserialize_list<'a, A: Borrow>>( .try_push_usize(value.len()) .expect("List offset is too large :/"); }, + BorrowedValue::Static(StaticNode::Null) => { + validity.push(false); + offsets.extend_constant(1) + }, + value @ (BorrowedValue::Static(_) | BorrowedValue::String(_)) => { + inner.push(value); + validity.push(true); + offsets + .try_push_usize(1) + .expect("List offset is too large :/"); + }, _ => { validity.push(false); offsets.extend_constant(1); diff --git a/crates/polars-json/src/json/mod.rs b/crates/polars-json/src/json/mod.rs index 1ab9c2dd15ce..d39d7513c431 100644 --- a/crates/polars-json/src/json/mod.rs +++ b/crates/polars-json/src/json/mod.rs @@ -5,3 +5,4 @@ pub use deserialize::deserialize; pub use infer_schema::{infer, infer_records_schema}; use polars_error::*; use polars_utils::aliases::*; +pub mod write; diff --git a/crates/polars-json/src/json/write/mod.rs b/crates/polars-json/src/json/write/mod.rs new file mode 100644 index 000000000000..343bae73e520 --- /dev/null +++ b/crates/polars-json/src/json/write/mod.rs @@ -0,0 +1,157 @@ +//!
APIs to write to JSON +mod serialize; +mod utf8; + +use std::io::Write; + +use arrow::array::Array; +use arrow::chunk::Chunk; +use arrow::datatypes::Schema; +use arrow::error::Error; +use arrow::io::iterator::StreamingIterator; +pub use fallible_streaming_iterator::*; +pub(crate) use serialize::new_serializer; +use serialize::serialize; + +/// [`FallibleStreamingIterator`] that serializes an [`Array`] to bytes of valid JSON +/// # Implementation +/// Advancing this iterator CPU-bounded +#[derive(Debug, Clone)] +pub struct Serializer +where + A: AsRef, + I: Iterator>, +{ + arrays: I, + buffer: Vec, +} + +impl Serializer +where + A: AsRef, + I: Iterator>, +{ + /// Creates a new [`Serializer`]. + pub fn new(arrays: I, buffer: Vec) -> Self { + Self { arrays, buffer } + } +} + +impl FallibleStreamingIterator for Serializer +where + A: AsRef, + I: Iterator>, +{ + type Item = [u8]; + + type Error = Error; + + fn advance(&mut self) -> Result<(), Error> { + self.buffer.clear(); + self.arrays + .next() + .map(|maybe_array| maybe_array.map(|array| serialize(array.as_ref(), &mut self.buffer))) + .transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// [`FallibleStreamingIterator`] that serializes a [`Chunk`] into bytes of JSON +/// in a (pandas-compatible) record-oriented format. +/// +/// # Implementation +/// Advancing this iterator is CPU-bounded. +pub struct RecordSerializer<'a> { + schema: Schema, + index: usize, + end: usize, + iterators: Vec + Send + Sync + 'a>>, + buffer: Vec, +} + +impl<'a> RecordSerializer<'a> { + /// Creates a new [`RecordSerializer`]. + pub fn new(schema: Schema, chunk: &'a Chunk, buffer: Vec) -> Self + where + A: AsRef, + { + let end = chunk.len(); + let iterators = chunk + .arrays() + .iter() + .map(|arr| new_serializer(arr.as_ref(), 0, usize::MAX)) + .collect(); + + Self { + schema, + index: 0, + end, + iterators, + buffer, + } + } +} + +impl<'a> FallibleStreamingIterator for RecordSerializer<'a> { + type Item = [u8]; + + type Error = Error; + + fn advance(&mut self) -> Result<(), Error> { + self.buffer.clear(); + if self.index == self.end { + return Ok(()); + } + + let mut is_first_row = true; + write!(&mut self.buffer, "{{")?; + for (f, ref mut it) in self.schema.fields.iter().zip(self.iterators.iter_mut()) { + if !is_first_row { + write!(&mut self.buffer, ",")?; + } + write!(&mut self.buffer, "\"{}\":", f.name)?; + + self.buffer.extend_from_slice(it.next().unwrap()); + is_first_row = false; + } + write!(&mut self.buffer, "}}")?; + + self.index += 1; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// Writes valid JSON from an iterator of (assumed JSON-encoded) bytes to `writer` +pub fn write(writer: &mut W, mut blocks: I) -> Result<(), Error> +where + W: std::io::Write, + I: FallibleStreamingIterator, +{ + writer.write_all(&[b'['])?; + let mut is_first_row = true; + while let Some(block) = blocks.next()? 
{ + if !is_first_row { + writer.write_all(&[b','])?; + } + is_first_row = false; + writer.write_all(block)?; + } + writer.write_all(&[b']'])?; + Ok(()) +} diff --git a/crates/polars-json/src/json/write/serialize.rs b/crates/polars-json/src/json/write/serialize.rs new file mode 100644 index 000000000000..7622d006e761 --- /dev/null +++ b/crates/polars-json/src/json/write/serialize.rs @@ -0,0 +1,522 @@ +use std::io::Write; + +use arrow::array::*; +use arrow::bitmap::utils::ZipValidity; +use arrow::datatypes::{DataType, IntegerType, TimeUnit}; +use arrow::io::iterator::BufStreamingIterator; +use arrow::offset::Offset; +#[cfg(feature = "chrono-tz")] +use arrow::temporal_conversions::parse_offset_tz; +use arrow::temporal_conversions::{ + date32_to_date, date64_to_date, duration_ms_to_duration, duration_ns_to_duration, + duration_s_to_duration, duration_us_to_duration, parse_offset, timestamp_ms_to_datetime, + timestamp_ns_to_datetime, timestamp_s_to_datetime, timestamp_to_datetime, + timestamp_us_to_datetime, +}; +use arrow::types::NativeType; +use chrono::{Duration, NaiveDate, NaiveDateTime}; +use streaming_iterator::StreamingIterator; + +use super::utf8; + +fn write_integer(buf: &mut Vec, val: I) { + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(val); + buf.extend_from_slice(value.as_bytes()) +} + +fn write_float(f: &mut Vec, val: I) { + let mut buffer = ryu::Buffer::new(); + let value = buffer.format(val); + f.extend_from_slice(value.as_bytes()) +} + +fn materialize_serializer<'a, I, F, T>( + f: F, + iterator: I, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: 'a, + I: Iterator + Send + Sync + 'a, + F: FnMut(T, &mut Vec) + Send + Sync + 'a, +{ + if offset > 0 || take < usize::MAX { + Box::new(BufStreamingIterator::new( + iterator.skip(offset).take(take), + f, + vec![], + )) + } else { + Box::new(BufStreamingIterator::new(iterator, f, vec![])) + } +} + +fn boolean_serializer<'a>( + array: &'a BooleanArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option, buf: &mut Vec| match x { + Some(true) => buf.extend_from_slice(b"true"), + Some(false) => buf.extend_from_slice(b"false"), + None => buf.extend_from_slice(b"null"), + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn null_serializer( + len: usize, + offset: usize, + take: usize, +) -> Box + Send + Sync> { + let f = |_x: (), buf: &mut Vec| buf.extend_from_slice(b"null"); + materialize_serializer(f, std::iter::repeat(()).take(len), offset, take) +} + +fn primitive_serializer<'a, T: NativeType + itoa::Integer>( + array: &'a PrimitiveArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + write_integer(buf, *x) + } else { + buf.extend(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn float_serializer<'a, T>( + array: &'a PrimitiveArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: num_traits::Float + NativeType + ryu::Float, +{ + let f = |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + if T::is_nan(*x) || T::is_infinite(*x) { + buf.extend(b"null") + } else { + write_float(buf, *x) + } + } else { + buf.extend(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn dictionary_utf8_serializer<'a, K: DictionaryKey, O: Offset>( + array: &'a DictionaryArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let iter = 
array.iter_typed::>().unwrap().skip(offset); + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, iter, offset, take) +} + +fn utf8_serializer<'a, O: Offset>( + array: &'a Utf8Array, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let f = |x: Option<&str>, buf: &mut Vec| { + if let Some(x) = x { + utf8::write_str(buf, x).unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn struct_serializer<'a>( + array: &'a StructArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + // {"a": [1, 2, 3], "b": [a, b, c], "c": {"a": [1, 2, 3]}} + // [ + // {"a": 1, "b": a, "c": {"a": 1}}, + // {"a": 2, "b": b, "c": {"a": 2}}, + // {"a": 3, "b": c, "c": {"a": 3}}, + // ] + // + let mut serializers = array + .values() + .iter() + .map(|x| x.as_ref()) + .map(|arr| new_serializer(arr, offset, take)) + .collect::>(); + let names = array.fields().iter().map(|f| f.name.as_str()); + + Box::new(BufStreamingIterator::new( + ZipValidity::new_with_validity(0..array.len(), array.validity()), + move |maybe, buf| { + if maybe.is_some() { + let names = names.clone(); + let mut record: Vec<(&str, &[u8])> = Default::default(); + serializers + .iter_mut() + .zip(names) + // `unwrap` is infalible because `array.len()` equals `len` on `Chunk` + .for_each(|(iter, name)| { + let item = iter.next().unwrap(); + record.push((name, item)); + }); + serialize_item(buf, &record, true); + } else { + serializers.iter_mut().for_each(|iter| { + let _ = iter.next(); + }); + buf.extend(b"null"); + } + }, + vec![], + )) +} + +fn list_serializer<'a, O: Offset>( + array: &'a ListArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + // [[1, 2], [3]] + // [ + // [1, 2], + // [3] + // ] + // + let offsets = array.offsets().as_slice(); + let start = offsets[0].to_usize(); + let end = offsets.last().unwrap().to_usize(); + let mut serializer = new_serializer(array.values().as_ref(), start, end - start); + + let f = move |offset: Option<&[O]>, buf: &mut Vec| { + if let Some(offset) = offset { + let length = (offset[1] - offset[0]).to_usize(); + buf.push(b'['); + let mut is_first_row = true; + for _ in 0..length { + if !is_first_row { + buf.push(b','); + } + is_first_row = false; + buf.extend(serializer.next().unwrap()); + } + buf.push(b']'); + } else { + buf.extend(b"null"); + } + }; + + let iter = + ZipValidity::new_with_validity(array.offsets().buffer().windows(2), array.validity()); + materialize_serializer(f, iter, offset, take) +} + +fn fixed_size_list_serializer<'a>( + array: &'a FixedSizeListArray, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + let mut serializer = new_serializer(array.values().as_ref(), offset, take); + + Box::new(BufStreamingIterator::new( + ZipValidity::new(0..array.len(), array.validity().map(|x| x.iter())), + move |ix, buf| { + if ix.is_some() { + let length = array.size(); + buf.push(b'['); + let mut is_first_row = true; + for _ in 0..length { + if !is_first_row { + buf.push(b','); + } + is_first_row = false; + buf.extend(serializer.next().unwrap()); + } + buf.push(b']'); + } else { + buf.extend(b"null"); + } + }, + vec![], + )) +} + +fn date_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> NaiveDate + 'static + Send + Sync, +{ + let 
f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let nd = convert(*x); + write!(buf, "\"{nd}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn duration_serializer<'a, T, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + T: NativeType, + F: Fn(T) -> Duration + 'static + Send + Sync, +{ + let f = move |x: Option<&T>, buf: &mut Vec| { + if let Some(x) = x { + let duration = convert(*x); + write!(buf, "\"{duration}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) +} + +fn timestamp_serializer<'a, F>( + array: &'a PrimitiveArray, + convert: F, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> +where + F: Fn(i64) -> NaiveDateTime + 'static + Send + Sync, +{ + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let ndt = convert(*x); + write!(buf, "\"{ndt}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + materialize_serializer(f, array.iter(), offset, take) +} + +fn timestamp_tz_serializer<'a>( + array: &'a PrimitiveArray, + time_unit: TimeUnit, + tz: &str, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + match parse_offset(tz) { + Ok(parsed_tz) => { + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let dt_str = timestamp_to_datetime(*x, time_unit, &parsed_tz).to_rfc3339(); + write!(buf, "\"{dt_str}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) + }, + #[cfg(feature = "chrono-tz")] + _ => match parse_offset_tz(tz) { + Ok(parsed_tz) => { + let f = move |x: Option<&i64>, buf: &mut Vec| { + if let Some(x) = x { + let dt_str = timestamp_to_datetime(*x, time_unit, &parsed_tz).to_rfc3339(); + write!(buf, "\"{dt_str}\"").unwrap(); + } else { + buf.extend_from_slice(b"null") + } + }; + + materialize_serializer(f, array.iter(), offset, take) + }, + _ => { + panic!("Timezone {} is invalid or not supported", tz); + }, + }, + #[cfg(not(feature = "chrono-tz"))] + _ => { + panic!("Invalid Offset format (must be [-]00:00) or chrono-tz feature not active"); + }, + } +} + +pub(crate) fn new_serializer<'a>( + array: &'a dyn Array, + offset: usize, + take: usize, +) -> Box + 'a + Send + Sync> { + match array.data_type().to_logical_type() { + DataType::Boolean => { + boolean_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int8 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int16 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int32 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Int64 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt8 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt16 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt32 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::UInt64 => { + primitive_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Float32 => { + float_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + 
DataType::Float64 => { + float_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Utf8 => { + utf8_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::LargeUtf8 => { + utf8_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::Struct(_) => { + struct_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::FixedSizeList(_, _) => { + fixed_size_list_serializer(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::List(_) => { + list_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + DataType::LargeList(_) => { + list_serializer::(array.as_any().downcast_ref().unwrap(), offset, take) + }, + other @ DataType::Dictionary(k, v, _) => match (k, &**v) { + (IntegerType::UInt32, DataType::LargeUtf8) => { + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + dictionary_utf8_serializer::(array, offset, take) + }, + _ => { + todo!("Writing {:?} to JSON", other) + }, + }, + DataType::Date32 => date_serializer( + array.as_any().downcast_ref().unwrap(), + date32_to_date, + offset, + take, + ), + DataType::Date64 => date_serializer( + array.as_any().downcast_ref().unwrap(), + date64_to_date, + offset, + take, + ), + DataType::Timestamp(tu, None) => { + let convert = match tu { + TimeUnit::Nanosecond => timestamp_ns_to_datetime, + TimeUnit::Microsecond => timestamp_us_to_datetime, + TimeUnit::Millisecond => timestamp_ms_to_datetime, + TimeUnit::Second => timestamp_s_to_datetime, + }; + timestamp_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + DataType::Timestamp(time_unit, Some(tz)) => timestamp_tz_serializer( + array.as_any().downcast_ref().unwrap(), + *time_unit, + tz, + offset, + take, + ), + DataType::Duration(tu) => { + let convert = match tu { + TimeUnit::Nanosecond => duration_ns_to_duration, + TimeUnit::Microsecond => duration_us_to_duration, + TimeUnit::Millisecond => duration_ms_to_duration, + TimeUnit::Second => duration_s_to_duration, + }; + duration_serializer( + array.as_any().downcast_ref().unwrap(), + convert, + offset, + take, + ) + }, + DataType::Null => null_serializer(array.len(), offset, take), + other => todo!("Writing {:?} to JSON", other), + } +} + +fn serialize_item(buffer: &mut Vec, record: &[(&str, &[u8])], is_first_row: bool) { + if !is_first_row { + buffer.push(b','); + } + buffer.push(b'{'); + let mut first_item = true; + for (key, value) in record { + if !first_item { + buffer.push(b','); + } + first_item = false; + utf8::write_str(buffer, key).unwrap(); + buffer.push(b':'); + buffer.extend(*value); + } + buffer.push(b'}'); +} + +/// Serializes `array` to a valid JSON to `buffer` +/// # Implementation +/// This operation is CPU-bounded +pub(crate) fn serialize(array: &dyn Array, buffer: &mut Vec) { + let mut serializer = new_serializer(array, 0, usize::MAX); + + (0..array.len()).for_each(|i| { + if i != 0 { + buffer.push(b','); + } + buffer.extend_from_slice(serializer.next().unwrap()); + }); +} diff --git a/crates/polars-json/src/json/write/utf8.rs b/crates/polars-json/src/json/write/utf8.rs new file mode 100644 index 000000000000..941d73379c3d --- /dev/null +++ b/crates/polars-json/src/json/write/utf8.rs @@ -0,0 +1,138 @@ +// Adapted from https://github.com/serde-rs/json/blob/f901012df66811354cb1d490ad59480d8fdf77b5/src/ser.rs +use std::io; + +pub fn write_str(writer: &mut W, value: &str) -> io::Result<()> +where + W: io::Write, +{ + 
writer.write_all(b"\"")?; + let bytes = value.as_bytes(); + + let mut start = 0; + + for (i, &byte) in bytes.iter().enumerate() { + let escape = ESCAPE[byte as usize]; + if escape == 0 { + continue; + } + + if start < i { + writer.write_all(&bytes[start..i])?; + } + + let char_escape = CharEscape::from_escape_table(escape, byte); + write_char_escape(writer, char_escape)?; + + start = i + 1; + } + + if start != bytes.len() { + writer.write_all(&bytes[start..])?; + } + writer.write_all(b"\"") +} + +const BB: u8 = b'b'; // \x08 +const TT: u8 = b't'; // \x09 +const NN: u8 = b'n'; // \x0A +const FF: u8 = b'f'; // \x0C +const RR: u8 = b'r'; // \x0D +const QU: u8 = b'"'; // \x22 +const BS: u8 = b'\\'; // \x5C +const UU: u8 = b'u'; // \x00...\x1F except the ones above +const __: u8 = 0; + +// Lookup table of escape sequences. A value of b'x' at index i means that byte +// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped. +static ESCAPE: [u8; 256] = [ + // 1 2 3 4 5 6 7 8 9 A B C D E F + UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0 + UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1 + __, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4 + __, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9 + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E + __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F +]; + +/// Represents a character escape code in a type-safe manner. 
+pub enum CharEscape { + /// An escaped quote `"` + Quote, + /// An escaped reverse solidus `\` + ReverseSolidus, + // An escaped solidus `/` + //Solidus, + /// An escaped backspace character (usually escaped as `\b`) + Backspace, + /// An escaped form feed character (usually escaped as `\f`) + FormFeed, + /// An escaped line feed character (usually escaped as `\n`) + LineFeed, + /// An escaped carriage return character (usually escaped as `\r`) + CarriageReturn, + /// An escaped tab character (usually escaped as `\t`) + Tab, + /// An escaped ASCII plane control character (usually escaped as + /// `\u00XX` where `XX` are two hex characters) + AsciiControl(u8), +} + +impl CharEscape { + #[inline] + fn from_escape_table(escape: u8, byte: u8) -> CharEscape { + match escape { + self::BB => CharEscape::Backspace, + self::TT => CharEscape::Tab, + self::NN => CharEscape::LineFeed, + self::FF => CharEscape::FormFeed, + self::RR => CharEscape::CarriageReturn, + self::QU => CharEscape::Quote, + self::BS => CharEscape::ReverseSolidus, + self::UU => CharEscape::AsciiControl(byte), + _ => unreachable!(), + } + } +} + +#[inline] +fn write_char_escape(writer: &mut W, char_escape: CharEscape) -> io::Result<()> +where + W: io::Write, +{ + use self::CharEscape::*; + + let s = match char_escape { + Quote => b"\\\"", + ReverseSolidus => b"\\\\", + //Solidus => b"\\/", + Backspace => b"\\b", + FormFeed => b"\\f", + LineFeed => b"\\n", + CarriageReturn => b"\\r", + Tab => b"\\t", + AsciiControl(byte) => { + static HEX_DIGITS: [u8; 16] = *b"0123456789abcdef"; + let bytes = &[ + b'\\', + b'u', + b'0', + b'0', + HEX_DIGITS[(byte >> 4) as usize], + HEX_DIGITS[(byte & 0xF) as usize], + ]; + return writer.write_all(bytes); + }, + }; + + writer.write_all(s) +} diff --git a/crates/polars-json/src/ndjson/mod.rs b/crates/polars-json/src/ndjson/mod.rs index 429b1096b1ae..2076715e711f 100644 --- a/crates/polars-json/src/ndjson/mod.rs +++ b/crates/polars-json/src/ndjson/mod.rs @@ -3,5 +3,6 @@ use polars_arrow::prelude::*; use polars_error::*; pub mod deserialize; mod file; +pub mod write; pub use file::{infer, infer_iter}; diff --git a/crates/polars-json/src/ndjson/write.rs b/crates/polars-json/src/ndjson/write.rs new file mode 100644 index 000000000000..5cbda120711f --- /dev/null +++ b/crates/polars-json/src/ndjson/write.rs @@ -0,0 +1,118 @@ +//! APIs to serialize and write to [NDJSON](http://ndjson.org/). +use std::io::Write; + +use arrow::array::Array; +use arrow::error::Error; +pub use fallible_streaming_iterator::FallibleStreamingIterator; + +use super::super::json::write::new_serializer; + +fn serialize(array: &dyn Array, buffer: &mut Vec) { + let mut serializer = new_serializer(array, 0, usize::MAX); + (0..array.len()).for_each(|_| { + buffer.extend_from_slice(serializer.next().unwrap()); + buffer.push(b'\n'); + }); +} + +/// [`FallibleStreamingIterator`] that serializes an [`Array`] to bytes of valid NDJSON +/// where every line is an element of the array. +/// # Implementation +/// Advancing this iterator CPU-bounded +#[derive(Debug, Clone)] +pub struct Serializer +where + A: AsRef, + I: Iterator>, +{ + arrays: I, + buffer: Vec, +} + +impl Serializer +where + A: AsRef, + I: Iterator>, +{ + /// Creates a new [`Serializer`]. 
+ pub fn new(arrays: I, buffer: Vec) -> Self { + Self { arrays, buffer } + } +} + +impl FallibleStreamingIterator for Serializer +where + A: AsRef, + I: Iterator>, +{ + type Item = [u8]; + + type Error = Error; + + fn advance(&mut self) -> Result<(), Error> { + self.buffer.clear(); + self.arrays + .next() + .map(|maybe_array| maybe_array.map(|array| serialize(array.as_ref(), &mut self.buffer))) + .transpose()?; + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if !self.buffer.is_empty() { + Some(&self.buffer) + } else { + None + } + } +} + +/// An iterator adapter that receives an implementer of [`Write`] and +/// an implementer of [`FallibleStreamingIterator`] (such as [`Serializer`]) +/// and writes a valid NDJSON +/// # Implementation +/// Advancing this iterator mixes CPU-bounded (serializing arrays) tasks and IO-bounded (write to the writer). +pub struct FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + writer: W, + iterator: I, +} + +impl FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + /// Creates a new [`FileWriter`]. + pub fn new(writer: W, iterator: I) -> Self { + Self { writer, iterator } + } + + /// Returns the inner content of this iterator + /// + /// There are two use-cases for this function: + /// * to continue writing to its writer + /// * to re-use an internal buffer of its iterator + pub fn into_inner(self) -> (W, I) { + (self.writer, self.iterator) + } +} + +impl Iterator for FileWriter +where + W: Write, + I: FallibleStreamingIterator, +{ + type Item = Result<(), Error>; + + fn next(&mut self) -> Option { + let item = self.iterator.next().transpose()?; + Some(item.and_then(|x| { + self.writer.write_all(x)?; + Ok(()) + })) + } +} diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index c78c49901acd..b14604a7338d 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -9,15 +9,15 @@ repository = { workspace = true } description = "Lazy query engine for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow" } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["lazy", "zip_with", "random"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", features = ["lazy", "csv"], default-features = false } -polars-json = { version = "0.32.0", path = "../polars-json", optional = true } -polars-ops = { version = "0.32.0", path = "../polars-ops", default-features = false } -polars-pipe = { version = "0.32.0", path = "../polars-pipe", optional = true } -polars-plan = { version = "0.32.0", path = "../polars-plan" } -polars-time = { version = "0.32.0", path = "../polars-time", optional = true } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } +polars-io = { workspace = true, features = ["lazy"] } +polars-json = { workspace = true, optional = true } +polars-ops = { workspace = true } +polars-pipe = { workspace = true, optional = true } +polars-plan = { workspace = true } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } ahash = { workspace = true } bitflags = { workspace = true } @@ -34,38 +34,35 @@ serde_json = { workspace = true } version_check = { workspace = true } [features] -nightly = ["polars-core/nightly", "polars-pipe/nightly", "polars-plan/nightly"] -compile = ["polars-plan/compile"] 
-streaming = ["chunked_ids", "polars-pipe/compile", "polars-plan/streaming"] -default = ["compile"] -parquet = ["polars-core/parquet", "polars-io/parquet", "polars-plan/parquet", "polars-pipe/parquet"] +nightly = ["polars-core/nightly", "polars-pipe?/nightly", "polars-plan/nightly"] +streaming = ["chunked_ids", "polars-pipe", "polars-plan/streaming"] +parquet = ["polars-core/parquet", "polars-io/parquet", "polars-plan/parquet", "polars-pipe?/parquet"] async = [ "polars-plan/async", "polars-io/cloud", - "polars-pipe/async", - "streaming", + "polars-pipe?/async", ] -cloud = ["async", "polars-pipe/cloud"] +cloud = ["async", "polars-pipe?/cloud", "polars-plan/cloud"] cloud_write = ["cloud"] -ipc = ["polars-io/ipc", "polars-plan/ipc", "polars-pipe/ipc"] +ipc = ["polars-io/ipc", "polars-plan/ipc", "polars-pipe?/ipc"] json = ["polars-io/json", "polars-plan/json", "polars-json"] -csv = ["polars-io/csv", "polars-plan/csv", "polars-pipe/csv"] +csv = ["polars-io/csv", "polars-plan/csv", "polars-pipe?/csv"] temporal = ["dtype-datetime", "dtype-date", "dtype-time", "dtype-duration", "polars-plan/temporal"] # debugging purposes fmt = ["polars-core/fmt", "polars-plan/fmt"] strings = ["polars-plan/strings"] future = [] -dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe/dtype-u8"] -dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe/dtype-u16"] -dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe/dtype-i8"] -dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe/dtype-i16"] -dtype-decimal = ["polars-plan/dtype-decimal", "polars-pipe/dtype-decimal"] +dtype-u8 = ["polars-plan/dtype-u8", "polars-pipe?/dtype-u8"] +dtype-u16 = ["polars-plan/dtype-u16", "polars-pipe?/dtype-u16"] +dtype-i8 = ["polars-plan/dtype-i8", "polars-pipe?/dtype-i8"] +dtype-i16 = ["polars-plan/dtype-i16", "polars-pipe?/dtype-i16"] +dtype-decimal = ["polars-plan/dtype-decimal", "polars-pipe?/dtype-decimal"] dtype-date = ["polars-plan/dtype-date", "polars-time/dtype-date", "temporal"] dtype-datetime = ["polars-plan/dtype-datetime", "polars-time/dtype-datetime", "temporal"] dtype-duration = ["polars-plan/dtype-duration", "polars-time/dtype-duration", "temporal"] dtype-time = ["polars-core/dtype-time", "temporal"] -dtype-array = ["polars-plan/dtype-array", "polars-pipe/dtype-array", "polars-ops/dtype-array"] -dtype-categorical = ["polars-plan/dtype-categorical", "polars-pipe/dtype-categorical"] +dtype-array = ["polars-plan/dtype-array", "polars-pipe?/dtype-array", "polars-ops/dtype-array"] +dtype-categorical = ["polars-plan/dtype-categorical", "polars-pipe?/dtype-categorical"] dtype-struct = ["polars-plan/dtype-struct"] object = ["polars-plan/object"] date_offset = ["polars-plan/date_offset"] @@ -83,10 +80,10 @@ approx_unique = ["polars-plan/approx_unique"] is_in = ["polars-plan/is_in", "polars-ops/is_in"] repeat_by = ["polars-plan/repeat_by"] round_series = ["polars-plan/round_series", "polars-ops/round_series"] -is_first = ["polars-plan/is_first"] -is_last = ["polars-plan/is_last"] +is_first_distinct = ["polars-plan/is_first_distinct"] +is_last_distinct = ["polars-plan/is_last_distinct"] is_unique = ["polars-plan/is_unique"] -cross_join = ["polars-plan/cross_join", "polars-pipe/cross_join", "polars-ops/cross_join"] +cross_join = ["polars-plan/cross_join", "polars-pipe?/cross_join", "polars-ops/cross_join"] asof_join = ["polars-plan/asof_join", "polars-time"] concat_str = ["polars-plan/concat_str"] range = ["polars-plan/range"] @@ -132,7 +129,7 @@ serde = [ "polars-plan/serde", "polars-arrow/serde", "polars-core/serde-lazy", - "polars-time/serde", + 
"polars-time?/serde", "polars-io/serde", "polars-ops/serde", ] diff --git a/crates/polars-lazy/src/frame/csv.rs b/crates/polars-lazy/src/frame/csv.rs index be497c336388..0b76c1233654 100644 --- a/crates/polars-lazy/src/frame/csv.rs +++ b/crates/polars-lazy/src/frame/csv.rs @@ -1,8 +1,9 @@ use std::path::{Path, PathBuf}; use polars_core::prelude::*; -use polars_io::csv::utils::{get_reader_bytes, infer_file_schema}; +use polars_io::csv::utils::infer_file_schema; use polars_io::csv::{CsvEncoding, NullValues}; +use polars_io::utils::get_reader_bytes; use polars_io::RowCount; use crate::frame::LazyFileListReader; diff --git a/crates/polars-lazy/src/frame/file_list_reader.rs b/crates/polars-lazy/src/frame/file_list_reader.rs index 8824406f2599..73511120017b 100644 --- a/crates/polars-lazy/src/frame/file_list_reader.rs +++ b/crates/polars-lazy/src/frame/file_list_reader.rs @@ -1,8 +1,8 @@ use std::path::{Path, PathBuf}; -use polars_core::cloud::CloudOptions; use polars_core::error::to_compute_err; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::{is_cloud_url, RowCount}; use crate::prelude::*; diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 6f367cae9a9b..141bea865bb5 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -689,7 +689,7 @@ impl LazyFrame { pub fn sink_parquet_cloud( mut self, uri: String, - cloud_options: Option, + cloud_options: Option, parquet_options: ParquetWriteOptions, ) -> PolarsResult<()> { self.opt_state.streaming = true; diff --git a/crates/polars-lazy/src/frame/parquet.rs b/crates/polars-lazy/src/frame/parquet.rs index c71ed3b7821d..0a1b7c1b51b0 100644 --- a/crates/polars-lazy/src/frame/parquet.rs +++ b/crates/polars-lazy/src/frame/parquet.rs @@ -1,7 +1,7 @@ use std::path::{Path, PathBuf}; -use polars_core::cloud::CloudOptions; use polars_core::prelude::*; +use polars_io::cloud::CloudOptions; use polars_io::parquet::ParallelStrategy; use polars_io::RowCount; diff --git a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs index b6d890cbac0a..20e3c789afaa 100644 --- a/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs +++ b/crates/polars-lazy/src/physical_plan/executors/group_by_rolling.rs @@ -66,14 +66,13 @@ impl GroupByRollingExec { // can be empty, but we still want to know the first value // of that group for key in keys.iter_mut() { - *key = key.take_unchecked_from_slice(first).unwrap(); + *key = key.take_unchecked_from_slice(first); } }, GroupsProxy::Slice { groups, .. 
} => { for key in keys.iter_mut() { - let iter = &mut groups.iter().map(|[first, _len]| *first as usize) - as &mut dyn TakeIterator; - *key = key.take_iter_unchecked(iter); + let indices = groups.iter().map(|[first, _len]| *first).collect_ca(""); + *key = key.take_unchecked(&indices); } }, } diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs index 79c7b43dccc6..e61c1b7057c5 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/ipc.rs @@ -20,7 +20,7 @@ impl IpcExec { self.file_options.n_rows, self.file_options.row_count.is_some(), ); - IpcReader::new(file) + IpcReader::new(file.unwrap()) .with_n_rows(n_rows) .with_row_count(std::mem::take(&mut self.file_options.row_count)) .set_rechunk(self.file_options.rechunk) diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs b/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs index 07788ff10a74..cd509e73097e 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/mod.rs @@ -41,8 +41,8 @@ fn prepare_scan_args( schema: &mut SchemaRef, n_rows: Option, has_row_count: bool, -) -> (std::fs::File, Projection, StopNRows, Predicate) { - let file = std::fs::File::open(path).unwrap(); +) -> (Option, Projection, StopNRows, Predicate) { + let file = std::fs::File::open(path).ok(); let with_columns = mem::take(with_columns); let schema = mem::take(schema); diff --git a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs index 987b92fbdcff..4584c49ec035 100644 --- a/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs +++ b/crates/polars-lazy/src/physical_plan/executors/scan/parquet.rs @@ -1,15 +1,16 @@ use std::path::PathBuf; -use polars_core::cloud::CloudOptions; +use polars_io::cloud::CloudOptions; +use polars_io::is_cloud_url; use super::*; -#[allow(dead_code)] pub struct ParquetExec { path: PathBuf, schema: SchemaRef, predicate: Option>, options: ParquetOptions, + #[allow(dead_code)] cloud_options: Option, file_options: FileScanOptions, } @@ -43,14 +44,35 @@ impl ParquetExec { self.file_options.row_count.is_some(), ); - ParquetReader::new(file) - .with_n_rows(n_rows) - .read_parallel(self.options.parallel) - .with_row_count(mem::take(&mut self.file_options.row_count)) - .set_rechunk(self.file_options.rechunk) - .set_low_memory(self.options.low_memory) - .use_statistics(self.options.use_statistics) - ._finish_with_scan_ops(predicate, projection.as_ref().map(|v| v.as_ref())) + if let Some(file) = file { + ParquetReader::new(file) + .with_n_rows(n_rows) + .read_parallel(self.options.parallel) + .with_row_count(mem::take(&mut self.file_options.row_count)) + .set_rechunk(self.file_options.rechunk) + .set_low_memory(self.options.low_memory) + .use_statistics(self.options.use_statistics) + ._finish_with_scan_ops(predicate, projection.as_ref().map(|v| v.as_ref())) + } else if is_cloud_url(self.path.as_path()) { + #[cfg(feature = "cloud")] + { + let reader = ParquetAsyncReader::from_uri( + &self.path.to_string_lossy(), + self.cloud_options.as_ref(), + )? 
+ .with_n_rows(n_rows) + .with_row_count(mem::take(&mut self.file_options.row_count)) + .use_statistics(self.options.use_statistics); + + reader.finish(predicate) + } + #[cfg(not(feature = "cloud"))] + { + panic!("activate cloud feature") + } + } else { + polars_bail!(ComputeError: "could not read {}", self.path.display()) + } } } diff --git a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs index 4b67da4b5d5a..4689afa2dd2b 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/aggregation.rs @@ -427,7 +427,7 @@ impl PartitionedAggregation for AggregationExpr { let ca = unsafe { // Safety // The indexes of the group_by operation are never out of bounds - ca.take_unchecked(idx.into()) + ca.take_unchecked(idx) }; process_group(ca)?; } diff --git a/crates/polars-lazy/src/physical_plan/expressions/sort.rs b/crates/polars-lazy/src/physical_plan/expressions/sort.rs index 473c43e5befc..4f95d3e555d3 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sort.rs @@ -84,13 +84,8 @@ impl PhysicalExpr for SortExpr { groups .par_iter() .map(|(first, idx)| { - // Safety: - // Group tuples are always in bounds - let group = unsafe { - series.take_iter_unchecked( - &mut idx.iter().map(|i| *i as usize), - ) - }; + // SAFETY: group tuples are always in bounds. + let group = unsafe { series.take_slice_unchecked(idx) }; let sorted_idx = group.arg_sort(sort_options); let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx); diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index 2aa1b01a367b..e7503e4d6995 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -92,7 +92,7 @@ impl PhysicalExpr for SortByExpr { ); // SAFETY: sorted index are within bounds. - unsafe { series.take_unchecked(&sorted_idx) } + unsafe { Ok(series.take_unchecked(&sorted_idx)) } } #[allow(clippy::ptr_arg)] @@ -135,7 +135,7 @@ impl PhysicalExpr for SortByExpr { multithreaded: false, ..Default::default() }); - Some(unsafe { s.take_unchecked(&idx).unwrap() }) + Some(unsafe { s.take_unchecked(&idx) }) } }, _ => None, @@ -168,11 +168,7 @@ impl PhysicalExpr for SortByExpr { let new_idx = match indicator { GroupsIndicator::Idx((_, idx)) => { // SAFETY: group tuples are always in bounds. - let group = unsafe { - sort_by_s.take_iter_unchecked( - &mut idx.iter().map(|i| *i as usize), - ) - }; + let group = unsafe { sort_by_s.take_slice_unchecked(idx) }; let sorted_idx = group.arg_sort(SortOptions { descending: descending[0], @@ -244,11 +240,7 @@ impl PhysicalExpr for SortByExpr { // SAFETY: group tuples are always in bounds. let groups = sort_by_s .iter() - .map(|s| unsafe { - s.take_iter_unchecked( - &mut idx.iter().map(|i| *i as usize), - ) - }) + .map(|s| unsafe { s.take_slice_unchecked(idx) }) .collect::>(); let options = SortMultipleOptions { diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index fb26b0e6fb5c..7f4b117ab01d 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -120,7 +120,7 @@ impl WindowExpr { // Safety: // groups should always be in bounds. 
- unsafe { flattened.take_unchecked(&idx) } + unsafe { Ok(flattened.take_unchecked(&idx)) } } #[allow(clippy::too_many_arguments)] @@ -635,9 +635,7 @@ fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Series) -> Se match join_opt_ids { Either::Left(ids) => unsafe { - out_column.take_opt_iter_unchecked( - &mut ids.iter().map(|&opt_i| opt_i.map(|i| i as usize)), - ) + out_column.take_unchecked(&ids.iter().copied().collect_ca("")) }, Either::Right(ids) => unsafe { out_column._take_opt_chunked_unchecked(ids) }, } @@ -645,9 +643,7 @@ fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Series) -> Se #[cfg(not(feature = "chunked_ids"))] unsafe { - out_column.take_opt_iter_unchecked( - &mut join_opt_ids.iter().map(|&opt_i| opt_i.map(|i| i as usize)), - ) + out_column.take_unchecked(&join_opt_ids.iter().copied().collect_ca("")) } } diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index d0e620056ea1..d69ddfc460f6 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs @@ -183,7 +183,7 @@ fn test_power_in_agg_list1() -> PolarsResult<()> { .collect()?; let agg = out.column("foo")?.list()?; - let first = agg.get(0).unwrap(); + let first = agg.get_as_series(0).unwrap(); let vals = first.f64()?; assert_eq!(Vec::from(vals), &[Some(1.0), Some(4.0), Some(25.0)]); diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index fe9035c0544a..568000e19ebf 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -299,7 +299,7 @@ fn test_lazy_query_5() { .unwrap() .list() .unwrap() - .get(0) + .get_as_series(0) .unwrap(); assert_eq!(s.len(), 2); let s = out @@ -307,7 +307,7 @@ fn test_lazy_query_5() { .unwrap() .list() .unwrap() - .get(0) + .get_as_series(0) .unwrap(); assert_eq!(s.len(), 2); } @@ -670,7 +670,7 @@ fn test_lazy_partition_agg() { .collect() .unwrap(); let cat_agg_list = out.select_at_idx(1).unwrap(); - let fruit_series = cat_agg_list.list().unwrap().get(0).unwrap(); + let fruit_series = cat_agg_list.list().unwrap().get_as_series(0).unwrap(); let fruit_list = fruit_series.i64().unwrap(); assert_eq!( Vec::from(fruit_list), @@ -1141,11 +1141,11 @@ fn test_fill_forward() -> PolarsResult<()> { .collect()?; let agg = out.column("b")?.list()?; - let a: Series = agg.get(0).unwrap(); + let a: Series = agg.get_as_series(0).unwrap(); assert!(a.series_equal(&Series::new("b", &[1, 1]))); - let a: Series = agg.get(2).unwrap(); + let a: Series = agg.get_as_series(2).unwrap(); assert!(a.series_equal(&Series::new("b", &[1, 1]))); - let a: Series = agg.get(1).unwrap(); + let a: Series = agg.get_as_series(1).unwrap(); assert_eq!(a.null_count(), 1); Ok(()) } @@ -1310,7 +1310,7 @@ fn test_filter_after_shift_in_groups() -> PolarsResult<()> { assert_eq!( out.column("filtered")? .list()? - .get(0) + .get_as_series(0) .unwrap() .i32()? .get(0) @@ -1320,14 +1320,21 @@ fn test_filter_after_shift_in_groups() -> PolarsResult<()> { assert_eq!( out.column("filtered")? .list()? - .get(1) + .get_as_series(1) .unwrap() .i32()? .get(0) .unwrap(), 5 ); - assert_eq!(out.column("filtered")?.list()?.get(2).unwrap().len(), 0); + assert_eq!( + out.column("filtered")? + .list()? 
+ .get_as_series(2) + .unwrap() + .len(), + 0 + ); Ok(()) } @@ -1564,7 +1571,7 @@ fn test_group_by_rank() -> PolarsResult<()> { .collect()?; let out = out.column("B")?; - let out = out.list()?.get(1).unwrap(); + let out = out.list()?.get_as_series(1).unwrap(); let out = out.idx()?; assert_eq!(Vec::from(out), &[Some(1)]); diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 6e122d0aea91..272cca8e65c7 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -9,18 +9,18 @@ repository = { workspace = true } description = "More operations on Polars data structures" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", default-features = false } -polars-core = { version = "0.32.0", path = "../polars-core", features = [], default-features = false } -polars-json = { version = "0.32.0", optional = true, path = "../polars-json", default-features = false } -polars-utils = { version = "0.32.0", path = "../polars-utils", default-features = false } +polars-arrow = { workspace = true, default-features = false } +polars-core = { workspace = true, features = ["algorithm_group_by", "algorithm_join"], default-features = false } +polars-json = { workspace = true, optional = true } +polars-utils = { workspace = true, default-features = false } argminmax = { version = "0.6.1", default-features = false, features = ["float"] } arrow = { workspace = true } -base64 = { version = "0.21", optional = true } +base64 = { workspace = true, optional = true } chrono = { workspace = true, optional = true } chrono-tz = { workspace = true, optional = true } either = { workspace = true } -hex = { version = "0.4", optional = true } +hex = { workspace = true, optional = true } indexmap = { workspace = true } jsonpath_lib = { version = "0.3", optional = true, git = "https://github.com/ritchie46/jsonpath", branch = "improve_compiled" } memchr = { workspace = true } @@ -52,8 +52,8 @@ propagate_nans = [] performant = ["polars-core/performant", "fused"] big_idx = ["polars-core/bigidx"] round_series = [] -is_first = [] -is_last = [] +is_first_distinct = [] +is_last_distinct = [] is_unique = [] approx_unique = [] fused = [] diff --git a/crates/polars-ops/src/chunked_array/binary/namespace.rs b/crates/polars-ops/src/chunked_array/binary/namespace.rs index 59c444b4ac1d..400f6023a94b 100644 --- a/crates/polars-ops/src/chunked_array/binary/namespace.rs +++ b/crates/polars-ops/src/chunked_array/binary/namespace.rs @@ -6,26 +6,27 @@ use base64::engine::general_purpose; #[cfg(feature = "binary_encoding")] use base64::Engine as _; use memchr::memmem::find; +use polars_core::prelude::arity::binary_elementwise_values; use super::*; pub trait BinaryNameSpaceImpl: AsBinary { /// Check if binary contains given literal - fn contains(&self, lit: &[u8]) -> PolarsResult { + fn contains(&self, lit: &[u8]) -> BooleanChunked { let ca = self.as_binary(); let f = |s: &[u8]| find(s, lit).is_some(); - let mut out: BooleanChunked = if !ca.has_validity() { - ca.into_no_null_iter().map(f).collect() - } else { - ca.into_iter().map(|opt_s| opt_s.map(f)).collect() - }; - out.rename(ca.name()); - Ok(out) + ca.apply_values_generic(f) } - /// Check if strings contain a given literal - fn contains_literal(&self, lit: &[u8]) -> PolarsResult { - self.contains(lit) + fn contains_chunked(&self, lit: &BinaryChunked) -> BooleanChunked { + let ca = self.as_binary(); + match lit.len() { + 1 => match lit.get(0) { + Some(lit) => ca.contains(lit), + None => BooleanChunked::full_null(ca.name(), ca.len()), + 
}, + _ => binary_elementwise_values(ca, lit, |src, lit| find(src, lit).is_some()), + } } /// Check if strings ends with a substring @@ -46,6 +47,28 @@ pub trait BinaryNameSpaceImpl: AsBinary { out } + fn starts_with_chunked(&self, prefix: &BinaryChunked) -> BooleanChunked { + let ca = self.as_binary(); + match prefix.len() { + 1 => match prefix.get(0) { + Some(s) => self.starts_with(s), + None => BooleanChunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise_values(ca, prefix, |s, sub| s.starts_with(sub)), + } + } + + fn ends_with_chunked(&self, suffix: &BinaryChunked) -> BooleanChunked { + let ca = self.as_binary(); + match suffix.len() { + 1 => match suffix.get(0) { + Some(s) => self.ends_with(s), + None => BooleanChunked::full_null(ca.name(), ca.len()), + }, + _ => binary_elementwise_values(ca, suffix, |s, sub| s.ends_with(sub)), + } + } + #[cfg(feature = "binary_encoding")] fn hex_decode(&self, strict: bool) -> PolarsResult { let ca = self.as_binary(); diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index bcd545a1ba06..44ddcad5d25f 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -76,22 +76,62 @@ fn cast_rhs( pub trait ListNameSpaceImpl: AsList { /// In case the inner dtype [`DataType::Utf8`], the individual items will be joined into a /// single string separated by `separator`. - fn lst_join(&self, separator: &str) -> PolarsResult { + fn lst_join(&self, separator: &Utf8Chunked) -> PolarsResult { let ca = self.as_list(); match ca.inner_dtype() { - DataType::Utf8 => { - // used to amortize heap allocs - let mut buf = String::with_capacity(128); + DataType::Utf8 => match separator.len() { + 1 => match separator.get(0) { + Some(separator) => self.join_literal(separator), + _ => Ok(Utf8Chunked::full_null(ca.name(), ca.len())), + }, + _ => self.join_many(separator), + }, + dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"), + } + } - let mut builder = Utf8ChunkedBuilder::new( - ca.name(), - ca.len(), - ca.get_values_size() + separator.len() * ca.len(), - ); + fn join_literal(&self, separator: &str) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = Utf8ChunkedBuilder::new( + ca.name(), + ca.len(), + ca.get_values_size() + separator.len() * ca.len(), + ); + + ca.for_each_amortized(|opt_s| { + let opt_val = opt_s.map(|s| { + // make sure that we don't write values of previous iteration + buf.clear(); + let ca = s.as_ref().utf8().unwrap(); + let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null")); + + for val in iter { + buf.write_str(val).unwrap(); + buf.write_str(separator).unwrap(); + } + // last value should not have a separator, so slice that off + // saturating sub because there might have been nothing written. + &buf[..buf.len().saturating_sub(separator.len())] + }); + builder.append_option(opt_val) + }); + Ok(builder.finish()) + } - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - ca.amortized_iter().for_each(|opt_s| { + fn join_many(&self, separator: &Utf8Chunked) -> PolarsResult { + let ca = self.as_list(); + // used to amortize heap allocs + let mut buf = String::with_capacity(128); + let mut builder = + Utf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size() + ca.len()); + // SAFETY: unstable series never lives longer than the iterator. 
+ unsafe { + ca.amortized_iter() + .zip(separator) + .for_each(|(opt_s, opt_sep)| match opt_sep { + Some(separator) => { let opt_val = opt_s.map(|s| { // make sure that we don't write values of previous iteration buf.clear(); @@ -107,12 +147,11 @@ pub trait ListNameSpaceImpl: AsList { &buf[..buf.len().saturating_sub(separator.len())] }); builder.append_option(opt_val) - }) - }; - Ok(builder.finish()) - }, - dt => polars_bail!(op = "`lst.join`", got = dt, expected = "Utf8"), + }, + _ => builder.append_null(), + }) } + Ok(builder.finish()) } fn lst_max(&self) -> Series { @@ -379,7 +418,7 @@ pub trait ListNameSpaceImpl: AsList { .iter() .flat_map(|s| { let lst = s.list().unwrap(); - lst.get(0) + lst.get_as_series(0) }) .collect::>(); // there was a None, so all values will be None diff --git a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs index ebf7f0d1545d..24e0b36a2d24 100644 --- a/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs +++ b/crates/polars-ops/src/chunked_array/nan_propagating_aggregate.rs @@ -118,7 +118,7 @@ where idx.len() as IdxSize, ), _ => { - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; ca_nan_agg(&take, nan_max) }, } @@ -190,7 +190,7 @@ where idx.len() as IdxSize, ), _ => { - let take = { ca.take_unchecked(idx.into()) }; + let take = { ca.take_unchecked(idx) }; ca_nan_agg(&take, nan_min) }, } diff --git a/crates/polars-ops/src/chunked_array/strings/mod.rs b/crates/polars-ops/src/chunked_array/strings/mod.rs index 3caaec8a9dba..caed488b40a9 100644 --- a/crates/polars-ops/src/chunked_array/strings/mod.rs +++ b/crates/polars-ops/src/chunked_array/strings/mod.rs @@ -10,6 +10,8 @@ mod justify; mod namespace; #[cfg(feature = "strings")] mod replace; +#[cfg(feature = "strings")] +mod substring; #[cfg(feature = "extract_jsonpath")] pub use json_path::*; diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 0cc3b8721c71..fae3a127143c 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -2,13 +2,11 @@ use base64::engine::general_purpose; #[cfg(feature = "string_encoding")] use base64::Engine as _; -use polars_arrow::export::arrow::compute::substring::substring; -use polars_arrow::export::arrow::{self}; use polars_arrow::kernels::string::*; #[cfg(feature = "string_from_radix")] use polars_core::export::num::Num; use polars_core::export::regex::Regex; -use polars_core::prelude::arity::try_binary_elementwise; +use polars_core::prelude::arity::*; use polars_utils::cache::FastFixedCache; use regex::escape; @@ -88,6 +86,56 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(out) } + fn contains_chunked( + &self, + pat: &Utf8Chunked, + literal: bool, + strict: bool, + ) -> PolarsResult { + let ca = self.as_utf8(); + match pat.len() { + 1 => match pat.get(0) { + Some(pat) => { + if literal { + ca.contains_literal(pat) + } else { + ca.contains(pat, strict) + } + }, + None => Ok(BooleanChunked::full_null(ca.name(), ca.len())), + }, + _ => { + if literal { + Ok(binary_elementwise_values(ca, pat, |src, pat| { + src.contains(pat) + })) + } else if strict { + // A sqrt(n) regex cache is not too small, not too large. 
+ let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); + try_binary_elementwise(ca, pat, |opt_src, opt_pat| match (opt_src, opt_pat) { + (Some(src), Some(pat)) => { + let reg = reg_cache.try_get_or_insert_with(pat, |p| Regex::new(p))?; + Ok(Some(reg.is_match(src))) + }, + _ => Ok(None), + }) + } else { + // A sqrt(n) regex cache is not too small, not too large. + let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); + Ok(binary_elementwise(ca, pat, |opt_src, opt_pat| { + match (opt_src, opt_pat) { + (Some(src), Some(pat)) => { + let reg = reg_cache.try_get_or_insert_with(pat, |p| Regex::new(p)); + reg.ok().map(|re| re.is_match(src)) + }, + _ => None, + } + })) + } + }, + } + } + /// Get the length of the string values as number of chars. fn str_n_chars(&self) -> UInt32Chunked { let ca = self.as_utf8(); @@ -135,18 +183,11 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { let res_reg = Regex::new(pat); let opt_reg = if strict { Some(res_reg?) } else { res_reg.ok() }; - let mut out: BooleanChunked = match (opt_reg, ca.has_validity()) { - (Some(reg), false) => ca - .into_no_null_iter() - .map(|s: &str| reg.is_match(s)) - .collect(), - (Some(reg), true) => ca - .into_iter() - .map(|opt_s| opt_s.map(|s: &str| reg.is_match(s))) - .collect(), - (None, _) => ca.into_iter().map(|_| None).collect(), + let out: BooleanChunked = if let Some(reg) = opt_reg { + ca.apply_values_generic(|s| reg.is_match(s)) + } else { + BooleanChunked::full_null(ca.name(), ca.len()) }; - out.rename(ca.name()); Ok(out) } @@ -158,24 +199,6 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { self.contains(regex::escape(lit).as_str(), true) } - /// Check if strings ends with a substring - fn ends_with(&self, sub: &str) -> BooleanChunked { - let ca = self.as_utf8(); - let f = |s: &str| s.ends_with(sub); - let mut out: BooleanChunked = ca.into_iter().map(|opt_s| opt_s.map(f)).collect(); - out.rename(ca.name()); - out - } - - /// Check if strings starts with a substring - fn starts_with(&self, sub: &str) -> BooleanChunked { - let ca = self.as_utf8(); - let f = |s: &str| s.starts_with(sub); - let mut out: BooleanChunked = ca.into_iter().map(|opt_s| opt_s.map(f)).collect(); - out.rename(ca.name()); - out - } - /// Replace the leftmost regex-matched (sub)string with another string fn replace<'a>(&'a self, pat: &str, val: &str) -> PolarsResult { let reg = Regex::new(pat)?; @@ -311,6 +334,66 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Ok(builder.finish()) } + fn split(&self, by: &str) -> ListChunked { + let ca = self.as_utf8(); + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + ca.for_each(|opt_v| match opt_v { + Some(val) => { + let iter = val.split(by); + builder.append_values_iter(iter) + }, + _ => builder.append_null(), + }); + builder.finish() + } + + fn split_many(&self, by: &Utf8Chunked) -> ListChunked { + let ca = self.as_utf8(); + + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let iter = s.split(by); + builder.append_values_iter(iter); + }, + _ => builder.append_null(), + }); + + builder.finish() + } + + fn split_inclusive(&self, by: &str) -> ListChunked { + let ca = self.as_utf8(); + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + ca.for_each(|opt_v| match opt_v { + Some(val) => { + let iter = val.split_inclusive(by); + builder.append_values_iter(iter) + }, + _ 
=> builder.append_null(), + }); + builder.finish() + } + + fn split_inclusive_many(&self, by: &Utf8Chunked) -> ListChunked { + let ca = self.as_utf8(); + + let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); + + binary_elementwise_for_each(ca, by, |opt_s, opt_by| match (opt_s, opt_by) { + (Some(s), Some(by)) => { + let iter = s.split_inclusive(by); + builder.append_values_iter(iter); + }, + _ => builder.append_null(), + }); + + builder.finish() + } + /// Extract each successive non-overlapping regex match in an individual string as an array. fn extract_all_many(&self, pat: &Utf8Chunked) -> PolarsResult { let ca = self.as_utf8(); @@ -323,15 +406,13 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { // A sqrt(n) regex cache is not too small, not too large. let mut reg_cache = FastFixedCache::new((ca.len() as f64).sqrt() as usize); let mut builder = ListUtf8ChunkedBuilder::new(ca.name(), ca.len(), ca.get_values_size()); - for (opt_s, opt_pat) in ca.into_iter().zip(pat) { - match (opt_s, opt_pat) { - (_, None) | (None, _) => builder.append_null(), - (Some(s), Some(pat)) => { - let reg = reg_cache.get_or_insert_with(pat, |p| Regex::new(p).unwrap()); - builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())); - }, - } - } + binary_elementwise_for_each(ca, pat, |opt_s, opt_pat| match (opt_s, opt_pat) { + (_, None) | (None, _) => builder.append_null(), + (Some(s), Some(pat)) => { + let reg = reg_cache.get_or_insert_with(pat, |p| Regex::new(p).unwrap()); + builder.append_values_iter(reg.find_iter(s).map(|m| m.as_str())); + }, + }); Ok(builder.finish()) } @@ -351,12 +432,7 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { Regex::new(pat)? }; - let mut out: UInt32Chunked = ca - .into_iter() - .map(|opt_s| opt_s.map(|s| reg.find_iter(s).count() as u32)) - .collect(); - out.rename(ca.name()); - Ok(out) + Ok(ca.apply_generic(|opt_s| opt_s.map(|s| reg.find_iter(s).count() as u32))) } /// Count all successive non-overlapping regex matches. @@ -424,14 +500,12 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { /// /// Determines a substring starting from `start` and with optional length `length` of each of the elements in `array`. /// `start` can be negative, in which case the start counts from the end of the string. - fn str_slice(&self, start: i64, length: Option) -> PolarsResult { + fn str_slice(&self, start: i64, length: Option) -> Utf8Chunked { let ca = self.as_utf8(); - let chunks = ca + let iter = ca .downcast_iter() - .map(|c| substring(c, start, &length)) - .collect::>()?; - // SAFETY: these are all the same type. - unsafe { Ok(Utf8Chunked::from_chunks(ca.name(), chunks)) } + .map(|c| substring::utf8_substring(c, start, &length)); + Utf8Chunked::from_chunk_iter_like(ca, iter) } } diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs new file mode 100644 index 000000000000..e485e25dd216 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -0,0 +1,51 @@ +use arrow::array::Utf8Array; + +/// Returns a Utf8Array with a substring starting from `start` and with optional length `length` of each of the elements in `array`. +/// `start` can be negative, in which case the start counts from the end of the string. +pub(super) fn utf8_substring( + array: &Utf8Array, + start: i64, + length: &Option, +) -> Utf8Array { + let length = length.map(|v| v as usize); + + let iter = array.values_iter().map(|str_val| { + // compute where we should start slicing this entry. 
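// --- Editorial note, not part of the patch -------------------------------------
// This kernel slices by *characters*, not bytes, so multi-byte UTF-8 values are
// never cut on an invalid boundary. A tiny stand-alone illustration of that idea
// (a simplified helper, not the kernel itself; it ignores negative starts):
fn take_chars(s: &str, start: usize, length: usize) -> &str {
    let begin = match s.char_indices().nth(start) {
        Some((i, _)) => i,
        None => return "",
    };
    let end = s
        .char_indices()
        .nth(start + length)
        .map(|(i, _)| i)
        .unwrap_or(s.len());
    &s[begin..end]
}
// take_chars("héllo", 1, 3) == "éll"; slicing the bytes at index 2 would panic,
// because that index falls inside the 2-byte 'é'.
// --------------------------------------------------------------------------------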
+ let start = if start >= 0 { + start as usize + } else { + let start = (0i64 - start) as usize; + str_val + .char_indices() + .rev() + .nth(start) + .map(|(idx, _)| idx + 1) + .unwrap_or(0) + }; + + let mut iter_chars = str_val.char_indices(); + if let Some((start_idx, _)) = iter_chars.nth(start) { + // length of the str + let len_end = str_val.len() - start_idx; + + // length to slice + let length = length.unwrap_or(len_end); + + if length == 0 { + return ""; + } + // compute + let end_idx = iter_chars + .nth(length.saturating_sub(1)) + .map(|(idx, _)| idx) + .unwrap_or(str_val.len()); + + &str_val[start_idx..end_idx] + } else { + "" + } + }); + + let new = Utf8Array::::from_trusted_len_values_iter(iter); + new.with_validity(array.validity().cloned()) +} diff --git a/crates/polars-ops/src/frame/join/mod.rs b/crates/polars-ops/src/frame/join/mod.rs index f7f035d4a0c8..50c6b68d28cd 100644 --- a/crates/polars-ops/src/frame/join/mod.rs +++ b/crates/polars-ops/src/frame/join/mod.rs @@ -295,17 +295,19 @@ pub trait DataFrameJoinOps: IntoDf { // Take the left and right dataframes by join tuples let (df_left, df_right) = POOL.join( || unsafe { - remove_selected(left_df, &selected_left).take_opt_iter_unchecked( - opt_join_tuples + remove_selected(left_df, &selected_left).take_unchecked( + &opt_join_tuples .iter() - .map(|(left, _right)| left.map(|i| i as usize)), + .map(|(left, _right)| *left) + .collect_ca(""), ) }, || unsafe { - remove_selected(other, &selected_right).take_opt_iter_unchecked( - opt_join_tuples + remove_selected(other, &selected_right).take_unchecked( + &opt_join_tuples .iter() - .map(|(_left, right)| right.map(|i| i as usize)), + .map(|(_left, right)| *right) + .collect_ca(""), ) }, ); diff --git a/crates/polars-ops/src/series/ops/index.rs b/crates/polars-ops/src/series/ops/index.rs index 8ede2ac374f6..a73b5378dd51 100644 --- a/crates/polars-ops/src/series/ops/index.rs +++ b/crates/polars-ops/src/series/ops/index.rs @@ -22,7 +22,7 @@ where Ok(IdxSize::try_from(v).unwrap()) } else { IdxSize::from_i64(len + v.to_i64().unwrap()).ok_or_else(|| { - PolarsError::ComputeError( + PolarsError::OutOfBounds( format!( "index {} is out of bounds for series of len {}", v, target_len diff --git a/crates/polars-ops/src/series/ops/is_first.rs b/crates/polars-ops/src/series/ops/is_first_distinct.rs similarity index 82% rename from crates/polars-ops/src/series/ops/is_first.rs rename to crates/polars-ops/src/series/ops/is_first_distinct.rs index 9ddc73ba1bf1..9542394c00ef 100644 --- a/crates/polars-ops/src/series/ops/is_first.rs +++ b/crates/polars-ops/src/series/ops/is_first_distinct.rs @@ -6,7 +6,7 @@ use polars_arrow::bit_util::*; use polars_arrow::utils::CustomIterTools; use polars_core::prelude::*; use polars_core::with_match_physical_integer_polars_type; -fn is_first_numeric(ca: &ChunkedArray) -> BooleanChunked +fn is_first_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, T::Native: Hash + Eq, @@ -21,7 +21,7 @@ where BooleanChunked::from_chunk_iter(ca.name(), chunks) } -fn is_first_bin(ca: &BinaryChunked) -> BooleanChunked { +fn is_first_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { let mut unique = PlHashSet::new(); let chunks = ca.downcast_iter().map(|arr| -> BooleanArray { arr.into_iter() @@ -32,7 +32,7 @@ fn is_first_bin(ca: &BinaryChunked) -> BooleanChunked { BooleanChunked::from_chunk_iter(ca.name(), chunks) } -fn is_first_boolean(ca: &BooleanChunked) -> BooleanChunked { +fn is_first_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { let 
mut out = MutableBitmap::with_capacity(ca.len()); out.extend_constant(ca.len(), false); @@ -71,7 +71,7 @@ fn is_first_boolean(ca: &BooleanChunked) -> BooleanChunked { } #[cfg(feature = "dtype-struct")] -fn is_first_struct(s: &Series) -> PolarsResult { +fn is_first_distinct_struct(s: &Series) -> PolarsResult { let groups = s.group_tuples(true, false)?; let first = groups.take_group_firsts(); let mut out = MutableBitmap::with_capacity(s.len()); @@ -87,7 +87,7 @@ fn is_first_struct(s: &Series) -> PolarsResult { } #[cfg(feature = "group_by_list")] -fn is_first_list(ca: &ListChunked) -> PolarsResult { +fn is_first_distinct_list(ca: &ListChunked) -> PolarsResult { let groups = ca.group_tuples(true, false)?; let first = groups.take_group_firsts(); let mut out = MutableBitmap::with_capacity(ca.len()); @@ -102,7 +102,7 @@ fn is_first_list(ca: &ListChunked) -> PolarsResult { Ok(BooleanChunked::with_chunk(ca.name(), arr)) } -pub fn is_first(s: &Series) -> PolarsResult { +pub fn is_first_distinct(s: &Series) -> PolarsResult { // fast path. if s.len() == 0 { return Ok(BooleanChunked::full_null(s.name(), 0)); @@ -116,38 +116,38 @@ pub fn is_first(s: &Series) -> PolarsResult { let out = match s.dtype() { Boolean => { let ca = s.bool().unwrap(); - is_first_boolean(ca) + is_first_distinct_boolean(ca) }, Binary => { let ca = s.binary().unwrap(); - is_first_bin(ca) + is_first_distinct_bin(ca) }, Utf8 => { let s = s.cast(&Binary).unwrap(); - return is_first(&s); + return is_first_distinct(&s); }, Float32 => { let ca = s.bit_repr_small(); - is_first_numeric(&ca) + is_first_distinct_numeric(&ca) }, Float64 => { let ca = s.bit_repr_large(); - is_first_numeric(&ca) + is_first_distinct_numeric(&ca) }, dt if dt.is_numeric() => { with_match_physical_integer_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - is_first_numeric(ca) + is_first_distinct_numeric(ca) }) }, #[cfg(feature = "dtype-struct")] - Struct(_) => return is_first_struct(&s), + Struct(_) => return is_first_distinct_struct(&s), #[cfg(feature = "group_by_list")] List(inner) if inner.is_numeric() => { let ca = s.list().unwrap(); - return is_first_list(ca); + return is_first_distinct_list(ca); }, - dt => polars_bail!(opq = is_first, dt), + dt => polars_bail!(opq = is_first_distinct, dt), }; Ok(out) } diff --git a/crates/polars-ops/src/series/ops/is_last.rs b/crates/polars-ops/src/series/ops/is_last_distinct.rs similarity index 85% rename from crates/polars-ops/src/series/ops/is_last.rs rename to crates/polars-ops/src/series/ops/is_last_distinct.rs index 02168d73b6c8..9cb00799dacf 100644 --- a/crates/polars-ops/src/series/ops/is_last.rs +++ b/crates/polars-ops/src/series/ops/is_last_distinct.rs @@ -7,7 +7,7 @@ use polars_core::prelude::*; use polars_core::utils::NoNull; use polars_core::with_match_physical_integer_polars_type; -pub fn is_last(s: &Series) -> PolarsResult { +pub fn is_last_distinct(s: &Series) -> PolarsResult { // fast path. 
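// --- Editorial note, not part of the patch -------------------------------------
// Semantics of the renamed `is_first_distinct`: true exactly at the first
// occurrence of every distinct value. A stand-alone sketch over a plain slice
// (the per-dtype kernels above do the same per chunk, with validity handling):
use std::collections::HashSet;

fn is_first_distinct_sketch(values: &[i64]) -> Vec<bool> {
    let mut seen = HashSet::new();
    values.iter().map(|v| seen.insert(*v)).collect()
}
// is_first_distinct_sketch(&[1, 1, 2, 1, 3]) == [true, false, true, false, true]
// --------------------------------------------------------------------------------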
if s.len() == 0 { return Ok(BooleanChunked::full_null(s.name(), 0)); @@ -21,43 +21,43 @@ pub fn is_last(s: &Series) -> PolarsResult { let out = match s.dtype() { Boolean => { let ca = s.bool().unwrap(); - is_last_boolean(ca) + is_last_distinct_boolean(ca) }, Binary => { let ca = s.binary().unwrap(); - is_last_bin(ca) + is_last_distinct_bin(ca) }, Utf8 => { let s = s.cast(&Binary).unwrap(); - return is_last(&s); + return is_last_distinct(&s); }, Float32 => { let ca = s.bit_repr_small(); - is_last_numeric(&ca) + is_last_distinct_numeric(&ca) }, Float64 => { let ca = s.bit_repr_large(); - is_last_numeric(&ca) + is_last_distinct_numeric(&ca) }, dt if dt.is_numeric() => { with_match_physical_integer_polars_type!(s.dtype(), |$T| { let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref(); - is_last_numeric(ca) + is_last_distinct_numeric(ca) }) }, #[cfg(feature = "dtype-struct")] - Struct(_) => return is_last_struct(&s), + Struct(_) => return is_last_distinct_struct(&s), #[cfg(feature = "group_by_list")] List(inner) if inner.is_numeric() => { let ca = s.list().unwrap(); - return is_last_list(ca); + return is_last_distinct_list(ca); }, - dt => polars_bail!(opq = is_last, dt), + dt => polars_bail!(opq = is_last_distinct, dt), }; Ok(out) } -fn is_last_boolean(ca: &BooleanChunked) -> BooleanChunked { +fn is_last_distinct_boolean(ca: &BooleanChunked) -> BooleanChunked { let mut out = MutableBitmap::with_capacity(ca.len()); out.extend_constant(ca.len(), false); @@ -114,7 +114,7 @@ fn is_last_boolean(ca: &BooleanChunked) -> BooleanChunked { BooleanChunked::with_chunk(ca.name(), arr) } -fn is_last_bin(ca: &BinaryChunked) -> BooleanChunked { +fn is_last_distinct_bin(ca: &BinaryChunked) -> BooleanChunked { let ca = ca.rechunk(); let arr = ca.downcast_iter().next().unwrap(); let mut unique = PlHashSet::new(); @@ -128,7 +128,7 @@ fn is_last_bin(ca: &BinaryChunked) -> BooleanChunked { new_ca } -fn is_last_numeric(ca: &ChunkedArray) -> BooleanChunked +fn is_last_distinct_numeric(ca: &ChunkedArray) -> BooleanChunked where T: PolarsNumericType, T::Native: Hash + Eq, @@ -147,7 +147,7 @@ where } #[cfg(feature = "dtype-struct")] -fn is_last_struct(s: &Series) -> PolarsResult { +fn is_last_distinct_struct(s: &Series) -> PolarsResult { let groups = s.group_tuples(true, false)?; let last = groups.take_group_lasts(); let mut out = MutableBitmap::with_capacity(s.len()); @@ -163,7 +163,7 @@ fn is_last_struct(s: &Series) -> PolarsResult { } #[cfg(feature = "group_by_list")] -fn is_last_list(ca: &ListChunked) -> PolarsResult { +fn is_last_distinct_list(ca: &ListChunked) -> PolarsResult { let groups = ca.group_tuples(true, false)?; let last = groups.take_group_lasts(); let mut out = MutableBitmap::with_capacity(ca.len()); diff --git a/crates/polars-ops/src/series/ops/mod.rs b/crates/polars-ops/src/series/ops/mod.rs index 38845eafd154..0d6ba4ce4b55 100644 --- a/crates/polars-ops/src/series/ops/mod.rs +++ b/crates/polars-ops/src/series/ops/mod.rs @@ -10,12 +10,12 @@ mod floor_divide; mod fused; #[cfg(feature = "convert_index")] mod index; -#[cfg(feature = "is_first")] -mod is_first; +#[cfg(feature = "is_first_distinct")] +mod is_first_distinct; #[cfg(feature = "is_in")] mod is_in; -#[cfg(feature = "is_last")] -mod is_last; +#[cfg(feature = "is_last_distinct")] +mod is_last_distinct; #[cfg(feature = "is_unique")] mod is_unique; #[cfg(feature = "log")] @@ -42,12 +42,12 @@ pub use floor_divide::*; pub use fused::*; #[cfg(feature = "convert_index")] pub use index::*; -#[cfg(feature = "is_first")] -pub use is_first::*; 
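// --- Editorial note, not part of the patch -------------------------------------
// `is_last_distinct` is the mirror image: true exactly at the last occurrence of
// every distinct value. One simple way to compute it is to walk the values back
// to front and reverse the result; a stand-alone sketch:
use std::collections::HashSet;

fn is_last_distinct_sketch(values: &[i64]) -> Vec<bool> {
    let mut seen = HashSet::new();
    let mut out: Vec<bool> = values.iter().rev().map(|v| seen.insert(*v)).collect();
    out.reverse();
    out
}
// is_last_distinct_sketch(&[1, 1, 2, 1, 3]) == [false, false, true, true, true]
// --------------------------------------------------------------------------------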
+#[cfg(feature = "is_first_distinct")] +pub use is_first_distinct::*; #[cfg(feature = "is_in")] pub use is_in::*; -#[cfg(feature = "is_last")] -pub use is_last::*; +#[cfg(feature = "is_last_distinct")] +pub use is_last_distinct::*; #[cfg(feature = "is_unique")] pub use is_unique::*; #[cfg(feature = "log")] diff --git a/crates/polars-pipe/Cargo.toml b/crates/polars-pipe/Cargo.toml index 96df46c82f50..d9e9db3fcb0e 100644 --- a/crates/polars-pipe/Cargo.toml +++ b/crates/polars-pipe/Cargo.toml @@ -9,16 +9,16 @@ repository = { workspace = true } description = "Lazy query engine for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", default-features = false } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["lazy", "zip_with", "random"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", default-features = false } -polars-ops = { version = "0.32.0", path = "../polars-ops", features = ["search_sorted"] } -polars-plan = { version = "0.32.0", path = "../polars-plan", default-features = false, features = ["compile"] } -polars-row = { version = "0.32.0", path = "../polars-row" } -polars-utils = { version = "0.32.0", path = "../polars-utils", features = ["sysinfo"] } +polars-arrow = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random", "rows", "chunked_ids"] } +polars-io = { workspace = true, features = ["ipc"] } +polars-ops = { workspace = true, features = ["search_sorted"] } +polars-plan = { workspace = true } +polars-row = { workspace = true } +polars-utils = { workspace = true, features = ["sysinfo"] } -crossbeam-channel = { version = "0.5", optional = true } -crossbeam-queue = { version = "0.3", optional = true } +crossbeam-channel = { version = "0.5" } +crossbeam-queue = { version = "0.3" } enum_dispatch = { version = "0.3" } hashbrown = { workspace = true } num-traits = { workspace = true } @@ -29,7 +29,6 @@ smartstring = { workspace = true } version_check = { workspace = true } [features] -compile = ["crossbeam-channel", "crossbeam-queue", "polars-io/ipc"] csv = ["polars-plan/csv", "polars-io/csv"] cloud = ["async", "polars-io/cloud", "polars-plan/cloud"] parquet = ["polars-plan/parquet", "polars-io/parquet"] @@ -45,4 +44,4 @@ dtype-decimal = ["polars-core/dtype-decimal"] dtype-array = ["polars-core/dtype-array"] dtype-categorical = ["polars-core/dtype-categorical"] trigger_ooc = [] -test = ["compile", "polars-core/chunked_ids"] +test = ["polars-core/chunked_ids"] diff --git a/crates/polars-pipe/src/executors/sinks/file_sink.rs b/crates/polars-pipe/src/executors/sinks/file_sink.rs index cd030e575a27..402d23b9f078 100644 --- a/crates/polars-pipe/src/executors/sinks/file_sink.rs +++ b/crates/polars-pipe/src/executors/sinks/file_sink.rs @@ -10,7 +10,7 @@ use polars_io::csv::CsvWriter; use polars_io::parquet::ParquetWriter; #[cfg(feature = "ipc")] use polars_io::prelude::IpcWriter; -#[cfg(feature = "ipc")] +#[cfg(any(feature = "ipc", feature = "csv"))] use polars_io::SerWriter; use polars_plan::prelude::*; @@ -106,7 +106,7 @@ impl ParquetCloudSink { #[allow(clippy::new_ret_no_self)] pub fn new( uri: &str, - cloud_options: Option<&polars_core::cloud::CloudOptions>, + cloud_options: Option<&polars_io::cloud::CloudOptions>, parquet_options: ParquetWriteOptions, schema: &Schema, ) -> PolarsResult { @@ -259,7 +259,7 @@ impl CsvCloudSink { } } -#[cfg(any(feature = "parquet", feature = "ipc"))] +#[cfg(any(feature = "parquet", feature = "ipc", 
feature = "csv"))] fn init_writer_thread( receiver: Receiver>, mut writer: Box, diff --git a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs index e848f2ff5367..b37b88ec4904 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/generic_build.rs @@ -4,9 +4,9 @@ use std::sync::Arc; use hashbrown::hash_map::RawEntryMut; use polars_arrow::export::arrow::array::BinaryArray; +use polars_core::datatypes::ChunkId; use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; -use polars_core::frame::hash_join::ChunkId; use polars_core::prelude::*; use polars_core::utils::{_set_partition_size, accumulate_dataframes_vertical_unchecked}; use polars_utils::hash_to_partition; diff --git a/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs b/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs index 7433dfa848cc..4c71815a3eae 100644 --- a/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs +++ b/crates/polars-pipe/src/executors/sinks/joins/inner_left.rs @@ -2,9 +2,10 @@ use std::borrow::Cow; use std::sync::Arc; use polars_arrow::export::arrow::array::BinaryArray; +use polars_core::datatypes::ChunkId; use polars_core::error::PolarsResult; use polars_core::export::ahash::RandomState; -use polars_core::frame::hash_join::{ChunkId, _finish_join}; +use polars_core::frame::hash_join::_finish_join; use polars_core::prelude::*; use polars_core::series::IsSorted; use polars_row::RowsEncoded; diff --git a/crates/polars-pipe/src/executors/sinks/mod.rs b/crates/polars-pipe/src/executors/sinks/mod.rs index 328ab178a9e6..8c9b46366da7 100644 --- a/crates/polars-pipe/src/executors/sinks/mod.rs +++ b/crates/polars-pipe/src/executors/sinks/mod.rs @@ -1,4 +1,4 @@ -#[cfg(any(feature = "parquet", feature = "ipc"))] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] mod file_sink; pub(crate) mod group_by; mod io; @@ -10,7 +10,7 @@ mod slice; mod sort; mod utils; -#[cfg(any(feature = "parquet", feature = "ipc"))] +#[cfg(any(feature = "parquet", feature = "ipc", feature = "csv"))] pub(crate) use file_sink::*; pub(crate) use joins::*; pub(crate) use ordered::*; diff --git a/crates/polars-pipe/src/executors/sources/parquet.rs b/crates/polars-pipe/src/executors/sources/parquet.rs index 4cb11f6b1688..dca282f18a62 100644 --- a/crates/polars-pipe/src/executors/sources/parquet.rs +++ b/crates/polars-pipe/src/executors/sources/parquet.rs @@ -1,9 +1,9 @@ use std::path::PathBuf; -use polars_core::cloud::CloudOptions; use polars_core::error::PolarsResult; use polars_core::schema::*; use polars_core::POOL; +use polars_io::cloud::CloudOptions; use polars_io::parquet::{BatchedParquetReader, ParquetReader}; #[cfg(feature = "async")] use polars_io::prelude::ParquetAsyncReader; diff --git a/crates/polars-pipe/src/lib.rs b/crates/polars-pipe/src/lib.rs index 4a63657adb8a..b2724e9a8981 100644 --- a/crates/polars-pipe/src/lib.rs +++ b/crates/polars-pipe/src/lib.rs @@ -1,13 +1,8 @@ extern crate core; -#[cfg(feature = "compile")] mod executors; -#[cfg(feature = "compile")] pub mod expressions; -#[cfg(feature = "compile")] pub mod operators; -#[cfg(feature = "compile")] pub mod pipeline; -#[cfg(feature = "compile")] pub use operators::SExecutionContext; diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 6212bf68e39f..5c3851f4b611 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ 
-12,12 +12,14 @@ description = "Lazy query engine for the Polars DataFrame library" doctest = false [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow" } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["lazy", "zip_with", "random"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", features = ["lazy", "csv"], default-features = false } -polars-ops = { version = "0.32.0", path = "../polars-ops", default-features = false } -polars-time = { version = "0.32.0", path = "../polars-time", optional = true } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +libloading = { version = "0.8.0", optional = true } +polars-arrow = { workspace = true } +polars-core = { workspace = true, features = ["lazy", "zip_with", "random"], default-features = false } +polars-ffi = { workspace = true, optional = true } +polars-io = { workspace = true, features = ["lazy"], default-features = false } +polars-ops = { workspace = true, default-features = false } +polars-time = { workspace = true, optional = true } +polars-utils = { workspace = true } ahash = { workspace = true } arrow = { workspace = true } @@ -40,9 +42,6 @@ version_check = { workspace = true } # debugging utility debugging = [] python = ["dep:pyo3", "ciborium"] -# make sure we don't compile unneeded things even though -# this dependency gets activated -compile = [] serde = [ "dep:serde", "polars-core/serde-lazy", @@ -50,7 +49,6 @@ serde = [ "polars-io/serde", "polars-ops/serde", ] -default = ["compile"] streaming = [] parquet = ["polars-core/parquet", "polars-io/parquet"] async = ["polars-io/async"] @@ -92,8 +90,8 @@ approx_unique = ["polars-ops/approx_unique"] is_in = ["polars-ops/is_in"] repeat_by = ["polars-core/repeat_by"] round_series = ["polars-core/round_series"] -is_first = ["polars-core/is_first", "polars-ops/is_first"] -is_last = ["polars-core/is_last", "polars-ops/is_last"] +is_first_distinct = ["polars-core/is_first_distinct", "polars-ops/is_first_distinct"] +is_last_distinct = ["polars-core/is_last_distinct", "polars-ops/is_last_distinct"] is_unique = ["polars-ops/is_unique"] cross_join = ["polars-core/cross_join"] asof_join = ["polars-core/asof_join", "polars-time", "polars-ops/asof_join"] @@ -140,8 +138,9 @@ list_any_all = ["polars-ops/list_any_all"] cutqcut = ["polars-ops/cutqcut"] rle = ["polars-ops/rle"] extract_groups = ["regex", "dtype-struct", "polars-ops/extract_groups"] +ffi_plugin = ["libloading", "polars-ffi"] -bigidx = ["polars-arrow/bigidx", "polars-core/bigidx", "polars-utils/bigidx"] +bigidx = ["polars-core/bigidx"] panic_on_schema = [] diff --git a/crates/polars-plan/src/dsl/binary.rs b/crates/polars-plan/src/dsl/binary.rs index 5a979d87d723..9a1d6f2fd81c 100644 --- a/crates/polars-plan/src/dsl/binary.rs +++ b/crates/polars-plan/src/dsl/binary.rs @@ -5,21 +5,29 @@ pub struct BinaryNameSpace(pub(crate) Expr); impl BinaryNameSpace { /// Check if a binary value contains a literal binary. - pub fn contains_literal>(self, pat: S) -> Expr { - let pat = pat.as_ref().into(); - self.0 - .map_private(BinaryFunction::Contains { pat, literal: true }.into()) + pub fn contains_literal(self, pat: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::BinaryExpr(BinaryFunction::Contains), + &[pat], + true, + ) } /// Check if a binary value ends with the given sequence. 
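// --- Editorial note, not part of the patch -------------------------------------
// The binary namespace methods in this hunk now take an `Expr` instead of a fixed
// byte literal, so the pattern can itself be a column. Hedged usage sketch
// (assumes the usual `col` helper and the `Expr::binary()` namespace accessor;
// column names are made up):
fn payload_has_expected_suffix() -> Expr {
    col("payload").binary().ends_with(col("expected_suffix"))
}
// A length-1 / scalar pattern still hits the broadcast fast path in polars-ops.
// --------------------------------------------------------------------------------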
- pub fn ends_with>(self, sub: S) -> Expr { - let sub = sub.as_ref().into(); - self.0.map_private(BinaryFunction::EndsWith(sub).into()) + pub fn ends_with(self, sub: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::BinaryExpr(BinaryFunction::EndsWith), + &[sub], + true, + ) } /// Check if a binary value starts with the given sequence. - pub fn starts_with>(self, sub: S) -> Expr { - let sub = sub.as_ref().into(); - self.0.map_private(BinaryFunction::StartsWith(sub).into()) + pub fn starts_with(self, sub: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::BinaryExpr(BinaryFunction::StartsWith), + &[sub], + true, + ) } } diff --git a/crates/polars-plan/src/dsl/dt.rs b/crates/polars-plan/src/dsl/dt.rs index 19877fd18e03..25c37906014c 100644 --- a/crates/polars-plan/src/dsl/dt.rs +++ b/crates/polars-plan/src/dsl/dt.rs @@ -215,10 +215,10 @@ impl DateLikeNameSpace { .map_private(FunctionExpr::TemporalExpr(TemporalFunction::TimeStamp(tu))) } - pub fn truncate(self, options: TruncateOptions, ambiguous: Expr) -> Expr { + pub fn truncate(self, every: Expr, offset: String, ambiguous: Expr) -> Expr { self.0.map_many_private( - FunctionExpr::TemporalExpr(TemporalFunction::Truncate(options)), - &[ambiguous], + FunctionExpr::TemporalExpr(TemporalFunction::Truncate(offset)), + &[every, ambiguous], false, ) } diff --git a/crates/polars-plan/src/dsl/function_expr/binary.rs b/crates/polars-plan/src/dsl/function_expr/binary.rs index 8ca4c7eaa256..0aa8688dde13 100644 --- a/crates/polars-plan/src/dsl/function_expr/binary.rs +++ b/crates/polars-plan/src/dsl/function_expr/binary.rs @@ -6,9 +6,9 @@ use super::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, PartialEq, Debug, Eq, Hash)] pub enum BinaryFunction { - Contains { pat: Vec, literal: bool }, - StartsWith(Vec), - EndsWith(Vec), + Contains, + StartsWith, + EndsWith, } impl Display for BinaryFunction { @@ -16,29 +16,36 @@ impl Display for BinaryFunction { use BinaryFunction::*; let s = match self { Contains { .. 
} => "contains", - StartsWith(_) => "starts_with", - EndsWith(_) => "ends_with", + StartsWith => "starts_with", + EndsWith => "ends_with", }; write!(f, "bin.{s}") } } -pub(super) fn contains(s: &Series, pat: &[u8], literal: bool) -> PolarsResult { - let ca = s.binary()?; - if literal { - ca.contains_literal(pat).map(|ca| ca.into_series()) - } else { - ca.contains(pat).map(|ca| ca.into_series()) - } +pub(super) fn contains(s: &[Series]) -> PolarsResult { + let ca = s[0].binary()?; + let lit = s[1].binary()?; + Ok(ca.contains_chunked(lit).with_name(ca.name()).into_series()) } -pub(super) fn ends_with(s: &Series, sub: &[u8]) -> PolarsResult { - let ca = s.binary()?; - Ok(ca.ends_with(sub).into_series()) +pub(super) fn ends_with(s: &[Series]) -> PolarsResult { + let ca = s[0].binary()?; + let suffix = s[1].binary()?; + + Ok(ca + .ends_with_chunked(suffix) + .with_name(ca.name()) + .into_series()) } -pub(super) fn starts_with(s: &Series, sub: &[u8]) -> PolarsResult { - let ca = s.binary()?; - Ok(ca.starts_with(sub).into_series()) +pub(super) fn starts_with(s: &[Series]) -> PolarsResult { + let ca = s[0].binary()?; + let prefix = s[1].binary()?; + + Ok(ca + .starts_with_chunked(prefix) + .with_name(ca.name()) + .into_series()) } impl From for FunctionExpr { diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs index 1a11781b4828..2f225596fb5e 100644 --- a/crates/polars-plan/src/dsl/function_expr/boolean.rs +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -21,10 +21,10 @@ pub enum BooleanFunction { IsInfinite, IsNan, IsNotNan, - #[cfg(feature = "is_first")] - IsFirst, - #[cfg(feature = "is_last")] - IsLast, + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct, + #[cfg(feature = "is_last_distinct")] + IsLastDistinct, #[cfg(feature = "is_unique")] IsUnique, #[cfg(feature = "is_unique")] @@ -59,10 +59,10 @@ impl Display for BooleanFunction { IsInfinite => "is_infinite", IsNan => "is_nan", IsNotNan => "is_not_nan", - #[cfg(feature = "is_first")] - IsFirst => "is_first", - #[cfg(feature = "is_last")] - IsLast => "is_last", + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct => "is_first_distinct", + #[cfg(feature = "is_last_distinct")] + IsLastDistinct => "is_last_distinct", #[cfg(feature = "is_unique")] IsUnique => "is_unique", #[cfg(feature = "is_unique")] @@ -89,10 +89,10 @@ impl From for SpecialEq> { IsInfinite => map!(is_infinite), IsNan => map!(is_nan), IsNotNan => map!(is_not_nan), - #[cfg(feature = "is_first")] - IsFirst => map!(is_first), - #[cfg(feature = "is_last")] - IsLast => map!(is_last), + #[cfg(feature = "is_first_distinct")] + IsFirstDistinct => map!(is_first_distinct), + #[cfg(feature = "is_last_distinct")] + IsLastDistinct => map!(is_last_distinct), #[cfg(feature = "is_unique")] IsUnique => map!(is_unique), #[cfg(feature = "is_unique")] @@ -154,14 +154,14 @@ pub(super) fn is_not_nan(s: &Series) -> PolarsResult { s.is_not_nan().map(|ca| ca.into_series()) } -#[cfg(feature = "is_first")] -fn is_first(s: &Series) -> PolarsResult { - polars_ops::prelude::is_first(s).map(|ca| ca.into_series()) +#[cfg(feature = "is_first_distinct")] +fn is_first_distinct(s: &Series) -> PolarsResult { + polars_ops::prelude::is_first_distinct(s).map(|ca| ca.into_series()) } -#[cfg(feature = "is_last")] -fn is_last(s: &Series) -> PolarsResult { - polars_ops::prelude::is_last(s).map(|ca| ca.into_series()) +#[cfg(feature = "is_last_distinct")] +fn is_last_distinct(s: &Series) -> PolarsResult { + 
polars_ops::prelude::is_last_distinct(s).map(|ca| ca.into_series()) } #[cfg(feature = "is_unique")] diff --git a/crates/polars-plan/src/dsl/function_expr/coerce.rs b/crates/polars-plan/src/dsl/function_expr/coerce.rs new file mode 100644 index 000000000000..00c180d0ba4a --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/coerce.rs @@ -0,0 +1,6 @@ +use polars_core::prelude::*; + +#[cfg(feature = "dtype-struct")] +pub fn as_struct(s: &[Series]) -> PolarsResult { + Ok(StructChunked::new(s[0].name(), s)?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/correlation.rs b/crates/polars-plan/src/dsl/function_expr/correlation.rs index a49df74b95c9..c56114487501 100644 --- a/crates/polars-plan/src/dsl/function_expr/correlation.rs +++ b/crates/polars-plan/src/dsl/function_expr/correlation.rs @@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize}; use super::*; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -#[derive(Copy, Clone, PartialEq, Debug)] +#[derive(Copy, Clone, PartialEq, Debug, Hash)] pub enum CorrelationMethod { Pearson, #[cfg(all(feature = "rank", feature = "propagate_nans"))] diff --git a/crates/polars-plan/src/dsl/function_expr/datetime.rs b/crates/polars-plan/src/dsl/function_expr/datetime.rs index 26c3d883e298..e8b44d402ac2 100644 --- a/crates/polars-plan/src/dsl/function_expr/datetime.rs +++ b/crates/polars-plan/src/dsl/function_expr/datetime.rs @@ -31,7 +31,7 @@ pub enum TemporalFunction { Microsecond, Nanosecond, TimeStamp(TimeUnit), - Truncate(TruncateOptions), + Truncate(String), #[cfg(feature = "date_offset")] MonthStart, #[cfg(feature = "date_offset")] @@ -201,27 +201,26 @@ pub(super) fn timestamp(s: &Series, tu: TimeUnit) -> PolarsResult { s.timestamp(tu).map(|ca| ca.into_series()) } -pub(super) fn truncate(s: &[Series], options: &TruncateOptions) -> PolarsResult { +pub(super) fn truncate(s: &[Series], offset: &str) -> PolarsResult { let time_series = &s[0]; - let ambiguous = &s[1].utf8().unwrap(); + let every = s[1].utf8()?; + let ambiguous = s[2].utf8()?; + let mut out = match time_series.dtype() { DataType::Datetime(_, tz) => match tz { #[cfg(feature = "timezones")] Some(tz) => time_series - .datetime() - .unwrap() - .truncate(options, tz.parse::().ok().as_ref(), ambiguous)? + .datetime()? + .truncate(tz.parse::().ok().as_ref(), every, offset, ambiguous)? .into_series(), _ => time_series - .datetime() - .unwrap() - .truncate(options, None, ambiguous)? + .datetime()? + .truncate(None, every, offset, ambiguous)? .into_series(), }, DataType::Date => time_series - .date() - .unwrap() - .truncate(options, None, ambiguous)? + .date()? + .truncate(None, every, offset, ambiguous)? 
.into_series(), dt => polars_bail!(opq = round, got = dt, expected = "date/datetime"), }; diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 856273fadc76..72151f4853da 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -22,6 +22,7 @@ pub enum ListFunction { Any, #[cfg(feature = "list_any_all")] All, + Join, } impl Display for ListFunction { @@ -45,6 +46,7 @@ impl Display for ListFunction { Any => "any", #[cfg(feature = "list_any_all")] All => "all", + Join => "join", }; write!(f, "{name}") } @@ -219,7 +221,7 @@ pub(super) fn get(s: &mut [Series]) -> PolarsResult> { }) .collect::(); let s = Series::try_from((ca.name(), arr.values().clone())).unwrap(); - unsafe { s.take_unchecked(&take_by) }.map(Some) + unsafe { Ok(Some(s.take_unchecked(&take_by))) } }, len => polars_bail!( ComputeError: @@ -279,3 +281,9 @@ pub(super) fn lst_any(s: &Series) -> PolarsResult { pub(super) fn lst_all(s: &Series) -> PolarsResult { s.list()?.lst_all() } + +pub(super) fn join(s: &[Series]) -> PolarsResult { + let ca = s[0].list()?; + let separator = s[1].utf8()?; + Ok(ca.lst_join(separator)?.into_series()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index b4a7d9574d57..12b7341c35bd 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -11,6 +11,7 @@ mod bounds; mod cat; #[cfg(feature = "round_series")] mod clip; +mod coerce; mod concat; mod correlation; mod cum; @@ -24,6 +25,8 @@ mod list; #[cfg(feature = "log")] mod log; mod nan; +#[cfg(feature = "ffi_plugin")] +mod plugin; mod pow; #[cfg(feature = "random")] mod random; @@ -35,7 +38,7 @@ mod rolling; mod round; #[cfg(feature = "row_hash")] mod row_hash; -mod schema; +pub(super) mod schema; #[cfg(feature = "search_sorted")] mod search_sorted; mod shift_and_fill; @@ -140,6 +143,8 @@ pub enum FunctionExpr { ArrayExpr(ArrayFunction), #[cfg(feature = "dtype-struct")] StructExpr(StructFunction), + #[cfg(feature = "dtype-struct")] + AsStruct, #[cfg(feature = "top_k")] TopK { k: usize, @@ -230,18 +235,32 @@ pub enum FunctionExpr { seed: Option, }, SetSortedFlag(IsSorted), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { + lib: Arc, + symbol: Arc, + }, } impl Hash for FunctionExpr { fn hash(&self, state: &mut H) { std::mem::discriminant(self).hash(state); match self { + FunctionExpr::Pow(f) => f.hash(state), + #[cfg(feature = "search_sorted")] + FunctionExpr::SearchSorted(f) => f.hash(state), FunctionExpr::BinaryExpr(f) => f.hash(state), FunctionExpr::Boolean(f) => f.hash(state), #[cfg(feature = "strings")] FunctionExpr::StringExpr(f) => f.hash(state), + FunctionExpr::ListExpr(f) => f.hash(state), + #[cfg(feature = "dtype-array")] + FunctionExpr::ArrayExpr(f) => f.hash(state), + #[cfg(feature = "dtype-struct")] + FunctionExpr::StructExpr(f) => f.hash(state), #[cfg(feature = "random")] FunctionExpr::Random { method, .. } => method.hash(state), + FunctionExpr::Correlation { method, .. 
} => method.hash(state), #[cfg(feature = "range")] FunctionExpr::Range(f) => f.hash(state), #[cfg(feature = "temporal")] @@ -250,10 +269,17 @@ impl Hash for FunctionExpr { FunctionExpr::Trigonometry(f) => f.hash(state), #[cfg(feature = "fused")] FunctionExpr::Fused(f) => f.hash(state), + #[cfg(feature = "diff")] + FunctionExpr::Diff(_, null_behavior) => null_behavior.hash(state), #[cfg(feature = "interpolate")] FunctionExpr::Interpolate(f) => f.hash(state), #[cfg(feature = "dtype-categorical")] FunctionExpr::Categorical(f) => f.hash(state), + #[cfg(feature = "ffi_plugin")] + FunctionExpr::FfiPlugin { lib, symbol } => { + lib.hash(state); + symbol.hash(state); + }, _ => {}, } } @@ -304,6 +330,8 @@ impl Display for FunctionExpr { ListExpr(func) => return write!(f, "{func}"), #[cfg(feature = "dtype-struct")] StructExpr(func) => return write!(f, "{func}"), + #[cfg(feature = "dtype-struct")] + AsStruct => "as_struct", #[cfg(feature = "top_k")] TopK { .. } => "top_k", Shift(_) => "shift", @@ -365,6 +393,8 @@ impl Display for FunctionExpr { #[cfg(feature = "random")] Random { method, .. } => method.into(), SetSortedFlag(_) => "set_sorted", + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol, .. } => return write!(f, "{lib}:{symbol}"), }; write!(f, "{s}") } @@ -535,6 +565,7 @@ impl From for SpecialEq> { Any => map!(list::lst_any), #[cfg(feature = "list_any_all")] All => map!(list::lst_all), + Join => map_as_slice!(list::join), } }, #[cfg(feature = "dtype-array")] @@ -555,6 +586,10 @@ impl From for SpecialEq> { FieldByName(name) => map!(struct_::get_by_name, name.clone()), } }, + #[cfg(feature = "dtype-struct")] + AsStruct => { + map_as_slice!(coerce::as_struct) + }, #[cfg(feature = "top_k")] TopK { k, descending } => { map!(top_k, k, descending) @@ -636,6 +671,10 @@ impl From for SpecialEq> { #[cfg(feature = "random")] Random { method, seed } => map!(random::random, method, seed), SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol, .. 
} => unsafe { + map_as_slice!(plugin::call_plugin, lib.as_ref(), symbol.as_ref()) + }, } } } @@ -680,6 +719,12 @@ impl From for SpecialEq> { Strptime(dtype, options) => { map_as_slice!(strings::strptime, dtype.clone(), &options) }, + Split => { + map_as_slice!(strings::split) + }, + SplitInclusive => { + map_as_slice!(strings::split_inclusive) + }, #[cfg(feature = "concat_str")] ConcatVertical(delimiter) => map!(strings::concat, &delimiter), #[cfg(feature = "concat_str")] @@ -714,14 +759,14 @@ impl From for SpecialEq> { fn from(func: BinaryFunction) -> Self { use BinaryFunction::*; match func { - Contains { pat, literal } => { - map!(binary::contains, &pat, literal) + Contains => { + map_as_slice!(binary::contains) }, - EndsWith(sub) => { - map!(binary::ends_with, &sub) + EndsWith => { + map_as_slice!(binary::ends_with) }, - StartsWith(sub) => { - map!(binary::starts_with, &sub) + StartsWith => { + map_as_slice!(binary::starts_with) }, } } @@ -751,8 +796,8 @@ impl From for SpecialEq> { Microsecond => map!(datetime::microsecond), Nanosecond => map!(datetime::nanosecond), TimeStamp(tu) => map!(datetime::timestamp, tu), - Truncate(truncate_options) => { - map_as_slice!(datetime::truncate, &truncate_options) + Truncate(offset) => { + map_as_slice!(datetime::truncate, &offset) }, #[cfg(feature = "date_offset")] MonthStart => map!(datetime::month_start), diff --git a/crates/polars-plan/src/dsl/function_expr/plugin.rs b/crates/polars-plan/src/dsl/function_expr/plugin.rs new file mode 100644 index 000000000000..6c8113a54aac --- /dev/null +++ b/crates/polars-plan/src/dsl/function_expr/plugin.rs @@ -0,0 +1,75 @@ +use std::sync::RwLock; + +use arrow::ffi::{import_field_from_c, ArrowSchema}; +use libloading::Library; +use once_cell::sync::Lazy; +use polars_ffi::*; + +use super::*; + +static LOADED: Lazy>> = Lazy::new(Default::default); + +fn get_lib(lib: &str) -> PolarsResult<&'static Library> { + let lib_map = LOADED.read().unwrap(); + if let Some(library) = lib_map.get(lib) { + // lifetime is static as we never remove libraries. + Ok(unsafe { std::mem::transmute::<&Library, &'static Library>(library) }) + } else { + drop(lib_map); + let library = unsafe { + Library::new(lib).map_err(|e| { + PolarsError::ComputeError(format!("error loading dynamic library: {e}").into()) + })? 
+ }; + + let mut lib_map = LOADED.write().unwrap(); + lib_map.insert(lib.to_string(), library); + drop(lib_map); + + get_lib(lib) + } +} + +pub(super) unsafe fn call_plugin(s: &[Series], lib: &str, symbol: &str) -> PolarsResult { + let lib = get_lib(lib)?; + + let symbol: libloading::Symbol< + unsafe extern "C" fn(*const SeriesExport, usize) -> SeriesExport, + > = lib.get(symbol.as_bytes()).unwrap(); + + let n_args = s.len(); + + let input = s.iter().map(export_series).collect::>(); + let slice_ptr = input.as_ptr(); + let out = symbol(slice_ptr, n_args); + + for e in input { + std::mem::forget(e); + } + + import_series(out) +} + +pub(super) unsafe fn plugin_field( + fields: &[Field], + lib: &str, + symbol: &str, +) -> PolarsResult { + let lib = get_lib(lib)?; + + let symbol: libloading::Symbol ArrowSchema> = + lib.get(symbol.as_bytes()).unwrap(); + + // we deallocate the fields buffer + let fields = fields + .iter() + .map(|field| arrow::ffi::export_field_to_c(&field.to_arrow())) + .collect::>() + .into_boxed_slice(); + let n_args = fields.len(); + let slice_ptr = fields.as_ptr(); + let out = symbol(slice_ptr, n_args); + + let arrow_field = import_field_from_c(&out)?; + Ok(Field::from(&arrow_field)) +} diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 71a4781847a7..cfdb181895f1 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -27,9 +27,7 @@ impl FunctionExpr { BinaryExpr(s) => { use BinaryFunction::*; match s { - Contains { .. } | EndsWith(_) | StartsWith(_) => { - mapper.with_dtype(DataType::Boolean) - }, + Contains { .. } | EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean), } }, #[cfg(feature = "temporal")] @@ -47,7 +45,7 @@ impl FunctionExpr { DataType::Datetime(tu, _) => DataType::Datetime(tu, None), dtype => polars_bail!(ComputeError: "expected Datetime, got {}", dtype), }, - Truncate(..) => mapper.with_same_dtype().unwrap().dtype, + Truncate(_) => mapper.with_same_dtype().unwrap().dtype, #[cfg(feature = "date_offset")] MonthStart => mapper.with_same_dtype().unwrap().dtype, #[cfg(feature = "date_offset")] @@ -115,6 +113,7 @@ impl FunctionExpr { Any => mapper.with_dtype(DataType::Boolean), #[cfg(feature = "list_any_all")] All => mapper.with_dtype(DataType::Boolean), + Join => mapper.with_dtype(DataType::Utf8), } }, #[cfg(feature = "dtype-array")] @@ -133,6 +132,11 @@ impl FunctionExpr { } }, #[cfg(feature = "dtype-struct")] + AsStruct => Ok(Field::new( + fields[0].name(), + DataType::Struct(fields.to_vec()), + )), + #[cfg(feature = "dtype-struct")] StructExpr(s) => { use polars_core::utils::slice_offsets; use StructFunction::*; @@ -242,33 +246,41 @@ impl FunctionExpr { #[cfg(feature = "random")] Random { .. } => mapper.with_same_dtype(), SetSortedFlag(_) => mapper.with_same_dtype(), + #[cfg(feature = "ffi_plugin")] + FfiPlugin { lib, symbol } => unsafe { + plugin::plugin_field(fields, lib, &format!("__polars_field_{}", symbol.as_ref())) + }, } } } -pub(super) struct FieldsMapper<'a> { +pub struct FieldsMapper<'a> { fields: &'a [Field], } impl<'a> FieldsMapper<'a> { + pub fn new(fields: &'a [Field]) -> Self { + Self { fields } + } + /// Field with the same dtype. - pub(super) fn with_same_dtype(&self) -> PolarsResult { + pub fn with_same_dtype(&self) -> PolarsResult { self.map_dtype(|dtype| dtype.clone()) } /// Set a dtype. 
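// --- Editorial note, not part of the patch -------------------------------------
// Apropos of the new plugin.rs above: `get_lib` keeps every dlopen'ed plugin alive
// for the process lifetime — try a read lock first, fall back to a write lock on a
// miss, never remove entries (which is what makes handing out `'static` references
// sound). The output-schema hook is then resolved per plugin as the symbol named
// `__polars_field_<symbol>`. A std-only sketch of the same caching pattern, with a
// stand-in `Handle` type instead of `libloading::Library`:
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};

struct Handle(String); // stand-in for the loaded library

static CACHE: OnceLock<RwLock<HashMap<String, Handle>>> = OnceLock::new();

fn get(name: &str) -> &'static Handle {
    let cache = CACHE.get_or_init(|| RwLock::new(HashMap::new()));
    if let Some(h) = cache.read().unwrap().get(name) {
        // Sound only because entries are never removed or mutated.
        return unsafe { std::mem::transmute::<&Handle, &'static Handle>(h) };
    }
    cache
        .write()
        .unwrap()
        .entry(name.to_string())
        .or_insert_with(|| Handle(name.to_string()));
    get(name) // re-enter through the read path
}
// --------------------------------------------------------------------------------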
- pub(super) fn with_dtype(&self, dtype: DataType) -> PolarsResult { + pub fn with_dtype(&self, dtype: DataType) -> PolarsResult { Ok(Field::new(self.fields[0].name(), dtype)) } /// Map a single dtype. - pub(super) fn map_dtype(&self, func: impl Fn(&DataType) -> DataType) -> PolarsResult { + pub fn map_dtype(&self, func: impl Fn(&DataType) -> DataType) -> PolarsResult { let dtype = func(self.fields[0].data_type()); Ok(Field::new(self.fields[0].name(), dtype)) } /// Map to a float supertype. - pub(super) fn map_to_float_dtype(&self) -> PolarsResult { + pub fn map_to_float_dtype(&self) -> PolarsResult { self.map_dtype(|dtype| match dtype { DataType::Float32 => DataType::Float32, _ => DataType::Float64, @@ -276,13 +288,13 @@ impl<'a> FieldsMapper<'a> { } /// Map to a physical type. - pub(super) fn to_physical_type(&self) -> PolarsResult { + pub fn to_physical_type(&self) -> PolarsResult { self.map_dtype(|dtype| dtype.to_physical()) } /// Map a single dtype with a potentially failing mapper function. #[cfg(any(feature = "timezones", feature = "dtype-array"))] - pub(super) fn try_map_dtype( + pub fn try_map_dtype( &self, func: impl Fn(&DataType) -> PolarsResult, ) -> PolarsResult { @@ -291,7 +303,7 @@ impl<'a> FieldsMapper<'a> { } /// Map all dtypes with a potentially failing mapper function. - pub(super) fn try_map_dtypes( + pub fn try_map_dtypes( &self, func: impl Fn(&[&DataType]) -> PolarsResult, ) -> PolarsResult { @@ -307,7 +319,7 @@ impl<'a> FieldsMapper<'a> { } /// Map the dtype to the "supertype" of all fields. - pub(super) fn map_to_supertype(&self) -> PolarsResult { + pub fn map_to_supertype(&self) -> PolarsResult { let mut first = self.fields[0].clone(); let mut st = first.data_type().clone(); for field in &self.fields[1..] { @@ -318,7 +330,7 @@ impl<'a> FieldsMapper<'a> { } /// Map the dtype to the dtype of the list elements. - pub(super) fn map_to_list_inner_dtype(&self) -> PolarsResult { + pub fn map_to_list_inner_dtype(&self) -> PolarsResult { let mut first = self.fields[0].clone(); let dt = first .data_type() @@ -330,7 +342,7 @@ impl<'a> FieldsMapper<'a> { } /// Map the dtypes to the "supertype" of a list of lists. - pub(super) fn map_to_list_supertype(&self) -> PolarsResult { + pub fn map_to_list_supertype(&self) -> PolarsResult { self.try_map_dtypes(|dts| { let mut super_type_inner = None; @@ -356,7 +368,7 @@ impl<'a> FieldsMapper<'a> { /// Set the timezone of a datetime dtype. 
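// --- Editorial note, not part of the patch -------------------------------------
// Making these helpers `pub` (together with the new `FieldsMapper::new`
// constructor above) lets code outside this module reuse them, for example an ffi
// plugin's schema hook. Hypothetical helper, shown only as a hedged sketch of how
// the now-public API could be used:
fn my_float_output(fields: &[Field]) -> PolarsResult<Field> {
    FieldsMapper::new(fields).map_to_float_dtype()
}
// --------------------------------------------------------------------------------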
#[cfg(feature = "timezones")] - pub(super) fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult { + pub fn map_datetime_dtype_timezone(&self, tz: Option<&TimeZone>) -> PolarsResult { self.try_map_dtype(|dt| { if let DataType::Datetime(tu, _) = dt { Ok(DataType::Datetime(*tu, tz.cloned())) @@ -366,7 +378,7 @@ impl<'a> FieldsMapper<'a> { }) } - fn nested_sum_type(&self) -> PolarsResult { + pub fn nested_sum_type(&self) -> PolarsResult { let mut first = self.fields[0].clone(); use DataType::*; let dt = first.data_type().inner_dtype().cloned().unwrap_or(Unknown); @@ -380,7 +392,7 @@ impl<'a> FieldsMapper<'a> { } #[cfg(feature = "extract_jsonpath")] - pub(super) fn with_opt_dtype(&self, dtype: Option) -> PolarsResult { + pub fn with_opt_dtype(&self, dtype: Option) -> PolarsResult { let dtype = dtype.unwrap_or(DataType::Unknown); self.with_dtype(dtype) } diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 2218641c052b..d9f72c0b1ff6 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -75,6 +75,8 @@ pub enum StringFunction { StripSuffix(String), #[cfg(feature = "temporal")] Strptime(DataType, StrptimeOptions), + Split, + SplitInclusive, #[cfg(feature = "dtype-decimal")] ToDecimal(usize), #[cfg(feature = "nightly")] @@ -89,7 +91,7 @@ impl StringFunction { use StringFunction::*; match self { #[cfg(feature = "concat_str")] - ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_same_dtype(), + ConcatVertical(_) | ConcatHorizontal(_) => mapper.with_dtype(DataType::Utf8), #[cfg(feature = "regex")] Contains { .. } => mapper.with_dtype(DataType::Boolean), CountMatches(_) => mapper.with_dtype(DataType::UInt32), @@ -109,6 +111,7 @@ impl StringFunction { Replace { .. } => mapper.with_same_dtype(), #[cfg(feature = "temporal")] Strptime(dtype, _) => mapper.with_dtype(dtype.clone()), + Split | SplitInclusive => mapper.with_dtype(DataType::List(Box::new(DataType::Utf8))), #[cfg(feature = "nightly")] Titlecase => mapper.with_same_dtype(), #[cfg(feature = "dtype-decimal")] @@ -165,6 +168,8 @@ impl Display for StringFunction { StringFunction::StripSuffix(_) => "strip_suffix", #[cfg(feature = "temporal")] StringFunction::Strptime(_, _) => "strptime", + StringFunction::Split => "split", + StringFunction::SplitInclusive => "split_inclusive", #[cfg(feature = "nightly")] StringFunction::Titlecase => "titlecase", #[cfg(feature = "dtype-decimal")] @@ -205,101 +210,24 @@ pub(super) fn lengths(s: &Series) -> PolarsResult { #[cfg(feature = "regex")] pub(super) fn contains(s: &[Series], literal: bool, strict: bool) -> PolarsResult { - // TODO! move to polars-ops let ca = s[0].utf8()?; let pat = s[1].utf8()?; - - let mut out: BooleanChunked = match pat.len() { - 1 => match pat.get(0) { - Some(pat) => { - if literal { - ca.contains_literal(pat)? - } else { - ca.contains(pat, strict)? - } - }, - None => BooleanChunked::full(ca.name(), false, ca.len()), - }, - _ => { - if literal { - ca.into_iter() - .zip(pat) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(pat)) => src.contains(pat), - _ => false, - }) - .collect_trusted() - } else if strict { - ca.into_iter() - .zip(pat) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(pat)) => { - let re = Regex::new(pat)?; - Ok(re.is_match(src)) - }, - _ => Ok(false), - }) - .collect::>()? 
- } else { - ca.into_iter() - .zip(pat) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(pat)) => Regex::new(pat).ok().map(|re| re.is_match(src)), - _ => Some(false), - }) - .collect_trusted() - } - }, - }; - - out.rename(ca.name()); - Ok(out.into_series()) + ca.contains_chunked(pat, literal, strict) + .map(|ok| ok.into_series()) } pub(super) fn ends_with(s: &[Series]) -> PolarsResult { - let ca = s[0].utf8()?; - let sub = s[1].utf8()?; + let ca = &s[0].utf8()?.as_binary(); + let suffix = &s[1].utf8()?.as_binary(); - let mut out: BooleanChunked = match sub.len() { - 1 => match sub.get(0) { - Some(s) => ca.ends_with(s), - None => BooleanChunked::full(ca.name(), false, ca.len()), - }, - _ => ca - .into_iter() - .zip(sub) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(val)) => src.ends_with(val), - _ => false, - }) - .collect_trusted(), - }; - - out.rename(ca.name()); - Ok(out.into_series()) + Ok(ca.ends_with_chunked(suffix).into_series()) } pub(super) fn starts_with(s: &[Series]) -> PolarsResult { - let ca = s[0].utf8()?; - let sub = s[1].utf8()?; - - let mut out: BooleanChunked = match sub.len() { - 1 => match sub.get(0) { - Some(s) => ca.starts_with(s), - None => BooleanChunked::full(ca.name(), false, ca.len()), - }, - _ => ca - .into_iter() - .zip(sub) - .map(|(opt_src, opt_val)| match (opt_src, opt_val) { - (Some(src), Some(val)) => src.starts_with(val), - _ => false, - }) - .collect_trusted(), - }; + let ca = &s[0].utf8()?.as_binary(); + let prefix = &s[1].utf8()?.as_binary(); - out.rename(ca.name()); - Ok(out.into_series()) + Ok(ca.starts_with_chunked(prefix).into_series()) } /// Extract a regex pattern from the a string value. @@ -467,6 +395,44 @@ pub(super) fn strptime( } } +pub(super) fn split(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + if by.len() == 1 { + if let Some(by) = by.get(0) { + Ok(ca.split(by).into_series()) + } else { + Ok(Series::full_null( + ca.name(), + ca.len(), + &DataType::List(Box::new(DataType::Utf8)), + )) + } + } else { + Ok(ca.split_many(by).into_series()) + } +} + +pub(super) fn split_inclusive(s: &[Series]) -> PolarsResult { + let ca = s[0].utf8()?; + let by = s[1].utf8()?; + + if by.len() == 1 { + if let Some(by) = by.get(0) { + Ok(ca.split_inclusive(by).into_series()) + } else { + Ok(Series::full_null( + ca.name(), + ca.len(), + &DataType::List(Box::new(DataType::Utf8)), + )) + } + } else { + Ok(ca.split_inclusive_many(by).into_series()) + } +} + fn handle_temporal_parsing_error( ca: &Utf8Chunked, out: &Series, @@ -778,7 +744,7 @@ pub(super) fn from_radix(s: &Series, radix: u32, strict: bool) -> PolarsResult) -> PolarsResult { let ca = s.utf8()?; - ca.str_slice(start, length).map(|ca| ca.into_series()) + Ok(ca.str_slice(start, length).into_series()) } pub(super) fn explode(s: &Series) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/function_expr/struct_.rs b/crates/polars-plan/src/dsl/function_expr/struct_.rs index 7d9522d133a9..c32b08df6822 100644 --- a/crates/polars-plan/src/dsl/function_expr/struct_.rs +++ b/crates/polars-plan/src/dsl/function_expr/struct_.rs @@ -13,8 +13,8 @@ impl Display for StructFunction { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use self::*; match self { - StructFunction::FieldByIndex(_) => write!(f, "struct.field_by_name"), - StructFunction::FieldByName(_) => write!(f, "struct.field_by_index"), + StructFunction::FieldByIndex(index) => write!(f, "struct.field_by_index({index})"), + 
StructFunction::FieldByName(name) => write!(f, "struct.field_by_name({name})"), } } } diff --git a/crates/polars-plan/src/dsl/function_expr/temporal.rs b/crates/polars-plan/src/dsl/function_expr/temporal.rs index 6611593049fe..2de19d5b5ab7 100644 --- a/crates/polars-plan/src/dsl/function_expr/temporal.rs +++ b/crates/polars-plan/src/dsl/function_expr/temporal.rs @@ -90,7 +90,7 @@ pub(super) fn datetime( .map(|ndt| match time_unit { TimeUnit::Milliseconds => ndt.timestamp_millis(), TimeUnit::Microseconds => ndt.timestamp_micros(), - TimeUnit::Nanoseconds => ndt.timestamp_nanos(), + TimeUnit::Nanoseconds => ndt.timestamp_nanos_opt().unwrap(), }) } else { None @@ -126,8 +126,14 @@ fn apply_offsets_to_datetime( offset_fn: fn(&Duration, i64, Option<&Tz>) -> PolarsResult, time_zone: Option<&Tz>, ) -> PolarsResult { - match offsets.len() { - 1 => match offsets.get(0) { + match (datetime.len(), offsets.len()) { + (1, _) => match datetime.0.get(0) { + Some(dt) => offsets.try_apply_values_generic(|offset| { + offset_fn(&Duration::parse(offset), dt, time_zone) + }), + _ => Ok(Int64Chunked::full_null(datetime.0.name(), offsets.len())), + }, + (_, 1) => match offsets.get(0) { Some(offset) => datetime .0 .try_apply(|v| offset_fn(&Duration::parse(offset), v, time_zone)), diff --git a/crates/polars-plan/src/dsl/functions/coerce.rs b/crates/polars-plan/src/dsl/functions/coerce.rs index e28a1697eefd..e009e9e61918 100644 --- a/crates/polars-plan/src/dsl/functions/coerce.rs +++ b/crates/polars-plan/src/dsl/functions/coerce.rs @@ -3,16 +3,15 @@ use super::*; /// Take several expressions and collect them into a [`StructChunked`]. #[cfg(feature = "dtype-struct")] -pub fn as_struct(exprs: &[Expr]) -> Expr { - map_multiple( - |s| StructChunked::new(s[0].name(), s).map(|ca| Some(ca.into_series())), - exprs, - GetOutput::map_fields(|fld| Field::new(fld[0].name(), DataType::Struct(fld.to_vec()))), - ) - .with_function_options(|mut options| { - options.input_wildcard_expansion = true; - options.fmt_str = "as_struct"; - options.pass_name_to_apply = true; - options - }) +pub fn as_struct(exprs: Vec) -> Expr { + Expr::Function { + input: exprs, + function: FunctionExpr::AsStruct, + options: FunctionOptions { + input_wildcard_expansion: true, + pass_name_to_apply: true, + collect_groups: ApplyOptions::ApplyFlat, + ..Default::default() + }, + } } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 427d99d5ca6c..4670f90e473a 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -162,18 +162,12 @@ impl ListNameSpace { /// Join all string items in a sublist and place a separator between them. /// # Error /// This errors if inner type of list `!= DataType::Utf8`. - pub fn join(self, separator: &str) -> Expr { - let separator = separator.to_string(); - self.0 - .map( - move |s| { - s.list()? - .lst_join(&separator) - .map(|ca| Some(ca.into_series())) - }, - GetOutput::from_type(DataType::Utf8), - ) - .with_fmt("list.join") + pub fn join(self, separator: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Join), + &[separator], + false, + ) } /// Return the index of the minimal value of every sublist diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 74b4c2c73c4e..866a7b4dbb7b 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1,5 +1,7 @@ #![allow(ambiguous_glob_reexports)] //! Domain specific language for the Lazy API. 
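// --- Editorial note, not part of the patch -------------------------------------
// `apply_offsets_to_datetime` above now broadcasts in both directions: a length-1
// datetime column against many offsets, or a length-1 offset against many
// datetimes. Stand-alone sketch of that dispatch over plain vectors:
fn broadcast_zip(lhs: &[i64], rhs: &[i64], f: impl Fn(i64, i64) -> i64) -> Vec<i64> {
    match (lhs.len(), rhs.len()) {
        (1, _) => rhs.iter().map(|r| f(lhs[0], *r)).collect(),
        (_, 1) => lhs.iter().map(|l| f(*l, rhs[0])).collect(),
        _ => {
            assert_eq!(lhs.len(), rhs.len(), "lengths must match when neither side is scalar");
            lhs.iter().zip(rhs).map(|(l, r)| f(*l, *r)).collect()
        },
    }
}
// --------------------------------------------------------------------------------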
+#[cfg(feature = "rolling_window")] +use polars_core::utils::ensure_sorted_arg; #[cfg(feature = "dtype-categorical")] pub mod cat; #[cfg(feature = "dtype-categorical")] @@ -15,7 +17,6 @@ mod expr; mod expr_dyn_fn; mod from; pub(crate) mod function_expr; -#[cfg(feature = "compile")] pub mod functions; mod list; #[cfg(feature = "meta")] @@ -39,6 +40,7 @@ pub use arity::*; #[cfg(feature = "dtype-array")] pub use array::*; pub use expr::*; +pub use function_expr::schema::FieldsMapper; pub use function_expr::*; pub use functions::*; pub use list::*; @@ -1096,18 +1098,18 @@ impl Expr { self.repeat_by_impl(by.into()) } - #[cfg(feature = "is_first")] + #[cfg(feature = "is_first_distinct")] #[allow(clippy::wrong_self_convention)] /// Get a mask of the first unique value. - pub fn is_first(self) -> Expr { - self.apply_private(BooleanFunction::IsFirst.into()) + pub fn is_first_distinct(self) -> Expr { + self.apply_private(BooleanFunction::IsFirstDistinct.into()) } - #[cfg(feature = "is_last")] + #[cfg(feature = "is_last_distinct")] #[allow(clippy::wrong_self_convention)] /// Get a mask of the last unique value. - pub fn is_last(self) -> Expr { - self.apply_private(BooleanFunction::IsLast.into()) + pub fn is_last_distinct(self) -> Expr { + self.apply_private(BooleanFunction::IsLastDistinct.into()) } fn dot_impl(self, other: Expr) -> Expr { @@ -1239,6 +1241,7 @@ impl Expr { }, _ => (by.clone(), &None), }; + ensure_sorted_arg(&by, expr_name)?; let by = by.datetime().unwrap(); let by_values = by.cont_slice().map_err(|_| { polars_err!( diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index d4bec70dcbf0..e7d0a7322061 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -214,53 +214,15 @@ impl StringNameSpace { } /// Split the string by a substring. The resulting dtype is `List`. - pub fn split(self, by: &str) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size()); - ca.into_iter().for_each(|opt_s| match opt_s { - None => builder.append_null(), - Some(s) => { - let iter = s.split(&by); - builder.append_values_iter(iter); - }, - }); - Ok(Some(builder.finish().into_series())) - }; + pub fn split(self, by: Expr) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))), - ) - .with_fmt("str.split") + .map_many_private(StringFunction::Split.into(), &[by], false) } /// Split the string by a substring and keep the substring. The resulting dtype is `List`. 
- pub fn split_inclusive(self, by: &str) -> Expr { - let by = by.to_string(); - - let function = move |s: Series| { - let ca = s.utf8()?; - - let mut builder = ListUtf8ChunkedBuilder::new(s.name(), s.len(), ca.get_values_size()); - ca.into_iter().for_each(|opt_s| match opt_s { - None => builder.append_null(), - Some(s) => { - let iter = s.split_inclusive(&by); - builder.append_values_iter(iter); - }, - }); - Ok(Some(builder.finish().into_series())) - }; + pub fn split_inclusive(self, by: Expr) -> Expr { self.0 - .map( - function, - GetOutput::from_type(DataType::List(Box::new(DataType::Utf8))), - ) - .with_fmt("str.split_inclusive") + .map_many_private(StringFunction::SplitInclusive.into(), &[by], false) } #[cfg(feature = "dtype-struct")] diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index 5847b1b5399a..cba20688c312 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -1,10 +1,10 @@ #[cfg(feature = "csv")] use std::io::{Read, Seek}; -#[cfg(feature = "parquet")] -use polars_core::cloud::CloudOptions; use polars_core::frame::explode::MeltArgs; use polars_core::prelude::*; +#[cfg(feature = "parquet")] +use polars_io::cloud::CloudOptions; #[cfg(feature = "ipc")] use polars_io::ipc::IpcReader; #[cfg(all(feature = "parquet", feature = "async"))] @@ -20,9 +20,10 @@ use polars_io::parquet::ParquetReader; use polars_io::RowCount; #[cfg(feature = "csv")] use polars_io::{ - csv::utils::{get_reader_bytes, infer_file_schema, is_compressed}, + csv::utils::{infer_file_schema, is_compressed}, csv::CsvEncoding, csv::NullValues, + utils::get_reader_bytes, }; use super::builder_functions::*; @@ -523,9 +524,11 @@ impl LogicalPlanBuilder { pub fn filter(self, predicate: Expr) -> Self { let predicate = if has_expr(&predicate, |e| match e { Expr::Column(name) => is_regex_projection(name), - Expr::Wildcard | Expr::RenameAlias { .. } | Expr::Columns(_) | Expr::DtypeColumn(_) => { - true - }, + Expr::Wildcard + | Expr::RenameAlias { .. 
} + | Expr::Columns(_) + | Expr::DtypeColumn(_) + | Expr::Nth(_) => true, _ => false, }) { let schema = try_delayed!(self.0.schema(), &self.0, into); diff --git a/crates/polars-plan/src/logical_plan/lit.rs b/crates/polars-plan/src/logical_plan/lit.rs index 7f8dc6da6446..6a7fb87ad6f6 100644 --- a/crates/polars-plan/src/logical_plan/lit.rs +++ b/crates/polars-plan/src/logical_plan/lit.rs @@ -252,7 +252,7 @@ impl Literal for NaiveDateTime { fn lit(self) -> Expr { if in_nanoseconds_window(&self) { Expr::Literal(LiteralValue::DateTime( - self.timestamp_nanos(), + self.timestamp_nanos_opt().unwrap(), TimeUnit::Nanoseconds, None, )) diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index 3865a268eda6..81814647d931 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -2,9 +2,9 @@ use std::fmt::Debug; use std::path::PathBuf; use std::sync::{Arc, Mutex}; -#[cfg(any(feature = "cloud", feature = "parquet"))] -use polars_core::cloud::CloudOptions; use polars_core::prelude::*; +#[cfg(any(feature = "cloud", feature = "parquet"))] +use polars_io::cloud::CloudOptions; use crate::logical_plan::LogicalPlan::DataFrameScan; use crate::prelude::*; diff --git a/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs b/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs new file mode 100644 index 000000000000..17b051411cde --- /dev/null +++ b/crates/polars-plan/src/logical_plan/optimizer/collect_members.rs @@ -0,0 +1,28 @@ +use super::*; + +pub(super) struct MemberCollector { + pub(crate) has_joins_or_unions: bool, + pub(crate) has_cache: bool, + pub(crate) has_ext_context: bool, +} + +impl MemberCollector { + pub(super) fn new() -> Self { + Self { + has_joins_or_unions: false, + has_cache: false, + has_ext_context: false, + } + } + pub fn collect(&mut self, root: Node, lp_arena: &Arena) { + use ALogicalPlan::*; + for (_, alp) in lp_arena.iter(root) { + match alp { + Join { .. } | Union { .. } => self.has_joins_or_unions = true, + Cache { .. } => self.has_cache = true, + ExtContext { .. } => self.has_ext_context = true, + _ => {}, + } + } + } +} diff --git a/crates/polars-plan/src/logical_plan/optimizer/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/mod.rs index caa1a19a593f..2f462239445b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/mod.rs @@ -9,6 +9,7 @@ mod cse; mod delay_rechunk; mod drop_nulls; +mod collect_members; #[cfg(feature = "cse")] mod cse_expr; mod fast_projection; @@ -43,6 +44,7 @@ pub use crate::frame::{AllowedOptimizations, OptState}; use crate::logical_plan::optimizer::cse_expr::CommonSubExprOptimizer; #[cfg(feature = "cse")] use crate::logical_plan::visitor::*; +use crate::prelude::optimizer::collect_members::MemberCollector; pub trait Optimize { fn optimize(&self, logical_plan: LogicalPlan) -> PolarsResult; @@ -75,8 +77,11 @@ pub fn optimize( let eager = opt_state.eager; #[cfg(feature = "cse")] let comm_subplan_elim = opt_state.comm_subplan_elim && !eager; + #[cfg(feature = "cse")] let comm_subexpr_elim = opt_state.comm_subexpr_elim; + #[cfg(not(feature = "cse"))] + let comm_subexpr_elim = false; #[allow(unused_variables)] let agg_scan_projection = opt_state.file_caching && !streaming && !eager; @@ -91,16 +96,23 @@ pub fn optimize( let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena)?; + // Collect members for optimizations that need it. 
+ let mut members = MemberCollector::new(); + if !eager && (comm_subexpr_elim || projection_pushdown) { + members.collect(lp_top, lp_arena) + } + #[cfg(feature = "cse")] - let cse_changed = if comm_subplan_elim { + let cse_plan_changed = if comm_subplan_elim { let (lp, changed) = cse::elim_cmn_subplans(lp_top, lp_arena, expr_arena); lp_top = lp; + members.has_cache |= changed; changed } else { false }; #[cfg(not(feature = "cse"))] - let cse_changed = false; + let cse_plan_changed = false; // we do simplification if simplify_expr { @@ -116,8 +128,8 @@ pub fn optimize( let alp = projection_pushdown_opt.optimize(alp, lp_arena, expr_arena)?; lp_arena.replace(lp_top, alp); - if projection_pushdown_opt.has_joins_or_unions && projection_pushdown_opt.has_cache { - cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_changed); + if members.has_joins_or_unions && members.has_cache { + cache_states::set_cache_states(lp_top, lp_arena, expr_arena, scratch, cse_plan_changed); } } @@ -160,7 +172,7 @@ pub fn optimize( // and predicate pushdown are done. At that moment // the file fingerprints are finished. #[cfg(any(feature = "cse", feature = "parquet", feature = "ipc", feature = "csv"))] - if agg_scan_projection || cse_changed { + if agg_scan_projection || cse_plan_changed { // we do this so that expressions are simplified created by the pushdown optimizations // we must clean up the predicates, because the agg_scan_projection // uses them in the hashtable to determine duplicates. @@ -181,7 +193,7 @@ pub fn optimize( file_cacher.assign_unions(lp_top, lp_arena, expr_arena, scratch); #[cfg(feature = "cse")] - if cse_changed { + if cse_plan_changed { // this must run after cse cse::decrement_file_counters_by_cache_hits(lp_top, lp_arena, expr_arena, 0, scratch); } @@ -196,7 +208,7 @@ pub fn optimize( // This one should run (nearly) last as this modifies the projections #[cfg(feature = "cse")] - if comm_subexpr_elim { + if comm_subexpr_elim && !members.has_ext_context { let mut optimizer = CommonSubExprOptimizer::new(expr_arena); lp_top = ALogicalPlanNode::with_context(lp_top, lp_arena, |alp_node| { alp_node.rewrite(&mut optimizer) diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs index aec5f1780554..30db76221022 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/join.rs @@ -28,9 +28,9 @@ fn should_block_join_specific(ae: &AExpr, how: &JoinType) -> LeftRight { | FunctionExpr::Boolean(BooleanFunction::IsDuplicated), .. } => LeftRight(true, true), - #[cfg(feature = "is_first")] + #[cfg(feature = "is_first_distinct")] Function { - function: FunctionExpr::Boolean(BooleanFunction::IsFirst), + function: FunctionExpr::Boolean(BooleanFunction::IsFirstDistinct), .. 
} => LeftRight(true, true), // any operation that checks for equality or ordering can be wrong because diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs index 9bd5290edeb9..d5c004ce8bad 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs @@ -42,7 +42,6 @@ pub(super) fn process_asof_join( lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - proj_pd.has_joins_or_unions = true; // n = 0 if no projections, so we don't allocate unneeded let n = acc_projections.len() * 2; let mut pushdown_left = Vec::with_capacity(n); @@ -222,7 +221,6 @@ pub(super) fn process_join( ); } - proj_pd.has_joins_or_unions = true; // n = 0 if no projections, so we don't allocate unneeded let n = acc_projections.len() * 2; let mut pushdown_left = Vec::with_capacity(n); diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index a412cb5a6dc5..2e208daf5530 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -149,17 +149,11 @@ fn update_scan_schema( Ok(new_schema) } -pub struct ProjectionPushDown { - pub(crate) has_joins_or_unions: bool, - pub(crate) has_cache: bool, -} +pub struct ProjectionPushDown {} impl ProjectionPushDown { pub(super) fn new() -> Self { - Self { - has_joins_or_unions: false, - has_cache: false, - } + Self {} } /// Projection will be done at this node, but we continue optimization @@ -675,18 +669,15 @@ impl ProjectionPushDown { lp_arena, expr_arena, ), - lp @ Union { .. } => { - self.has_joins_or_unions = true; - process_generic( - self, - lp, - acc_projections, - projected_names, - projections_seen, - lp_arena, - expr_arena, - ) - }, + lp @ Union { .. } => process_generic( + self, + lp, + acc_projections, + projected_names, + projections_seen, + lp_arena, + expr_arena, + ), // These nodes only have inputs and exprs, so we can use same logic. lp @ Slice { .. } | lp @ Sink { .. } => process_generic( self, @@ -698,7 +689,6 @@ impl ProjectionPushDown { expr_arena, ), Cache { .. 
} => { - self.has_cache = true; // projections above this cache will be accumulated and pushed down // later // the redundant projection will be cleaned in the fast projection optimization diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs index 7e0cee38462e..16b2f5bb073b 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/semi_anti_join.rs @@ -15,7 +15,6 @@ pub(super) fn process_semi_anti_join( lp_arena: &mut Arena, expr_arena: &mut Arena, ) -> PolarsResult { - proj_pd.has_joins_or_unions = true; // n = 0 if no projections, so we don't allocate unneeded let n = acc_projections.len() * 2; let mut pushdown_left = Vec::with_capacity(n); diff --git a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs index ba92141239fe..ac51f5b8fabe 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/type_coercion/mod.rs @@ -519,28 +519,56 @@ fn early_escape(type_self: &DataType, type_other: &DataType) -> Option<()> { } } -// TODO: Fix this test and re-enable it (currently does not compile) -// #[cfg(test)] -// #[cfg(feature = "dtype-categorical")] -// mod test { -// use polars_core::prelude::*; - -// use super::*; -// use crate::prelude::*; - -// #[test] -// fn test_categorical_utf8() { -// let mut rules: Vec> = vec![Box::new(TypeCoercionRule {})]; -// let schema = Schema::from_iter([Field::new("fruits", DataType::Categorical(None))]); - -// let expr = col("fruits").eq(lit("somestr")); -// let out = optimize_expr(expr.clone(), schema.clone(), &mut rules); -// // we test that the fruits column is not casted to utf8 for the comparison -// assert_eq!(out, expr); - -// let expr = col("fruits") + (lit("somestr")); -// let out = optimize_expr(expr, schema, &mut rules); -// let expected = col("fruits").cast(DataType::Utf8) + lit("somestr"); -// assert_eq!(out, expected); -// } -// } +#[cfg(test)] +#[cfg(feature = "dtype-categorical")] +mod test { + use polars_core::prelude::*; + + use super::*; + + #[test] + fn test_categorical_utf8() { + let mut expr_arena = Arena::new(); + let mut lp_arena = Arena::new(); + let optimizer = StackOptimizer {}; + let rules: &mut [Box] = &mut [Box::new(TypeCoercionRule {})]; + + let df = DataFrame::new(Vec::from([Series::new_empty( + "fruits", + &DataType::Categorical(None), + )])) + .unwrap(); + + let expr_in = vec![col("fruits").eq(lit("somestr"))]; + let lp = LogicalPlanBuilder::from_existing_df(df.clone()) + .project(expr_in.clone(), Default::default()) + .build(); + + let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is not casted to utf8 for the comparison + if let LogicalPlan::Projection { expr, .. 
} = lp { + assert_eq!(expr, expr_in); + }; + + let expr_in = vec![col("fruits") + (lit("somestr"))]; + let lp = LogicalPlanBuilder::from_existing_df(df) + .project(expr_in, Default::default()) + .build(); + let mut lp_top = to_alp(lp, &mut expr_arena, &mut lp_arena).unwrap(); + lp_top = optimizer + .optimize_loop(rules, &mut expr_arena, &mut lp_arena, lp_top) + .unwrap(); + let lp = node_to_lp(lp_top, &expr_arena, &mut lp_arena); + + // we test that the fruits column is casted to utf8 for the addition + let expected = vec![col("fruits").cast(DataType::Utf8) + lit("somestr")]; + if let LogicalPlan::Projection { expr, .. } = lp { + assert_eq!(expr, expected); + }; + } +} diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index 96489db7bd85..1c4a3e8aa358 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -313,7 +313,7 @@ pub enum SinkType { Cloud { uri: Arc, file_type: FileType, - cloud_options: Option, + cloud_options: Option, }, } diff --git a/crates/polars-plan/src/logical_plan/pyarrow.rs b/crates/polars-plan/src/logical_plan/pyarrow.rs index ca2b23f7b934..a007887be351 100644 --- a/crates/polars-plan/src/logical_plan/pyarrow.rs +++ b/crates/polars-plan/src/logical_plan/pyarrow.rs @@ -1,6 +1,7 @@ use std::fmt::Write; use polars_core::datatypes::AnyValue; +use polars_core::prelude::{TimeUnit, TimeZone}; use crate::prelude::*; @@ -11,6 +12,15 @@ pub(super) struct Args { allow_literal_series: bool, } +fn to_py_datetime(v: i64, tu: &TimeUnit, tz: Option<&TimeZone>) -> String { + // note: `_to_python_datetime` and the `Datetime` + // dtype have to be in-scope on the python side + match tz { + None => format!("_to_python_datetime({},'{}')", v, tu.to_ascii()), + Some(tz) => format!("_to_python_datetime({},'{}',{})", v, tu.to_ascii(), tz), + } +} + // convert to a pyarrow expression that can be evaluated with pythons eval pub(super) fn predicate_to_pa( predicate: Node, @@ -39,15 +49,18 @@ pub(super) fn predicate_to_pa( if let AnyValue::Boolean(v) = av { let s = if v { "True" } else { "False" }; write!(list_repr, "{},", s).unwrap(); + } else if let AnyValue::Datetime(v, tu, tz) = av { + let dtm = to_py_datetime(v, &tu, tz.as_ref()); + write!(list_repr, "{dtm},").unwrap(); + } else if let AnyValue::Date(v) = av { + write!(list_repr, "_to_python_date({v}),").unwrap(); } else { write!(list_repr, "{av},").unwrap(); } } - // pop last comma list_repr.pop(); list_repr.push(']'); - Some(list_repr) } }, @@ -68,26 +81,10 @@ pub(super) fn predicate_to_pa( AnyValue::Date(v) => { // the function `_to_python_date` and the `Date` // dtype have to be in scope on the python side - Some(format!("_to_python_date(value={v})")) + Some(format!("_to_python_date({v})")) }, #[cfg(feature = "dtype-datetime")] - AnyValue::Datetime(v, tu, tz) => { - // the function `_to_python_datetime` and the `Datetime` - // dtype have to be in scope on the python side - match tz { - None => Some(format!( - "_to_python_datetime(value={}, tu='{}')", - v, - tu.to_ascii() - )), - Some(tz) => Some(format!( - "_to_python_datetime(value={}, tu='{}', tz={})", - v, - tu.to_ascii(), - tz - )), - } - }, + AnyValue::Datetime(v, tu, tz) => Some(to_py_datetime(v, &tu, tz.as_ref())), // Activate once pyarrow supports them // #[cfg(feature = "dtype-time")] // AnyValue::Time(v) => { diff --git a/crates/polars-row/Cargo.toml b/crates/polars-row/Cargo.toml index 7a61189a7d6b..cd764898b2d6 100644 --- a/crates/polars-row/Cargo.toml +++ 
b/crates/polars-row/Cargo.toml @@ -9,7 +9,7 @@ repository = { workspace = true } description = "Row encodings for the Polars DataFrame library" [dependencies] -polars-error = { version = "0.32.0", path = "../polars-error" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-error = { workspace = true } +polars-utils = { workspace = true } arrow = { workspace = true } diff --git a/crates/polars-sql/Cargo.toml b/crates/polars-sql/Cargo.toml index cbd079b8c6c2..30f65a7565d7 100644 --- a/crates/polars-sql/Cargo.toml +++ b/crates/polars-sql/Cargo.toml @@ -9,10 +9,10 @@ repository = { workspace = true } description = "SQL transpiler for Polars. Converts SQL to Polars logical plans" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", features = ["like"] } -polars-core = { version = "0.32.0", path = "../polars-core", features = [] } -polars-lazy = { version = "0.32.0", path = "../polars-lazy", features = ["compile", "strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg"] } -polars-plan = { version = "0.32.0", path = "../polars-plan", features = ["compile"] } +polars-arrow = { workspace = true } +polars-core = { workspace = true } +polars-lazy = { workspace = true, features = ["strings", "cross_join", "trigonometry", "abs", "round_series", "log", "regex", "is_in", "meta", "cum_agg"] } +polars-plan = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index aeca8376aff2..d4b73a371b97 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -682,14 +682,7 @@ impl SqlFunctionVisitor<'_> { ArrayReverse => self.visit_unary(|e| e.list().reverse()), ArraySum => self.visit_unary(|e| e.list().sum()), ArrayToString => self.try_visit_binary(|e, s| { - let sep = match s { - Expr::Literal(LiteralValue::Utf8(ref sep)) => sep, - _ => { - polars_bail!(InvalidOperation: "Invalid 'separator' for ArrayToString: {}", function.args[1]); - } - }; - - Ok(e.list().join(sep)) + Ok(e.list().join(s)) }), ArrayUnique => self.visit_unary(|e| e.list().unique()), Explode => self.visit_unary(|e| e.explode()), diff --git a/crates/polars-sql/tests/functions_io.rs b/crates/polars-sql/tests/functions_io.rs index 823744b389e8..59740a6a792f 100644 --- a/crates/polars-sql/tests/functions_io.rs +++ b/crates/polars-sql/tests/functions_io.rs @@ -1,5 +1,8 @@ +#[cfg(any(feature = "csv", feature = "ipc"))] use polars_core::prelude::*; +#[cfg(any(feature = "csv", feature = "ipc"))] use polars_lazy::prelude::*; +#[cfg(any(feature = "csv", feature = "ipc"))] use polars_sql::*; #[test] diff --git a/crates/polars-sql/tests/functions_string.rs b/crates/polars-sql/tests/functions_string.rs index eecdc497d271..b91a375157a7 100644 --- a/crates/polars-sql/tests/functions_string.rs +++ b/crates/polars-sql/tests/functions_string.rs @@ -116,7 +116,7 @@ fn array_to_string() { .lazy() .group_by([col("b")]) .agg([col("a")]) - .select(&[col("b"), col("a").list().join(", ").alias("as")]) + .select(&[col("b"), col("a").list().join(lit(", ")).alias("as")]) .sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true) .collect() .unwrap(); diff --git a/crates/polars-sql/tests/iss_7436.rs b/crates/polars-sql/tests/iss_7436.rs index 34895e3a657f..65b3f1c854ec 100644 --- a/crates/polars-sql/tests/iss_7436.rs +++ b/crates/polars-sql/tests/iss_7436.rs @@ -1,9 +1,9 @@ -use polars_lazy::prelude::*; -use 
polars_sql::*; - #[test] #[cfg(feature = "csv")] fn iss_7436() { + use polars_lazy::prelude::*; + use polars_sql::*; + let mut context = SQLContext::new(); let sql = r#" CREATE TABLE foods AS diff --git a/crates/polars-sql/tests/iss_7437.rs b/crates/polars-sql/tests/iss_7437.rs index 29229ba5c4c6..9b150ac06244 100644 --- a/crates/polars-sql/tests/iss_7437.rs +++ b/crates/polars-sql/tests/iss_7437.rs @@ -1,5 +1,8 @@ +#[cfg(feature = "csv")] use polars_core::prelude::*; +#[cfg(feature = "csv")] use polars_lazy::prelude::*; +#[cfg(feature = "csv")] use polars_sql::*; #[test] diff --git a/crates/polars-sql/tests/iss_8395.rs b/crates/polars-sql/tests/iss_8395.rs index a54f360a456b..b48c30718771 100644 --- a/crates/polars-sql/tests/iss_8395.rs +++ b/crates/polars-sql/tests/iss_8395.rs @@ -1,4 +1,6 @@ +#[cfg(feature = "csv")] use polars_core::prelude::*; +#[cfg(feature = "csv")] use polars_sql::*; #[test] diff --git a/crates/polars-time/Cargo.toml b/crates/polars-time/Cargo.toml index a06a270c8dd9..88a39d3659f4 100644 --- a/crates/polars-time/Cargo.toml +++ b/crates/polars-time/Cargo.toml @@ -9,10 +9,10 @@ repository = { workspace = true } description = "Time related code for the Polars DataFrame library" [dependencies] -polars-arrow = { version = "0.32.0", path = "../polars-arrow", features = ["compute", "temporal"] } -polars-core = { version = "0.32.0", path = "../polars-core", default-features = false, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } -polars-ops = { version = "0.32.0", path = "../polars-ops" } -polars-utils = { version = "0.32.0", path = "../polars-utils" } +polars-arrow = { workspace = true, features = ["compute", "temporal"] } +polars-core = { workspace = true, default-features = false, features = ["dtype-datetime", "dtype-duration", "dtype-time", "dtype-date"] } +polars-ops = { workspace = true } +polars-utils = { workspace = true } arrow = { workspace = true } atoi = { workspace = true } diff --git a/crates/polars-time/src/date_range.rs b/crates/polars-time/src/date_range.rs index 91ee5f9f1b7b..7b81ecbc27dc 100644 --- a/crates/polars-time/src/date_range.rs +++ b/crates/polars-time/src/date_range.rs @@ -24,7 +24,10 @@ pub fn date_range( tz: Option, ) -> PolarsResult { let (start, end) = match tu { - TimeUnit::Nanoseconds => (start.timestamp_nanos(), end.timestamp_nanos()), + TimeUnit::Nanoseconds => ( + start.timestamp_nanos_opt().unwrap(), + end.timestamp_nanos_opt().unwrap(), + ), TimeUnit::Microseconds => (start.timestamp_micros(), end.timestamp_micros()), TimeUnit::Milliseconds => (start.timestamp_millis(), end.timestamp_millis()), }; diff --git a/crates/polars-time/src/group_by/dynamic.rs b/crates/polars-time/src/group_by/dynamic.rs index fd06f47954b7..519b7dd1edee 100644 --- a/crates/polars-time/src/group_by/dynamic.rs +++ b/crates/polars-time/src/group_by/dynamic.rs @@ -338,7 +338,7 @@ impl Wrap<&DataFrame> { let ir = groups .par_iter() .map(|base_g| { - let dt = unsafe { dt.take_unchecked(base_g.1.into()) }; + let dt = unsafe { dt.take_unchecked(base_g.1) }; let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); if options.check_sorted @@ -417,7 +417,7 @@ impl Wrap<&DataFrame> { let groupsidx = groups .par_iter() .map(|base_g| { - let dt = unsafe { dt.take_unchecked(base_g.1.into()) }; + let dt = unsafe { dt.take_unchecked(base_g.1) }; let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); if options.check_sorted @@ -560,7 +560,7 @@ impl Wrap<&DataFrame> { let idx = groups .par_iter() 
.map(|base_g| { - let dt = unsafe { dt_local.take_unchecked(base_g.1.into()) }; + let dt = unsafe { dt_local.take_unchecked(base_g.1) }; let vals = dt.downcast_iter().next().unwrap(); let ts = vals.values().as_slice(); if options.check_sorted diff --git a/crates/polars-time/src/month_start.rs b/crates/polars-time/src/month_start.rs index f317d3852112..ff03f3317f51 100644 --- a/crates/polars-time/src/month_start.rs +++ b/crates/polars-time/src/month_start.rs @@ -30,16 +30,18 @@ pub(crate) fn roll_backward( ts.second(), ts.timestamp_subsec_nanos(), ) - .ok_or(polars_err!( - ComputeError: - format!( - "Could not construct time {}:{}:{}.{}", - ts.hour(), - ts.minute(), - ts.second(), - ts.timestamp_subsec_nanos() - ) - ))?; + .ok_or_else(|| { + polars_err!( + ComputeError: + format!( + "Could not construct time {}:{}:{}.{}", + ts.hour(), + ts.minute(), + ts.second(), + ts.timestamp_subsec_nanos() + ) + ) + })?; let ndt = NaiveDateTime::new(date, time); let t = match tz { #[cfg(feature = "timezones")] diff --git a/crates/polars-time/src/truncate.rs b/crates/polars-time/src/truncate.rs index bc233c51a662..c682e27a9661 100644 --- a/crates/polars-time/src/truncate.rs +++ b/crates/polars-time/src/truncate.rs @@ -1,26 +1,17 @@ #[cfg(feature = "dtype-date")] use polars_arrow::export::arrow::temporal_conversions::{MILLISECONDS, SECONDS_IN_DAY}; use polars_arrow::time_zone::Tz; -use polars_core::chunked_array::ops::arity::try_binary_elementwise_values; +use polars_core::chunked_array::ops::arity::{try_binary_elementwise, try_ternary_elementwise}; use polars_core::prelude::*; -#[cfg(feature = "serde")] -use serde::{Deserialize, Serialize}; use crate::prelude::*; -#[derive(Clone, PartialEq, Debug, Eq, Hash)] -#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] -pub struct TruncateOptions { - /// Period length - pub every: String, - /// Offset of the window - pub offset: String, -} pub trait PolarsTruncate { fn truncate( &self, - options: &TruncateOptions, tz: Option<&Tz>, + every: &Utf8Chunked, + offset: &str, ambiguous: &Utf8Chunked, ) -> PolarsResult where @@ -31,13 +22,12 @@ pub trait PolarsTruncate { impl PolarsTruncate for DatetimeChunked { fn truncate( &self, - options: &TruncateOptions, tz: Option<&Tz>, + every: &Utf8Chunked, + offset: &str, ambiguous: &Utf8Chunked, ) -> PolarsResult { - let every = Duration::parse(&options.every); - let offset = Duration::parse(&options.offset); - let w = Window::new(every, every, offset); + let offset = Duration::parse(offset); let func = match self.time_unit() { TimeUnit::Nanoseconds => Window::truncate_ns, @@ -45,18 +35,65 @@ impl PolarsTruncate for DatetimeChunked { TimeUnit::Milliseconds => Window::truncate_ms, }; - let out = match ambiguous.len() { - 1 => match ambiguous.get(0) { - Some(ambiguous) => self - .0 - .try_apply(|timestamp| func(&w, timestamp, tz, ambiguous)), - _ => Ok(self.0.apply(|_| None)), + let out = match (every.len(), ambiguous.len()) { + (1, 1) => match (every.get(0), ambiguous.get(0)) { + (Some(every), Some(ambiguous)) => { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + self.0 + .try_apply(|timestamp| func(&w, timestamp, tz, ambiguous)) + }, + _ => Ok(Int64Chunked::full_null(self.name(), self.len())), + }, + (1, _) => { + if let Some(every) = every.get(0) { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + try_binary_elementwise(self, ambiguous, |opt_timestamp, opt_ambiguous| { + match (opt_timestamp, opt_ambiguous) { + (Some(timestamp), Some(ambiguous)) 
=> { + func(&w, timestamp, tz, ambiguous).map(Some) + }, + _ => Ok(None), + } + }) + } else { + Ok(Int64Chunked::full_null(self.name(), self.len())) + } }, - _ => { - try_binary_elementwise_values(self, ambiguous, |timestamp: i64, ambiguous: &str| { - func(&w, timestamp, tz, ambiguous) - }) + (_, 1) => { + if let Some(ambiguous) = ambiguous.get(0) { + try_binary_elementwise(self, every, |opt_timestamp, opt_every| { + match (opt_timestamp, opt_every) { + (Some(timestamp), Some(every)) => { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + func(&w, timestamp, tz, ambiguous).map(Some) + }, + _ => Ok(None), + } + }) + } else { + Ok(Int64Chunked::full_null(self.name(), self.len())) + } }, + _ => try_ternary_elementwise( + self, + every, + ambiguous, + |opt_timestamp, opt_every, opt_ambiguous| match ( + opt_timestamp, + opt_every, + opt_ambiguous, + ) { + (Some(timestamp), Some(every), Some(ambiguous)) => { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + func(&w, timestamp, tz, ambiguous).map(Some) + }, + _ => Ok(None), + }, + ), }; Ok(out?.into_datetime(self.time_unit(), self.time_zone().clone())) } @@ -66,18 +103,42 @@ impl PolarsTruncate for DatetimeChunked { impl PolarsTruncate for DateChunked { fn truncate( &self, - options: &TruncateOptions, _tz: Option<&Tz>, + every: &Utf8Chunked, + offset: &str, _ambiguous: &Utf8Chunked, ) -> PolarsResult { - let every = Duration::parse(&options.every); - let offset = Duration::parse(&options.offset); - let w = Window::new(every, every, offset); - Ok(self - .try_apply(|t| { - const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; - Ok((w.truncate_ms(MSECS_IN_DAY * t as i64, None, "raise")? / MSECS_IN_DAY) as i32) - })? - .into_date()) + let offset = Duration::parse(offset); + let out = + match every.len() { + 1 => { + if let Some(every) = every.get(0) { + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + self.try_apply(|t| { + const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; + Ok((w.truncate_ms(MSECS_IN_DAY * t as i64, None, "raise")? + / MSECS_IN_DAY) as i32) + }) + } else { + Ok(Int32Chunked::full_null(self.name(), self.len())) + } + }, + _ => try_binary_elementwise(&self.0, every, |opt_t, opt_every| { + match (opt_t, opt_every) { + (Some(t), Some(every)) => { + const MSECS_IN_DAY: i64 = MILLISECONDS * SECONDS_IN_DAY; + let every = Duration::parse(every); + let w = Window::new(every, every, offset); + Ok(Some( + (w.truncate_ms(MSECS_IN_DAY * t as i64, None, "raise")? + / MSECS_IN_DAY) as i32, + )) + }, + _ => Ok(None), + } + }), + }; + Ok(out?.into_date()) } } diff --git a/crates/polars-time/src/utils.rs b/crates/polars-time/src/utils.rs index 21edd285941f..a5c781ee66c7 100644 --- a/crates/polars-time/src/utils.rs +++ b/crates/polars-time/src/utils.rs @@ -53,7 +53,8 @@ pub(crate) fn localize_timestamp(timestamp: i64, tu: TimeUnit, tz: Tz) -> Polars TimeUnit::Nanoseconds => { Ok( localize_datetime(timestamp_ns_to_datetime(timestamp), &tz, "raise")? 
- .timestamp_nanos(), + .timestamp_nanos_opt() + .unwrap(), ) }, TimeUnit::Microseconds => { @@ -74,9 +75,9 @@ pub(crate) fn localize_timestamp(timestamp: i64, tu: TimeUnit, tz: Tz) -> Polars #[cfg(feature = "timezones")] pub(crate) fn unlocalize_timestamp(timestamp: i64, tu: TimeUnit, tz: Tz) -> i64 { match tu { - TimeUnit::Nanoseconds => { - unlocalize_datetime(timestamp_ns_to_datetime(timestamp), &tz).timestamp_nanos() - }, + TimeUnit::Nanoseconds => unlocalize_datetime(timestamp_ns_to_datetime(timestamp), &tz) + .timestamp_nanos_opt() + .unwrap(), TimeUnit::Microseconds => { unlocalize_datetime(timestamp_us_to_datetime(timestamp), &tz).timestamp_micros() }, diff --git a/crates/polars-time/src/windows/test.rs b/crates/polars-time/src/windows/test.rs index 8c5afb240e14..95cf8e44b161 100644 --- a/crates/polars-time/src/windows/test.rs +++ b/crates/polars-time/src/windows/test.rs @@ -17,8 +17,8 @@ fn test_date_range() { .and_hms_opt(0, 0, 0) .unwrap(); let dates = datetime_range_i64( - start.timestamp_nanos(), - end.timestamp_nanos(), + start.timestamp_nanos_opt().unwrap(), + end.timestamp_nanos_opt().unwrap(), Duration::parse("1mo"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -32,7 +32,12 @@ fn test_date_range() { NaiveDate::from_ymd_opt(2022, 4, 1).unwrap(), ] .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); assert_eq!(dates, expected); } @@ -48,8 +53,8 @@ fn test_feb_date_range() { .and_hms_opt(0, 0, 0) .unwrap(); let dates = datetime_range_i64( - start.timestamp_nanos(), - end.timestamp_nanos(), + start.timestamp_nanos_opt().unwrap(), + end.timestamp_nanos_opt().unwrap(), Duration::parse("1mo"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -61,7 +66,12 @@ fn test_feb_date_range() { NaiveDate::from_ymd_opt(2022, 3, 1).unwrap(), ] .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); assert_eq!(dates, expected); } @@ -89,7 +99,12 @@ fn test_groups_large_interval() { ]; let ts = dates .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); let dur = Duration::parse("2d"); @@ -141,7 +156,8 @@ fn test_offset() { .unwrap() .and_hms_opt(0, 0, 0) .unwrap() - .timestamp_nanos(); + .timestamp_nanos_opt() + .unwrap(); let w = Window::new( Duration::parse("5m"), Duration::parse("5m"), @@ -153,7 +169,8 @@ fn test_offset() { .unwrap() .and_hms_opt(23, 58, 0) .unwrap() - .timestamp_nanos(); + .timestamp_nanos_opt() + .unwrap(); assert_eq!(b.start, start); } @@ -169,8 +186,8 @@ fn test_boundaries() { .unwrap(); let ts = datetime_range_i64( - start.timestamp_nanos(), - stop.timestamp_nanos(), + start.timestamp_nanos_opt().unwrap(), + stop.timestamp_nanos_opt().unwrap(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -189,7 +206,7 @@ fn test_boundaries() { // earliest bound is first datapoint: 2021-12-16 00:00:00 let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); - assert_eq!(b.start, start.timestamp_nanos()); + assert_eq!(b.start, start.timestamp_nanos_opt().unwrap()); // test closed: "both" (includes both ends of the interval) let (groups, lower, higher) = group_by_windows( @@ -226,9 +243,9 @@ fn test_boundaries() { assert_eq!( g, &[ - t0.timestamp_nanos(), - t1.timestamp_nanos(), - t2.timestamp_nanos() + 
t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap(), + t2.timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -241,7 +258,10 @@ fn test_boundaries() { .unwrap(); assert_eq!( &[lower[0], higher[0]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); // 2nd group @@ -267,9 +287,9 @@ fn test_boundaries() { assert_eq!( g, &[ - t0.timestamp_nanos(), - t1.timestamp_nanos(), - t2.timestamp_nanos() + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap(), + t2.timestamp_nanos_opt().unwrap() ] ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) @@ -282,7 +302,10 @@ fn test_boundaries() { .unwrap(); assert_eq!( &[lower[1], higher[1]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); assert_eq!(groups[2], [4, 3]); @@ -345,8 +368,8 @@ fn test_boundaries_2() { .unwrap(); let ts = datetime_range_i64( - start.timestamp_nanos(), - stop.timestamp_nanos(), + start.timestamp_nanos_opt().unwrap(), + stop.timestamp_nanos_opt().unwrap(), Duration::parse("30m"), ClosedWindow::Both, TimeUnit::Nanoseconds, @@ -366,7 +389,10 @@ fn test_boundaries_2() { // earliest bound is first datapoint: 2021-12-16 00:00:00 + 30m offset: 2021-12-16 00:30:00 let b = w.get_earliest_bounds_ns(ts[0], None).unwrap(); - assert_eq!(b.start, start.timestamp_nanos() + offset.duration_ns()); + assert_eq!( + b.start, + start.timestamp_nanos_opt().unwrap() + offset.duration_ns() + ); let (groups, lower, higher) = group_by_windows( w, @@ -396,7 +422,13 @@ fn test_boundaries_2() { .unwrap() .and_hms_opt(1, 0, 0) .unwrap(); - assert_eq!(g, &[t0.timestamp_nanos(), t1.timestamp_nanos()]); + assert_eq!( + g, + &[ + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap() + ] + ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(0, 30, 0) @@ -407,7 +439,10 @@ fn test_boundaries_2() { .unwrap(); assert_eq!( &[lower[0], higher[0]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); // 2nd group @@ -426,7 +461,13 @@ fn test_boundaries_2() { .unwrap() .and_hms_opt(3, 0, 0) .unwrap(); - assert_eq!(g, &[t0.timestamp_nanos(), t1.timestamp_nanos()]); + assert_eq!( + g, + &[ + t0.timestamp_nanos_opt().unwrap(), + t1.timestamp_nanos_opt().unwrap() + ] + ); let b_start = NaiveDate::from_ymd_opt(2021, 12, 16) .unwrap() .and_hms_opt(2, 30, 0) @@ -437,7 +478,10 @@ fn test_boundaries_2() { .unwrap(); assert_eq!( &[lower[1], higher[1]], - &[b_start.timestamp_nanos(), b_end.timestamp_nanos()] + &[ + b_start.timestamp_nanos_opt().unwrap(), + b_end.timestamp_nanos_opt().unwrap() + ] ); } @@ -823,7 +867,12 @@ fn test_group_by_windows_offsets_3776() { ]; let ts = dates .iter() - .map(|d| d.and_hms_opt(0, 0, 0).unwrap().timestamp_nanos()) + .map(|d| { + d.and_hms_opt(0, 0, 0) + .unwrap() + .timestamp_nanos_opt() + .unwrap() + }) .collect::>(); let window = Window::new( diff --git a/crates/polars-utils/Cargo.toml b/crates/polars-utils/Cargo.toml index 1c04b8894c8c..c017f731739c 100644 --- a/crates/polars-utils/Cargo.toml +++ b/crates/polars-utils/Cargo.toml @@ -9,7 +9,7 @@ repository = { workspace = true } description = "Private utils for the Polars DataFrame library" [dependencies] -polars-error = { version = "0.32.0", path = "../polars-error" } +polars-error = { 
workspace = true } ahash = { workspace = true } bytemuck = { workspace = true } diff --git a/crates/polars-utils/src/cache.rs b/crates/polars-utils/src/cache.rs index 21a022e6f604..26c6a932bfb5 100644 --- a/crates/polars-utils/src/cache.rs +++ b/crates/polars-utils/src/cache.rs @@ -111,6 +111,25 @@ impl FastFixedCache { } } + pub fn try_get_or_insert_with(&mut self, key: &Q, f: F) -> Result<&mut V, E> + where + K: Borrow, + Q: Hash + Eq + ToOwned + ?Sized, + F: FnOnce(&K) -> Result, + { + unsafe { + let h = self.hash(key); + if let Some(slot_idx) = self.raw_get(self.hash(&key), key) { + let slot = self.slots.get_unchecked_mut(slot_idx); + return Ok(slot.value.assume_init_mut()); + } + + let key = key.to_owned(); + let val = f(&key)?; + Ok(self.raw_insert(h, key, val)) + } + } + unsafe fn raw_get(&self, h: HashResult, key: &Q) -> Option where K: Borrow, diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 5a0ebefe093a..e4071bd68134 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -11,13 +11,13 @@ repository = { workspace = true } description = "DataFrame library based on Apache Arrow" [dependencies] -polars-algo = { version = "0.32.0", path = "../polars-algo", optional = true } -polars-core = { version = "0.32.0", path = "../polars-core", features = ["docs"], default-features = false } -polars-io = { version = "0.32.0", path = "../polars-io", features = [], default-features = false, optional = true } -polars-lazy = { version = "0.32.0", path = "../polars-lazy", features = [], default-features = false, optional = true } -polars-ops = { version = "0.32.0", path = "../polars-ops" } -polars-sql = { version = "0.32.0", path = "../polars-sql", default-features = false, optional = true } -polars-time = { version = "0.32.0", path = "../polars-time", default-features = false, optional = true } +polars-algo = { workspace = true, optional = true } +polars-core = { workspace = true } +polars-io = { workspace = true, optional = true } +polars-lazy = { workspace = true, default-features = false, optional = true } +polars-ops = { workspace = true } +polars-sql = { workspace = true, optional = true } +polars-time = { workspace = true, optional = true } [dev-dependencies] ahash = { workspace = true } @@ -36,10 +36,10 @@ sql = ["polars-sql"] rows = ["polars-core/rows"] simd = ["polars-core/simd", "polars-io/simd", "polars-ops/simd"] avx512 = ["polars-core/avx512"] -nightly = ["polars-core/nightly", "polars-ops/nightly", "simd", "polars-lazy/nightly"] +nightly = ["polars-core/nightly", "polars-ops/nightly", "simd", "polars-lazy?/nightly"] docs = ["polars-core/docs"] -temporal = ["polars-core/temporal", "polars-lazy/temporal", "polars-io/temporal", "polars-time"] -random = ["polars-core/random", "polars-lazy/random"] +temporal = ["polars-core/temporal", "polars-lazy?/temporal", "polars-io/temporal", "polars-time"] +random = ["polars-core/random", "polars-lazy?/random"] default = [ "docs", "zip_with", @@ -51,38 +51,44 @@ default = [ ndarray = ["polars-core/ndarray"] # serde support for dataframes and series serde = ["polars-core/serde"] -serde-lazy = ["polars-core/serde-lazy", "polars-lazy/serde", "polars-time/serde", "polars-io/serde", "polars-ops/serde"] -parquet = ["polars-io", "polars-core/parquet", "polars-lazy/parquet", "polars-io/parquet", "polars-sql/parquet"] -async = ["polars-lazy/async"] -cloud = ["polars-lazy/cloud", "polars-io/cloud"] -cloud_write = ["cloud", "polars-lazy/cloud_write"] +serde-lazy = [ + "polars-core/serde-lazy", + "polars-lazy?/serde", + 
"polars-time?/serde", + "polars-io/serde", + "polars-ops/serde", +] +parquet = ["polars-io", "polars-core/parquet", "polars-lazy?/parquet", "polars-io/parquet", "polars-sql?/parquet"] +async = ["polars-lazy?/async"] +cloud = ["polars-lazy?/cloud", "polars-io/cloud"] +cloud_write = ["cloud", "polars-lazy?/cloud_write"] aws = ["async", "cloud", "polars-io/aws"] azure = ["async", "cloud", "polars-io/azure"] gcp = ["async", "cloud", "polars-io/gcp"] -lazy = ["polars-core/lazy", "polars-lazy", "polars-lazy/compile"] +lazy = ["polars-core/lazy", "polars-lazy"] # commented out until UB is fixed # parallel = ["polars-core/parallel"] # extra utilities for Utf8Chunked -strings = ["polars-core/strings", "polars-lazy/strings", "polars-ops/strings"] +strings = ["polars-core/strings", "polars-lazy?/strings", "polars-ops/strings"] # support for ObjectChunked (downcastable Series of any type) -object = ["polars-core/object", "polars-lazy/object", "polars-io/object"] +object = ["polars-core/object", "polars-lazy?/object", "polars-io/object"] # support for arrows json parsing -json = ["polars-io", "polars-io/json", "polars-lazy/json", "polars-sql/json", "dtype-struct"] +json = ["polars-io", "polars-io/json", "polars-lazy?/json", "polars-sql?/json", "dtype-struct"] # support for arrows ipc file parsing -ipc = ["polars-io", "polars-io/ipc", "polars-lazy/ipc", "polars-sql/ipc"] +ipc = ["polars-io", "polars-io/ipc", "polars-lazy?/ipc", "polars-sql?/ipc"] # support for arrows streaming ipc file parsing -ipc_streaming = ["polars-io", "polars-io/ipc_streaming", "polars-lazy/ipc"] +ipc_streaming = ["polars-io", "polars-io/ipc_streaming", "polars-lazy?/ipc"] # support for apache avro file parsing avro = ["polars-io", "polars-io/avro"] # support for arrows csv file parsing -csv = ["polars-io", "polars-io/csv", "polars-lazy/csv", "polars-sql/csv"] +csv = ["polars-io", "polars-io/csv", "polars-lazy?/csv", "polars-sql?/csv"] # slower builds performant = [ @@ -105,88 +111,88 @@ fmt_no_tty = ["polars-core/fmt_no_tty"] sort_multiple = ["polars-core/sort_multiple"] # extra operations -approx_unique = ["polars-lazy/approx_unique", "polars-ops/approx_unique"] -is_in = ["polars-lazy/is_in"] +approx_unique = ["polars-lazy?/approx_unique", "polars-ops/approx_unique"] +is_in = ["polars-lazy?/is_in"] zip_with = ["polars-core/zip_with"] -round_series = ["polars-core/round_series", "polars-lazy/round_series", "polars-ops/round_series"] +round_series = ["polars-core/round_series", "polars-lazy?/round_series", "polars-ops/round_series"] checked_arithmetic = ["polars-core/checked_arithmetic"] -repeat_by = ["polars-core/repeat_by", "polars-lazy/repeat_by"] -is_first = ["polars-lazy/is_first", "polars-ops/is_first"] -is_last = ["polars-lazy/is_last", "polars-ops/is_last"] -is_unique = ["polars-lazy/is_unique", "polars-ops/is_unique"] -asof_join = ["polars-core/asof_join", "polars-lazy/asof_join", "polars-ops/asof_join"] -cross_join = ["polars-core/cross_join", "polars-lazy/cross_join", "polars-ops/cross_join"] +repeat_by = ["polars-core/repeat_by", "polars-lazy?/repeat_by"] +is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"] +is_last_distinct = ["polars-lazy?/is_last_distinct", "polars-ops/is_last_distinct"] +is_unique = ["polars-lazy?/is_unique", "polars-ops/is_unique"] +asof_join = ["polars-core/asof_join", "polars-lazy?/asof_join", "polars-ops/asof_join"] +cross_join = ["polars-core/cross_join", "polars-lazy?/cross_join", "polars-ops/cross_join"] dot_product = ["polars-core/dot_product"] -concat_str 
= ["polars-core/concat_str", "polars-lazy/concat_str"] -row_hash = ["polars-core/row_hash", "polars-lazy/row_hash"] +concat_str = ["polars-core/concat_str", "polars-lazy?/concat_str"] +row_hash = ["polars-core/row_hash", "polars-lazy?/row_hash"] reinterpret = ["polars-core/reinterpret"] decompress = ["polars-io/decompress"] decompress-fast = ["polars-io/decompress-fast"] -mode = ["polars-core/mode", "polars-lazy/mode"] +mode = ["polars-core/mode", "polars-lazy?/mode"] take_opt_iter = ["polars-core/take_opt_iter"] extract_jsonpath = [ "polars-core/strings", "polars-ops/extract_jsonpath", "polars-ops/strings", - "polars-lazy/extract_jsonpath", + "polars-lazy?/extract_jsonpath", ] string_encoding = ["polars-ops/string_encoding", "polars-core/strings"] binary_encoding = ["polars-ops/binary_encoding"] group_by_list = ["polars-core/group_by_list", "polars-ops/group_by_list"] -lazy_regex = ["polars-lazy/regex"] +lazy_regex = ["polars-lazy?/regex"] cum_agg = ["polars-core/cum_agg", "polars-core/cum_agg"] -rolling_window = ["polars-core/rolling_window", "polars-lazy/rolling_window", "polars-time/rolling_window"] -interpolate = ["polars-ops/interpolate", "polars-lazy/interpolate"] -rank = ["polars-core/rank", "polars-lazy/rank"] -diff = ["polars-core/diff", "polars-lazy/diff", "polars-ops/diff"] -pct_change = ["polars-core/pct_change", "polars-lazy/pct_change"] -moment = ["polars-core/moment", "polars-lazy/moment", "polars-ops/moment"] -range = ["polars-lazy/range"] -true_div = ["polars-lazy/true_div"] -diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy/diagonal_concat"] +rolling_window = ["polars-core/rolling_window", "polars-lazy?/rolling_window", "polars-time/rolling_window"] +interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"] +rank = ["polars-core/rank", "polars-lazy?/rank"] +diff = ["polars-core/diff", "polars-lazy?/diff", "polars-ops/diff"] +pct_change = ["polars-core/pct_change", "polars-lazy?/pct_change"] +moment = ["polars-core/moment", "polars-lazy?/moment", "polars-ops/moment"] +range = ["polars-lazy?/range"] +true_div = ["polars-lazy?/true_div"] +diagonal_concat = ["polars-core/diagonal_concat", "polars-lazy?/diagonal_concat"] horizontal_concat = ["polars-core/horizontal_concat"] -abs = ["polars-core/abs", "polars-lazy/abs"] -dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy/dynamic_group_by"] -ewma = ["polars-core/ewma", "polars-lazy/ewma"] -dot_diagram = ["polars-lazy/dot_diagram"] +abs = ["polars-core/abs", "polars-lazy?/abs"] +dynamic_group_by = ["polars-core/dynamic_group_by", "polars-lazy?/dynamic_group_by"] +ewma = ["polars-core/ewma", "polars-lazy?/ewma"] +dot_diagram = ["polars-lazy?/dot_diagram"] dataframe_arithmetic = ["polars-core/dataframe_arithmetic"] product = ["polars-core/product"] -unique_counts = ["polars-core/unique_counts", "polars-lazy/unique_counts"] -log = ["polars-ops/log", "polars-lazy/log"] +unique_counts = ["polars-core/unique_counts", "polars-lazy?/unique_counts"] +log = ["polars-ops/log", "polars-lazy?/log"] partition_by = ["polars-core/partition_by"] -semi_anti_join = ["polars-core/semi_anti_join", "polars-lazy/semi_anti_join", "polars-ops/semi_anti_join"] -list_eval = ["polars-lazy/list_eval"] -cumulative_eval = ["polars-lazy/cumulative_eval"] -chunked_ids = ["polars-core/chunked_ids", "polars-lazy/chunked_ids", "polars-core/chunked_ids"] +semi_anti_join = ["polars-core/semi_anti_join", "polars-lazy?/semi_anti_join", "polars-ops/semi_anti_join"] +list_eval = ["polars-lazy?/list_eval"] +cumulative_eval = 
["polars-lazy?/cumulative_eval"] +chunked_ids = ["polars-lazy?/chunked_ids", "polars-core/chunked_ids"] to_dummies = ["polars-ops/to_dummies"] -bigidx = ["polars-core/bigidx", "polars-lazy/bigidx", "polars-ops/big_idx"] -list_to_struct = ["polars-ops/list_to_struct", "polars-lazy/list_to_struct"] -list_count = ["polars-ops/list_count", "polars-lazy/list_count"] -list_take = ["polars-ops/list_take", "polars-lazy/list_take"] +bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx"] +list_to_struct = ["polars-ops/list_to_struct", "polars-lazy?/list_to_struct"] +list_count = ["polars-ops/list_count", "polars-lazy?/list_count"] +list_take = ["polars-ops/list_take", "polars-lazy?/list_take"] describe = ["polars-core/describe"] -timezones = ["polars-core/timezones", "polars-lazy/timezones", "polars-io/timezones"] -string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"] -string_from_radix = ["polars-lazy/string_from_radix", "polars-ops/string_from_radix"] -arg_where = ["polars-lazy/arg_where"] -search_sorted = ["polars-lazy/search_sorted"] -merge_sorted = ["polars-lazy/merge_sorted"] -meta = ["polars-lazy/meta"] -date_offset = ["polars-lazy/date_offset"] -trigonometry = ["polars-lazy/trigonometry"] -sign = ["polars-lazy/sign"] -pivot = ["polars-lazy/pivot"] -top_k = ["polars-lazy/top_k"] +timezones = ["polars-core/timezones", "polars-lazy?/timezones", "polars-io/timezones"] +string_justify = ["polars-lazy?/string_justify", "polars-ops/string_justify"] +string_from_radix = ["polars-lazy?/string_from_radix", "polars-ops/string_from_radix"] +arg_where = ["polars-lazy?/arg_where"] +search_sorted = ["polars-lazy?/search_sorted"] +merge_sorted = ["polars-lazy?/merge_sorted"] +meta = ["polars-lazy?/meta"] +date_offset = ["polars-lazy?/date_offset"] +trigonometry = ["polars-lazy?/trigonometry"] +sign = ["polars-lazy?/sign"] +pivot = ["polars-lazy?/pivot"] +top_k = ["polars-lazy?/top_k"] algo = ["polars-algo"] -cse = ["polars-lazy/cse"] -propagate_nans = ["polars-lazy/propagate_nans"] -coalesce = ["polars-lazy/coalesce"] -streaming = ["polars-lazy/streaming"] -fused = ["polars-ops/fused", "polars-lazy/fused"] -list_sets = ["polars-lazy/list_sets"] -list_any_all = ["polars-lazy/list_any_all"] -cutqcut = ["polars-lazy/cutqcut"] -rle = ["polars-lazy/rle"] -extract_groups = ["polars-lazy/extract_groups"] +cse = ["polars-lazy?/cse"] +propagate_nans = ["polars-lazy?/propagate_nans"] +coalesce = ["polars-lazy?/coalesce"] +streaming = ["polars-lazy?/streaming"] +fused = ["polars-ops/fused", "polars-lazy?/fused"] +list_sets = ["polars-lazy?/list_sets"] +list_any_all = ["polars-lazy?/list_any_all"] +cutqcut = ["polars-lazy?/cutqcut"] +rle = ["polars-lazy?/rle"] +extract_groups = ["polars-lazy?/extract_groups"] test = [ "lazy", @@ -231,51 +237,51 @@ dtype-slim = [ # opt-in datatypes for Series dtype-date = [ "polars-core/dtype-date", - "polars-lazy/dtype-date", + "polars-lazy?/dtype-date", "polars-io/dtype-date", - "polars-time/dtype-date", + "polars-time?/dtype-date", "polars-core/dtype-date", "polars-ops/dtype-date", ] dtype-datetime = [ "polars-core/dtype-datetime", - "polars-lazy/dtype-datetime", + "polars-lazy?/dtype-datetime", "polars-io/dtype-datetime", - "polars-time/dtype-datetime", + "polars-time?/dtype-datetime", "polars-ops/dtype-datetime", ] dtype-duration = [ "polars-core/dtype-duration", - "polars-lazy/dtype-duration", - "polars-time/dtype-duration", + "polars-lazy?/dtype-duration", + "polars-time?/dtype-duration", "polars-core/dtype-duration", 
"polars-ops/dtype-duration", ] -dtype-time = ["polars-core/dtype-time", "polars-io/dtype-time", "polars-time/dtype-time", "polars-ops/dtype-time"] +dtype-time = ["polars-core/dtype-time", "polars-io/dtype-time", "polars-time?/dtype-time", "polars-ops/dtype-time"] dtype-array = [ "polars-core/dtype-array", - "polars-lazy/dtype-array", + "polars-lazy?/dtype-array", "polars-ops/dtype-array", ] -dtype-i8 = ["polars-core/dtype-i8", "polars-lazy/dtype-i8", "polars-ops/dtype-i8"] -dtype-i16 = ["polars-core/dtype-i16", "polars-lazy/dtype-i16", "polars-ops/dtype-i16"] +dtype-i8 = ["polars-core/dtype-i8", "polars-lazy?/dtype-i8", "polars-ops/dtype-i8"] +dtype-i16 = ["polars-core/dtype-i16", "polars-lazy?/dtype-i16", "polars-ops/dtype-i16"] dtype-decimal = [ "polars-core/dtype-decimal", - "polars-lazy/dtype-decimal", + "polars-lazy?/dtype-decimal", "polars-ops/dtype-decimal", "polars-io/dtype-decimal", ] -dtype-u8 = ["polars-core/dtype-u8", "polars-lazy/dtype-u8", "polars-ops/dtype-u8"] -dtype-u16 = ["polars-core/dtype-u16", "polars-lazy/dtype-u16", "polars-ops/dtype-u16"] +dtype-u8 = ["polars-core/dtype-u8", "polars-lazy?/dtype-u8", "polars-ops/dtype-u8"] +dtype-u16 = ["polars-core/dtype-u16", "polars-lazy?/dtype-u16", "polars-ops/dtype-u16"] dtype-categorical = [ "polars-core/dtype-categorical", "polars-io/dtype-categorical", - "polars-lazy/dtype-categorical", + "polars-lazy?/dtype-categorical", "polars-ops/dtype-categorical", ] dtype-struct = [ "polars-core/dtype-struct", - "polars-lazy/dtype-struct", + "polars-lazy?/dtype-struct", "polars-ops/dtype-struct", "polars-io/dtype-struct", ] @@ -300,8 +306,8 @@ docs-selection = [ "checked_arithmetic", "ndarray", "repeat_by", - "is_first", - "is_last", + "is_first_distinct", + "is_last_distinct", "asof_join", "cross_join", "concat_str", diff --git a/crates/polars/src/docs/lazy.rs b/crates/polars/src/docs/lazy.rs index fb0c7dfd2d9e..44b536914ce1 100644 --- a/crates/polars/src/docs/lazy.rs +++ b/crates/polars/src/docs/lazy.rs @@ -106,7 +106,7 @@ //! //! ## Groupby //! -//! This example is from the polars [user guide](https://pola-rs.github.io/polars-book/user-guide/concepts/contexts/#group_by-aggregation). +//! This example is from the polars [user guide](https://pola-rs.github.io/polars/user-guide/concepts/contexts/#group_by-aggregation). //! //! ``` //! use polars::prelude::*; diff --git a/crates/polars/src/lib.rs b/crates/polars/src/lib.rs index bd6affefb10e..11abf37644c7 100644 --- a/crates/polars/src/lib.rs +++ b/crates/polars/src/lib.rs @@ -147,7 +147,7 @@ //! (Note that within an expression there may be more parallelization going on). //! //! Understanding polars expressions is most important when starting with the polars library. Read more -//! about them in the [User Guide](https://pola-rs.github.io/polars-book/user-guide/concepts/expressions). +//! about them in the [User Guide](https://pola-rs.github.io/polars/user-guide/concepts/expressions). //! Though the examples given there are in python. The expressions API is almost identical and the //! the read should certainly be valuable to rust users as well. //! @@ -238,8 +238,8 @@ //! - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip). //! - `round_series` - round underlying float types of [`Series`]. //! - `repeat_by` - [Repeat element in an Array N times, where N is given by another array. -//! - `is_first` - Check if element is first unique value. -//! - `is_last` - Check if element is last unique value. +//! 
- `is_first_distinct` - Check if element is first unique value. +//! - `is_last_distinct` - Check if element is last unique value. //! - `checked_arithmetic` - checked arithmetic/ returning [`None`] on invalid operations. //! - `dot_product` - Dot/inner product on [`Series`] and [`Expr`]. //! - `concat_str` - Concat string data in linear time. @@ -397,7 +397,7 @@ //! * `POLARS_NO_CHUNKED_JOIN` -> force rechunk before joins. //! //! ## User Guide -//! If you want to read more, [check the User Guide](https://pola-rs.github.io/polars-book/). +//! If you want to read more, [check the User Guide](https://pola-rs.github.io/polars/). #![cfg_attr(docsrs, feature(doc_auto_cfg))] #![allow(ambiguous_glob_reexports)] pub mod docs; @@ -408,8 +408,8 @@ pub mod prelude; pub mod sql; pub use polars_core::{ - apply_method_all_arrow_series, chunked_array, datatypes, df, doc, error, frame, functions, - series, testing, + apply_method_all_arrow_series, chunked_array, datatypes, df, error, frame, functions, series, + testing, }; #[cfg(feature = "dtype-categorical")] pub use polars_core::{enable_string_cache, using_string_cache}; diff --git a/crates/polars/tests/it/core/ops/take.rs b/crates/polars/tests/it/core/ops/take.rs index 14acb9338970..a958997954c1 100644 --- a/crates/polars/tests/it/core/ops/take.rs +++ b/crates/polars/tests/it/core/ops/take.rs @@ -2,13 +2,13 @@ use super::*; #[test] fn test_list_take_nulls_and_empty() { - unsafe { - let a: &[i32] = &[]; - let a = Series::new("", a); - let b = Series::new("", &[None, Some(a.clone())]); - let mut iter = [Some(0), Some(1usize), None].iter().copied(); - let out = b.take_opt_iter_unchecked(&mut iter); - let expected = Series::new("", &[None, Some(a), None]); - assert!(out.series_equal_missing(&expected)) - } + let a: &[i32] = &[]; + let a = Series::new("", a); + let b = Series::new("", &[None, Some(a.clone())]); + let indices = [Some(0 as IdxSize), Some(1), None] + .into_iter() + .collect_ca(""); + let out = b.take(&indices).unwrap(); + let expected = Series::new("", &[None, Some(a), None]); + assert!(out.series_equal_missing(&expected)) } diff --git a/crates/polars/tests/it/lazy/explodes.rs b/crates/polars/tests/it/lazy/explodes.rs index 540af19a1525..01cc6ff69db7 100644 --- a/crates/polars/tests/it/lazy/explodes.rs +++ b/crates/polars/tests/it/lazy/explodes.rs @@ -9,7 +9,7 @@ fn test_explode_row_numbers() -> PolarsResult<()> { "text" => ["one two three four", "uno dos tres cuatro"] ]? 
.lazy() - .select([col("text").str().split(" ").alias("tokens")]) + .select([col("text").str().split(lit(" ")).alias("tokens")]) .with_row_count("row_nr", None) .explode([col("tokens")]) .select([col("row_nr"), col("tokens")]) diff --git a/crates/polars/tests/it/lazy/queries.rs b/crates/polars/tests/it/lazy/queries.rs index d0af51efaab3..67b27707a871 100644 --- a/crates/polars/tests/it/lazy/queries.rs +++ b/crates/polars/tests/it/lazy/queries.rs @@ -225,7 +225,7 @@ fn test_apply_multiple_columns() -> PolarsResult<()> { .collect()?; let out = out.column("A")?; - let out = out.list()?.get(1).unwrap(); + let out = out.list()?.get_as_series(1).unwrap(); let out = out.i32()?; assert_eq!(Vec::from(out), &[Some(16)]); diff --git a/docs/_build/API_REFERENCE_LINKS.yml b/docs/_build/API_REFERENCE_LINKS.yml new file mode 100644 index 000000000000..4e028d99a8b2 --- /dev/null +++ b/docs/_build/API_REFERENCE_LINKS.yml @@ -0,0 +1,264 @@ +python: + DataFrame: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/index.html + Categorical: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.Categorical.html + Series: https://pola-rs.github.io/polars/py-polars/html/reference/series/index.html + select: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.select.html + filter: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.filter.html + with_columns: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.with_columns.html + group_by: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by.html + join: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.join.html + hstack: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.hstack.html + read_csv: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_csv.html + write_csv: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_csv.html + read_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_json.html + write_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_json.html + read_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html + write_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_parquet.html + min: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.min.html + max: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.max.html + value_counts: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.value_counts.html + unnest: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.unnest.html + field: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.struct.field.html + struct: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.struct.html + rename_fields: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.struct.rename_fields.html + is_duplicated: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.is_duplicated.html + replace: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.str.replace.html + sample: 
https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.sample.html + day: https://pola-rs.github.io/polars/py-polars/html/reference/series/api/polars.Series.dt.day.html + head: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.head.html + tail: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.tail.html + describe: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.describe.html + col: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.col.html + sort: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.sort.html + scan_csv: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_csv.html + collect: https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.collect.html + fold: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.fold.html + concat_str: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.concat_str.html + str.split: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.split.html + Expr.List: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/list.html + element: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.element.html + all: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.all.html + exclude: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.exclude.html + alias: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.alias.html + prefix: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.prefix.html + suffix: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.suffix.html + map_alias: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.map_alias.html + n_unique: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.n_unique.html + approx_n_unique: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.approx_n_unique.html + when: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.when.html + concat_list: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.concat_list.html + list.eval: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.list.eval.html + null_count: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.null_count.html + is_null: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.is_null.html + fill_null: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.fill_null.html + interpolate: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.interpolate.html + fill_nan: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.fill_nan.html + operators: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/operators.html + map: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.map.html + apply: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.apply.html + over: 
https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.over.html + implode: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.implode.html + dt_to_string: + link: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.to_string.html + name: dt.to_string + selectors: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html + cs_numeric: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.numeric + name: cs.numeric + cs_by_name: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.by_name + name: cs.by_name + cs_first: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.first + name: cs.first + cs_temporal: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.temporal + name: cs.temporal + cs_contains: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.contains + name: cs.contains + cs_matches: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.matches + name: cs.matches + is_selector: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.is_selector + name: is_selector + selector_column_names: + link: https://pola-rs.github.io/polars/py-polars/html/reference/selectors.html#polars.selectors.selector_column_names + name: selector_column_names + DataFrame.explode: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.explode.html + read_database_connectorx: + name: read_database + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_database.html + feature_flags: ['connectorx'] + read_database: + name: read_database + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_database.html + write_database: + name: write_database + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_database.html + read_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_parquet.html + write_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_parquet.html + scan_parquet: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_parquet.html + read_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_json.html + read_ndjson: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.read_ndjson.html + write_ndjson: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_ndjson.html + write_json: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.DataFrame.write_json.html + scan_ndjson: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.scan_ndjson.html + from_arrow: + name: from_arrow + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.from_arrow.html + feature_flags: ['fsspec','pyarrow'] + show_graph: https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.show_graph.html + lazy: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.lazy.html + explain: https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.explain.html + fetch: 
https://pola-rs.github.io/polars/py-polars/html/reference/lazyframe/api/polars.LazyFrame.fetch.html + SQLContext: https://pola-rs.github.io/polars/py-polars/html/reference/sql + SQLregister: + name: register + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.register.html#polars.SQLContext.register + SQLregister_many: + name: register_many + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.register_many.html + SQLquery: + name: query + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.query.html + SQLexecute: + name: execute + link: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.SQLContext.execute.html + join_asof: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.join_asof.html + concat: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.concat.html + pivot: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.pivot.html + melt: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.melt.html + is_between: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.is_between.html + strftime: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.strftime.html + strptime: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.strptime.html + year: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.year.html + convert_time_zone: + name: convert_time_zone + link: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.convert_time_zone.html + feature_flags: ['timezone'] + replace_time_zone: + name: replace_time_zone + link: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.replace_time_zone.html + feature_flags: ['timezone'] + date_range: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.date_range.html + upsample: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.upsample.html + group_by_dynamic: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.group_by_dynamic.html + explode: https://pola-rs.github.io/polars/py-polars/html/reference/dataframe/api/polars.DataFrame.explode.html + cast: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.cast.html + np.log: + name: log + link: https://numpy.org/doc/stable/reference/generated/numpy.log.html + feature_flags: ['numpy'] + lengths: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.lengths.html + n_chars: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.n_chars.html + str.contains: + name: str.contains + link: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.contains.html + starts_with: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.starts_with.html + ends_with: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.ends_with.html + extract: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.extract.html + extract_all: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.extract_all.html + 
replace: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.replace.html + replace_all: https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.str.replace_all.html + Array: https://pola-rs.github.io/polars/py-polars/html/reference/api/polars.Array.html + arr: https://pola-rs.github.io/polars/py-polars/html/reference/series/array.html + +rust: + DataFrame: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html + Series: https://pola-rs.github.io/polars/docs/rust/dev/polars/series/struct.Series.html + Categorical: + name: Categorical + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/enum.DataType.html#variant.Categorical + feature_flags: ['dtype-categorical'] + select: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.select + filter: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.filter + with_columns: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.with_columns + group_by: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyFrame.html#method.group_by + join: https://pola-rs.github.io/polars/docs/rust/dev/polars_core/frame/hash_join/index.html + hstack: https://pola-rs.github.io/polars/docs/rust/dev/polars_core/frame/struct.DataFrame.html#method.hstack + SQLContext: https://pola-rs.github.io/polars/py-polars/html/reference/sql.html + read_csv: + name: CsvReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/csv/struct.CsvReader.html + feature_flags: ['csv'] + scan_csv: + name: LazyCsvReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/struct.LazyCsvReader.html + feature_flags: ['csv'] + write_csv: + name: CsvWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/csv/struct.CsvWriter.html + feature_flags: ['csv'] + read_json: + name: JsonReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/json/struct.JsonReader.html + feature_flags: ['json'] + read_ndjson: + name: JsonLineReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/ndjson_core/ndjson/struct.JsonLineReader.html + feature_flags: ['json'] + write_json: + name: JsonWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/json/struct.JsonWriter.html + feature_flags: ['json'] + write_ndjson: + name: JsonWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/json/struct.JsonWriter.html + feature_flags: ['json'] + scan_ndjson: + name: LazyJsonLineReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/frame/struct.LazyJsonLineReader.html + feature_flags: ['json'] + read_parquet: + name: ParquetReader + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/parquet/struct.ParquetReader.html + feature_flags: ['parquet'] + write_parquet: + name: ParquetWriter + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_io/parquet/struct.ParquetWriter.html + feature_flags: ['parquet'] + scan_parquet: + name: scan_parquet + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/struct.LazyFrame.html#method.scan_parquet + feature_flags: ['parquet'] + min: https://pola-rs.github.io/polars/docs/rust/dev/polars/series/struct.Series.html#method.min + max: https://pola-rs.github.io/polars/docs/rust/dev/polars/series/struct.Series.html#method.max + struct: + name: Struct + link: 
https://pola-rs.github.io/polars/docs/rust/dev/polars/datatypes/enum.DataType.html#variant.Struct + feature_flags: ['dtype-struct'] + implode: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.implode + sample: + name: sample_n + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.sample_n + head: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.head + tail: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.tail + describe: + name: describe + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/frame/struct.DataFrame.html#method.describe + feature_flags: ['describe'] + collect: + name: collect + link: https://pola-rs.github.io/polars/docs/rust/dev/polars/prelude/struct.LazyFrame.html#method.collect + feature_flags: ['streaming'] + col: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.col.html + sort: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.sort + arr.eval: + name: arr + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.arr + feature_flags: ['list_eval','rank'] + fold: + name: fold_exprs + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.fold_exprs.html + concat_str: + name: concat_str + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.concat_str.html + feature_flags: ['concat_str'] + concat_list: + name: concat_lst + link: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/fn.concat_lst.html + map: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.map + apply: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.apply + over: https://pola-rs.github.io/polars/docs/rust/dev/polars_lazy/dsl/enum.Expr.html#method.over diff --git a/docs/_build/assets/logo.png b/docs/_build/assets/logo.png new file mode 100644 index 000000000000..9b5486edce3b Binary files /dev/null and b/docs/_build/assets/logo.png differ diff --git a/docs/_build/css/extra.css b/docs/_build/css/extra.css new file mode 100644 index 000000000000..420db3966780 --- /dev/null +++ b/docs/_build/css/extra.css @@ -0,0 +1,64 @@ +:root { + --md-primary-fg-color: #0B7189 ; + --md-primary-fg-color--light: #C2CCD6; + --md-primary-fg-color--dark: #103547; + --md-text-font: 'Proxima Nova', sans-serif; +} + + +span .md-typeset .emojione, .md-typeset .gemoji, .md-typeset .twemoji { + vertical-align: text-bottom; +} + +@font-face { + font-family: 'Proxima Nova', sans-serif; + src: 'https://fonts.cdnfonts.com/css/proxima-nova-2' +} + +:root { + --md-code-font: "Source Code Pro" !important; +} + +.contributor_icon { + height:40px; + width:40px; + border-radius: 20px; + margin: 0 5px; +} + +.feature-flag{ + background-color: rgba(255, 245, 214,.5); + border: none; + padding: 0px 5px; + text-align: center; + text-decoration: none; + display: inline-block; + margin: 4px 2px; + cursor: pointer; + font-size: .85em; +} + +[data-md-color-scheme=slate] .feature-flag{ + background-color:var(--md-code-bg-color); +} +.md-typeset ol li, .md-typeset ul li{ + margin-bottom: 0em !important; +} + +:root { + --md-admonition-icon--rust: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 512 512'%3E%3C!--! 
Font Awesome Free 6.4.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2023 Fonticons, Inc.--%3E%3Cpath d='m508.52 249.75-21.82-13.51c-.17-2-.34-3.93-.55-5.88l18.72-17.5a7.35 7.35 0 0 0-2.44-12.25l-24-9c-.54-1.88-1.08-3.78-1.67-5.64l15-20.83a7.35 7.35 0 0 0-4.79-11.54l-25.42-4.15c-.9-1.73-1.79-3.45-2.73-5.15l10.68-23.42a7.35 7.35 0 0 0-6.95-10.39l-25.82.91q-1.79-2.22-3.61-4.4L439 81.84a7.36 7.36 0 0 0-8.84-8.84L405 78.93q-2.17-1.83-4.4-3.61l.91-25.82a7.35 7.35 0 0 0-10.39-7L367.7 53.23c-1.7-.94-3.43-1.84-5.15-2.73l-4.15-25.42a7.35 7.35 0 0 0-11.54-4.79L326 35.26c-1.86-.59-3.75-1.13-5.64-1.67l-9-24a7.35 7.35 0 0 0-12.25-2.44l-17.5 18.72c-1.95-.21-3.91-.38-5.88-.55L262.25 3.48a7.35 7.35 0 0 0-12.5 0L236.24 25.3c-2 .17-3.93.34-5.88.55l-17.5-18.72a7.35 7.35 0 0 0-12.25 2.44l-9 24c-1.89.55-3.79 1.08-5.66 1.68l-20.82-15a7.35 7.35 0 0 0-11.54 4.79l-4.15 25.41c-1.73.9-3.45 1.79-5.16 2.73l-23.4-10.63a7.35 7.35 0 0 0-10.39 7l.92 25.81c-1.49 1.19-3 2.39-4.42 3.61L81.84 73A7.36 7.36 0 0 0 73 81.84L78.93 107c-1.23 1.45-2.43 2.93-3.62 4.41l-25.81-.91a7.42 7.42 0 0 0-6.37 3.26 7.35 7.35 0 0 0-.57 7.13l10.66 23.41c-.94 1.7-1.83 3.43-2.73 5.16l-25.41 4.14a7.35 7.35 0 0 0-4.79 11.54l15 20.82c-.59 1.87-1.13 3.77-1.68 5.66l-24 9a7.35 7.35 0 0 0-2.44 12.25l18.72 17.5c-.21 1.95-.38 3.91-.55 5.88l-21.86 13.5a7.35 7.35 0 0 0 0 12.5l21.82 13.51c.17 2 .34 3.92.55 5.87l-18.72 17.5a7.35 7.35 0 0 0 2.44 12.25l24 9c.55 1.89 1.08 3.78 1.68 5.65l-15 20.83a7.35 7.35 0 0 0 4.79 11.54l25.42 4.15c.9 1.72 1.79 3.45 2.73 5.14l-10.63 23.43a7.35 7.35 0 0 0 .57 7.13 7.13 7.13 0 0 0 6.37 3.26l25.83-.91q1.77 2.22 3.6 4.4L73 430.16a7.36 7.36 0 0 0 8.84 8.84l25.16-5.93q2.18 1.83 4.41 3.61l-.92 25.82a7.35 7.35 0 0 0 10.39 6.95l23.43-10.68c1.69.94 3.42 1.83 5.14 2.73l4.15 25.42a7.34 7.34 0 0 0 11.54 4.78l20.83-15c1.86.6 3.76 1.13 5.65 1.68l9 24a7.36 7.36 0 0 0 12.25 2.44l17.5-18.72c1.95.21 3.92.38 5.88.55l13.51 21.82a7.35 7.35 0 0 0 12.5 0l13.51-21.82c2-.17 3.93-.34 5.88-.56l17.5 18.73a7.36 7.36 0 0 0 12.25-2.44l9-24c1.89-.55 3.78-1.08 5.65-1.68l20.82 15a7.34 7.34 0 0 0 11.54-4.78l4.15-25.42c1.72-.9 3.45-1.79 5.15-2.73l23.42 10.68a7.35 7.35 0 0 0 10.39-6.95l-.91-25.82q2.22-1.79 4.4-3.61l25.15 5.93a7.36 7.36 0 0 0 8.84-8.84L433.07 405q1.83-2.17 3.61-4.4l25.82.91a7.23 7.23 0 0 0 6.37-3.26 7.35 7.35 0 0 0 .58-7.13l-10.68-23.42c.94-1.7 1.83-3.43 2.73-5.15l25.42-4.15a7.35 7.35 0 0 0 4.79-11.54l-15-20.83c.59-1.87 1.13-3.76 1.67-5.65l24-9a7.35 7.35 0 0 0 2.44-12.25l-18.72-17.5c.21-1.95.38-3.91.55-5.87l21.82-13.51a7.35 7.35 0 0 0 0-12.5Zm-151 129.08A13.91 13.91 0 0 0 341 389.51l-7.64 35.67a187.51 187.51 0 0 1-156.36-.74l-7.64-35.66a13.87 13.87 0 0 0-16.46-10.68l-31.51 6.76a187.38 187.38 0 0 1-16.26-19.21H258.3c1.72 0 2.89-.29 2.89-1.91v-54.19c0-1.57-1.17-1.91-2.89-1.91h-44.83l.05-34.35H262c4.41 0 23.66 1.28 29.79 25.87 1.91 7.55 6.17 32.14 9.06 40 2.89 8.82 14.6 26.46 27.1 26.46H407a187.3 187.3 0 0 1-17.34 20.09Zm25.77 34.49A15.24 15.24 0 1 1 368 398.08h.44a15.23 15.23 0 0 1 14.8 15.24Zm-225.62-.68a15.24 15.24 0 1 1-15.25-15.25h.45a15.25 15.25 0 0 1 14.75 15.25Zm-88.1-178.49 32.83-14.6a13.88 13.88 0 0 0 7.06-18.33L102.69 186h26.56v119.73h-53.6a187.65 187.65 0 0 1-6.08-71.58Zm-11.26-36.06a15.24 15.24 0 0 1 15.23-15.25H74a15.24 15.24 0 1 1-15.67 15.24Zm155.16 24.49.05-35.32h63.26c3.28 0 23.07 3.77 23.07 18.62 0 12.29-15.19 16.7-27.68 16.7ZM399 306.71c-9.8 1.13-20.63-4.12-22-10.09-5.78-32.49-15.39-39.4-30.57-51.4 18.86-11.95 38.46-29.64 
38.46-53.26 0-25.52-17.49-41.59-29.4-49.48-16.76-11-35.28-13.23-40.27-13.23h-198.9a187.49 187.49 0 0 1 104.89-59.19l23.47 24.6a13.82 13.82 0 0 0 19.6.44l26.26-25a187.51 187.51 0 0 1 128.37 91.43l-18 40.57a14 14 0 0 0 7.09 18.33l34.59 15.33a187.12 187.12 0 0 1 .4 32.54h-19.28c-1.91 0-2.69 1.27-2.69 3.13v8.82C421 301 409.31 305.58 399 306.71ZM240 60.21A15.24 15.24 0 0 1 255.21 45h.45A15.24 15.24 0 1 1 240 60.21ZM436.84 214a15.24 15.24 0 1 1 0-30.48h.44a15.24 15.24 0 0 1-.44 30.48Z'/%3E%3C/svg%3E"); + } + .md-typeset .admonition.rust, + .md-typeset details.rust { + border-color: rgb(205, 121, 44); + } + .md-typeset .rust > .admonition-title, + .md-typeset .rust > summary { + background-color: rgb(205, 121, 44,.1); + } + .md-typeset .rust > .admonition-title::before, + .md-typeset .rust > summary::before { + background-color:rgb(205, 121, 44); + -webkit-mask-image: var(--md-admonition-icon--rust); + mask-image: var(--md-admonition-icon--rust); + } \ No newline at end of file diff --git a/docs/_build/overrides/404.html b/docs/_build/overrides/404.html new file mode 100644 index 000000000000..ee9b8faa2aba --- /dev/null +++ b/docs/_build/overrides/404.html @@ -0,0 +1,222 @@ +{% extends "main.html" %} +{% block content %} +

+
+{% endblock %}
diff --git a/docs/_build/scripts/macro.py b/docs/_build/scripts/macro.py
new file mode 100644
index 000000000000..d93d5170adec
--- /dev/null
+++ b/docs/_build/scripts/macro.py
@@ -0,0 +1,156 @@
+from collections import OrderedDict
+import os
+from typing import List, Optional, Set
+import yaml
+import logging
+
+
+# Supported Languages and their metadata
+LANGUAGES = OrderedDict(
+    python={
+        "extension": ".py",
+        "display_name": "Python",
+        "icon_name": "python",
+        "code_name": "python",
+    },
+    rust={
+        "extension": ".rs",
+        "display_name": "Rust",
+        "icon_name": "rust",
+        "code_name": "rust",
+    },
+)
+
+# Load all links to reference docs
+with open("docs/_build/API_REFERENCE_LINKS.yml", "r") as f:
+    API_REFERENCE_LINKS = yaml.load(f, Loader=yaml.CLoader)
+
+
+def create_feature_flag_link(feature_name: str) -> str:
+    """Create a feature flag warning telling the user to activate a certain feature before running the code
+
+    Args:
+        feature_name (str): name of the feature
+
+    Returns:
+        str: Markdown formatted string with a link and the feature flag message
+    """
+    return f'[:material-flag-plus: Available on feature {feature_name}](/polars/user-guide/installation/#feature-flags "To use this functionality enable the feature flag {feature_name}"){{.feature-flag}}'
+
+
+def create_feature_flag_links(language: str, api_functions: List[str]) -> List[str]:
+    """Generate markdown feature flags for the code tabs based on the api_functions.
+    It checks for the key feature_flags in the configuration yaml for each function and, if present, prints out markdown
+
+    Args:
+        language (str): programming language
+        api_functions (List[str]): Api functions that are called
+
+    Returns:
+        List[str]: Per unique feature flag a markdown formatted string for the feature flag
+    """
+    api_functions_info = [
+        info
+        for f in api_functions
+        if (info := API_REFERENCE_LINKS.get(language).get(f))
+    ]
+    feature_flags: Set[str] = {
+        flag
+        for info in api_functions_info
+        if type(info) == dict and info.get("feature_flags")
+        for flag in info.get("feature_flags")
+    }
+
+    return [create_feature_flag_link(flag) for flag in feature_flags]
+
+
+def create_api_function_link(language: str, function_key: str) -> Optional[str]:
+    """Create an API link in markdown with an icon, based on the YAML file
+
+    Args:
+        language (str): programming language
+        function_key (str): Key to the specific function
+
+    Returns:
+        str: If the function is found then the link, else None
+    """
+    info = API_REFERENCE_LINKS.get(language, {}).get(function_key)
+
+    if info is None:
+        logging.warning(f"Could not find {function_key} for language {language}")
+        return None
+    else:
+        # The entry can either be a direct link
+        if type(info) == str:
+            return f"[:material-api: `{function_key}`]({info})"
+        else:
+            function_name = info["name"]
+            link = info["link"]
+            return f"[:material-api: `{function_name}`]({link})"
+
+
+def code_tab(
+    base_path: str,
+    section: Optional[str],
+    language_info: dict,
+    api_functions: List[str],
+) -> str:
+    """Generate a single tab for the code block corresponding to a specific language.
+    It gets the code at base_path and possible section and pretty prints markdown for it
+
+    Args:
+        base_path (str): path where the code is located
+        section (str, optional): section in the code that should be displayed
+        language_info (dict): Language specific information (icon name, display name, ...)
+ api_functions (List[str]): List of api functions which should be linked + + Returns: + str: A markdown formatted string represented a single tab + """ + language = language_info["code_name"] + + # Create feature flags + feature_flags_links = create_feature_flag_links(language, api_functions) + + # Create API Links if they are defined in the YAML + api_functions = [ + link for f in api_functions if (link := create_api_function_link(language, f)) + ] + language_headers = " ·".join(api_functions + feature_flags_links) + + # Create path for Snippets extension + snippets_file_name = f"{base_path}:{section}" if section else f"{base_path}" + + # See Content Tabs for details https://squidfunk.github.io/mkdocs-material/reference/content-tabs/ + return f"""=== \":fontawesome-brands-{language_info['icon_name']}: {language_info['display_name']}\" + {language_headers} + ```{language} + --8<-- \"{snippets_file_name}\" + ``` + """ + + +def define_env(env): + @env.macro + def code_block( + path: str, section: str = None, api_functions: List[str] = None + ) -> str: + """Dynamically generate a code block for the code located under {language}/path + + Args: + path (str): base_path for each language + section (str, optional): Optional segment within the code file. Defaults to None. + api_functions (List[str], optional): API functions that should be linked. Defaults to None. + Returns: + str: Markdown tabbed code block with possible links to api functions and feature flags + """ + result = [] + + for language, info in LANGUAGES.items(): + base_path = f"{language}/{path}{info['extension']}" + full_path = "docs/src/" + base_path + # Check if file exists for the language + if os.path.exists(full_path): + result.append(code_tab(base_path, section, info, api_functions)) + + return "\n".join(result) diff --git a/docs/_build/scripts/people.py b/docs/_build/scripts/people.py new file mode 100644 index 000000000000..81ba1982f132 --- /dev/null +++ b/docs/_build/scripts/people.py @@ -0,0 +1,38 @@ +import itertools +from github import Github + +g = Github(None) + +ICON_TEMPLATE = "[![{login}]({avatar_url}){{.contributor_icon}}]({html_url})" + + +def get_people_md(): + repo = g.get_repo("pola-rs/polars") + contributors = repo.get_contributors() + with open("./docs/people.md", "w") as f: + for c in itertools.islice(contributors, 50): + # We love dependabot, but he doesn't need a spot on our website + if c.login == "dependabot[bot]": + continue + + f.write( + ICON_TEMPLATE.format( + login=c.login, + avatar_url=c.avatar_url, + html_url=c.html_url, + ) + + "\n" + ) + + +def on_startup(command, dirty): + """Mkdocs hook to autogenerate docs/people.md on startup""" + try: + get_people_md() + except Exception as e: + msg = f"WARNING:{__file__}: Could not generate docs/people.md. Got error: {str(e)}" + print(msg) + + +if __name__ == "__main__": + get_people_md() diff --git a/docs/_build/snippets/under_construction.md b/docs/_build/snippets/under_construction.md new file mode 100644 index 000000000000..c4bb56a735af --- /dev/null +++ b/docs/_build/snippets/under_construction.md @@ -0,0 +1,4 @@ +!!! warning ":construction: Under Construction :construction: " + + This section is still under development. Want to help out? Consider contributing and making a [pull request](https://github.com/pola-rs/polars) to our repository. + Please read our [Contribution Guidelines](https://github.com/pola-rs/polars/blob/main/CONTRIBUTING.md) on how to proceed. 
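For orientation, here is a small, hypothetical driver exercising the helpers defined in `macro.py` above. It is a sketch, not part of this diff: it assumes the repository root as the working directory (so the `API_REFERENCE_LINKS.yml` load at import time succeeds) and loads the script by path, since it is a build script rather than an installable module. The keys `select` and `from_arrow` are taken from the YAML shown earlier.

```python
# Hypothetical sketch (not part of the diff): exercise the doc-build macros directly.
# Assumes it is run from the repository root so the YAML load inside macro.py works.
import importlib.util

spec = importlib.util.spec_from_file_location("macro", "docs/_build/scripts/macro.py")
macro = importlib.util.module_from_spec(spec)
spec.loader.exec_module(macro)

# A plain string entry in API_REFERENCE_LINKS renders as a single API link...
print(macro.create_api_function_link("python", "select"))
# ...while an entry that declares `feature_flags` also yields warning badges.
print(macro.create_feature_flag_links("python", ["from_arrow"]))
```

Inside the documentation pages these helpers are reached through the `{{code_block(...)}}` macro that `define_env` registers, as the getting-started pages later in this diff show.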
diff --git a/docs/data/apple_stock.csv b/docs/data/apple_stock.csv new file mode 100644 index 000000000000..6c3f9752d587 --- /dev/null +++ b/docs/data/apple_stock.csv @@ -0,0 +1,101 @@ +Date,Close +1981-02-23,24.62 +1981-05-06,27.38 +1981-05-18,28.0 +1981-09-25,14.25 +1982-07-08,11.0 +1983-01-03,28.5 +1983-04-06,40.0 +1983-10-03,23.13 +1984-07-27,27.13 +1984-08-17,27.5 +1984-08-24,28.12 +1985-05-07,20.0 +1985-09-03,14.75 +1985-12-06,19.75 +1986-03-12,24.75 +1986-04-09,27.13 +1986-04-17,29.0 +1986-09-17,34.25 +1986-11-26,40.5 +1987-02-25,69.13 +1987-04-15,71.0 +1988-02-23,42.75 +1988-03-07,46.88 +1988-03-23,42.5 +1988-12-12,38.5 +1988-12-19,40.75 +1989-04-17,39.25 +1989-11-13,46.5 +1990-11-23,36.38 +1991-03-22,63.25 +1991-05-17,47.0 +1991-06-03,49.25 +1991-06-18,42.12 +1992-06-25,45.62 +1992-10-12,44.0 +1993-07-06,37.75 +1993-09-15,24.5 +1993-09-30,23.37 +1993-11-09,30.12 +1994-01-24,35.0 +1994-03-15,37.62 +1994-06-27,26.25 +1994-07-08,27.06 +1994-12-21,38.38 +1995-07-06,47.0 +1995-10-16,36.13 +1995-11-17,40.13 +1995-12-12,38.0 +1996-01-31,27.63 +1996-02-05,29.25 +1996-07-15,17.19 +1996-09-20,22.87 +1996-12-23,23.25 +1997-03-17,16.5 +1997-05-09,17.06 +1997-08-06,26.31 +1997-09-30,21.69 +1998-02-09,19.19 +1998-03-12,27.0 +1998-05-07,30.19 +1998-05-12,30.12 +1999-07-09,55.63 +1999-12-08,110.06 +2000-01-14,100.44 +2000-06-27,51.75 +2000-07-05,51.62 +2000-07-19,52.69 +2000-08-07,47.94 +2000-08-28,58.06 +2000-09-26,51.44 +2001-03-02,19.25 +2001-12-10,22.54 +2002-01-25,23.25 +2002-03-07,24.38 +2002-08-16,15.81 +2002-10-03,14.3 +2003-11-18,20.41 +2004-02-26,23.04 +2004-03-08,26.0 +2004-09-22,36.92 +2005-06-24,37.76 +2005-12-07,73.95 +2005-12-22,74.02 +2006-06-22,59.58 +2006-11-28,91.81 +2007-08-13,127.79 +2007-12-04,179.81 +2007-12-31,198.08 +2008-05-09,183.45 +2008-06-27,170.09 +2009-08-03,166.43 +2010-04-01,235.97 +2010-12-10,320.56 +2011-04-28,346.75 +2011-12-02,389.7 +2012-05-16,546.08 +2012-12-04,575.85 +2013-07-05,417.42 +2013-11-07,512.49 +2014-02-25,522.06 \ No newline at end of file diff --git a/docs/data/iris.csv b/docs/data/iris.csv new file mode 100644 index 000000000000..d6b466b31892 --- /dev/null +++ b/docs/data/iris.csv @@ -0,0 +1,151 @@ +sepal_length,sepal_width,petal_length,petal_width,species +5.1,3.5,1.4,.2,Setosa +4.9,3,1.4,.2,Setosa +4.7,3.2,1.3,.2,Setosa +4.6,3.1,1.5,.2,Setosa +5,3.6,1.4,.2,Setosa +5.4,3.9,1.7,.4,Setosa +4.6,3.4,1.4,.3,Setosa +5,3.4,1.5,.2,Setosa +4.4,2.9,1.4,.2,Setosa +4.9,3.1,1.5,.1,Setosa +5.4,3.7,1.5,.2,Setosa +4.8,3.4,1.6,.2,Setosa +4.8,3,1.4,.1,Setosa +4.3,3,1.1,.1,Setosa +5.8,4,1.2,.2,Setosa +5.7,4.4,1.5,.4,Setosa +5.4,3.9,1.3,.4,Setosa +5.1,3.5,1.4,.3,Setosa +5.7,3.8,1.7,.3,Setosa +5.1,3.8,1.5,.3,Setosa +5.4,3.4,1.7,.2,Setosa +5.1,3.7,1.5,.4,Setosa +4.6,3.6,1,.2,Setosa +5.1,3.3,1.7,.5,Setosa +4.8,3.4,1.9,.2,Setosa +5,3,1.6,.2,Setosa +5,3.4,1.6,.4,Setosa +5.2,3.5,1.5,.2,Setosa +5.2,3.4,1.4,.2,Setosa +4.7,3.2,1.6,.2,Setosa +4.8,3.1,1.6,.2,Setosa +5.4,3.4,1.5,.4,Setosa +5.2,4.1,1.5,.1,Setosa +5.5,4.2,1.4,.2,Setosa +4.9,3.1,1.5,.2,Setosa +5,3.2,1.2,.2,Setosa +5.5,3.5,1.3,.2,Setosa +4.9,3.6,1.4,.1,Setosa +4.4,3,1.3,.2,Setosa +5.1,3.4,1.5,.2,Setosa +5,3.5,1.3,.3,Setosa +4.5,2.3,1.3,.3,Setosa +4.4,3.2,1.3,.2,Setosa +5,3.5,1.6,.6,Setosa +5.1,3.8,1.9,.4,Setosa +4.8,3,1.4,.3,Setosa +5.1,3.8,1.6,.2,Setosa +4.6,3.2,1.4,.2,Setosa +5.3,3.7,1.5,.2,Setosa +5,3.3,1.4,.2,Setosa +7,3.2,4.7,1.4,Versicolor +6.4,3.2,4.5,1.5,Versicolor +6.9,3.1,4.9,1.5,Versicolor +5.5,2.3,4,1.3,Versicolor +6.5,2.8,4.6,1.5,Versicolor +5.7,2.8,4.5,1.3,Versicolor +6.3,3.3,4.7,1.6,Versicolor 
+4.9,2.4,3.3,1,Versicolor +6.6,2.9,4.6,1.3,Versicolor +5.2,2.7,3.9,1.4,Versicolor +5,2,3.5,1,Versicolor +5.9,3,4.2,1.5,Versicolor +6,2.2,4,1,Versicolor +6.1,2.9,4.7,1.4,Versicolor +5.6,2.9,3.6,1.3,Versicolor +6.7,3.1,4.4,1.4,Versicolor +5.6,3,4.5,1.5,Versicolor +5.8,2.7,4.1,1,Versicolor +6.2,2.2,4.5,1.5,Versicolor +5.6,2.5,3.9,1.1,Versicolor +5.9,3.2,4.8,1.8,Versicolor +6.1,2.8,4,1.3,Versicolor +6.3,2.5,4.9,1.5,Versicolor +6.1,2.8,4.7,1.2,Versicolor +6.4,2.9,4.3,1.3,Versicolor +6.6,3,4.4,1.4,Versicolor +6.8,2.8,4.8,1.4,Versicolor +6.7,3,5,1.7,Versicolor +6,2.9,4.5,1.5,Versicolor +5.7,2.6,3.5,1,Versicolor +5.5,2.4,3.8,1.1,Versicolor +5.5,2.4,3.7,1,Versicolor +5.8,2.7,3.9,1.2,Versicolor +6,2.7,5.1,1.6,Versicolor +5.4,3,4.5,1.5,Versicolor +6,3.4,4.5,1.6,Versicolor +6.7,3.1,4.7,1.5,Versicolor +6.3,2.3,4.4,1.3,Versicolor +5.6,3,4.1,1.3,Versicolor +5.5,2.5,4,1.3,Versicolor +5.5,2.6,4.4,1.2,Versicolor +6.1,3,4.6,1.4,Versicolor +5.8,2.6,4,1.2,Versicolor +5,2.3,3.3,1,Versicolor +5.6,2.7,4.2,1.3,Versicolor +5.7,3,4.2,1.2,Versicolor +5.7,2.9,4.2,1.3,Versicolor +6.2,2.9,4.3,1.3,Versicolor +5.1,2.5,3,1.1,Versicolor +5.7,2.8,4.1,1.3,Versicolor +6.3,3.3,6,2.5,Virginica +5.8,2.7,5.1,1.9,Virginica +7.1,3,5.9,2.1,Virginica +6.3,2.9,5.6,1.8,Virginica +6.5,3,5.8,2.2,Virginica +7.6,3,6.6,2.1,Virginica +4.9,2.5,4.5,1.7,Virginica +7.3,2.9,6.3,1.8,Virginica +6.7,2.5,5.8,1.8,Virginica +7.2,3.6,6.1,2.5,Virginica +6.5,3.2,5.1,2,Virginica +6.4,2.7,5.3,1.9,Virginica +6.8,3,5.5,2.1,Virginica +5.7,2.5,5,2,Virginica +5.8,2.8,5.1,2.4,Virginica +6.4,3.2,5.3,2.3,Virginica +6.5,3,5.5,1.8,Virginica +7.7,3.8,6.7,2.2,Virginica +7.7,2.6,6.9,2.3,Virginica +6,2.2,5,1.5,Virginica +6.9,3.2,5.7,2.3,Virginica +5.6,2.8,4.9,2,Virginica +7.7,2.8,6.7,2,Virginica +6.3,2.7,4.9,1.8,Virginica +6.7,3.3,5.7,2.1,Virginica +7.2,3.2,6,1.8,Virginica +6.2,2.8,4.8,1.8,Virginica +6.1,3,4.9,1.8,Virginica +6.4,2.8,5.6,2.1,Virginica +7.2,3,5.8,1.6,Virginica +7.4,2.8,6.1,1.9,Virginica +7.9,3.8,6.4,2,Virginica +6.4,2.8,5.6,2.2,Virginica +6.3,2.8,5.1,1.5,Virginica +6.1,2.6,5.6,1.4,Virginica +7.7,3,6.1,2.3,Virginica +6.3,3.4,5.6,2.4,Virginica +6.4,3.1,5.5,1.8,Virginica +6,3,4.8,1.8,Virginica +6.9,3.1,5.4,2.1,Virginica +6.7,3.1,5.6,2.4,Virginica +6.9,3.1,5.1,2.3,Virginica +5.8,2.7,5.1,1.9,Virginica +6.8,3.2,5.9,2.3,Virginica +6.7,3.3,5.7,2.5,Virginica +6.7,3,5.2,2.3,Virginica +6.3,2.5,5,1.9,Virginica +6.5,3,5.2,2,Virginica +6.2,3.4,5.4,2.3,Virginica +5.9,3,5.1,1.8,Virginica \ No newline at end of file diff --git a/docs/data/reddit.csv b/docs/data/reddit.csv new file mode 100644 index 000000000000..88f91e3df7db --- /dev/null +++ b/docs/data/reddit.csv @@ -0,0 +1,100 @@ +id,name,created_utc,updated_on,comment_karma,link_karma +1,truman48lamb_jasonbroken,1397113470,1536527864,0,0 +2,johnethen06_jasonbroken,1397113483,1536527864,0,0 +3,yaseinrez_jasonbroken,1397113483,1536527864,0,1 +4,Valve92_jasonbroken,1397113503,1536527864,0,0 +5,srbhuyan_jasonbroken,1397113506,1536527864,0,0 +6,taojianlong_jasonbroken,1397113510,1536527864,4,0 +7,YourPalGrant92_jasonbroken,1397113513,1536527864,0,0 +8,Lucki87_jasonbroken,1397113515,1536527864,0,0 +9,punkstock_jasonbroken,1397113517,1536527864,0,0 +10,duder_con_chile_jasonbroken,1397113519,1536527864,0,2 +11,IHaveBigBalls_jasonbroken,1397113520,1536527864,0,0 +12,Foggybanana_jasonbroken,1397113523,1536527864,0,0 +13,Thedrinkdriver_jasonbroken,1397113527,1536527864,-9,0 +14,littlemissd_jasonbroken,1397113530,1536527864,0,-3 +15,phonethaway_jasonbroken,1397113537,1536527864,0,0 
+16,DreamingOfWinterfell_jasonbroken,1397113538,1536527864,0,0 +17,ssaig_jasonbroken,1397113544,1536527864,1,0 +18,divinetribe_jasonbroken,1397113549,1536527864,0,0 +19,fdbvfdssdgfds_jasonbroken,1397113552,1536527864,3,0 +20,hjtrsh54yh43_jasonbroken,1397113559,1536527864,-1,-1 +21,Dalin86_jasonbroken,1397113561,1536527864,0,0 +22,sgalex_jasonbroken,1397113561,1536527864,0,0 +23,beszhthw_jasonbroken,1397113566,1536527864,0,0 +24,WojkeN_jasonbroken,1397113572,1536527864,-8,0 +25,LixksHD_jasonbroken,1397113572,1536527864,0,0 +26,bradhrvf78_jasonbroken,1397113574,1536527864,0,0 +27,ravenfeathers_jasonbroken,1397113576,1536527864,0,0 +28,jayne101_jasonbroken,1397113583,1536527864,0,0 +29,jdennis6701_jasonbroken,1397113585,1536527864,0,0 +30,Puppy243_jasonbroken,1397113592,1536527864,0,0 +31,sissyt_jasonbroken,1397113609,1536527864,0,0 +32,fengye78_jasonbroken,1397113613,1536527864,0,0 +33,bigspender1988_jasonbroken,1397113614,1536527864,0,21 +34,bitdownworld_jasonbroken,1397113618,1536527864,0,0 +35,adhyufsdtha12_jasonbroken,1397113619,1536527864,0,0 +36,Haydenac_jasonbroken,1397113635,1536527864,0,0 +37,ihatewhoweare_jasonbroken,1397113636,1536527864,61,0 +38,HungDaddy69__jasonbroken,1397113641,1536527864,0,0 +39,FSUJohnny24_jasonbroken,1397113646,1536527864,0,0 +40,Toejimon_jasonbroken,1397113650,1536527864,0,0 +41,mine69flesh_jasonbroken,1397113651,1536527864,0,0 +42,brycentkt_jasonbroken,1397113653,1536527864,0,0 +43,hmmmitsbig,1397113655,1536527864,0,0 +77714,hockeyschtick,1137474000,1536497404,11104,451 +77715,kbmunkholm,1137474000,1536528267,0,0 +77716,dickb,1137588452,1536528267,0,0 +77717,stephenjcole,1137474000,1536528267,0,2 +77718,rosetree,1137474000,1536528267,0,0 +77719,benhawK,1138180921,1536528267,0,0 +77720,joenowak,1137474000,1536528268,0,0 +77721,constant,1137474000,1536528268,1,0 +77722,jpscott,1137474000,1536528268,0,1 +77723,meryn,1137474000,1536528268,0,2 +77724,momerath,1128916800,1536528268,2490,101 +77725,inuse,1137474000,1536528269,0,0 +77726,dubert11,1137474000,1536528269,38,59 +77727,CaliMark,1137474000,1536528269,0,0 +77728,Maniac,1137474000,1536528269,0,0 +77729,earlpearl,1137474000,1536528269,0,0 +77730,ghost,1137474000,1536497404,767,0 +77731,paulzg,1137474000,1536528270,0,0 +77732,rshawgo,1137474000,1536497404,707,6883 +77733,spage,1137474000,1536528270,0,0 +77734,HrothgarReborn,1137474000,1536528270,0,0 +77735,darknessvisible,1137474000,1536528270,26133,139 +77736,finleyt,1137714898,1536528270,0,0 +77737,Dalton,1137474000,1536528271,118,2 +77738,graemes,1137474000,1536528271,0,0 +77739,lettuce,1137780958,1536497404,4546,724 +77740,mudkicker,1137474000,1536528271,0,0 +77741,mydignet,1139649149,1536528271,0,0 +77742,markbo,1137474000,1536528271,0,0 +77743,mrfrostee,1137474000,1536528272,227,43 +77744,parappayo,1136350800,1536528272,53,164 +77745,danastasi,1137474000,1536528272,2335,146 +77747,AltherrWeb,1137474000,1536528272,1387,1605 +77748,dtpetty,1137474000,1536528273,0,0 +77749,jamesluke4,1137474000,1536528273,0,0 +77750,sankeld,1137474000,1536528273,9,45 +77751,iampivot,1139479524,1536497404,2640,31 +77752,mcaamano,1137474000,1536528273,0,0 +77753,wonsungi,1137596632,1536528273,0,0 +77754,naotakem,1137474000,1536528274,0,0 +77755,bis,1137474000,1536497404,2191,285 +77756,imeinzen,1137474000,1536528274,0,0 +77757,zrenneh,1137474000,1536528274,79,0 +77758,onclephilippe,1137474000,1536528274,0,0 +77759,Mokzaio415,1139422169,1536528274,0,0 +77761,-brisse,1137474000,1536528275,14,1 +77762,coolin86,1138303196,1536528275,40,7 
+77763,Lunchy,1137599510,1536528275,65,0 +77764,jannemans,1137474000,1536528275,0,0 +77765,compostellas,1137474000,1536528276,6,0 +77766,genericbob,1137474000,1536528276,291,14 +77767,domlexch,1139482978,1536528276,0,0 +77768,TinheadNed,1139665457,1536497404,4434,103 +77769,patopurifik,1137474000,1536528276,0,0 +77770,PoPPo,1139057558,1536528276,0,0 +77771,tandrews,1137474000,1536528277,0,0 diff --git a/docs/getting-started/expressions.md b/docs/getting-started/expressions.md new file mode 100644 index 000000000000..692806d75de9 --- /dev/null +++ b/docs/getting-started/expressions.md @@ -0,0 +1,130 @@ +# Expressions + +`Expressions` are the core strength of `Polars`. The `expressions` offer a versatile structure that both solves easy queries and is easily extended to complex ones. Below we will cover the basic components that serve as building block (or in `Polars` terminology contexts) for all your queries: + +- `select` +- `filter` +- `with_columns` +- `group_by` + +To learn more about expressions and the context in which they operate, see the User Guide sections: [Contexts](../user-guide/concepts/contexts.md) and [Expressions](../user-guide/concepts/expressions.md). + +### Select statement + +To select a column we need to do two things. Define the `DataFrame` we want the data from. And second, select the data that we need. In the example below you see that we select `col('*')`. The asterisk stands for all columns. + +{{code_block('getting-started/expressions','select',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:setup" +print( + --8<-- "python/getting-started/expressions.py:select" +) +``` + +You can also specify the specific columns that you want to return. There are two ways to do this. The first option is to create a `list` of column names, as seen below. + +{{code_block('getting-started/expressions','select2',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:select2" +) +``` + +The second option is to specify each column within a `list` in the `select` statement. This option is shown below. + +{{code_block('getting-started/expressions','select3',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:select3" +) +``` + +If you want to exclude an entire column from your view, you can simply use `exclude` in your `select` statement. + +{{code_block('getting-started/expressions','exclude',['select'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:exclude" +) +``` + +### Filter + +The `filter` option allows us to create a subset of the `DataFrame`. We use the same `DataFrame` as earlier and we filter between two specified dates. + +{{code_block('getting-started/expressions','filter',['filter'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:filter" +) +``` + +With `filter` you can also create more complex filters that include multiple columns. + +{{code_block('getting-started/expressions','filter2',['filter'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:filter2" +) +``` + +### With_columns + +`with_columns` allows you to create new columns for your analyses. 
We create two new columns `e` and `b+42`. First we sum all values from column `b` and store the results in column `e`. After that we add `42` to the values of `b`. Creating a new column `b+42` to store these results. + +{{code_block('getting-started/expressions','with_columns',['with_columns'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:with_columns" +) +``` + +### Group by + +We will create a new `DataFrame` for the Group by functionality. This new `DataFrame` will include several 'groups' that we want to group by. + +{{code_block('getting-started/expressions','dataframe2',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:dataframe2" +print(df2) +``` + +{{code_block('getting-started/expressions','group_by',['group_by'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:group_by" +) +``` + +{{code_block('getting-started/expressions','group_by2',['group_by'])}} + +```python exec="on" result="text" session="getting-started/expressions" +print( + --8<-- "python/getting-started/expressions.py:group_by2" +) +``` + +### Combining operations + +Below are some examples on how to combine operations to create the `DataFrame` you require. + +{{code_block('getting-started/expressions','combine',['select','with_columns'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:combine" +``` + +{{code_block('getting-started/expressions','combine2',['select','with_columns'])}} + +```python exec="on" result="text" session="getting-started/expressions" +--8<-- "python/getting-started/expressions.py:combine2" +``` diff --git a/docs/getting-started/installation.md b/docs/getting-started/installation.md new file mode 100644 index 000000000000..b8b8d18441e6 --- /dev/null +++ b/docs/getting-started/installation.md @@ -0,0 +1,31 @@ +# Installation + +Polars is a library and installation is as simple as invoking the package manager of the corresponding programming language. + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars + ``` + +## Importing + +To use the library import it into your project + +=== ":fontawesome-brands-python: Python" + + ``` python + import polars as pl + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` rust + use polars::prelude::*; + ``` diff --git a/docs/getting-started/intro.md b/docs/getting-started/intro.md new file mode 100644 index 000000000000..81d4ac110efc --- /dev/null +++ b/docs/getting-started/intro.md @@ -0,0 +1,16 @@ +# Introduction + +This getting started guide is written for new users of Polars. The goal is to provide a quick overview of the most common functionality. For a more detailed explanation, please go to the [User Guide](../user-guide/index.md) + +!!! rust "Rust Users Only" + + Due to historical reasons the eager API in Rust is outdated. In the future we would like to redesign it as a small wrapper around the lazy API (as is the design in Python / NodeJS). In the examples we will use the lazy API instead with `.lazy()` and `.collect()`. For now you can ignore these two functions. If you want to know more about the lazy and eager API go [here](../user-guide/concepts/lazy-vs-eager.md). 
+ + To enable the Lazy API ensure you have the feature flag `lazy` configured when installing Polars + ``` + # Cargo.toml + [dependencies] + polars = { version = "x", features = ["lazy", ...]} + ``` + + Because of the ownership ruling in Rust we can not reuse the same `DataFrame` multiple times in the examples. For simplicity reasons we call `clone()` to overcome this issue. Note that this does not duplicate the data but just increments a pointer (`Arc`). diff --git a/docs/getting-started/joins.md b/docs/getting-started/joins.md new file mode 100644 index 000000000000..42d875d79144 --- /dev/null +++ b/docs/getting-started/joins.md @@ -0,0 +1,26 @@ +# Combining DataFrames + +There are two ways `DataFrame`s can be combined depending on the use case: join and concat. + +## Join + +Polars supports all types of join (e.g. left, right, inner, outer). Let's have a closer look on how to `join` two `DataFrames` into a single `DataFrame`. Our two `DataFrames` both have an 'id'-like column: `a` and `x`. We can use those columns to `join` the `DataFrames` in this example. + +{{code_block('getting-started/joins','join',['join'])}} + +```python exec="on" result="text" session="getting-started/joins" +--8<-- "python/getting-started/joins.py:setup" +--8<-- "python/getting-started/joins.py:join" +``` + +To see more examples with other types of joins, go the [User Guide](../user-guide/transformations/joins.md). + +## Concat + +We can also `concatenate` two `DataFrames`. Vertical concatenation will make the `DataFrame` longer. Horizontal concatenation will make the `DataFrame` wider. Below you can see the result of an horizontal concatenation of our two `DataFrames`. + +{{code_block('getting-started/joins','hstack',['hstack'])}} + +```python exec="on" result="text" session="getting-started/joins" +--8<-- "python/getting-started/joins.py:hstack" +``` diff --git a/docs/getting-started/reading-writing.md b/docs/getting-started/reading-writing.md new file mode 100644 index 000000000000..ad91be50f0f6 --- /dev/null +++ b/docs/getting-started/reading-writing.md @@ -0,0 +1,45 @@ +# Reading & writing + +Polars supports reading and writing to all common files (e.g. csv, json, parquet), cloud storage (S3, Azure Blob, BigQuery) and databases (e.g. postgres, mysql). In the following examples we will show how to operate on most common file formats. For the following dataframe + +{{code_block('getting-started/reading-writing','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:dataframe" +``` + +#### CSV + +Polars has its own fast implementation for csv reading with many flexible configuration options. + +{{code_block('getting-started/reading-writing','csv',['read_csv','write_csv'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:csv" +``` + +As we can see above, Polars made the datetimes a `string`. We can tell Polars to parse dates, when reading the csv, to ensure the date becomes a datetime. 
The example can be found below: + +{{code_block('getting-started/reading-writing','csv2',['read_csv'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:csv2" +``` + +#### JSON + +{{code_block('getting-started/reading-writing','json',['read_json','write_json'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:json" +``` + +#### Parquet + +{{code_block('getting-started/reading-writing','parquet',['read_parquet','write_parquet'])}} + +```python exec="on" result="text" session="getting-started/reading" +--8<-- "python/getting-started/reading-writing.py:parquet" +``` + +To see more examples and other data formats go to the [User Guide](../user-guide/io/csv.md), section IO. diff --git a/docs/getting-started/series-dataframes.md b/docs/getting-started/series-dataframes.md new file mode 100644 index 000000000000..07e05c194b93 --- /dev/null +++ b/docs/getting-started/series-dataframes.md @@ -0,0 +1,102 @@ +# Series & DataFrames + +The core base data structures provided by Polars are `Series` and `DataFrames`. + +## Series + +Series are a 1-dimensional data structure. Within a series all elements have the same data type (e.g. int, string). +The snippet below shows how to create a simple named `Series` object. In a later section of this getting started guide we will learn how to read data from external sources (e.g. files, database), for now lets keep it simple. + +{{code_block('getting-started/series-dataframes','series',['Series'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:series" +``` + +### Methods + +Although it is more common to work directly on a `DataFrame` object, `Series` implement a number of base methods which make it easy to perform transformations. Below are some examples of common operations you might want to perform. Note that these are for illustration purposes and only show a small subset of what is available. + +##### Aggregations + +`Series` out of the box supports all basic aggregations (e.g. min, max, mean, mode, ...). + +{{code_block('getting-started/series-dataframes','minmax',['min','max'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:minmax" +``` + +##### String + +There are a number of methods related to string operations in the `StringNamespace`. These only work on `Series` with the Datatype `Utf8`. + +{{code_block('getting-started/series-dataframes','string',['replace'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:string" +``` + +##### Datetime + +Similar to strings, there is a separate namespace for datetime related operations in the `DateLikeNameSpace`. These only work on `Series`with DataTypes related to dates. + +{{code_block('getting-started/series-dataframes','dt',['day'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:dt" +``` + +## DataFrame + +A `DataFrame` is a 2-dimensional data structure that is backed by a `Series`, and it could be seen as an abstraction of on collection (e.g. list) of `Series`. Operations that can be executed on `DataFrame` are very similar to what is done in a `SQL` like query. You can `GROUP BY`, `JOIN`, `PIVOT`, but also define custom functions. 
In the next pages we will cover how to perform these transformations.
+
+{{code_block('getting-started/series-dataframes','dataframe',['DataFrame'])}}
+
+```python exec="on" result="text" session="getting-started/series"
+--8<-- "python/getting-started/series-dataframes.py:dataframe"
+```
+
+### Viewing data
+
+This part focuses on viewing data in a `DataFrame`. We will use the `DataFrame` from the previous example as a starting point.
+
+#### Head
+
+The `head` function shows the first 5 rows of a `DataFrame` by default. You can specify the number of rows you want to see (e.g. `df.head(10)`).
+
+{{code_block('getting-started/series-dataframes','head',['head'])}}
+
+```python exec="on" result="text" session="getting-started/series"
+--8<-- "python/getting-started/series-dataframes.py:head"
+```
+
+#### Tail
+
+The `tail` function shows the last 5 rows of a `DataFrame`. You can also specify the number of rows you want to see, similar to `head`.
+
+{{code_block('getting-started/series-dataframes','tail',['tail'])}}
+
+```python exec="on" result="text" session="getting-started/series"
+--8<-- "python/getting-started/series-dataframes.py:tail"
+```
+
+#### Sample
+
+If you want to get an impression of the data of your `DataFrame`, you can also use `sample`. With `sample` you get _n_ random rows from the `DataFrame`.
+
+{{code_block('getting-started/series-dataframes','sample',['sample'])}}
+
+```python exec="on" result="text" session="getting-started/series"
+--8<-- "python/getting-started/series-dataframes.py:sample"
+```
+
+#### Describe
+
+`describe` returns summary statistics for your `DataFrame`. It provides several quick statistics where possible.
+
+{{code_block('getting-started/series-dataframes','describe',['describe'])}}
+
+```python exec="on" result="text" session="getting-started/series"
+--8<-- "python/getting-started/series-dataframes.py:describe"
+```
diff --git a/docs/images/.gitignore b/docs/images/.gitignore
new file mode 100644
index 000000000000..72e8ffc0db8a
--- /dev/null
+++ b/docs/images/.gitignore
@@ -0,0 +1 @@
+*
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 000000000000..2621ba4ee11d
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,71 @@
+---
+hide:
+  - navigation
+---
+
+# Polars
+
+![logo](https://raw.githubusercontent.com/pola-rs/polars-static/master/logos/polars_github_logo_rect_dark_name.svg)
+

+Blazingly Fast DataFrame Library

+
+
+Polars is a highly performant DataFrame library for manipulating structured data. The core is written in Rust, but the library is also available in Python. Its key features are:
+
+- **Fast**: Polars is written from the ground up, designed close to the machine and without external dependencies.
+- **I/O**: First-class support for all common data storage layers: local, cloud storage & databases.
+- **Easy to use**: Write your queries the way they were intended. Polars, internally, will determine the most efficient way to execute using its query optimizer.
+- **Out of Core**: Polars supports out-of-core data transformation with its streaming API, allowing you to process your results without requiring all your data to be in memory at the same time.
+- **Parallel**: Polars fully utilises the power of your machine by dividing the workload among the available CPU cores without any additional configuration.
+- **Vectorized Query Engine**: Polars uses [Apache Arrow](https://arrow.apache.org/), a columnar data format, to process your queries in a vectorized manner. It uses [SIMD](https://en.wikipedia.org/wiki/Single_instruction,_multiple_data) to optimize CPU usage.
+
+## About this guide
+
+The `Polars` user guide is intended to live alongside the API documentation. Its purpose is to explain to (new) users how to use `Polars` and to provide meaningful examples. The guide is split into two parts:
+
+- [Getting Started](getting-started/intro.md): A 10 minute helicopter view of the library and its primary functionality.
+- [User Guide](user-guide/index.md): A detailed explanation of how the library is set up and how to use it most effectively.
+
+If you are looking for details on a specific function or object, it is probably best to go to the API documentation: [Python](https://pola-rs.github.io/polars/py-polars/html/reference/index.html) | [Rust](https://docs.rs/polars/latest/polars/).
+
+## Performance :rocket: :rocket:
+
+`Polars` is very fast, and in fact is one of the best performing solutions available.
+See the results in h2oai's [db-benchmark](https://duckdblabs.github.io/db-benchmark/), revived by the DuckDB project.
+
+`Polars` [TPCH Benchmark results](https://www.pola.rs/benchmarks.html) are now available on the official website.
+
+## Example
+
+{{code_block('home/example','example',['scan_csv','filter','group_by','collect'])}}
+
+## Sponsors
+
+[](https://www.xomnia.com/)   [](https://www.jetbrains.com)
+
+## Community
+
+`Polars` has a very active community with frequent releases (approximately weekly). Below are some of the top contributors to the project:
+
+--8<-- "docs/people.md"
+
+## Contribute
+
+Thanks for taking the time to contribute! We appreciate all contributions, from reporting bugs to implementing new features. If you're unclear on how to proceed, read our [contribution guide](https://github.com/pola-rs/polars/blob/main/CONTRIBUTING.md) or contact us on [Discord](https://discord.com/invite/4UfP5cfBE7).
+
+## License
+
+This project is licensed under the terms of the MIT license.
diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000000..2c317b06415b --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,9 @@ +pandas +pyarrow +graphviz +matplotlib + +mkdocs-material==9.2.5 +mkdocs-macros-plugin==1.0.4 +markdown-exec[ansi]==1.6.0 +PyGithub==1.59.1 diff --git a/docs/src/python/getting-started/expressions.py b/docs/src/python/getting-started/expressions.py new file mode 100644 index 000000000000..ea73e0819a90 --- /dev/null +++ b/docs/src/python/getting-started/expressions.py @@ -0,0 +1,91 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np +from datetime import datetime + +df = pl.DataFrame( + { + "a": np.arange(0, 8), + "b": np.random.rand(8), + "c": [ + datetime(2022, 12, 1), + datetime(2022, 12, 2), + datetime(2022, 12, 3), + datetime(2022, 12, 4), + datetime(2022, 12, 5), + datetime(2022, 12, 6), + datetime(2022, 12, 7), + datetime(2022, 12, 8), + ], + "d": [1, 2.0, np.NaN, np.NaN, 0, -5, -42, None], + } +) +# --8<-- [end:setup] + +# --8<-- [start:select] +df.select(pl.col("*")) +# --8<-- [end:select] + +# --8<-- [start:select2] +df.select(pl.col(["a", "b"])) +# --8<-- [end:select2] + +# --8<-- [start:select3] +df.select([pl.col("a"), pl.col("b")]).limit(3) +# --8<-- [end:select3] + +# --8<-- [start:exclude] +df.select([pl.exclude("a")]) +# --8<-- [end:exclude] + +# --8<-- [start:filter] +df.filter( + pl.col("c").is_between(datetime(2022, 12, 2), datetime(2022, 12, 8)), +) +# --8<-- [end:filter] + +# --8<-- [start:filter2] +df.filter((pl.col("a") <= 3) & (pl.col("d").is_not_nan())) +# --8<-- [end:filter2] + +# --8<-- [start:with_columns] +df.with_columns([pl.col("b").sum().alias("e"), (pl.col("b") + 42).alias("b+42")]) +# --8<-- [end:with_columns] + +# --8<-- [start:dataframe2] +df2 = pl.DataFrame( + { + "x": np.arange(0, 8), + "y": ["A", "A", "A", "B", "B", "C", "X", "X"], + } +) +# --8<-- [end:dataframe2] + +# --8<-- [start:group_by] +df2.group_by("y", maintain_order=True).count() +# --8<-- [end:group_by] + +# --8<-- [start:group_by2] +df2.group_by("y", maintain_order=True).agg( + [ + pl.col("*").count().alias("count"), + pl.col("*").sum().alias("sum"), + ] +) +# --8<-- [end:group_by2] + +# --8<-- [start:combine] +df_x = df.with_columns((pl.col("a") * pl.col("b")).alias("a * b")).select( + [pl.all().exclude(["c", "d"])] +) + +print(df_x) +# --8<-- [end:combine] + +# --8<-- [start:combine2] +df_y = df.with_columns([(pl.col("a") * pl.col("b")).alias("a * b")]).select( + [pl.all().exclude("d")] +) + +print(df_y) +# --8<-- [end:combine2] diff --git a/docs/src/python/getting-started/joins.py b/docs/src/python/getting-started/joins.py new file mode 100644 index 000000000000..e5a52416eef1 --- /dev/null +++ b/docs/src/python/getting-started/joins.py @@ -0,0 +1,29 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np + +# --8<-- [end:setup] + +# --8<-- [start:join] +df = pl.DataFrame( + { + "a": np.arange(0, 8), + "b": np.random.rand(8), + "d": [1, 2.0, np.NaN, np.NaN, 0, -5, -42, None], + } +) + +df2 = pl.DataFrame( + { + "x": np.arange(0, 8), + "y": ["A", "A", "A", "B", "B", "C", "X", "X"], + } +) +joined = df.join(df2, left_on="a", right_on="x") +print(joined) +# --8<-- [end:join] + +# --8<-- [start:hstack] +stacked = df.hstack(df2) +print(stacked) +# --8<-- [end:hstack] diff --git a/docs/src/python/getting-started/reading-writing.py b/docs/src/python/getting-started/reading-writing.py new file mode 100644 index 000000000000..dc8a54ebd18f --- /dev/null +++ 
b/docs/src/python/getting-started/reading-writing.py @@ -0,0 +1,41 @@ +# --8<-- [start:dataframe] +import polars as pl +from datetime import datetime + +df = pl.DataFrame( + { + "integer": [1, 2, 3], + "date": [ + datetime(2022, 1, 1), + datetime(2022, 1, 2), + datetime(2022, 1, 3), + ], + "float": [4.0, 5.0, 6.0], + } +) + +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:csv] +df.write_csv("docs/data/output.csv") +df_csv = pl.read_csv("docs/data/output.csv") +print(df_csv) +# --8<-- [end:csv] + +# --8<-- [start:csv2] +df_csv = pl.read_csv("docs/data/output.csv", try_parse_dates=True) +print(df_csv) +# --8<-- [end:csv2] + +# --8<-- [start:json] +df.write_json("docs/data/output.json") +df_json = pl.read_json("docs/data/output.json") +print(df_json) +# --8<-- [end:json] + +# --8<-- [start:parquet] +df.write_parquet("docs/data/output.parquet") +df_parquet = pl.read_parquet("docs/data/output.parquet") +print(df_parquet) +# --8<-- [end:parquet] diff --git a/docs/src/python/getting-started/series-dataframes.py b/docs/src/python/getting-started/series-dataframes.py new file mode 100644 index 000000000000..6f2fdf265c22 --- /dev/null +++ b/docs/src/python/getting-started/series-dataframes.py @@ -0,0 +1,64 @@ +# --8<-- [start:series] +import polars as pl + +s = pl.Series("a", [1, 2, 3, 4, 5]) +print(s) +# --8<-- [end:series] + +# --8<-- [start:minmax] +s = pl.Series("a", [1, 2, 3, 4, 5]) +print(s.min()) +print(s.max()) +# --8<-- [end:minmax] + +# --8<-- [start:string] +s = pl.Series("a", ["polar", "bear", "arctic", "polar fox", "polar bear"]) +s2 = s.str.replace("polar", "pola") +print(s2) +# --8<-- [end:string] + +# --8<-- [start:dt] +from datetime import date + +start = date(2001, 1, 1) +stop = date(2001, 1, 9) +s = pl.date_range(start, stop, interval="2d", eager=True) +s.dt.day() +print(s) +# --8<-- [end:dt] + +# --8<-- [start:dataframe] +from datetime import datetime + +df = pl.DataFrame( + { + "integer": [1, 2, 3, 4, 5], + "date": [ + datetime(2022, 1, 1), + datetime(2022, 1, 2), + datetime(2022, 1, 3), + datetime(2022, 1, 4), + datetime(2022, 1, 5), + ], + "float": [4.0, 5.0, 6.0, 7.0, 8.0], + } +) + +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:head] +print(df.head(3)) +# --8<-- [end:head] + +# --8<-- [start:tail] +print(df.tail(3)) +# --8<-- [end:tail] + +# --8<-- [start:sample] +print(df.sample(2)) +# --8<-- [end:sample] + +# --8<-- [start:describe] +print(df.describe()) +# --8<-- [end:describe] diff --git a/docs/src/python/home/example.py b/docs/src/python/home/example.py new file mode 100644 index 000000000000..5f675f4e82e4 --- /dev/null +++ b/docs/src/python/home/example.py @@ -0,0 +1,12 @@ +# --8<-- [start:example] +import polars as pl + +q = ( + pl.scan_csv("docs/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.all().sum()) +) + +df = q.collect() +# --8<-- [end:example] diff --git a/docs/src/python/user-guide/concepts/contexts.py b/docs/src/python/user-guide/concepts/contexts.py new file mode 100644 index 000000000000..ea3baf965b52 --- /dev/null +++ b/docs/src/python/user-guide/concepts/contexts.py @@ -0,0 +1,55 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np + +np.random.seed(12) +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", None], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:select] + +out = df.select( + pl.sum("nrs"), + 
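+    # each expression yields one output column; length-1 results such as
+    # pl.sum("nrs") are broadcast to the frame's height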
pl.col("names").sort(), + pl.col("names").first().alias("first name"), + (pl.mean("nrs") * 10).alias("10xnrs"), +) +print(out) +# --8<-- [end:select] + +# --8<-- [start:filter] +out = df.filter(pl.col("nrs") > 2) +print(out) +# --8<-- [end:filter] + +# --8<-- [start:with_columns] + +df = df.with_columns( + pl.sum("nrs").alias("nrs_sum"), + pl.col("random").count().alias("count"), +) +print(df) +# --8<-- [end:with_columns] + + +# --8<-- [start:group_by] +out = df.group_by("groups").agg( + pl.sum("nrs"), # sum nrs by groups + pl.col("random").count().alias("count"), # count group members + # sum random where name != null + pl.col("random").filter(pl.col("names").is_not_null()).sum().suffix("_sum"), + pl.col("names").reverse().alias("reversed names"), +) +print(out) +# --8<-- [end:group_by] diff --git a/docs/src/python/user-guide/concepts/expressions.py b/docs/src/python/user-guide/concepts/expressions.py new file mode 100644 index 000000000000..83e6c4514c23 --- /dev/null +++ b/docs/src/python/user-guide/concepts/expressions.py @@ -0,0 +1,16 @@ +import polars as pl + +df = pl.DataFrame( + { + "foo": [1, 2, 3, None, 5], + "bar": ["foo", "ham", "spam", "egg", None], + } +) + +# --8<-- [start:example1] +pl.col("foo").sort().head(2) +# --8<-- [end:example1] + +# --8<-- [start:example2] +df.select(pl.col("foo").sort().head(2), pl.col("bar").filter(pl.col("foo") == 1).sum()) +# --8<-- [end:example2] diff --git a/docs/src/python/user-guide/concepts/lazy-vs-eager.py b/docs/src/python/user-guide/concepts/lazy-vs-eager.py new file mode 100644 index 000000000000..1327bac6357a --- /dev/null +++ b/docs/src/python/user-guide/concepts/lazy-vs-eager.py @@ -0,0 +1,20 @@ +import polars as pl + +# --8<-- [start:eager] + +df = pl.read_csv("docs/data/iris.csv") +df_small = df.filter(pl.col("sepal_length") > 5) +df_agg = df_small.group_by("species").agg(pl.col("sepal_width").mean()) +print(df_agg) +# --8<-- [end:eager] + +# --8<-- [start:lazy] +q = ( + pl.scan_csv("docs/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.col("sepal_width").mean()) +) + +df = q.collect() +# --8<-- [end:lazy] diff --git a/docs/src/python/user-guide/concepts/streaming.py b/docs/src/python/user-guide/concepts/streaming.py new file mode 100644 index 000000000000..955750bf6c30 --- /dev/null +++ b/docs/src/python/user-guide/concepts/streaming.py @@ -0,0 +1,12 @@ +import polars as pl + +# --8<-- [start:streaming] +q = ( + pl.scan_csv("docs/data/iris.csv") + .filter(pl.col("sepal_length") > 5) + .group_by("species") + .agg(pl.col("sepal_width").mean()) +) + +df = q.collect(streaming=True) +# --8<-- [end:streaming] diff --git a/docs/src/python/user-guide/expressions/aggregation.py b/docs/src/python/user-guide/expressions/aggregation.py new file mode 100644 index 000000000000..55a986164fbd --- /dev/null +++ b/docs/src/python/user-guide/expressions/aggregation.py @@ -0,0 +1,169 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import date + +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +url = "https://theunitedstates.io/congress-legislators/legislators-historical.csv" + +dtypes = { + "first_name": pl.Categorical, + "gender": pl.Categorical, + "type": pl.Categorical, + "state": pl.Categorical, + "party": pl.Categorical, +} + +dataset = pl.read_csv(url, dtypes=dtypes).with_columns( + pl.col("birthday").str.strptime(pl.Date, strict=False) +) +# --8<-- [end:dataframe] + +# --8<-- [start:basic] +q = ( + dataset.lazy() + .group_by("first_name") + .agg( + pl.count(), + pl.col("gender"), + 
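+        # a bare pl.col("gender") has no aggregation, so each group's values are collected into a list;
+        # pl.first keeps a single value per group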
pl.first("last_name"), + ) + .sort("count", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:basic] + +# --8<-- [start:conditional] +q = ( + dataset.lazy() + .group_by("state") + .agg( + (pl.col("party") == "Anti-Administration").sum().alias("anti"), + (pl.col("party") == "Pro-Administration").sum().alias("pro"), + ) + .sort("pro", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:conditional] + +# --8<-- [start:nested] +q = ( + dataset.lazy() + .group_by("state", "party") + .agg(pl.count("party").alias("count")) + .filter( + (pl.col("party") == "Anti-Administration") + | (pl.col("party") == "Pro-Administration") + ) + .sort("count", descending=True) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:nested] + + +# --8<-- [start:filter] +def compute_age() -> pl.Expr: + return date(2021, 1, 1).year - pl.col("birthday").dt.year() + + +def avg_birthday(gender: str) -> pl.Expr: + return ( + compute_age() + .filter(pl.col("gender") == gender) + .mean() + .alias(f"avg {gender} birthday") + ) + + +q = ( + dataset.lazy() + .group_by("state") + .agg( + avg_birthday("M"), + avg_birthday("F"), + (pl.col("gender") == "M").sum().alias("# male"), + (pl.col("gender") == "F").sum().alias("# female"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:filter] + + +# --8<-- [start:sort] +def get_person() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort] + + +# --8<-- [start:sort2] +def get_person() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort().first().alias("alphabetical_first"), + ) + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort2] + + +# --8<-- [start:sort3] +def get_person() -> pl.Expr: + return pl.col("first_name") + pl.lit(" ") + pl.col("last_name") + + +q = ( + dataset.lazy() + .sort("birthday", descending=True) + .group_by("state") + .agg( + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort().first().alias("alphabetical_first"), + pl.col("gender").sort_by("first_name").first().alias("gender"), + ) + .sort("state") + .limit(5) +) + +df = q.collect() +print(df) +# --8<-- [end:sort3] diff --git a/docs/src/python/user-guide/expressions/casting.py b/docs/src/python/user-guide/expressions/casting.py new file mode 100644 index 000000000000..7a57ac13656f --- /dev/null +++ b/docs/src/python/user-guide/expressions/casting.py @@ -0,0 +1,129 @@ +# --8<-- [start:setup] + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:dfnum] +df = pl.DataFrame( + { + "integers": [1, 2, 3, 4, 5], + "big_integers": [1, 10000002, 3, 10000004, 10000005], + "floats": [4.0, 5.0, 6.0, 7.0, 8.0], + "floats_with_decimal": [4.532, 5.5, 6.5, 7.5, 8.5], + } +) + +print(df) +# --8<-- [end:dfnum] + +# --8<-- [start:castnum] +out = df.select( + pl.col("integers").cast(pl.Float32).alias("integers_as_floats"), + pl.col("floats").cast(pl.Int32).alias("floats_as_integers"), + pl.col("floats_with_decimal") + .cast(pl.Int32) + .alias("floats_with_decimal_as_integers"), +) +print(out) +# --8<-- 
[end:castnum] + + +# --8<-- [start:downcast] +out = df.select( + pl.col("integers").cast(pl.Int16).alias("integers_smallfootprint"), + pl.col("floats").cast(pl.Float32).alias("floats_smallfootprint"), +) +print(out) +# --8<-- [end:downcast] + +# --8<-- [start:overflow] +try: + out = df.select(pl.col("big_integers").cast(pl.Int8)) + print(out) +except Exception as e: + print(e) +# --8<-- [end:overflow] + +# --8<-- [start:overflow2] +out = df.select(pl.col("big_integers").cast(pl.Int8, strict=False)) +print(out) +# --8<-- [end:overflow2] + + +# --8<-- [start:strings] +df = pl.DataFrame( + { + "integers": [1, 2, 3, 4, 5], + "float": [4.0, 5.03, 6.0, 7.0, 8.0], + "floats_as_string": ["4.0", "5.0", "6.0", "7.0", "8.0"], + } +) + +out = df.select( + pl.col("integers").cast(pl.Utf8), + pl.col("float").cast(pl.Utf8), + pl.col("floats_as_string").cast(pl.Float64), +) +print(out) +# --8<-- [end:strings] + + +# --8<-- [start:strings2] +df = pl.DataFrame({"strings_not_float": ["4.0", "not_a_number", "6.0", "7.0", "8.0"]}) +try: + out = df.select(pl.col("strings_not_float").cast(pl.Float64)) + print(out) +except Exception as e: + print(e) +# --8<-- [end:strings2] + +# --8<-- [start:bool] +df = pl.DataFrame( + { + "integers": [-1, 0, 2, 3, 4], + "floats": [0.0, 1.0, 2.0, 3.0, 4.0], + "bools": [True, False, True, False, True], + } +) + +out = df.select(pl.col("integers").cast(pl.Boolean), pl.col("floats").cast(pl.Boolean)) +print(out) +# --8<-- [end:bool] + +# --8<-- [start:dates] +from datetime import date, datetime + +df = pl.DataFrame( + { + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 5), eager=True), + "datetime": pl.datetime_range( + datetime(2022, 1, 1), datetime(2022, 1, 5), eager=True + ), + } +) + +out = df.select(pl.col("date").cast(pl.Int64), pl.col("datetime").cast(pl.Int64)) +print(out) +# --8<-- [end:dates] + +# --8<-- [start:dates2] +df = pl.DataFrame( + { + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 5), eager=True), + "string": [ + "2022-01-01", + "2022-01-02", + "2022-01-03", + "2022-01-04", + "2022-01-05", + ], + } +) + +out = df.select( + pl.col("date").dt.strftime("%Y-%m-%d"), + pl.col("string").str.strptime(pl.Datetime, "%Y-%m-%d"), +) +print(out) +# --8<-- [end:dates2] diff --git a/docs/src/python/user-guide/expressions/column-selections.py b/docs/src/python/user-guide/expressions/column-selections.py new file mode 100644 index 000000000000..88951eaee831 --- /dev/null +++ b/docs/src/python/user-guide/expressions/column-selections.py @@ -0,0 +1,91 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:selectors_df] +from datetime import date, datetime + +df = pl.DataFrame( + { + "id": [9, 4, 2], + "place": ["Mars", "Earth", "Saturn"], + "date": pl.date_range(date(2022, 1, 1), date(2022, 1, 3), "1d", eager=True), + "sales": [33.4, 2142134.1, 44.7], + "has_people": [False, True, False], + "logged_at": pl.datetime_range( + datetime(2022, 12, 1), datetime(2022, 12, 1, 0, 0, 2), "1s", eager=True + ), + } +).with_row_count("rn") +print(df) +# --8<-- [end:selectors_df] + +# --8<-- [start:all] +out = df.select(pl.col("*")) + +# Is equivalent to +out = df.select(pl.all()) +print(out) +# --8<-- [end:all] + +# --8<-- [start:exclude] +out = df.select(pl.col("*").exclude("logged_at", "rn")) +print(out) +# --8<-- [end:exclude] + +# --8<-- [start:expansion_by_names] +out = df.select(pl.col("date", "logged_at").dt.to_string("%Y-%h-%d")) +print(out) +# --8<-- [end:expansion_by_names] + +# --8<-- [start:expansion_by_regex] +out = 
df.select(pl.col("^.*(as|sa).*$")) +print(out) +# --8<-- [end:expansion_by_regex] + +# --8<-- [start:expansion_by_dtype] +out = df.select(pl.col(pl.Int64, pl.UInt32, pl.Boolean).n_unique()) +print(out) +# --8<-- [end:expansion_by_dtype] + +# --8<-- [start:selectors_intro] +import polars.selectors as cs + +out = df.select(cs.integer(), cs.string()) +print(out) +# --8<-- [end:selectors_intro] + +# --8<-- [start:selectors_diff] +out = df.select(cs.numeric() - cs.first()) +print(out) +# --8<-- [end:selectors_diff] + +# --8<-- [start:selectors_union] +out = df.select(cs.by_name("rn") | ~cs.numeric()) +print(out) +# --8<-- [end:selectors_union] + +# --8<-- [start:selectors_by_name] +out = df.select(cs.contains("rn"), cs.matches(".*_.*")) +print(out) +# --8<-- [end:selectors_by_name] + +# --8<-- [start:selectors_to_expr] +out = df.select(cs.temporal().as_expr().dt.to_string("%Y-%h-%d")) +print(out) +# --8<-- [end:selectors_to_expr] + +# --8<-- [start:selectors_is_selector_utility] +from polars.selectors import is_selector + +out = cs.temporal() +print(is_selector(out)) +# --8<-- [end:selectors_is_selector_utility] + +# --8<-- [start:selectors_colnames_utility] +from polars.selectors import expand_selector + +out = cs.temporal().as_expr().dt.to_string("%Y-%h-%d") +print(expand_selector(df, out)) +# --8<-- [end:selectors_colnames_utility] diff --git a/docs/src/python/user-guide/expressions/folds.py b/docs/src/python/user-guide/expressions/folds.py new file mode 100644 index 000000000000..803591b5b581 --- /dev/null +++ b/docs/src/python/user-guide/expressions/folds.py @@ -0,0 +1,50 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:mansum] +df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [10, 20, 30], + } +) + +out = df.select( + pl.fold(acc=pl.lit(0), function=lambda acc, x: acc + x, exprs=pl.all()).alias( + "sum" + ), +) +print(out) +# --8<-- [end:mansum] + +# --8<-- [start:conditional] +df = pl.DataFrame( + { + "a": [1, 2, 3], + "b": [0, 1, 2], + } +) + +out = df.filter( + pl.fold( + acc=pl.lit(True), + function=lambda acc, x: acc & x, + exprs=pl.col("*") > 1, + ) +) +print(out) +# --8<-- [end:conditional] + +# --8<-- [start:string] +df = pl.DataFrame( + { + "a": ["a", "b", "c"], + "b": [1, 2, 3], + } +) + +out = df.select(pl.concat_str(["a", "b"])) +print(out) +# --8<-- [end:string] diff --git a/docs/src/python/user-guide/expressions/functions.py b/docs/src/python/user-guide/expressions/functions.py new file mode 100644 index 000000000000..5f9bbd5bb1da --- /dev/null +++ b/docs/src/python/user-guide/expressions/functions.py @@ -0,0 +1,60 @@ +# --8<-- [start:setup] + +import polars as pl +import numpy as np + +np.random.seed(12) +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", "spam"], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:samename] +df_samename = df.select(pl.col("nrs") + 5) +print(df_samename) +# --8<-- [end:samename] + + +# --8<-- [start:samenametwice] +try: + df_samename2 = df.select(pl.col("nrs") + 5, pl.col("nrs") - 5) + print(df_samename2) +except Exception as e: + print(e) +# --8<-- [end:samenametwice] + +# --8<-- [start:samenamealias] +df_alias = df.select( + (pl.col("nrs") + 5).alias("nrs + 5"), + (pl.col("nrs") - 5).alias("nrs - 5"), +) +print(df_alias) +# --8<-- [end:samenamealias] + +# --8<-- [start:countunique] +df_alias = df.select( + 
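+    # n_unique computes an exact distinct count; approx_n_unique is an approximate (and typically faster on large data) alternative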
pl.col("names").n_unique().alias("unique"), + pl.approx_n_unique("names").alias("unique_approx"), +) +print(df_alias) +# --8<-- [end:countunique] + +# --8<-- [start:conditional] +df_conditional = df.select( + pl.col("nrs"), + pl.when(pl.col("nrs") > 2) + .then(pl.lit(True)) + .otherwise(pl.lit(False)) + .alias("conditional"), +) +print(df_conditional) +# --8<-- [end:conditional] diff --git a/docs/src/python/user-guide/expressions/lists.py b/docs/src/python/user-guide/expressions/lists.py new file mode 100644 index 000000000000..d81dac154461 --- /dev/null +++ b/docs/src/python/user-guide/expressions/lists.py @@ -0,0 +1,111 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:weather_df] +weather = pl.DataFrame( + { + "station": ["Station " + str(x) for x in range(1, 6)], + "temperatures": [ + "20 5 5 E1 7 13 19 9 6 20", + "18 8 16 11 23 E2 8 E2 E2 E2 90 70 40", + "19 24 E9 16 6 12 10 22", + "E2 E0 15 7 8 10 E1 24 17 13 6", + "14 8 E0 16 22 24 E1", + ], + } +) +print(weather) +# --8<-- [end:weather_df] + +# --8<-- [start:string_to_list] +out = weather.with_columns(pl.col("temperatures").str.split(" ")) +print(out) +# --8<-- [end:string_to_list] + +# --8<-- [start:explode_to_atomic] +out = weather.with_columns(pl.col("temperatures").str.split(" ")).explode( + "temperatures" +) +print(out) +# --8<-- [end:explode_to_atomic] + +# --8<-- [start:list_ops] +out = weather.with_columns(pl.col("temperatures").str.split(" ")).with_columns( + pl.col("temperatures").list.head(3).alias("top3"), + pl.col("temperatures").list.slice(-3, 3).alias("bottom_3"), + pl.col("temperatures").list.lengths().alias("obs"), +) +print(out) +# --8<-- [end:list_ops] + + +# --8<-- [start:count_errors] +out = weather.with_columns( + pl.col("temperatures") + .str.split(" ") + .list.eval(pl.element().cast(pl.Int64, strict=False).is_null()) + .list.sum() + .alias("errors") +) +print(out) +# --8<-- [end:count_errors] + +# --8<-- [start:count_errors_regex] +out = weather.with_columns( + pl.col("temperatures") + .str.split(" ") + .list.eval(pl.element().str.contains("(?i)[a-z]")) + .list.sum() + .alias("errors") +) +print(out) +# --8<-- [end:count_errors_regex] + +# --8<-- [start:weather_by_day] +weather_by_day = pl.DataFrame( + { + "station": ["Station " + str(x) for x in range(1, 11)], + "day_1": [17, 11, 8, 22, 9, 21, 20, 8, 8, 17], + "day_2": [15, 11, 10, 8, 7, 14, 18, 21, 15, 13], + "day_3": [16, 15, 24, 24, 8, 23, 19, 23, 16, 10], + } +) +print(weather_by_day) +# --8<-- [end:weather_by_day] + +# --8<-- [start:weather_by_day_rank] +rank_pct = (pl.element().rank(descending=True) / pl.col("*").count()).round(2) + +out = weather_by_day.with_columns( + # create the list of homogeneous data + pl.concat_list(pl.all().exclude("station")).alias("all_temps") +).select( + # select all columns except the intermediate list + pl.all().exclude("all_temps"), + # compute the rank by calling `list.eval` + pl.col("all_temps").list.eval(rank_pct, parallel=True).alias("temps_rank"), +) + +print(out) +# --8<-- [end:weather_by_day_rank] + +# --8<-- [start:array_df] +array_df = pl.DataFrame( + [ + pl.Series("Array_1", [[1, 3], [2, 5]]), + pl.Series("Array_2", [[1, 7, 3], [8, 1, 0]]), + ], + schema={"Array_1": pl.Array(2, pl.Int64), "Array_2": pl.Array(3, pl.Int64)}, +) +print(array_df) +# --8<-- [end:array_df] + +# --8<-- [start:array_ops] +out = array_df.select( + pl.col("Array_1").arr.min().suffix("_min"), + pl.col("Array_2").arr.sum().suffix("_sum"), +) +print(out) +# --8<-- [end:array_ops] diff --git 
a/docs/src/python/user-guide/expressions/null.py b/docs/src/python/user-guide/expressions/null.py new file mode 100644 index 000000000000..4641773bbb85 --- /dev/null +++ b/docs/src/python/user-guide/expressions/null.py @@ -0,0 +1,88 @@ +# --8<-- [start:setup] +import polars as pl +import numpy as np + +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "value": [1, None], + }, +) +print(df) +# --8<-- [end:dataframe] + + +# --8<-- [start:count] +null_count_df = df.null_count() +print(null_count_df) +# --8<-- [end:count] + + +# --8<-- [start:isnull] +is_null_series = df.select( + pl.col("value").is_null(), +) +print(is_null_series) +# --8<-- [end:isnull] + + +# --8<-- [start:dataframe2] +df = pl.DataFrame( + { + "col1": [1, 2, 3], + "col2": [1, None, 3], + }, +) +print(df) +# --8<-- [end:dataframe2] + + +# --8<-- [start:fill] +fill_literal_df = ( + df.with_columns( + pl.col("col2").fill_null( + pl.lit(2), + ), + ), +) +print(fill_literal_df) +# --8<-- [end:fill] + +# --8<-- [start:fillstrategy] +fill_forward_df = df.with_columns( + pl.col("col2").fill_null(strategy="forward"), +) +print(fill_forward_df) +# --8<-- [end:fillstrategy] + +# --8<-- [start:fillexpr] +fill_median_df = df.with_columns( + pl.col("col2").fill_null(pl.median("col2")), +) +print(fill_median_df) +# --8<-- [end:fillexpr] + +# --8<-- [start:fillinterpolate] +fill_interpolation_df = df.with_columns( + pl.col("col2").interpolate(), +) +print(fill_interpolation_df) +# --8<-- [end:fillinterpolate] + +# --8<-- [start:nan] +nan_df = pl.DataFrame( + { + "value": [1.0, np.NaN, float("nan"), 3.0], + }, +) +print(nan_df) +# --8<-- [end:nan] + +# --8<-- [start:nanfill] +mean_nan_df = nan_df.with_columns( + pl.col("value").fill_nan(None).alias("value"), +).mean() +print(mean_nan_df) +# --8<-- [end:nanfill] diff --git a/docs/src/python/user-guide/expressions/numpy-example.py b/docs/src/python/user-guide/expressions/numpy-example.py new file mode 100644 index 000000000000..d3300591c4d6 --- /dev/null +++ b/docs/src/python/user-guide/expressions/numpy-example.py @@ -0,0 +1,7 @@ +import polars as pl +import numpy as np + +df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + +out = df.select(np.log(pl.all()).suffix("_log")) +print(out) diff --git a/docs/src/python/user-guide/expressions/operators.py b/docs/src/python/user-guide/expressions/operators.py new file mode 100644 index 000000000000..6f617487c81e --- /dev/null +++ b/docs/src/python/user-guide/expressions/operators.py @@ -0,0 +1,44 @@ +# --8<-- [start:setup] + +import polars as pl +import numpy as np + +np.random.seed(12) +# --8<-- [end:setup] + + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "nrs": [1, 2, 3, None, 5], + "names": ["foo", "ham", "spam", "egg", None], + "random": np.random.rand(5), + "groups": ["A", "A", "B", "C", "B"], + } +) +print(df) +# --8<-- [end:dataframe] + +# --8<-- [start:numerical] + +df_numerical = df.select( + (pl.col("nrs") + 5).alias("nrs + 5"), + (pl.col("nrs") - 5).alias("nrs - 5"), + (pl.col("nrs") * pl.col("random")).alias("nrs * random"), + (pl.col("nrs") / pl.col("random")).alias("nrs / random"), +) +print(df_numerical) + +# --8<-- [end:numerical] + +# --8<-- [start:logical] +df_logical = df.select( + (pl.col("nrs") > 1).alias("nrs > 1"), + (pl.col("random") <= 0.5).alias("random < .5"), + (pl.col("nrs") != 1).alias("nrs != 1"), + (pl.col("nrs") == 1).alias("nrs == 1"), + ((pl.col("random") <= 0.5) & (pl.col("nrs") > 1)).alias("and_expr"), # and + ((pl.col("random") <= 0.5) | (pl.col("nrs") > 
1)).alias("or_expr"), # or +) +print(df_logical) +# --8<-- [end:logical] diff --git a/docs/src/python/user-guide/expressions/strings.py b/docs/src/python/user-guide/expressions/strings.py new file mode 100644 index 000000000000..9bec188f8930 --- /dev/null +++ b/docs/src/python/user-guide/expressions/strings.py @@ -0,0 +1,61 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:df] +df = pl.DataFrame({"animal": ["Crab", "cat and dog", "rab$bit", None]}) + +out = df.select( + pl.col("animal").str.lengths().alias("byte_count"), + pl.col("animal").str.n_chars().alias("letter_count"), +) +print(out) +# --8<-- [end:df] + +# --8<-- [start:existence] +out = df.select( + pl.col("animal"), + pl.col("animal").str.contains("cat|bit").alias("regex"), + pl.col("animal").str.contains("rab$", literal=True).alias("literal"), + pl.col("animal").str.starts_with("rab").alias("starts_with"), + pl.col("animal").str.ends_with("dog").alias("ends_with"), +) +print(out) +# --8<-- [end:existence] + +# --8<-- [start:extract] +df = pl.DataFrame( + { + "a": [ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + } +) +out = df.select( + pl.col("a").str.extract(r"candidate=(\w+)", group_index=1), +) +print(out) +# --8<-- [end:extract] + + +# --8<-- [start:extract_all] +df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"]}) +out = df.select( + pl.col("foo").str.extract_all(r"(\d+)").alias("extracted_nrs"), +) +print(out) +# --8<-- [end:extract_all] + + +# --8<-- [start:replace] +df = pl.DataFrame({"id": [1, 2], "text": ["123abc", "abc456"]}) +out = df.with_columns( + pl.col("text").str.replace(r"abc\b", "ABC"), + pl.col("text").str.replace_all("a", "-", literal=True).alias("text_replace_all"), +) +print(out) +# --8<-- [end:replace] diff --git a/docs/src/python/user-guide/expressions/structs.py b/docs/src/python/user-guide/expressions/structs.py new file mode 100644 index 000000000000..f209420a37ab --- /dev/null +++ b/docs/src/python/user-guide/expressions/structs.py @@ -0,0 +1,66 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:ratings_df] +ratings = pl.DataFrame( + { + "Movie": ["Cars", "IT", "ET", "Cars", "Up", "IT", "Cars", "ET", "Up", "ET"], + "Theatre": ["NE", "ME", "IL", "ND", "NE", "SD", "NE", "IL", "IL", "SD"], + "Avg_Rating": [4.5, 4.4, 4.6, 4.3, 4.8, 4.7, 4.7, 4.9, 4.7, 4.6], + "Count": [30, 27, 26, 29, 31, 28, 28, 26, 33, 26], + } +) +print(ratings) +# --8<-- [end:ratings_df] + +# --8<-- [start:state_value_counts] +out = ratings.select(pl.col("Theatre").value_counts(sort=True)) +print(out) +# --8<-- [end:state_value_counts] + +# --8<-- [start:struct_unnest] +out = ratings.select(pl.col("Theatre").value_counts(sort=True)).unnest("Theatre") +print(out) +# --8<-- [end:struct_unnest] + +# --8<-- [start:series_struct] +rating_Series = pl.Series( + "ratings", + [ + {"Movie": "Cars", "Theatre": "NE", "Avg_Rating": 4.5}, + {"Movie": "Toy Story", "Theatre": "ME", "Avg_Rating": 4.9}, + ], +) +print(rating_Series) +# --8<-- [end:series_struct] + +# --8<-- [start:series_struct_extract] +out = rating_Series.struct.field("Movie") +print(out) +# --8<-- [end:series_struct_extract] + +# --8<-- [start:series_struct_rename] +out = ( + rating_Series.to_frame() + .select(pl.col("ratings").struct.rename_fields(["Film", "State", "Value"])) + .unnest("ratings") +) +print(out) +# --8<-- [end:series_struct_rename] + 
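+# The snippets below use pl.struct("Movie", "Theatre") as a composite key:
+# is_duplicated flags rows sharing the same (Movie, Theatre) pair, and the
+# ranking example computes a rank within each such pair via over().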
+# --8<-- [start:struct_duplicates] +out = ratings.filter(pl.struct("Movie", "Theatre").is_duplicated()) +print(out) +# --8<-- [end:struct_duplicates] + +# --8<-- [start:struct_ranking] +out = ratings.with_columns( + pl.struct("Count", "Avg_Rating") + .rank("dense", descending=True) + .over("Movie", "Theatre") + .alias("Rank") +).filter(pl.struct("Movie", "Theatre").is_duplicated()) +print(out) +# --8<-- [end:struct_ranking] diff --git a/docs/src/python/user-guide/expressions/user-defined-functions.py b/docs/src/python/user-guide/expressions/user-defined-functions.py new file mode 100644 index 000000000000..89fa51420554 --- /dev/null +++ b/docs/src/python/user-guide/expressions/user-defined-functions.py @@ -0,0 +1,56 @@ +# --8<-- [start:setup] + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:dataframe] +df = pl.DataFrame( + { + "keys": ["a", "a", "b"], + "values": [10, 7, 1], + } +) + +out = df.group_by("keys", maintain_order=True).agg( + pl.col("values").map_batches(lambda s: s.shift()).alias("shift_map"), + pl.col("values").shift().alias("shift_expression"), +) +print(df) +# --8<-- [end:dataframe] + + +# --8<-- [start:apply] +out = df.group_by("keys", maintain_order=True).agg( + pl.col("values").map_elements(lambda s: s.shift()).alias("shift_map"), + pl.col("values").shift().alias("shift_expression"), +) +print(out) +# --8<-- [end:apply] + +# --8<-- [start:counter] +counter = 0 + + +def add_counter(val: int) -> int: + global counter + counter += 1 + return counter + val + + +out = df.select( + pl.col("values").map_elements(add_counter).alias("solution_apply"), + (pl.col("values") + pl.int_range(1, pl.count() + 1)).alias("solution_expr"), +) +print(out) +# --8<-- [end:counter] + +# --8<-- [start:combine] +out = df.select( + pl.struct(["keys", "values"]) + .map_elements(lambda x: len(x["keys"]) + x["values"]) + .alias("solution_apply"), + (pl.col("keys").str.lengths() + pl.col("values")).alias("solution_expr"), +) +print(out) +# --8<-- [end:combine] diff --git a/docs/src/python/user-guide/expressions/window.py b/docs/src/python/user-guide/expressions/window.py new file mode 100644 index 000000000000..bd2adda867f5 --- /dev/null +++ b/docs/src/python/user-guide/expressions/window.py @@ -0,0 +1,84 @@ +# --8<-- [start:pokemon] +import polars as pl + +# then let's load some csv data with information about pokemon +df = pl.read_csv( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv" +) +print(df.head()) +# --8<-- [end:pokemon] + + +# --8<-- [start:group_by] +out = df.select( + "Type 1", + "Type 2", + pl.col("Attack").mean().over("Type 1").alias("avg_attack_by_type"), + pl.col("Defense") + .mean() + .over(["Type 1", "Type 2"]) + .alias("avg_defense_by_type_combination"), + pl.col("Attack").mean().alias("avg_attack"), +) +print(out) +# --8<-- [end:group_by] + +# --8<-- [start:operations] +filtered = df.filter(pl.col("Type 2") == "Psychic").select( + "Name", + "Type 1", + "Speed", +) +print(filtered) +# --8<-- [end:operations] + +# --8<-- [start:sort] +out = filtered.with_columns( + pl.col(["Name", "Speed"]).sort_by("Speed", descending=True).over("Type 1"), +) +print(out) +# --8<-- [end:sort] + +# --8<-- [start:rules] +# aggregate and broadcast within a group +# output type: -> Int32 +pl.sum("foo").over("groups") + +# sum within a group and multiply with group elements +# output type: -> Int32 +(pl.col("x").sum() * pl.col("y")).over("groups") + +# sum within a group and multiply with group 
elements +# and aggregate the group to a list +# output type: -> List(Int32) +(pl.col("x").sum() * pl.col("y")).over("groups", mapping_strategy="join") + +# sum within a group and multiply with group elements +# and aggregate the group to a list +# then explode the list to multiple rows + +# This is the fastest method to do things over groups when the groups are sorted +(pl.col("x").sum() * pl.col("y")).over("groups", mapping_strategy="explode") +# --8<-- [end:rules] + +# --8<-- [start:examples] +out = df.sort("Type 1").select( + pl.col("Type 1").head(3).over("Type 1", mapping_strategy="explode"), + pl.col("Name") + .sort_by(pl.col("Speed"), descending=True) + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("fastest/group"), + pl.col("Name") + .sort_by(pl.col("Attack"), descending=True) + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("strongest/group"), + pl.col("Name") + .sort() + .head(3) + .over("Type 1", mapping_strategy="explode") + .alias("sorted_by_alphabet"), +) +print(out) +# --8<-- [end:examples] diff --git a/docs/src/python/user-guide/io/aws.py b/docs/src/python/user-guide/io/aws.py new file mode 100644 index 000000000000..c8bfa94941d2 --- /dev/null +++ b/docs/src/python/user-guide/io/aws.py @@ -0,0 +1,14 @@ +""" +# --8<-- [start:bucket] +import polars as pl +import pyarrow.parquet as pq +import s3fs + +fs = s3fs.S3FileSystem() +bucket = "" +path = "" + +dataset = pq.ParquetDataset(f"s3://{bucket}/{path}", filesystem=fs) +df = pl.from_arrow(dataset.read()) +# --8<-- [end:bucket] +""" diff --git a/docs/src/python/user-guide/io/bigquery.py b/docs/src/python/user-guide/io/bigquery.py new file mode 100644 index 000000000000..678ed70200b4 --- /dev/null +++ b/docs/src/python/user-guide/io/bigquery.py @@ -0,0 +1,38 @@ +""" +# --8<-- [start:read] +import polars as pl +from google.cloud import bigquery + +client = bigquery.Client() + +# Perform a query. 
+QUERY = ( + 'SELECT name FROM `bigquery-public-data.usa_names.usa_1910_2013` ' + 'WHERE state = "TX" ' + 'LIMIT 100') +query_job = client.query(QUERY) # API request +rows = query_job.result() # Waits for query to finish + +df = pl.from_arrow(rows.to_arrow()) +# --8<-- [end:read] + +# --8<-- [start:write] +from google.cloud import bigquery + +client = bigquery.Client() + +# Write dataframe to stream as parquet file; does not hit disk +with io.BytesIO() as stream: + df.write_parquet(stream) + stream.seek(0) + job = client.load_table_from_file( + stream, + destination='tablename', + project='projectname', + job_config=bigquery.LoadJobConfig( + source_format=bigquery.SourceFormat.PARQUET, + ), + ) +job.result() # Waits for the job to complete +# --8<-- [end:write] +""" diff --git a/docs/src/python/user-guide/io/csv.py b/docs/src/python/user-guide/io/csv.py new file mode 100644 index 000000000000..d4039a43ce35 --- /dev/null +++ b/docs/src/python/user-guide/io/csv.py @@ -0,0 +1,19 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_csv("docs/data/path.csv") +# --8<-- [end:read] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_csv("docs/data/path.csv") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_csv("docs/data/path.csv") +# --8<-- [end:scan] diff --git a/docs/src/python/user-guide/io/database.py b/docs/src/python/user-guide/io/database.py new file mode 100644 index 000000000000..97e8f659de73 --- /dev/null +++ b/docs/src/python/user-guide/io/database.py @@ -0,0 +1,32 @@ +""" +# --8<-- [start:read] +import polars as pl + +connection_uri = "postgres://username:password@server:port/database" +query = "SELECT * FROM foo" + +pl.read_database(query=query, connection_uri=connection_uri) +# --8<-- [end:read] + +# --8<-- [start:adbc] +connection_uri = "postgres://username:password@server:port/database" +query = "SELECT * FROM foo" + +pl.read_database(query=query, connection_uri=connection_uri, engine="adbc") +# --8<-- [end:adbc] + +# --8<-- [start:write] +connection_uri = "postgres://username:password@server:port/database" +df = pl.DataFrame({"foo": [1, 2, 3]}) + +df.write_database(table_name="records", connection_uri=connection_uri) +# --8<-- [end:write] + +# --8<-- [start:write_adbc] +connection_uri = "postgres://username:password@server:port/database" +df = pl.DataFrame({"foo": [1, 2, 3]}) + +df.write_database(table_name="records", connection_uri=connection_uri, engine="adbc") +# --8<-- [end:write_adbc] + +""" diff --git a/docs/src/python/user-guide/io/multiple.py b/docs/src/python/user-guide/io/multiple.py new file mode 100644 index 000000000000..f7500b6b6684 --- /dev/null +++ b/docs/src/python/user-guide/io/multiple.py @@ -0,0 +1,41 @@ +# --8<-- [start:create] +import polars as pl + +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "ham", "spam"]}) + +for i in range(5): + df.write_csv(f"docs/data/my_many_files_{i}.csv") +# --8<-- [end:create] + +# --8<-- [start:read] +df = pl.read_csv("docs/data/my_many_files_*.csv") +print(df) +# --8<-- [end:read] + +# --8<-- [start:creategraph] +import base64 + +pl.scan_csv("docs/data/my_many_files_*.csv").show_graph( + output_path="docs/images/multiple.png", show=False +) +with open("docs/images/multiple.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:creategraph] + +# --8<-- [start:graph] +pl.scan_csv("docs/data/my_many_files_*.csv").show_graph() +# --8<-- [end:graph] + +# --8<-- 
[start:glob] +import polars as pl +import glob + +queries = [] +for file in glob.glob("docs/data/my_many_files_*.csv"): + q = pl.scan_csv(file).group_by("bar").agg([pl.count(), pl.sum("foo")]) + queries.append(q) + +dataframes = pl.collect_all(queries) +print(dataframes) +# --8<-- [end:glob] diff --git a/docs/src/python/user-guide/io/parquet.py b/docs/src/python/user-guide/io/parquet.py new file mode 100644 index 000000000000..feba73df9a19 --- /dev/null +++ b/docs/src/python/user-guide/io/parquet.py @@ -0,0 +1,19 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +""" +# --8<-- [start:read] +df = pl.read_parquet("docs/data/path.parquet") +# --8<-- [end:read] +""" + +# --8<-- [start:write] +df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, "bak", "baz"]}) +df.write_parquet("docs/data/path.parquet") +# --8<-- [end:write] + +# --8<-- [start:scan] +df = pl.scan_parquet("docs/data/path.parquet") +# --8<-- [end:scan] diff --git a/docs/src/python/user-guide/lazy/execution.py b/docs/src/python/user-guide/lazy/execution.py new file mode 100644 index 000000000000..110fb0105500 --- /dev/null +++ b/docs/src/python/user-guide/lazy/execution.py @@ -0,0 +1,36 @@ +import polars as pl + +""" +# --8<-- [start:df] +q1 = ( + pl.scan_csv("docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:df] + +# --8<-- [start:collect] +q4 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect() +) +# --8<-- [end:collect] +# --8<-- [start:stream] +q5 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .collect(streaming=True) +) +# --8<-- [end:stream] +# --8<-- [start:partial] +q9 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) + .fetch(n_rows=int(100)) +) +# --8<-- [end:partial] +""" diff --git a/docs/src/python/user-guide/lazy/query_plan.py b/docs/src/python/user-guide/lazy/query_plan.py new file mode 100644 index 000000000000..ed2c3f4bac45 --- /dev/null +++ b/docs/src/python/user-guide/lazy/query_plan.py @@ -0,0 +1,48 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:plan] +q1 = ( + pl.scan_csv("docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:plan] + +# --8<-- [start:createplan] +import base64 + +q1.show_graph(optimized=False, show=False, output_path="docs/images/query_plan.png") +with open("docs/images/query_plan.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:createplan] + +""" +# --8<-- [start:showplan] +q1.show_graph(optimized=False) +# --8<-- [end:showplan] +""" + +# --8<-- [start:describe] +q1.explain(optimized=False) +# --8<-- [end:describe] + +# --8<-- [start:createplan2] +q1.show_graph(show=False, output_path="docs/images/query_plan_optimized.png") +with open("docs/images/query_plan_optimized.png", "rb") as f: + png = base64.b64encode(f.read()).decode() + print(f'') +# --8<-- [end:createplan2] + +""" +# --8<-- [start:show] +q1.show_graph() +# --8<-- [end:show] +""" + +# --8<-- [start:optimized] +q1.explain() +# --8<-- [end:optimized] diff --git a/docs/src/python/user-guide/lazy/schema.py b/docs/src/python/user-guide/lazy/schema.py new file mode 100644 index 000000000000..e621718307ee --- /dev/null 
+++ b/docs/src/python/user-guide/lazy/schema.py @@ -0,0 +1,38 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:schema] +q3 = pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy() + +print(q3.schema) +# --8<-- [end:schema] + +# --8<-- [start:typecheck] +pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy().with_columns( + pl.col("bar").round(0) +) +# --8<-- [end:typecheck] + +# --8<-- [start:lazyeager] +lazy_eager_query = ( + pl.DataFrame( + { + "id": ["a", "b", "c"], + "month": ["jan", "feb", "mar"], + "values": [0, 1, 2], + } + ) + .lazy() + .with_columns((2 * pl.col("values")).alias("double_values")) + .collect() + .pivot( + index="id", columns="month", values="double_values", aggregate_function="first" + ) + .lazy() + .filter(pl.col("mar").is_null()) + .collect() +) +print(lazy_eager_query) +# --8<-- [end:lazyeager] diff --git a/docs/src/python/user-guide/lazy/using.py b/docs/src/python/user-guide/lazy/using.py new file mode 100644 index 000000000000..1a10abb189d2 --- /dev/null +++ b/docs/src/python/user-guide/lazy/using.py @@ -0,0 +1,15 @@ +import polars as pl + +""" +# --8<-- [start:dataframe] +q1 = ( + pl.scan_csv(f"docs/data/reddit.csv") + .with_columns(pl.col("name").str.to_uppercase()) + .filter(pl.col("comment_karma") > 0) +) +# --8<-- [end:dataframe] + +# --8<-- [start:fromdf] +q3 = pl.DataFrame({"foo": ["a", "b", "c"], "bar": [0, 1, 2]}).lazy() +# --8<-- [end:fromdf] +""" diff --git a/docs/src/python/user-guide/misc/multiprocess.py b/docs/src/python/user-guide/misc/multiprocess.py new file mode 100644 index 000000000000..55aec52d6b9f --- /dev/null +++ b/docs/src/python/user-guide/misc/multiprocess.py @@ -0,0 +1,84 @@ +""" +# --8<-- [start:recommendation] +from multiprocessing import get_context + + +def my_fun(s): + print(s) + + +with get_context("spawn").Pool() as pool: + pool.map(my_fun, ["input1", "input2", ...]) + +# --8<-- [end:recommendation] + +# --8<-- [start:example1] +import multiprocessing +import polars as pl + + +def test_sub_process(df: pl.DataFrame, job_id): + df_filtered = df.filter(pl.col("a") > 0) + print(f"Filtered (job_id: {job_id})", df_filtered, sep="\n") + + +def create_dataset(): + return pl.DataFrame({"a": [0, 2, 3, 4, 5], "b": [0, 4, 5, 56, 4]}) + + +def setup(): + # some setup work + df = create_dataset() + df.write_parquet("/tmp/test.parquet") + + +def main(): + test_df = pl.read_parquet("/tmp/test.parquet") + + for i in range(0, 5): + proc = multiprocessing.get_context("spawn").Process( + target=test_sub_process, args=(test_df, i) + ) + proc.start() + proc.join() + + print(f"Executed sub process {i}") + + +if __name__ == "__main__": + setup() + main() + +# --8<-- [end:example1] +""" +# --8<-- [start:example2] +import multiprocessing +import polars as pl + + +def test_sub_process(df: pl.DataFrame, job_id): + df_filtered = df.filter(pl.col("a") > 0) + print(f"Filtered (job_id: {job_id})", df_filtered, sep="\n") + + +def create_dataset(): + return pl.DataFrame({"a": [0, 2, 3, 4, 5], "b": [0, 4, 5, 56, 4]}) + + +def main(): + test_df = create_dataset() + + for i in range(0, 5): + proc = multiprocessing.get_context("fork").Process( + target=test_sub_process, args=(test_df, i) + ) + proc.start() + proc.join() + + print(f"Executed sub process {i}") + + +if __name__ == "__main__": + main() + +# --8<-- [end:example2] diff --git a/docs/src/python/user-guide/sql/create.py b/docs/src/python/user-guide/sql/create.py new file mode 100644 index 000000000000..e26ffd0a31f1 --- /dev/null +++ 
b/docs/src/python/user-guide/sql/create.py @@ -0,0 +1,21 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:create] +data = {"name": ["Alice", "Bob", "Charlie", "David"], "age": [25, 30, 35, 40]} +df = pl.LazyFrame(data) + +ctx = pl.SQLContext(my_table=df, eager_execution=True) + +result = ctx.execute( + """ + CREATE TABLE older_people + AS + SELECT * FROM my_table WHERE age > 30 +""" +) + +print(ctx.execute("SELECT * FROM older_people")) +# --8<-- [end:create] diff --git a/docs/src/python/user-guide/sql/cte.py b/docs/src/python/user-guide/sql/cte.py new file mode 100644 index 000000000000..c44b906cf3ad --- /dev/null +++ b/docs/src/python/user-guide/sql/cte.py @@ -0,0 +1,24 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:cte] +ctx = pl.SQLContext() +df = pl.LazyFrame( + {"name": ["Alice", "Bob", "Charlie", "David"], "age": [25, 30, 35, 40]} +) +ctx.register("my_table", df) + +result = ctx.execute( + """ + WITH older_people AS ( + SELECT * FROM my_table WHERE age > 30 + ) + SELECT * FROM older_people WHERE STARTS_WITH(name,'C') +""", + eager=True, +) + +print(result) +# --8<-- [end:cte] diff --git a/docs/src/python/user-guide/sql/intro.py b/docs/src/python/user-guide/sql/intro.py new file mode 100644 index 000000000000..3b59ac9e70d1 --- /dev/null +++ b/docs/src/python/user-guide/sql/intro.py @@ -0,0 +1,100 @@ +# --8<-- [start:setup] +import os + +import pandas as pd +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:context] +ctx = pl.SQLContext() +# --8<-- [end:context] + +# --8<-- [start:register_context] +df = pl.DataFrame({"a": [1, 2, 3]}) +lf = pl.LazyFrame({"b": [4, 5, 6]}) + +# Register all dataframes in the global namespace: registers both df and lf +ctx = pl.SQLContext(register_globals=True) + +# Other option: register dataframe df as "df" and lazyframe lf as "lf" +ctx = pl.SQLContext(df=df, lf=lf) +# --8<-- [end:register_context] + +# --8<-- [start:register_pandas] +import pandas as pd + +df_pandas = pd.DataFrame({"c": [7, 8, 9]}) +ctx = pl.SQLContext(df_pandas=pl.from_pandas(df_pandas)) +# --8<-- [end:register_pandas] + +# --8<-- [start:execute] +# For local files use scan_csv instead +pokemon = pl.read_csv( + "https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv" +) +ctx = pl.SQLContext(register_globals=True, eager_execution=True) +df_small = ctx.execute("SELECT * from pokemon LIMIT 5") +print(df_small) +# --8<-- [end:execute] + +# --8<-- [start:prepare_multiple_sources] +with open("products_categories.json", "w") as temp_file: + json_data = """{"product_id": 1, "category": "Category 1"} +{"product_id": 2, "category": "Category 1"} +{"product_id": 3, "category": "Category 2"} +{"product_id": 4, "category": "Category 2"} +{"product_id": 5, "category": "Category 3"}""" + + temp_file.write(json_data) + +with open("products_masterdata.csv", "w") as temp_file: + csv_data = """product_id,product_name +1,Product A +2,Product B +3,Product C +4,Product D +5,Product E""" + + temp_file.write(csv_data) + +sales_data = pd.DataFrame( + { + "product_id": [1, 2, 3, 4, 5], + "sales": [100, 200, 150, 250, 300], + } +) +# --8<-- [end:prepare_multiple_sources] + +# --8<-- [start:execute_multiple_sources] +# Input data: +# products_masterdata.csv with schema {'product_id': Int64, 'product_name': Utf8} +# products_categories.json with schema {'product_id': Int64, 'category': Utf8} +# sales_data is a Pandas DataFrame with 
schema {'product_id': Int64, 'sales': Int64} + +ctx = pl.SQLContext( + products_masterdata=pl.scan_csv("products_masterdata.csv"), + products_categories=pl.scan_ndjson("products_categories.json"), + sales_data=pl.from_pandas(sales_data), + eager_execution=True, +) + +query = """ +SELECT + product_id, + product_name, + category, + sales +FROM + products_masterdata +LEFT JOIN products_categories USING (product_id) +LEFT JOIN sales_data USING (product_id) +""" + +print(ctx.execute(query)) +# --8<-- [end:execute_multiple_sources] + +# --8<-- [start:clean_multiple_sources] +os.remove("products_categories.json") +os.remove("products_masterdata.csv") +# --8<-- [end:clean_multiple_sources] diff --git a/docs/src/python/user-guide/sql/show.py b/docs/src/python/user-guide/sql/show.py new file mode 100644 index 000000000000..cedf425dc54b --- /dev/null +++ b/docs/src/python/user-guide/sql/show.py @@ -0,0 +1,26 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:show] +# Create some DataFrames and register them with the SQLContext +df1 = pl.LazyFrame( + { + "name": ["Alice", "Bob", "Charlie", "David"], + "age": [25, 30, 35, 40], + } +) +df2 = pl.LazyFrame( + { + "name": ["Ellen", "Frank", "Gina", "Henry"], + "age": [45, 50, 55, 60], + } +) +ctx = pl.SQLContext(mytable1=df1, mytable2=df2) + +tables = ctx.execute("SHOW TABLES", eager=True) + +print(tables) +# --8<-- [end:show] diff --git a/docs/src/python/user-guide/sql/sql_select.py b/docs/src/python/user-guide/sql/sql_select.py new file mode 100644 index 000000000000..1e040c739b99 --- /dev/null +++ b/docs/src/python/user-guide/sql/sql_select.py @@ -0,0 +1,106 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + + +# --8<-- [start:df] +df = pl.DataFrame( + { + "city": [ + "New York", + "Los Angeles", + "Chicago", + "Houston", + "Phoenix", + "Amsterdam", + ], + "country": ["USA", "USA", "USA", "USA", "USA", "Netherlands"], + "population": [8399000, 3997000, 2705000, 2320000, 1680000, 900000], + } +) + +ctx = pl.SQLContext(population=df, eager_execution=True) + +print(ctx.execute("SELECT * FROM population")) +# --8<-- [end:df] + +# --8<-- [start:group_by] +result = ctx.execute( + """ + SELECT country, AVG(population) as avg_population + FROM population + GROUP BY country + """ +) +print(result) +# --8<-- [end:group_by] + + +# --8<-- [start:orderby] +result = ctx.execute( + """ + SELECT city, population + FROM population + ORDER BY population + """ +) +print(result) +# --8<-- [end:orderby] + +# --8<-- [start:join] +income = pl.DataFrame( + { + "city": [ + "New York", + "Los Angeles", + "Chicago", + "Houston", + "Amsterdam", + "Rotterdam", + "Utrecht", + ], + "country": [ + "USA", + "USA", + "USA", + "USA", + "Netherlands", + "Netherlands", + "Netherlands", + ], + "income": [55000, 62000, 48000, 52000, 42000, 38000, 41000], + } +) +ctx.register_many(income=income) +result = ctx.execute( + """ + SELECT country, city, income, population + FROM population + LEFT JOIN income on population.city = income.city + """ +) +print(result) +# --8<-- [end:join] + + +# --8<-- [start:functions] +result = ctx.execute( + """ + SELECT city, population + FROM population + WHERE STARTS_WITH(country,'U') + """ +) +print(result) +# --8<-- [end:functions] + +# --8<-- [start:tablefunctions] +result = ctx.execute( + """ + SELECT * + FROM read_csv('docs/data/iris.csv') + """ +) +print(result) +# --8<-- [end:tablefunctions] diff --git a/docs/src/python/user-guide/transformations/concatenation.py 
b/docs/src/python/user-guide/transformations/concatenation.py new file mode 100644 index 000000000000..65b5c8239e83 --- /dev/null +++ b/docs/src/python/user-guide/transformations/concatenation.py @@ -0,0 +1,76 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import datetime + +# --8<-- [end:setup] + +# --8<-- [start:vertical] +df_v1 = pl.DataFrame( + { + "a": [1], + "b": [3], + } +) +df_v2 = pl.DataFrame( + { + "a": [2], + "b": [4], + } +) +df_vertical_concat = pl.concat( + [ + df_v1, + df_v2, + ], + how="vertical", +) +print(df_vertical_concat) +# --8<-- [end:vertical] + +# --8<-- [start:horizontal] +df_h1 = pl.DataFrame( + { + "l1": [1, 2], + "l2": [3, 4], + } +) +df_h2 = pl.DataFrame( + { + "r1": [5, 6], + "r2": [7, 8], + "r3": [9, 10], + } +) +df_horizontal_concat = pl.concat( + [ + df_h1, + df_h2, + ], + how="horizontal", +) +print(df_horizontal_concat) +# --8<-- [end:horizontal] + +# --8<-- [start:cross] +df_d1 = pl.DataFrame( + { + "a": [1], + "b": [3], + } +) +df_d2 = pl.DataFrame( + { + "a": [2], + "d": [4], + } +) + +df_diagonal_concat = pl.concat( + [ + df_d1, + df_d2, + ], + how="diagonal", +) +print(df_diagonal_concat) +# --8<-- [end:cross] diff --git a/docs/src/python/user-guide/transformations/joins.py b/docs/src/python/user-guide/transformations/joins.py new file mode 100644 index 000000000000..98828020820d --- /dev/null +++ b/docs/src/python/user-guide/transformations/joins.py @@ -0,0 +1,150 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import datetime + +# --8<-- [end:setup] + +# --8<-- [start:innerdf] +df_customers = pl.DataFrame( + { + "customer_id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + } +) +print(df_customers) +# --8<-- [end:innerdf] + +# --8<-- [start:innerdf2] +df_orders = pl.DataFrame( + { + "order_id": ["a", "b", "c"], + "customer_id": [1, 2, 2], + "amount": [100, 200, 300], + } +) +print(df_orders) +# --8<-- [end:innerdf2] + + +# --8<-- [start:inner] +df_inner_customer_join = df_customers.join(df_orders, on="customer_id", how="inner") +print(df_inner_customer_join) +# --8<-- [end:inner] + +# --8<-- [start:left] +df_left_join = df_customers.join(df_orders, on="customer_id", how="left") +print(df_left_join) +# --8<-- [end:left] + +# --8<-- [start:outer] +df_outer_join = df_customers.join(df_orders, on="customer_id", how="outer") +print(df_outer_join) +# --8<-- [end:outer] + +# --8<-- [start:df3] +df_colors = pl.DataFrame( + { + "color": ["red", "blue", "green"], + } +) +print(df_colors) +# --8<-- [end:df3] + +# --8<-- [start:df4] +df_sizes = pl.DataFrame( + { + "size": ["S", "M", "L"], + } +) +print(df_sizes) +# --8<-- [end:df4] + +# --8<-- [start:cross] +df_cross_join = df_colors.join(df_sizes, how="cross") +print(df_cross_join) +# --8<-- [end:cross] + +# --8<-- [start:df5] +df_cars = pl.DataFrame( + { + "id": ["a", "b", "c"], + "make": ["ford", "toyota", "bmw"], + } +) +print(df_cars) +# --8<-- [end:df5] + +# --8<-- [start:df6] +df_repairs = pl.DataFrame( + { + "id": ["c", "c"], + "cost": [100, 200], + } +) +print(df_repairs) +# --8<-- [end:df6] + +# --8<-- [start:inner2] +df_inner_join = df_cars.join(df_repairs, on="id", how="inner") +print(df_inner_join) +# --8<-- [end:inner2] + +# --8<-- [start:semi] +df_semi_join = df_cars.join(df_repairs, on="id", how="semi") +print(df_semi_join) +# --8<-- [end:semi] + +# --8<-- [start:anti] +df_anti_join = df_cars.join(df_repairs, on="id", how="anti") +print(df_anti_join) +# --8<-- [end:anti] + +# --8<-- [start:df7] +df_trades = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 
1, 9, 1, 0), + datetime(2020, 1, 1, 9, 1, 0), + datetime(2020, 1, 1, 9, 3, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "B", "C"], + "trade": [101, 299, 301, 500], + } +) +print(df_trades) +# --8<-- [end:df7] + +# --8<-- [start:df8] +df_quotes = pl.DataFrame( + { + "time": [ + datetime(2020, 1, 1, 9, 0, 0), + datetime(2020, 1, 1, 9, 2, 0), + datetime(2020, 1, 1, 9, 4, 0), + datetime(2020, 1, 1, 9, 6, 0), + ], + "stock": ["A", "B", "C", "A"], + "quote": [100, 300, 501, 102], + } +) + +print(df_quotes) +# --8<-- [end:df8] + +# --8<-- [start:asofpre] +df_trades = df_trades.sort("time") +df_quotes = df_quotes.sort("time") # Set column as sorted +# --8<-- [end:asofpre] + +# --8<-- [start:asof] +df_asof_join = df_trades.join_asof(df_quotes, on="time", by="stock") +print(df_asof_join) +# --8<-- [end:asof] + +# --8<-- [start:asof2] +df_asof_tolerance_join = df_trades.join_asof( + df_quotes, on="time", by="stock", tolerance="1m" +) +print(df_asof_tolerance_join) +# --8<-- [end:asof2] diff --git a/docs/src/python/user-guide/transformations/melt.py b/docs/src/python/user-guide/transformations/melt.py new file mode 100644 index 000000000000..e9bf53a96ec7 --- /dev/null +++ b/docs/src/python/user-guide/transformations/melt.py @@ -0,0 +1,18 @@ +# --8<-- [start:df] +import polars as pl + +df = pl.DataFrame( + { + "A": ["a", "b", "a"], + "B": [1, 3, 5], + "C": [10, 11, 12], + "D": [2, 4, 6], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:melt] +out = df.melt(id_vars=["A", "B"], value_vars=["C", "D"]) +print(out) +# --8<-- [end:melt] diff --git a/docs/src/python/user-guide/transformations/pivot.py b/docs/src/python/user-guide/transformations/pivot.py new file mode 100644 index 000000000000..d80b26ee0c34 --- /dev/null +++ b/docs/src/python/user-guide/transformations/pivot.py @@ -0,0 +1,31 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.DataFrame( + { + "foo": ["A", "A", "B", "B", "C"], + "N": [1, 2, 2, 4, 2], + "bar": ["k", "l", "m", "n", "o"], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:eager] +out = df.pivot(index="foo", columns="bar", values="N", aggregate_function="first") +print(out) +# --8<-- [end:eager] + +# --8<-- [start:lazy] +q = ( + df.lazy() + .collect() + .pivot(index="foo", columns="bar", values="N", aggregate_function="first") + .lazy() +) +out = q.collect() +print(out) +# --8<-- [end:lazy] diff --git a/docs/src/python/user-guide/transformations/time-series/filter.py b/docs/src/python/user-guide/transformations/time-series/filter.py new file mode 100644 index 000000000000..6a2a28e44f8c --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/filter.py @@ -0,0 +1,30 @@ +# --8<-- [start:df] +import polars as pl +from datetime import datetime + +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=True) +print(df) +# --8<-- [end:df] + +# --8<-- [start:filter] +filtered_df = df.filter( + pl.col("Date") == datetime(1995, 10, 16), +) +print(filtered_df) +# --8<-- [end:filter] + +# --8<-- [start:range] +filtered_range_df = df.filter( + pl.col("Date").is_between(datetime(1995, 7, 1), datetime(1995, 11, 1)), +) +print(filtered_range_df) +# --8<-- [end:range] + +# --8<-- [start:negative] +ts = pl.Series(["-1300-05-23", "-1400-03-02"]).str.strptime(pl.Date) + +negative_dates_df = pl.DataFrame({"ts": ts, "values": [3, 4]}) + +negative_dates_filtered_df = negative_dates_df.filter(pl.col("ts").dt.year() < -1300) +print(negative_dates_filtered_df) +# --8<-- [end:negative] diff --git 
a/docs/src/python/user-guide/transformations/time-series/parsing.py b/docs/src/python/user-guide/transformations/time-series/parsing.py new file mode 100644 index 000000000000..0e49df5495a0 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/parsing.py @@ -0,0 +1,43 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=True) +print(df) +# --8<-- [end:df] + + +# --8<-- [start:cast] +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=False) + +df = df.with_columns(pl.col("Date").str.strptime(pl.Date, format="%Y-%m-%d")) +print(df) +# --8<-- [end:cast] + + +# --8<-- [start:df3] +df_with_year = df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:df3] + +# --8<-- [start:extract] +df_with_year = df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:extract] + +# --8<-- [start:mixed] +data = [ + "2021-03-27T00:00:00+0100", + "2021-03-28T00:00:00+0100", + "2021-03-29T00:00:00+0200", + "2021-03-30T00:00:00+0200", +] +mixed_parsed = ( + pl.Series(data) + .str.strptime(pl.Datetime, format="%Y-%m-%dT%H:%M:%S%z") + .dt.convert_time_zone("Europe/Brussels") +) +print(mixed_parsed) +# --8<-- [end:mixed] diff --git a/docs/src/python/user-guide/transformations/time-series/resampling.py b/docs/src/python/user-guide/transformations/time-series/resampling.py new file mode 100644 index 000000000000..80a7b2597a67 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/resampling.py @@ -0,0 +1,36 @@ +# --8<-- [start:setup] +from datetime import datetime + +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + "values": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + } +) +print(df) +# --8<-- [end:df] + +# --8<-- [start:upsample] +out1 = df.upsample(time_column="time", every="15m").fill_null(strategy="forward") +print(out1) +# --8<-- [end:upsample] + +# --8<-- [start:upsample2] +out2 = ( + df.upsample(time_column="time", every="15m") + .interpolate() + .fill_null(strategy="forward") +) +print(out2) +# --8<-- [end:upsample2] diff --git a/docs/src/python/user-guide/transformations/time-series/rolling.py b/docs/src/python/user-guide/transformations/time-series/rolling.py new file mode 100644 index 000000000000..16f751523ade --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/rolling.py @@ -0,0 +1,75 @@ +# --8<-- [start:setup] +import polars as pl +from datetime import date, datetime + +# --8<-- [end:setup] + +# --8<-- [start:df] +df = pl.read_csv("docs/data/apple_stock.csv", try_parse_dates=True) +df = df.sort("Date") +print(df) +# --8<-- [end:df] + +# --8<-- [start:group_by] +annual_average_df = df.group_by_dynamic("Date", every="1y").agg(pl.col("Close").mean()) + +df_with_year = annual_average_df.with_columns(pl.col("Date").dt.year().alias("year")) +print(df_with_year) +# --8<-- [end:group_by] + +# --8<-- [start:group_by_dyn] +df = ( + pl.date_range( + start=date(2021, 1, 1), + end=date(2021, 12, 31), + interval="1d", + eager=True, + ) + .alias("time") + .to_frame() +) + +out = ( + df.group_by_dynamic("time", every="1mo", period="1mo", closed="left") + .agg( + [ + pl.col("time").cumcount().reverse().head(3).alias("day/eom"), + ((pl.col("time") - 
pl.col("time").first()).last().dt.days() + 1).alias( + "days_in_month" + ), + ] + ) + .explode("day/eom") +) +print(out) +# --8<-- [end:group_by_dyn] + +# --8<-- [start:group_by_roll] +df = pl.DataFrame( + { + "time": pl.datetime_range( + start=datetime(2021, 12, 16), + end=datetime(2021, 12, 16, 3), + interval="30m", + eager=True, + ), + "groups": ["a", "a", "a", "b", "b", "a", "a"], + } +) +print(df) +# --8<-- [end:group_by_roll] + +# --8<-- [start:group_by_dyn2] +out = df.group_by_dynamic( + "time", + every="1h", + closed="both", + by="groups", + include_boundaries=True, +).agg( + [ + pl.count(), + ] +) +print(out) +# --8<-- [end:group_by_dyn2] diff --git a/docs/src/python/user-guide/transformations/time-series/timezones.py b/docs/src/python/user-guide/transformations/time-series/timezones.py new file mode 100644 index 000000000000..13234a9d8e30 --- /dev/null +++ b/docs/src/python/user-guide/transformations/time-series/timezones.py @@ -0,0 +1,27 @@ +# --8<-- [start:setup] +import polars as pl + +# --8<-- [end:setup] + +# --8<-- [start:example] +ts = ["2021-03-27 03:00", "2021-03-28 03:00"] +tz_naive = pl.Series("tz_naive", ts).str.strptime(pl.Datetime) +tz_aware = tz_naive.dt.replace_time_zone("UTC").rename("tz_aware") +time_zones_df = pl.DataFrame([tz_naive, tz_aware]) +print(time_zones_df) +# --8<-- [end:example] + +# --8<-- [start:example2] +time_zones_operations = time_zones_df.select( + [ + pl.col("tz_aware") + .dt.replace_time_zone("Europe/Brussels") + .alias("replace time zone"), + pl.col("tz_aware") + .dt.convert_time_zone("Asia/Kathmandu") + .alias("convert time zone"), + pl.col("tz_aware").dt.replace_time_zone(None).alias("unset time zone"), + ] +) +print(time_zones_operations) +# --8<-- [end:example2] diff --git a/docs/src/rust/getting-started/expressions.rs b/docs/src/rust/getting-started/expressions.rs new file mode 100644 index 000000000000..e8d031ebd1f7 --- /dev/null +++ b/docs/src/rust/getting-started/expressions.rs @@ -0,0 +1,144 @@ +use chrono::prelude::*; +use polars::prelude::*; +use rand::Rng; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + let mut rng = rand::thread_rng(); + + let df: DataFrame = df!("a" => 0..8, + "b"=> (0..8).map(|_| rng.gen::<f64>()).collect::<Vec<f64>>(), + "c"=> [ + NaiveDate::from_ymd_opt(2022, 12, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 4).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 6).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 7).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 12, 8).unwrap().and_hms_opt(0, 0, 0).unwrap(), + ], + "d"=> [Some(1.0), Some(2.0), None, None, Some(0.0), Some(-5.0), Some(-42.), None] + ) + .expect("should not fail"); + + // --8<-- [start:select] + let out = df.clone().lazy().select([col("*")]).collect()?; + println!("{}", out); + // --8<-- [end:select] + + // --8<-- [start:select2] + let out = df.clone().lazy().select([col("a"), col("b")]).collect()?; + println!("{}", out); + // --8<-- [end:select2] + + // --8<-- [start:select3] + let out = df + .clone() + .lazy() + .select([col("a"), col("b")]) + .limit(3) + .collect()?; + println!("{}", out); + // --8<-- [end:select3] + + // --8<-- [start:exclude] + let out = df + .clone() + .lazy() + .select([col("*").exclude(["a"])]) +
.collect()?; + println!("{}", out); + // --8<-- [end:exclude] + + // --8<-- [start:filter] + let start_date = NaiveDate::from_ymd_opt(2022, 12, 2) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let end_date = NaiveDate::from_ymd_opt(2022, 12, 8) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap(); + let out = df + .clone() + .lazy() + .filter( + col("c") + .gt_eq(lit(start_date)) + .and(col("c").lt_eq(lit(end_date))), + ) + .collect()?; + println!("{}", out); + // --8<-- [end:filter] + + // --8<-- [start:filter2] + let out = df + .clone() + .lazy() + .filter(col("a").lt_eq(3).and(col("d").is_not_null())) + .collect()?; + println!("{}", out); + // --8<-- [end:filter2] + + // --8<-- [start:with_columns] + let out = df + .clone() + .lazy() + .with_columns([ + col("b").sum().alias("e"), + (col("b") + lit(42)).alias("b+42"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:with_columns] + + // --8<-- [start:dataframe2] + let df2: DataFrame = df!("x" => 0..8, + "y"=> &["A", "A", "A", "B", "B", "C", "X", "X"], + ) + .expect("should not fail"); + println!("{}", df2); + // --8<-- [end:dataframe2] + + // --8<-- [start:group_by] + let out = df2 + .clone() + .lazy() + .group_by(["y"]) + .agg([count()]) + .collect()?; + println!("{}", out); + // --8<-- [end:group_by] + + // --8<-- [start:group_by2] + let out = df2 + .clone() + .lazy() + .group_by(["y"]) + .agg([col("*").count().alias("count"), col("*").sum().alias("sum")]) + .collect()?; + println!("{}", out); + // --8<-- [end:group_by2] + + // --8<-- [start:combine] + let out = df + .clone() + .lazy() + .with_columns([(col("a") * col("b")).alias("a * b")]) + .select([col("*").exclude(["c", "d"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:combine] + + // --8<-- [start:combine2] + let out = df + .clone() + .lazy() + .with_columns([(col("a") * col("b")).alias("a * b")]) + .select([col("*").exclude(["d"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:combine2] + + Ok(()) +} diff --git a/docs/src/rust/getting-started/joins.rs b/docs/src/rust/getting-started/joins.rs new file mode 100644 index 000000000000..1f583dc0e4f9 --- /dev/null +++ b/docs/src/rust/getting-started/joins.rs @@ -0,0 +1,29 @@ +use polars::prelude::*; + + +fn main() -> Result<(), Box>{ + + + // --8<-- [start:join] + use rand::Rng; + let mut rng = rand::thread_rng(); + + let df: DataFrame = df!("a" => 0..8, + "b"=> (0..8).map(|_| rng.gen::()).collect::>(), + "d"=> [Some(1.0), Some(2.0), None, None, Some(0.0), Some(-5.0), Some(-42.), None] + ).expect("should not fail"); + let df2: DataFrame = df!("x" => 0..8, + "y"=> &["A", "A", "A", "B", "B", "C", "X", "X"], + ).expect("should not fail"); + let joined = df.join(&df2,["a"],["x"],JoinType::Left,None)?; + println!("{}",joined); + // --8<-- [end:join] + + // --8<-- [start:hstack] + let stacked = df.hstack(df2.get_columns())?; + println!("{}",stacked); + // --8<-- [end:hstack] + + Ok(()) + +} diff --git a/docs/src/rust/getting-started/reading-writing.rs b/docs/src/rust/getting-started/reading-writing.rs new file mode 100644 index 000000000000..4fe035d34f82 --- /dev/null +++ b/docs/src/rust/getting-started/reading-writing.rs @@ -0,0 +1,67 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use chrono::prelude::*; + use std::fs::File; + + let mut df: DataFrame = df!( + "integer" => &[1, 2, 3], + "date" => &[ + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + 
NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + ], + "float" => &[4.0, 5.0, 6.0] + ) + .expect("should not fail"); + println!("{}", df); + // --8<-- [end:dataframe] + + // --8<-- [start:csv] + let mut file = File::create("docs/data/output.csv").expect("could not create file"); + CsvWriter::new(&mut file) + .has_header(true) + .with_delimiter(b',') + .finish(&mut df)?; + let df_csv = CsvReader::from_path("docs/data/output.csv")? + .infer_schema(None) + .has_header(true) + .finish()?; + println!("{}", df_csv); + // --8<-- [end:csv] + + // --8<-- [start:csv2] + let mut file = File::create("docs/data/output.csv").expect("could not create file"); + CsvWriter::new(&mut file) + .has_header(true) + .with_delimiter(b',') + .finish(&mut df)?; + let df_csv = CsvReader::from_path("docs/data/output.csv")? + .infer_schema(None) + .has_header(true) + .with_parse_dates(true) + .finish()?; + println!("{}", df_csv); + // --8<-- [end:csv2] + + // --8<-- [start:json] + let mut file = File::create("docs/data/output.json").expect("could not create file"); + JsonWriter::new(&mut file).finish(&mut df)?; + let mut f = File::open("docs/data/output.json")?; + let df_json = JsonReader::new(f) + .with_json_format(JsonFormat::JsonLines) + .finish()?; + println!("{}", df_json); + // --8<-- [end:json] + + // --8<-- [start:parquet] + let mut file = File::create("docs/data/output.parquet").expect("could not create file"); + ParquetWriter::new(&mut file).finish(&mut df)?; + let mut f = File::open("docs/data/output.parquet")?; + let df_parquet = ParquetReader::new(f).finish()?; + println!("{}", df_parquet); + // --8<-- [end:parquet] + + Ok(()) +} diff --git a/docs/src/rust/getting-started/series-dataframes.rs b/docs/src/rust/getting-started/series-dataframes.rs new file mode 100644 index 000000000000..09b45d705bac --- /dev/null +++ b/docs/src/rust/getting-started/series-dataframes.rs @@ -0,0 +1,59 @@ +fn main() -> Result<(), Box<dyn std::error::Error>> { + // --8<-- [start:series] + use polars::prelude::*; + + let s = Series::new("a", [1, 2, 3, 4, 5]); + println!("{}", s); + // --8<-- [end:series] + + // --8<-- [start:minmax] + let s = Series::new("a", [1, 2, 3, 4, 5]); + // The use of generics is necessary for the type system + println!("{}", s.min::<u64>().unwrap()); + println!("{}", s.max::<u64>().unwrap()); + // --8<-- [end:minmax] + + // --8<-- [start:string] + // This operation is not directly available on the Series object yet, only on the DataFrame + // --8<-- [end:string] + + // --8<-- [start:dt] + // This operation is not directly available on the Series object yet, only on the DataFrame + // --8<-- [end:dt] + + // --8<-- [start:dataframe] + use chrono::prelude::*; + + let df: DataFrame = df!( + "integer" => &[1, 2, 3, 4, 5], + "date" => &[ + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 2).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 4).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap() + ], + "float" => &[4.0, 5.0, 6.0, 7.0, 8.0], + ) + .unwrap(); + + println!("{}", df); + // --8<-- [end:dataframe] + + // --8<-- [start:head] + println!("{}", df.head(Some(3))); + // --8<-- [end:head] + + // --8<-- [start:tail] + println!("{}", df.tail(Some(3))); + // --8<-- [end:tail] + + // --8<-- [start:sample] + println!("{}", df.sample_n(2, false, true, None)?); + // --8<-- [end:sample] + + // 
--8<-- [start:describe] + println!("{:?}", df.describe(None)); + // --8<-- [end:describe] + Ok(()) +} diff --git a/docs/src/rust/home/example.rs b/docs/src/rust/home/example.rs new file mode 100644 index 000000000000..00cf7de67bfb --- /dev/null +++ b/docs/src/rust/home/example.rs @@ -0,0 +1,16 @@ +fn main() -> Result<(), Box> { + // --8<-- [start:example] + use polars::prelude::*; + + let q = LazyCsvReader::new("docs/data/iris.csv") + .has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("*").sum()]); + + let df = q.collect(); + // --8<-- [end:example] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/concepts/contexts.rs b/docs/src/rust/user-guide/concepts/contexts.rs new file mode 100644 index 000000000000..b911faa8fd6d --- /dev/null +++ b/docs/src/rust/user-guide/concepts/contexts.rs @@ -0,0 +1,69 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use rand::{thread_rng, Rng}; + + let mut arr = [0f64; 5]; + thread_rng().fill(&mut arr); + + let df = df! ( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &[Some("foo"), Some("ham"), Some("spam"), Some("eggs"), None], + "random" => &arr, + "groups" => &["A", "A", "B", "C", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:select] + let out = df + .clone() + .lazy() + .select([ + sum("nrs"), + col("names").sort(false), + col("names").first().alias("first name"), + (mean("nrs") * lit(10)).alias("10xnrs"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:select] + + // --8<-- [start:filter] + let out = df.clone().lazy().filter(col("nrs").gt(lit(2))).collect()?; + println!("{}", out); + // --8<-- [end:filter] + + // --8<-- [start:with_columns] + let out = df + .clone() + .lazy() + .with_columns([ + sum("nrs").alias("nrs_sum"), + col("random").count().alias("count"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:with_columns] + + // --8<-- [start:group_by] + let out = df + .lazy() + .group_by([col("groups")]) + .agg([ + sum("nrs"), // sum nrs by groups + col("random").count().alias("count"), // count group members + // sum random where name != null + col("random") + .filter(col("names").is_not_null()) + .sum() + .suffix("_sum"), + col("names").reverse().alias("reversed names"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:group_by] + Ok(()) +} diff --git a/docs/src/rust/user-guide/concepts/expressions.rs b/docs/src/rust/user-guide/concepts/expressions.rs new file mode 100644 index 000000000000..9c76fc6642e8 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/expressions.rs @@ -0,0 +1,24 @@ +use polars::prelude::*; +use rand::Rng; +use chrono::prelude::*; + +fn main() -> Result<(), Box>{ + + let df = df! 
( + "foo" => &[Some(1), Some(2), Some(3), None, Some(5)], + "bar" => &[Some("foo"), Some("ham"), Some("spam"), Some("egg"), None], + )?; + + // --8<-- [start:example1] + df.column("foo")?.sort(false).head(Some(2)); + // --8<-- [end:example1] + + // --8<-- [start:example2] + df.clone().lazy().select([ + col("foo").sort(Default::default()).head(Some(2)), + col("bar").filter(col("foo").eq(lit(1))).sum(), + ]).collect()?; + // --8<-- [end:example2] + + Ok(()) +} \ No newline at end of file diff --git a/docs/src/rust/user-guide/concepts/lazy-vs-eager.rs b/docs/src/rust/user-guide/concepts/lazy-vs-eager.rs new file mode 100644 index 000000000000..910235fbbf65 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/lazy-vs-eager.rs @@ -0,0 +1,30 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:eager] + let df = CsvReader::from_path("docs/data/iris.csv") + .unwrap() + .finish() + .unwrap(); + let mask = df.column("sepal_length")?.f64()?.gt(5.0); + let df_small = df.filter(&mask)?; + let df_agg = df_small + .group_by(["species"])? + .select(["sepal_width"]) + .mean()?; + println!("{}", df_agg); + // --8<-- [end:eager] + + // --8<-- [start:lazy] + let q = LazyCsvReader::new("docs/data/iris.csv") + .has_header(true) + .finish()? + .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + let df = q.collect()?; + println!("{}", df); + // --8<-- [end:lazy] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/concepts/streaming.rs b/docs/src/rust/user-guide/concepts/streaming.rs new file mode 100644 index 000000000000..f00b5e92ca99 --- /dev/null +++ b/docs/src/rust/user-guide/concepts/streaming.rs @@ -0,0 +1,19 @@ +use chrono::prelude::*; +use polars::prelude::*; +use rand::Rng; + +fn main() -> Result<(), Box> { + // --8<-- [start:streaming] + let q = LazyCsvReader::new("docs/data/iris.csv") + .has_header(true) + .finish()? 
+ .filter(col("sepal_length").gt(lit(5))) + .group_by(vec![col("species")]) + .agg([col("sepal_width").mean()]); + + let df = q.with_streaming(true).collect()?; + println!("{}", df); + // --8<-- [end:streaming] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/aggregation.rs b/docs/src/rust/user-guide/expressions/aggregation.rs new file mode 100644 index 000000000000..205ec2f01bf7 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/aggregation.rs @@ -0,0 +1,204 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use reqwest::blocking::Client; + use std::io::Cursor; + + let url = "https://theunitedstates.io/congress-legislators/legislators-historical.csv"; + + let mut schema = Schema::new(); + schema.with_column("first_name".to_string(), DataType::Categorical(None)); + schema.with_column("gender".to_string(), DataType::Categorical(None)); + schema.with_column("type".to_string(), DataType::Categorical(None)); + schema.with_column("state".to_string(), DataType::Categorical(None)); + schema.with_column("party".to_string(), DataType::Categorical(None)); + schema.with_column("birthday".to_string(), DataType::Date); + + let data: Vec = Client::new().get(url).send()?.text()?.bytes().collect(); + + let dataset = CsvReader::new(Cursor::new(data)) + .has_header(true) + .with_dtypes(Some(&schema)) + .with_parse_dates(true) + .finish()?; + + println!("{}", &dataset); + // --8<-- [end:dataframe] + + // --8<-- [start:basic] + let df = dataset + .clone() + .lazy() + .group_by(["first_name"]) + .agg([count(), col("gender").list(), col("last_name").first()]) + .sort( + "count", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:basic] + + // --8<-- [start:conditional] + let df = dataset + .clone() + .lazy() + .group_by(["state"]) + .agg([ + (col("party").eq(lit("Anti-Administration"))) + .sum() + .alias("anti"), + (col("party").eq(lit("Pro-Administration"))) + .sum() + .alias("pro"), + ]) + .sort( + "pro", + SortOptions { + descending: true, + nulls_last: false, + }, + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:conditional] + + // --8<-- [start:nested] + let df = dataset + .clone() + .lazy() + .group_by(["state", "party"]) + .agg([col("party").count().alias("count")]) + .filter( + col("party") + .eq(lit("Anti-Administration")) + .or(col("party").eq(lit("Pro-Administration"))), + ) + .sort( + "count", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:nested] + + // --8<-- [start:filter] + fn compute_age() -> Expr { + lit(2022) - col("birthday").dt().year() + } + + fn avg_birthday(gender: &str) -> Expr { + compute_age() + .filter(col("gender").eq(lit(gender))) + .mean() + .alias(&format!("avg {} birthday", gender)) + } + + let df = dataset + .clone() + .lazy() + .group_by(["state"]) + .agg([ + avg_birthday("M"), + avg_birthday("F"), + (col("gender").eq(lit("M"))).sum().alias("# male"), + (col("gender").eq(lit("F"))).sum().alias("# female"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:filter] + + // --8<-- [start:sort] + fn get_person() -> Expr { + col("first_name") + lit(" ") + col("last_name") + } + + let df = dataset + .clone() + .lazy() + .sort( + "birthday", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .group_by(["state"]) + .agg([ + get_person().first().alias("youngest"), + 
get_person().last().alias("oldest"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort] + + // --8<-- [start:sort2] + let df = dataset + .clone() + .lazy() + .sort( + "birthday", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .group_by(["state"]) + .agg([ + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort(false).first().alias("alphabetical_first"), + ]) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort2] + + // --8<-- [start:sort3] + let df = dataset + .clone() + .lazy() + .sort( + "birthday", + SortOptions { + descending: true, + nulls_last: true, + }, + ) + .group_by(["state"]) + .agg([ + get_person().first().alias("youngest"), + get_person().last().alias("oldest"), + get_person().sort(false).first().alias("alphabetical_first"), + col("gender") + .sort_by(["first_name"], [false]) + .first() + .alias("gender"), + ]) + .sort("state", SortOptions::default()) + .limit(5) + .collect()?; + + println!("{}", df); + // --8<-- [end:sort3] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/casting.rs b/docs/src/rust/user-guide/expressions/casting.rs new file mode 100644 index 000000000000..2dda1e185215 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/casting.rs @@ -0,0 +1,201 @@ +// --8<-- [start:setup] +use polars::lazy::dsl::StrptimeOptions; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:dfnum] + let df = df! ( + "integers"=> &[1, 2, 3, 4, 5], + "big_integers"=> &[1, 10000002, 3, 10000004, 10000005], + "floats"=> &[4.0, 5.0, 6.0, 7.0, 8.0], + "floats_with_decimal"=> &[4.532, 5.5, 6.5, 7.5, 8.5], + )?; + + println!("{}", &df); + // --8<-- [end:dfnum] + + // --8<-- [start:castnum] + let out = df + .clone() + .lazy() + .select([ + col("integers") + .cast(DataType::Float32) + .alias("integers_as_floats"), + col("floats") + .cast(DataType::Int32) + .alias("floats_as_integers"), + col("floats_with_decimal") + .cast(DataType::Int32) + .alias("floats_with_decimal_as_integers"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:castnum] + + // --8<-- [start:downcast] + let out = df + .clone() + .lazy() + .select([ + col("integers") + .cast(DataType::Int16) + .alias("integers_smallfootprint"), + col("floats") + .cast(DataType::Float32) + .alias("floats_smallfootprint"), + ]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:downcast] + + // --8<-- [start:overflow] + + let out = df + .clone() + .lazy() + .select([col("big_integers").strict_cast(DataType::Int8)]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:overflow] + + // --8<-- [start:overflow2] + let out = df + .clone() + .lazy() + .select([col("big_integers").cast(DataType::Int8)]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:overflow2] + + // --8<-- [start:strings] + + let df = df! ( + "integers" => &[1, 2, 3, 4, 5], + "float" => &[4.0, 5.03, 6.0, 7.0, 8.0], + "floats_as_string" => &["4.0", "5.0", "6.0", "7.0", "8.0"], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("integers").cast(DataType::Utf8), + col("float").cast(DataType::Utf8), + col("floats_as_string").cast(DataType::Float64), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:strings] + + // --8<-- [start:strings2] + + let df = df! 
("strings_not_float"=> ["4.0", "not_a_number", "6.0", "7.0", "8.0"])?; + + let out = df + .clone() + .lazy() + .select([col("strings_not_float").cast(DataType::Float64)]) + .collect(); + match out { + Ok(out) => println!("{}", &out), + Err(e) => println!("{:?}", e), + }; + // --8<-- [end:strings2] + + // --8<-- [start:bool] + + let df = df! ( + "integers"=> &[-1, 0, 2, 3, 4], + "floats"=> &[0.0, 1.0, 2.0, 3.0, 4.0], + "bools"=> &[true, false, true, false, true], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("integers").cast(DataType::Boolean), + col("floats").cast(DataType::Boolean), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:bool] + + // --8<-- [start:dates] + + use chrono::prelude::*; + use polars::time::*; + + let df = df! ( + "date" => date_range( + "date", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None + )?.cast(&DataType::Date)?, + "datetime" => datetime_range( + "datetime", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, + None + )?, + )?; + + let out = df + .clone() + .lazy() + .select([ + col("date").cast(DataType::Int64), + col("datetime").cast(DataType::Int64), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:dates] + + // --8<-- [start:dates2] + + let df = df! ( + "date" => date_range("date", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 5).unwrap().and_hms_opt(0, 0, 0).unwrap(), Duration::parse("1d"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + "string" => &[ + "2022-01-01", + "2022-01-02", + "2022-01-03", + "2022-01-04", + "2022-01-05", + ], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("date").dt().strftime("%Y-%m-%d"), + col("string").str().strptime( + DataType::Datetime(TimeUnit::Microseconds, None), + StrptimeOptions::default(), + ), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:dates2] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/column-selections.rs b/docs/src/rust/user-guide/expressions/column-selections.rs new file mode 100644 index 000000000000..105cc6f102df --- /dev/null +++ b/docs/src/rust/user-guide/expressions/column-selections.rs @@ -0,0 +1,99 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:selectors_df] + + use chrono::prelude::*; + use polars::time::*; + + let df = df!( + "id" => &[9, 4, 2], + "place" => &["Mars", "Earth", "Saturn"], + "date" => date_range("date", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 3).unwrap().and_hms_opt(0, 0, 0).unwrap(), Duration::parse("1d"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + "sales" => &[33.4, 2142134.1, 44.7], + "has_people" => &[false, true, false], + "logged_at" => date_range("logged_at", + NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), NaiveDate::from_ymd_opt(2022, 1, 1).unwrap().and_hms_opt(0, 0, 2).unwrap(), Duration::parse("1s"),ClosedWindow::Both, TimeUnit::Milliseconds, None)?, + )? 
+ .with_row_count("rn", None)?; + println!("{}", &df); + // --8<-- [end:selectors_df] + + // --8<-- [start:all] + let out = df.clone().lazy().select([col("*")]).collect()?; + + // Is equivalent to + let out = df.clone().lazy().select([all()]).collect()?; + println!("{}", &out); + // --8<-- [end:all] + + // --8<-- [start:exclude] + let out = df + .clone() + .lazy() + .select([col("*").exclude(["logged_at", "rn"])]) + .collect()?; + println!("{}", &out); + // --8<-- [end:exclude] + + // --8<-- [start:expansion_by_names] + let out = df + .clone() + .lazy() + .select([cols(["date", "logged_at"]).dt().to_string("%Y-%h-%d")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:expansion_by_names] + + // --8<-- [start:expansion_by_regex] + let out = df.clone().lazy().select([col("^.*(as|sa).*$")]).collect()?; + println!("{}", &out); + // --8<-- [end:expansion_by_regex] + + // --8<-- [start:expansion_by_dtype] + let out = df + .clone() + .lazy() + .select([dtype_cols([DataType::Int64, DataType::UInt32, DataType::Boolean]).n_unique()]) + .collect()?; + // gives different result than python as the id col is i32 in rust + println!("{}", &out); + // --8<-- [end:expansion_by_dtype] + + // --8<-- [start:selectors_intro] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_intro] + + // --8<-- [start:selectors_diff] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_diff] + + // --8<-- [start:selectors_union] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_union] + + // --8<-- [start:selectors_by_name] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/1059 + // --8<-- [end:selectors_by_name] + + // --8<-- [start:selectors_to_expr] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_to_expr] + + // --8<-- [start:selectors_is_selector_utility] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_is_selector_utility] + + // --8<-- [start:selectors_colnames_utility] + // Not available in Rust, refer the following link + // https://github.com/pola-rs/polars/issues/10594 + // --8<-- [end:selectors_colnames_utility] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/folds.rs b/docs/src/rust/user-guide/expressions/folds.rs new file mode 100644 index 000000000000..b851557f8e37 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/folds.rs @@ -0,0 +1,49 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + + // --8<-- [start:mansum] + let df = df!( + "a" => &[1, 2, 3], + "b" => &[10, 20, 30], + )?; + + let out = df + .lazy() + .select([fold_exprs(lit(0), |acc, x| Ok(Some(acc + x)), [col("*")]).alias("sum")]) + .collect()?; + println!("{}", out); + // --8<-- [end:mansum] + + // --8<-- [start:conditional] + let df = df!( + "a" => &[1, 2, 3], + "b" => &[0, 1, 2], + )?; + + let out = df + .lazy() + .filter(fold_exprs( + lit(true), + |acc, x| Some(acc.bitand(&x)), + [col("*").gt(1)], + )) + .collect()?; + println!("{}", out); + // --8<-- [end:conditional] + + // --8<-- [start:string] + let df = df!( + "a" => &["a", "b", "c"], + "b" => &[1, 2, 3], + )?; + + let out = df + .lazy() + .select([concat_str([col("a"), col("b")], "")]) + .collect()?; + 
println!("{:?}", out); + // --8<-- [end:string] + + Ok(()) +} \ No newline at end of file diff --git a/docs/src/rust/user-guide/expressions/functions.rs b/docs/src/rust/user-guide/expressions/functions.rs new file mode 100644 index 000000000000..490809b75557 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/functions.rs @@ -0,0 +1,79 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use rand::{thread_rng, Rng}; + + let mut arr = [0f64; 5]; + thread_rng().fill(&mut arr); + + let df = df! ( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &["foo", "ham", "spam", "egg", "spam"], + "random" => &arr, + "groups" => &["A", "A", "B", "C", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:samename] + let df_samename = df.clone().lazy().select([col("nrs") + lit(5)]).collect()?; + println!("{}", &df_samename); + // --8<-- [end:samename] + + // --8<-- [start:samenametwice] + let df_samename2 = df + .clone() + .lazy() + .select([col("nrs") + lit(5), col("nrs") - lit(5)]) + .collect(); + match df_samename2 { + Ok(df) => println!("{}", &df), + Err(e) => println!("{:?}", &e), + }; + // --8<-- [end:samenametwice] + + // --8<-- [start:samenamealias] + let df_alias = df + .clone() + .lazy() + .select([ + (col("nrs") + lit(5)).alias("nrs + 5"), + (col("nrs") - lit(5)).alias("nrs - 5"), + ]) + .collect()?; + println!("{}", &df_alias); + // --8<-- [end:samenamealias] + + // --8<-- [start:countunique] + let df_alias = df + .clone() + .lazy() + .select([ + col("names").n_unique().alias("unique"), + // Following query shows there isn't anything in Rust API + // https://docs.rs/polars/latest/polars/?search=approx_n_unique + // col("names").approx_n_unique().alias("unique_approx"), + ]) + .collect()?; + println!("{}", &df_alias); + // --8<-- [end:countunique] + + // --8<-- [start:conditional] + let df_conditional = df + .clone() + .lazy() + .select([ + col("nrs"), + when(col("nrs").gt(2)) + .then(lit(true)) + .otherwise(lit(false)) + .alias("conditional"), + ]) + .collect()?; + println!("{}", &df_conditional); + // --8<-- [end:conditional] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/lists.rs b/docs/src/rust/user-guide/expressions/lists.rs new file mode 100644 index 000000000000..257649e0cc7d --- /dev/null +++ b/docs/src/rust/user-guide/expressions/lists.rs @@ -0,0 +1,162 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] +fn main() -> Result<(), Box> { + // --8<-- [start:weather_df] + let stns: Vec = (1..6).map(|i| format!("Station {i}")).collect(); + let weather = df!( + "station"=> &stns, + "temperatures"=> &[ + "20 5 5 E1 7 13 19 9 6 20", + "18 8 16 11 23 E2 8 E2 E2 E2 90 70 40", + "19 24 E9 16 6 12 10 22", + "E2 E0 15 7 8 10 E1 24 17 13 6", + "14 8 E0 16 22 24 E1", + ], + )?; + println!("{}", &weather); + // --8<-- [end:weather_df] + + // --8<-- [start:string_to_list] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures").str().split(lit(" "))]) + .collect()?; + println!("{}", &out); + // --8<-- [end:string_to_list] + + // --8<-- [start:explode_to_atomic] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures").str().split(lit(" "))]) + .explode(["temperatures"]) + .collect()?; + println!("{}", &out); + // --8<-- [end:explode_to_atomic] + + // --8<-- [start:list_ops] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures").str().split(lit(" "))]) + .with_columns([ + 
col("temperatures").list().head(lit(3)).alias("top3"), + col("temperatures") + .list() + .slice(lit(-3), lit(3)) + .alias("bottom_3"), + col("temperatures").list().lengths().alias("obs"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:list_ops] + + // --8<-- [start:count_errors] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures") + .str() + .split(lit(" ")) + .list() + .eval(col("").cast(DataType::Int64).is_null(), false) + .list() + .sum() + .alias("errors")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:count_errors] + + // --8<-- [start:count_errors_regex] + let out = weather + .clone() + .lazy() + .with_columns([col("temperatures") + .str() + .split(lit(" ")) + .list() + .eval(col("").str().contains(lit("(?i)[a-z]"), false), false) + .list() + .sum() + .alias("errors")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:count_errors_regex] + + // --8<-- [start:weather_by_day] + let stns: Vec<String> = (1..11).map(|i| format!("Station {i}")).collect(); + let weather_by_day = df!( + "station" => &stns, + "day_1" => &[17, 11, 8, 22, 9, 21, 20, 8, 8, 17], + "day_2" => &[15, 11, 10, 8, 7, 14, 18, 21, 15, 13], + "day_3" => &[16, 15, 24, 24, 8, 23, 19, 23, 16, 10], + )?; + println!("{}", &weather_by_day); + // --8<-- [end:weather_by_day] + + // --8<-- [start:weather_by_day_rank] + let rank_pct = (col("") + .rank( + RankOptions { + method: RankMethod::Average, + descending: true, + }, + None, + ) + .cast(DataType::Float32) + / col("*").count().cast(DataType::Float32)) + .round(2); + + let out = weather_by_day + .clone() + .lazy() + .with_columns( + // create the list of homogeneous data + [concat_list([all().exclude(["station"])])?.alias("all_temps")], + ) + .select( + // select all columns except the intermediate list + [ + all().exclude(["all_temps"]), + // compute the rank by calling `list.eval` + col("all_temps") + .list() + .eval(rank_pct, true) + .alias("temps_rank"), + ], + ) + .collect()?; + + println!("{}", &out); + // --8<-- [end:weather_by_day_rank] + + // --8<-- [start:array_df] + let mut col1: ListPrimitiveChunkedBuilder<Int32Type> = + ListPrimitiveChunkedBuilder::new("Array_1", 8, 8, DataType::Int32); + col1.append_slice(&[1, 3]); + col1.append_slice(&[2, 5]); + let mut col2: ListPrimitiveChunkedBuilder<Int32Type> = + ListPrimitiveChunkedBuilder::new("Array_2", 8, 8, DataType::Int32); + col2.append_slice(&[1, 7, 3]); + col2.append_slice(&[8, 1, 0]); + let array_df = DataFrame::new([col1.finish(), col2.finish()].into())?; + + println!("{}", &array_df); + // --8<-- [end:array_df] + + // --8<-- [start:array_ops] + let out = array_df + .clone() + .lazy() + .select([ + col("Array_1").list().min().suffix("_min"), + col("Array_2").list().sum().suffix("_sum"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:array_ops] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/null.rs b/docs/src/rust/user-guide/expressions/null.rs new file mode 100644 index 000000000000..8d78310cb0a9 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/null.rs @@ -0,0 +1,89 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box<dyn std::error::Error>> { + // --8<-- [start:dataframe] + + let df = df!
( + "value" => &[Some(1), None], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:count] + let null_count_df = df.null_count(); + println!("{}", &null_count_df); + // --8<-- [end:count] + + // --8<-- [start:isnull] + let is_null_series = df + .clone() + .lazy() + .select([col("value").is_null()]) + .collect()?; + println!("{}", &is_null_series); + // --8<-- [end:isnull] + + // --8<-- [start:dataframe2] + let df = df!( + "col1" => &[Some(1), Some(2), Some(3)], + "col2" => &[Some(1), None, Some(3)], + + )?; + println!("{}", &df); + // --8<-- [end:dataframe2] + + // --8<-- [start:fill] + let fill_literal_df = df + .clone() + .lazy() + .with_columns([col("col2").fill_null(lit(2))]) + .collect()?; + println!("{}", &fill_literal_df); + // --8<-- [end:fill] + + // --8<-- [start:fillstrategy] + let fill_forward_df = df + .clone() + .lazy() + .with_columns([col("col2").forward_fill(None)]) + .collect()?; + println!("{}", &fill_forward_df); + // --8<-- [end:fillstrategy] + + // --8<-- [start:fillexpr] + let fill_median_df = df + .clone() + .lazy() + .with_columns([col("col2").fill_null(median("col2"))]) + .collect()?; + println!("{}", &fill_median_df); + // --8<-- [end:fillexpr] + + // --8<-- [start:fillinterpolate] + let fill_interpolation_df = df + .clone() + .lazy() + .with_columns([col("col2").interpolate(InterpolationMethod::Linear)]) + .collect()?; + println!("{}", &fill_interpolation_df); + // --8<-- [end:fillinterpolate] + + // --8<-- [start:nan] + let nan_df = df!( + "value" => [1.0, f64::NAN, f64::NAN, 3.0], + )?; + println!("{}", &nan_df); + // --8<-- [end:nan] + + // --8<-- [start:nanfill] + let mean_nan_df = nan_df + .clone() + .lazy() + .with_columns([col("value").fill_nan(lit(NULL)).alias("value")]) + .mean() + .collect()?; + println!("{}", &mean_nan_df); + // --8<-- [end:nanfill] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/operators.rs b/docs/src/rust/user-guide/expressions/operators.rs new file mode 100644 index 000000000000..868d301c2182 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/operators.rs @@ -0,0 +1,54 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + use rand::{thread_rng, Rng}; + + let mut arr = [0f64; 5]; + thread_rng().fill(&mut arr); + + let df = df! 
( + "nrs" => &[Some(1), Some(2), Some(3), None, Some(5)], + "names" => &[Some("foo"), Some("ham"), Some("spam"), Some("eggs"), None], + "random" => &arr, + "groups" => &["A", "A", "B", "C", "B"], + )?; + + println!("{}", &df); + // --8<-- [end:dataframe] + + // --8<-- [start:numerical] + let df_numerical = df + .clone() + .lazy() + .select([ + (col("nrs") + lit(5)).alias("nrs + 5"), + (col("nrs") - lit(5)).alias("nrs - 5"), + (col("nrs") * col("random")).alias("nrs * random"), + (col("nrs") / col("random")).alias("nrs / random"), + ]) + .collect()?; + println!("{}", &df_numerical); + // --8<-- [end:numerical] + + // --8<-- [start:logical] + let df_logical = df + .clone() + .lazy() + .select([ + col("nrs").gt(1).alias("nrs > 1"), + col("random").lt_eq(0.5).alias("random < .5"), + col("nrs").neq(1).alias("nrs != 1"), + col("nrs").eq(1).alias("nrs == 1"), + (col("random").lt_eq(0.5)) + .and(col("nrs").gt(1)) + .alias("and_expr"), // and + (col("random").lt_eq(0.5)) + .or(col("nrs").gt(1)) + .alias("or_expr"), // or + ]) + .collect()?; + println!("{}", &df_logical); + // --8<-- [end:logical] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/strings.rs b/docs/src/rust/user-guide/expressions/strings.rs new file mode 100644 index 000000000000..f3020e4fa2ce --- /dev/null +++ b/docs/src/rust/user-guide/expressions/strings.rs @@ -0,0 +1,93 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df! ( + "animal" => &[Some("Crab"), Some("cat and dog"), Some("rab$bit"), None], + )?; + + let out = df + .clone() + .lazy() + .select([ + col("animal").str().lengths().alias("byte_count"), + col("animal").str().n_chars().alias("letter_count"), + ]) + .collect()?; + + println!("{}", &out); + // --8<-- [end:df] + + // --8<-- [start:existence] + let out = df + .clone() + .lazy() + .select([ + col("animal"), + col("animal") + .str() + .contains(lit("cat|bit"), false) + .alias("regex"), + col("animal") + .str() + .contains_literal(lit("rab$")) + .alias("literal"), + col("animal") + .str() + .starts_with(lit("rab")) + .alias("starts_with"), + col("animal").str().ends_with(lit("dog")).alias("ends_with"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:existence] + + // --8<-- [start:extract] + let df = df!( + "a" => &[ + "http://vote.com/ballon_dor?candidate=messi&ref=polars", + "http://vote.com/ballon_dor?candidat=jorginho&ref=polars", + "http://vote.com/ballon_dor?candidate=ronaldo&ref=polars", + ] + )?; + let out = df + .clone() + .lazy() + .select([col("a").str().extract(r"candidate=(\w+)", 1)]) + .collect()?; + println!("{}", &out); + // --8<-- [end:extract] + + // --8<-- [start:extract_all] + let df = df!("foo"=> &["123 bla 45 asd", "xyz 678 910t"])?; + let out = df + .clone() + .lazy() + .select([col("foo") + .str() + .extract_all(lit(r"(\d+)")) + .alias("extracted_nrs")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:extract_all] + + // --8<-- [start:replace] + let df = df!("id"=> &[1, 2], "text"=> &["123abc", "abc456"])?; + let out = df + .clone() + .lazy() + .with_columns([ + col("text").str().replace(lit(r"abc\b"), lit("ABC"), false), + col("text") + .str() + .replace_all(lit("a"), lit("-"), false) + .alias("text_replace_all"), + ]) + .collect()?; + println!("{}", &out); + // --8<-- [end:replace] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/structs.rs b/docs/src/rust/user-guide/expressions/structs.rs new file mode 100644 index 000000000000..662e264222a6 --- 
/dev/null +++ b/docs/src/rust/user-guide/expressions/structs.rs @@ -0,0 +1,99 @@ +// --8<-- [start:setup] +use polars::{lazy::dsl::count, prelude::*}; +// --8<-- [end:setup] +fn main() -> Result<(), Box> { + // --8<-- [start:ratings_df] + let ratings = df!( + "Movie"=> &["Cars", "IT", "ET", "Cars", "Up", "IT", "Cars", "ET", "Up", "ET"], + "Theatre"=> &["NE", "ME", "IL", "ND", "NE", "SD", "NE", "IL", "IL", "SD"], + "Avg_Rating"=> &[4.5, 4.4, 4.6, 4.3, 4.8, 4.7, 4.7, 4.9, 4.7, 4.6], + "Count"=> &[30, 27, 26, 29, 31, 28, 28, 26, 33, 26], + + )?; + println!("{}", &ratings); + // --8<-- [end:ratings_df] + + // --8<-- [start:state_value_counts] + let out = ratings + .clone() + .lazy() + .select([col("Theatre").value_counts(true, true)]) + .collect()?; + println!("{}", &out); + // --8<-- [end:state_value_counts] + + // --8<-- [start:struct_unnest] + let out = ratings + .clone() + .lazy() + .select([col("Theatre").value_counts(true, true)]) + .unnest(["Theatre"]) + .collect()?; + println!("{}", &out); + // --8<-- [end:struct_unnest] + + // --8<-- [start:series_struct] + // Don't think we can make it the same way in rust, but this works + let rating_series = df!( + "Movie" => &["Cars", "Toy Story"], + "Theatre" => &["NE", "ME"], + "Avg_Rating" => &[4.5, 4.9], + )? + .into_struct("ratings") + .into_series(); + println!("{}", &rating_series); + // // --8<-- [end:series_struct] + + // --8<-- [start:series_struct_extract] + let out = rating_series.struct_()?.field_by_name("Movie")?; + println!("{}", &out); + // --8<-- [end:series_struct_extract] + + // --8<-- [start:series_struct_rename] + let out = DataFrame::new([rating_series].into())? + .lazy() + .select([col("ratings") + .struct_() + .rename_fields(["Film".into(), "State".into(), "Value".into()].to_vec())]) + .unnest(["ratings"]) + .collect()?; + + println!("{}", &out); + // --8<-- [end:series_struct_rename] + + // --8<-- [start:struct_duplicates] + let out = ratings + .clone() + .lazy() + // .filter(as_struct(&[col("Movie"), col("Theatre")]).is_duplicated()) + // Error: .is_duplicated() not available if you try that + // https://github.com/pola-rs/polars/issues/3803 + .filter(count().over([col("Movie"), col("Theatre")]).gt(lit(1))) + .collect()?; + println!("{}", &out); + // --8<-- [end:struct_duplicates] + + // --8<-- [start:struct_ranking] + let out = ratings + .clone() + .lazy() + .with_columns([as_struct(&[col("Count"), col("Avg_Rating")]) + .rank( + RankOptions { + method: RankMethod::Dense, + descending: false, + }, + None, + ) + .over([col("Movie"), col("Theatre")]) + .alias("Rank")]) + // .filter(as_struct(&[col("Movie"), col("Theatre")]).is_duplicated()) + // Error: .is_duplicated() not available if you try that + // https://github.com/pola-rs/polars/issues/3803 + .filter(count().over([col("Movie"), col("Theatre")]).gt(lit(1))) + .collect()?; + println!("{}", &out); + // --8<-- [end:struct_ranking] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/user-defined-functions.rs b/docs/src/rust/user-guide/expressions/user-defined-functions.rs new file mode 100644 index 000000000000..7cbe1605f3e3 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/user-defined-functions.rs @@ -0,0 +1,84 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:dataframe] + let df = df!( + "keys" => &["a", "a", "b"], + "values" => &[10, 7, 1], + )?; + + let out = df + .lazy() + .group_by(["keys"]) + .agg([ + col("values") + .map(|s| Ok(s.shift(1)), GetOutput::default()) + .alias("shift_map"), + 
col("values").shift(1).alias("shift_expression"), + ]) + .collect()?; + + println!("{}", out); + // --8<-- [end:dataframe] + + // --8<-- [start:apply] + let out = df + .clone() + .lazy() + .group_by([col("keys")]) + .agg([ + col("values") + .apply(|s| Ok(s.shift(1)), GetOutput::default()) + .alias("shift_map"), + col("values").shift(1).alias("shift_expression"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:apply] + + // --8<-- [start:counter] + + // --8<-- [end:counter] + + // --8<-- [start:combine] + let out = df + .lazy() + .select([ + // pack to struct to get access to multiple fields in a custom `apply/map` + as_struct(&[col("keys"), col("values")]) + // we will compute the len(a) + b + .apply( + |s| { + // downcast to struct + let ca = s.struct_()?; + + // get the fields as Series + let s_a = &ca.fields()[0]; + let s_b = &ca.fields()[1]; + + // downcast the `Series` to their known type + let ca_a = s_a.utf8()?; + let ca_b = s_b.i32()?; + + // iterate both `ChunkedArrays` + let out: Int32Chunked = ca_a + .into_iter() + .zip(ca_b) + .map(|(opt_a, opt_b)| match (opt_a, opt_b) { + (Some(a), Some(b)) => Some(a.len() as i32 + b), + _ => None, + }) + .collect(); + + Ok(out.into_series()) + }, + GetOutput::from_type(DataType::Int32), + ) + .alias("solution_apply"), + (col("keys").str().count_match(".") + col("values")).alias("solution_expr"), + ]) + .collect()?; + println!("{}", out); + // --8<-- [end:combine] + Ok(()) +} diff --git a/docs/src/rust/user-guide/expressions/window.rs b/docs/src/rust/user-guide/expressions/window.rs new file mode 100644 index 000000000000..2fcc32cdc309 --- /dev/null +++ b/docs/src/rust/user-guide/expressions/window.rs @@ -0,0 +1,131 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box> { + // --8<-- [start:pokemon] + use polars::prelude::*; + use reqwest::blocking::Client; + + let data: Vec = Client::new() + .get("https://gist.githubusercontent.com/ritchie46/cac6b337ea52281aa23c049250a4ff03/raw/89a957ff3919d90e6ef2d34235e6bf22304f3366/pokemon.csv") + .send()? + .text()? 
+ .bytes() + .collect(); + + let df = CsvReader::new(std::io::Cursor::new(data)) + .has_header(true) + .finish()?; + + println!("{}", df); + // --8<-- [end:pokemon] + + // --8<-- [start:group_by] + let out = df + .clone() + .lazy() + .select([ + col("Type 1"), + col("Type 2"), + col("Attack") + .mean() + .over(["Type 1"]) + .alias("avg_attack_by_type"), + col("Defense") + .mean() + .over(["Type 1", "Type 2"]) + .alias("avg_defense_by_type_combination"), + col("Attack").mean().alias("avg_attack"), + ]) + .collect()?; + + println!("{}", out); + // --8<-- [end:group_by] + + // --8<-- [start:operations] + let filtered = df + .clone() + .lazy() + .filter(col("Type 2").eq(lit("Psychic"))) + .select([col("Name"), col("Type 1"), col("Speed")]) + .collect()?; + + println!("{}", filtered); + // --8<-- [end:operations] + + // --8<-- [start:sort] + let out = filtered + .lazy() + .with_columns([cols(["Name", "Speed"]) + .sort_by(["Speed"], [true]) + .over(["Type 1"])]) + .collect()?; + println!("{}", out); + // --8<-- [end:sort] + + // --8<-- [start:rules] + // aggregate and broadcast within a group + // output type: -> i32 + sum("foo").over([col("groups")]) + // sum within a group and multiply with group elements + // output type: -> i32 + (col("x").sum() * col("y")) + .over([col("groups")]) + .alias("x1") + // sum within a group and multiply with group elements + // and aggregate the group to a list + // output type: -> ChunkedArray + (col("x").sum() * col("y")) + .list() + .over([col("groups")]) + .alias("x2") + // note that it will require an explicit `list()` call + // sum within a group and multiply with group elements + // and aggregate the group to a list + // the flatten call explodes that list + + // This is the fastest method to do things over groups when the groups are sorted + (col("x").sum() * col("y")) + .list() + .over([col("groups")]) + .flatten() + .alias("x3"); + // --8<-- [end:rules] + + // --8<-- [start:examples] + let out = df + .clone() + .lazy() + .select([ + col("Type 1") + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten(), + col("Name") + .sort_by(["Speed"], [true]) + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten() + .alias("fastest/group"), + col("Name") + .sort_by(["Attack"], [true]) + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten() + .alias("strongest/group"), + col("Name") + .sort(false) + .head(Some(3)) + .list() + .over(["Type 1"]) + .flatten() + .alias("sorted_by_alphabet"), + ]) + .collect()?; + println!("{:?}", out); + // --8<-- [end:examples] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/io/aws.rs b/docs/src/rust/user-guide/io/aws.rs new file mode 100644 index 000000000000..0a1924d9d294 --- /dev/null +++ b/docs/src/rust/user-guide/io/aws.rs @@ -0,0 +1,32 @@ +""" +# --8<-- [start:bucket] +use aws_sdk_s3::Region; + +use aws_config::meta::region::RegionProviderChain; +use aws_sdk_s3::Client; +use std::borrow::Cow; + +use polars::prelude::*; + +#[tokio::main] +async fn main() { + let bucket = ""; + let path = ""; + + let config = aws_config::from_env().load().await; + let client = Client::new(&config); + + let req = client.get_object().bucket(bucket).key(path); + + let res = req.clone().send().await.unwrap(); + let bytes = res.body.collect().await.unwrap(); + let bytes = bytes.into_bytes(); + + let cursor = std::io::Cursor::new(bytes); + + let df = CsvReader::new(cursor).finish().unwrap(); + + println!("{:?}", df); +} +# --8<-- [end:bucket] +""" diff --git a/docs/src/rust/user-guide/io/csv.rs b/docs/src/rust/user-guide/io/csv.rs new file 
mode 100644 index 000000000000..7c56d813e626 --- /dev/null +++ b/docs/src/rust/user-guide/io/csv.rs @@ -0,0 +1,29 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box>{ + + """ + // --8<-- [start:read] + use polars::prelude::*; + + let df = CsvReader::from_path("docs/data/path.csv").unwrap().finish().unwrap(); + // --8<-- [end:read] + """ + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/data/path.csv").unwrap(); + CsvWriter::new(&mut file).finish(&mut df).unwrap(); + // --8<-- [end:write] + + // --8<-- [start:scan] + let df = LazyCsvReader::new("./test.csv").finish().unwrap(); + // --8<-- [end:scan] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/io/json-file.rs b/docs/src/rust/user-guide/io/json-file.rs new file mode 100644 index 000000000000..ab4df729c955 --- /dev/null +++ b/docs/src/rust/user-guide/io/json-file.rs @@ -0,0 +1,47 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box>{ + + """ + // --8<-- [start:read] + use polars::prelude::*; + + let mut file = std::fs::File::open("docs/data/path.json").unwrap(); + let df = JsonReader::new(&mut file).finish().unwrap(); + // --8<-- [end:read] + + + // --8<-- [start:readnd] + let mut file = std::fs::File::open("docs/data/path.json").unwrap(); + let df = JsonLineReader::new(&mut file).finish().unwrap(); + // --8<-- [end:readnd] + """ + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/data/path.json").unwrap(); + + // json + JsonWriter::new(&mut file) + .with_json_format(JsonFormat::Json) + .finish(&mut df) + .unwrap(); + + // ndjson + JsonWriter::new(&mut file) + .with_json_format(JsonFormat::JsonLines) + .finish(&mut df) + .unwrap(); + // --8<-- [end:write] + + // --8<-- [start:scan] + let df = LazyJsonLineReader::new("docs/data/path.json".to_string()).finish().unwrap(); + // --8<-- [end:scan] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/io/parquet.rs b/docs/src/rust/user-guide/io/parquet.rs new file mode 100644 index 000000000000..f3469ffd4e2c --- /dev/null +++ b/docs/src/rust/user-guide/io/parquet.rs @@ -0,0 +1,30 @@ +use polars::prelude::*; + +fn main() -> Result<(), Box>{ + + """ + // --8<-- [start:read] + let mut file = std::fs::File::open("docs/data/path.parquet").unwrap(); + + let df = ParquetReader::new(&mut file).finish().unwrap(); + // --8<-- [end:read] + """ + + // --8<-- [start:write] + let mut df = df!( + "foo" => &[1, 2, 3], + "bar" => &[None, Some("bak"), Some("baz")], + ) + .unwrap(); + + let mut file = std::fs::File::create("docs/data/path.parquet").unwrap(); + ParquetWriter::new(&mut file).finish(&mut df).unwrap(); + // --8<-- [end:write] + + // --8<-- [start:scan] + let args = ScanArgsParquet::default(); + let df = LazyFrame::scan_parquet("./file.parquet",args).unwrap(); + // --8<-- [end:scan] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/concatenation.rs b/docs/src/rust/user-guide/transformations/concatenation.rs new file mode 100644 index 000000000000..ecb9dba877a6 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/concatenation.rs @@ -0,0 +1,49 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:vertical] + let df_v1 = df!( + "a"=> &[1], + "b"=> &[3], + )?; + let df_v2 = df!( + "a"=> &[2], + "b"=> &[4], + )?; + let df_vertical_concat 
= concat( + [df_v1.clone().lazy(), df_v2.clone().lazy()], + UnionArgs::default(), + )? + .collect()?; + println!("{}", &df_vertical_concat); + // --8<-- [end:vertical] + + // --8<-- [start:horizontal] + let df_h1 = df!( + "l1"=> &[1, 2], + "l2"=> &[3, 4], + )?; + let df_h2 = df!( + "r1"=> &[5, 6], + "r2"=> &[7, 8], + "r3"=> &[9, 10], + )?; + let df_horizontal_concat = polars::functions::hor_concat_df(&[df_h1, df_h2])?; + println!("{}", &df_horizontal_concat); + // --8<-- [end:horizontal] + + // --8<-- [start:cross] + let df_d1 = df!( + "a"=> &[1], + "b"=> &[3], + )?; + let df_d2 = df!( + "a"=> &[2], + "d"=> &[4],)?; + let df_diagonal_concat = polars::functions::diag_concat_df(&[df_d1, df_d2])?; + println!("{}", &df_diagonal_concat); + // --8<-- [end:cross] + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/joins.rs b/docs/src/rust/user-guide/transformations/joins.rs new file mode 100644 index 000000000000..aa444c5d9a1a --- /dev/null +++ b/docs/src/rust/user-guide/transformations/joins.rs @@ -0,0 +1,205 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:innerdf] + let df_customers = df! ( + + "customer_id" => &[1, 2, 3], + "name" => &["Alice", "Bob", "Charlie"], + )?; + + println!("{}", &df_customers); + // --8<-- [end:innerdf] + + // --8<-- [start:innerdf2] + let df_orders = df!( + "order_id"=> &["a", "b", "c"], + "customer_id"=> &[1, 2, 2], + "amount"=> &[100, 200, 300], + )?; + println!("{}", &df_orders); + // --8<-- [end:innerdf2] + + // --8<-- [start:inner] + let df_inner_customer_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Inner), + ) + .collect()?; + println!("{}", &df_inner_customer_join); + // --8<-- [end:inner] + + // --8<-- [start:left] + let df_left_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Left), + ) + .collect()?; + println!("{}", &df_left_join); + // --8<-- [end:left] + + // --8<-- [start:outer] + let df_outer_join = df_customers + .clone() + .lazy() + .join( + df_orders.clone().lazy(), + [col("customer_id")], + [col("customer_id")], + JoinArgs::new(JoinType::Outer), + ) + .collect()?; + println!("{}", &df_outer_join); + // --8<-- [end:outer] + + // --8<-- [start:df3] + let df_colors = df!( + "color"=> &["red", "blue", "green"], + )?; + println!("{}", &df_colors); + // --8<-- [end:df3] + + // --8<-- [start:df4] + let df_sizes = df!( + "size"=> &["S", "M", "L"], + )?; + println!("{}", &df_sizes); + // --8<-- [end:df4] + + // --8<-- [start:cross] + let df_cross_join = df_colors + .clone() + .lazy() + .cross_join(df_sizes.clone().lazy()) + .collect()?; + println!("{}", &df_cross_join); + // --8<-- [end:cross] + + // --8<-- [start:df5] + let df_cars = df!( + "id"=> &["a", "b", "c"], + "make"=> &["ford", "toyota", "bmw"], + )?; + println!("{}", &df_cars); + // --8<-- [end:df5] + + // --8<-- [start:df6] + let df_repairs = df!( + "id"=> &["c", "c"], + "cost"=> &[100, 200], + )?; + println!("{}", &df_repairs); + // --8<-- [end:df6] + + // --8<-- [start:inner2] + let df_inner_join = df_cars + .clone() + .lazy() + .inner_join(df_repairs.clone().lazy(), col("id"), col("id")) + .collect()?; + println!("{}", &df_inner_join); + // --8<-- [end:inner2] + + // --8<-- [start:semi] + let df_semi_join = df_cars + .clone() + .lazy() + .join( + df_repairs.clone().lazy(), + [col("id")], + 
[col("id")], + JoinArgs::new(JoinType::Semi), + ) + .collect()?; + println!("{}", &df_semi_join); + // --8<-- [end:semi] + + // --8<-- [start:anti] + let df_anti_join = df_cars + .clone() + .lazy() + .join( + df_repairs.clone().lazy(), + [col("id")], + [col("id")], + JoinArgs::new(JoinType::Anti), + ) + .collect()?; + println!("{}", &df_anti_join); + // --8<-- [end:anti] + + // --8<-- [start:df7] + use chrono::prelude::*; + let df_trades = df!( + "time"=> &[ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 1, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 3, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock"=> &["A", "B", "B", "C"], + "trade"=> &[101, 299, 301, 500], + )?; + println!("{}", &df_trades); + // --8<-- [end:df7] + + // --8<-- [start:df8] + let df_quotes = df!( + "time"=> &[ + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 2, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 4, 0).unwrap(), + NaiveDate::from_ymd_opt(2020, 1, 1).unwrap().and_hms_opt(9, 6, 0).unwrap(), + ], + "stock"=> &["A", "B", "C", "A"], + "quote"=> &[100, 300, 501, 102], + )?; + + println!("{}", &df_quotes); + // --8<-- [end:df8] + + // --8<-- [start:asofpre] + let df_trades = df_trades.sort(["time"], false, true).unwrap(); + let df_quotes = df_quotes.sort(["time"], false, true).unwrap(); + // --8<-- [end:asofpre] + + // --8<-- [start:asof] + let df_asof_join = df_trades.join_asof_by( + &df_quotes, + "time", + "time", + ["stock"], + ["stock"], + AsofStrategy::Backward, + None, + )?; + println!("{}", &df_asof_join); + // --8<-- [end:asof] + + // --8<-- [start:asof2] + let df_asof_tolerance_join = df_trades.join_asof_by( + &df_quotes, + "time", + "time", + ["stock"], + ["stock"], + AsofStrategy::Backward, + Some(AnyValue::Duration(60000, TimeUnit::Milliseconds)), + )?; + println!("{}", &df_asof_tolerance_join); + // --8<-- [end:asof2] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/melt.rs b/docs/src/rust/user-guide/transformations/melt.rs new file mode 100644 index 000000000000..ff797423d293 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/melt.rs @@ -0,0 +1,21 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "A"=> &["a", "b", "a"], + "B"=> &[1, 3, 5], + "C"=> &[10, 11, 12], + "D"=> &[2, 4, 6], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:melt] + let out = df.melt(["A", "B"], ["C", "D"])?; + println!("{}", &out); + // --8<-- [end:melt] + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/pivot.rs b/docs/src/rust/user-guide/transformations/pivot.rs new file mode 100644 index 000000000000..e632f095f31b --- /dev/null +++ b/docs/src/rust/user-guide/transformations/pivot.rs @@ -0,0 +1,28 @@ +// --8<-- [start:setup] +use polars::prelude::{pivot::pivot, *}; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "foo"=> ["A", "A", "B", "B", "C"], + "N"=> [1, 2, 2, 4, 2], + "bar"=> ["k", "l", "m", "n", "o"], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:eager] + let out = pivot(&df, ["N"], ["foo"], ["bar"], false, None, None)?; + println!("{}", &out); + // --8<-- [end:eager] + + // --8<-- 
[start:lazy] + let q = df.lazy(); + let q2 = pivot(&q.collect()?, ["N"], ["foo"], ["bar"], false, None, None)?.lazy(); + let out = q2.collect()?; + println!("{}", &out); + // --8<-- [end:lazy] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/filter.rs b/docs/src/rust/user-guide/transformations/time-series/filter.rs new file mode 100644 index 000000000000..6e5b2175b81c --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/filter.rs @@ -0,0 +1,61 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::lazy::dsl::StrptimeOptions; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(true) + .finish() + .unwrap(); + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:filter] + let filtered_df = df + .clone() + .lazy() + .filter(col("Date").eq(lit(NaiveDate::from_ymd_opt(1995, 10, 16).unwrap()))) + .collect()?; + println!("{}", &filtered_df); + // --8<-- [end:filter] + + // --8<-- [start:range] + let filtered_range_df = df + .clone() + .lazy() + .filter( + col("Date") + .gt(lit(NaiveDate::from_ymd_opt(1995, 7, 1).unwrap())) + .and(col("Date").lt(lit(NaiveDate::from_ymd_opt(1995, 11, 1).unwrap()))), + ) + .collect()?; + println!("{}", &filtered_range_df); + // --8<-- [end:range] + + // --8<-- [start:negative] + let negative_dates_df = df!( + "ts"=> &["-1300-05-23", "-1400-03-02"], + "values"=> &[3, 4])? + .lazy() + .with_column( + col("ts") + .str() + .strptime(DataType::Date, StrptimeOptions::default()), + ) + .collect()?; + + let negative_dates_filtered_df = negative_dates_df + .clone() + .lazy() + .filter(col("ts").dt().year().lt(-1300)) + .collect()?; + println!("{}", &negative_dates_filtered_df); + // --8<-- [end:negative] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/parsing.rs b/docs/src/rust/user-guide/transformations/time-series/parsing.rs new file mode 100644 index 000000000000..275ed0bf0e6a --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/parsing.rs @@ -0,0 +1,75 @@ +// --8<-- [start:setup] +use polars::io::prelude::*; +use polars::lazy::dsl::StrptimeOptions; +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(true) + .finish() + .unwrap(); + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:cast] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(false) + .finish() + .unwrap(); + let df = df + .clone() + .lazy() + .with_columns([col("Date") + .str() + .strptime(DataType::Date, StrptimeOptions::default())]) + .collect()?; + println!("{}", &df); + // --8<-- [end:cast] + + // --8<-- [start:df3] + let df_with_year = df + .clone() + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:df3] + + // --8<-- [start:extract] + let df_with_year = df + .clone() + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:extract] + + // --8<-- [start:mixed] + let data = [ + "2021-03-27T00:00:00+0100", + "2021-03-28T00:00:00+0100", + "2021-03-29T00:00:00+0200", + "2021-03-30T00:00:00+0200", + ]; + let q = col("date") + .str() + .strptime( + 
DataType::Datetime(TimeUnit::Microseconds, None), + StrptimeOptions { + format: Some("%Y-%m-%dT%H:%M:%S%z".to_string()), + ..Default::default() + }, + ) + .dt() + .convert_time_zone("Europe/Brussels".to_string()); + let mixed_parsed = df!("date" => &data)?.lazy().select([q]).collect()?; + + println!("{}", &mixed_parsed); + // --8<-- [end:mixed] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/resampling.rs b/docs/src/rust/user-guide/transformations/time-series/resampling.rs new file mode 100644 index 000000000000..60888c264e12 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/resampling.rs @@ -0,0 +1,43 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::prelude::*; +use polars::time::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = df!( + "time" => date_range( + "time", + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(3, 0, 0).unwrap(), + Duration::parse("30m"), + ClosedWindow::Both, + TimeUnit::Milliseconds, None)?, + "groups" => &["a", "a", "a", "b", "b", "a", "a"], + "values" => &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + )?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:upsample] + let out1 = df + .clone() + .upsample::<[String; 0]>([], "time", Duration::parse("15m"), Duration::parse("0"))? + .fill_null(FillNullStrategy::Forward(None))?; + println!("{}", &out1); + // --8<-- [end:upsample] + + // --8<-- [start:upsample2] + let out2 = df + .clone() + .upsample::<[String; 0]>([], "time", Duration::parse("15m"), Duration::parse("0"))? + .lazy() + .with_columns([col("values").interpolate(InterpolationMethod::Linear)]) + .collect()? 
+ .fill_null(FillNullStrategy::Forward(None))?; + println!("{}", &out2); + // --8<-- [end:upsample2] + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/src/rust/user-guide/transformations/time-series/rolling.rs new file mode 100644 index 000000000000..6458eb69bdfc --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/rolling.rs @@ -0,0 +1,130 @@ +// --8<-- [start:setup] +use chrono::prelude::*; +use polars::io::prelude::*; +use polars::lazy::dsl::GetOutput; +use polars::prelude::*; +use polars::time::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:df] + let df = CsvReader::from_path("docs/data/apple_stock.csv") + .unwrap() + .with_try_parse_dates(true) + .finish() + .unwrap() + .sort(["Date"], false, true)?; + println!("{}", &df); + // --8<-- [end:df] + + // --8<-- [start:group_by] + let annual_average_df = df + .clone() + .lazy() + .groupby_dynamic( + col("Date"), + [], + DynamicGroupOptions { + every: Duration::parse("1y"), + period: Duration::parse("1y"), + offset: Duration::parse("0"), + ..Default::default() + }, + ) + .agg([col("Close").mean()]) + .collect()?; + + let df_with_year = annual_average_df + .lazy() + .with_columns([col("Date").dt().year().alias("year")]) + .collect()?; + println!("{}", &df_with_year); + // --8<-- [end:group_by] + + // --8<-- [start:group_by_dyn] + let df = df!( + "time" => date_range( + "time", + NaiveDate::from_ymd_opt(2021, 1, 1).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 31).unwrap().and_hms_opt(0, 0, 0).unwrap(), + Duration::parse("1d"), + ClosedWindow::Both, + TimeUnit::Milliseconds, None)?.cast(&DataType::Date)?)?; + + let out = df + .clone() + .lazy() + .groupby_dynamic( + col("time"), + [], + DynamicGroupOptions { + every: Duration::parse("1mo"), + period: Duration::parse("1mo"), + offset: Duration::parse("0"), + closed_window: ClosedWindow::Left, + ..Default::default() + }, + ) + .agg([ + col("time") + .cumcount(true) // python example has false + .reverse() + .head(Some(3)) + .alias("day/eom"), + ((col("time").last() - col("time").first()).map( + // had to use map as .duration().days() is not available + |s| { + Ok(Some( + s.duration()? 
+ .into_iter() + .map(|d| d.map(|v| v / 1000 / 24 / 60 / 60)) + .collect::() + .into_series(), + )) + }, + GetOutput::from_type(DataType::Int64), + ) + lit(1)) + .alias("days_in_month"), + ]) + .explode([col("day/eom")]) + .collect()?; + println!("{}", &out); + // --8<-- [end:group_by_dyn] + + // --8<-- [start:group_by_roll] + let df = df!( + "time" => date_range( + "time", + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(0, 0, 0).unwrap(), + NaiveDate::from_ymd_opt(2021, 12, 16).unwrap().and_hms_opt(3, 0, 0).unwrap(), + Duration::parse("30m"), + ClosedWindow::Both, + TimeUnit::Milliseconds, None)?, + "groups"=> ["a", "a", "a", "b", "b", "a", "a"], + )?; + println!("{}", &df); + // --8<-- [end:group_by_roll] + + // --8<-- [start:group_by_dyn2] + let out = df + .clone() + .lazy() + .groupby_dynamic( + col("time"), + [col("groups")], + DynamicGroupOptions { + every: Duration::parse("1h"), + period: Duration::parse("1h"), + offset: Duration::parse("0"), + include_boundaries: true, + closed_window: ClosedWindow::Both, + ..Default::default() + }, + ) + .agg([count()]) + .collect()?; + println!("{}", &out); + // --8<-- [end:group_by_dyn2] + + Ok(()) +} diff --git a/docs/src/rust/user-guide/transformations/time-series/timezones.rs b/docs/src/rust/user-guide/transformations/time-series/timezones.rs new file mode 100644 index 000000000000..20f818954667 --- /dev/null +++ b/docs/src/rust/user-guide/transformations/time-series/timezones.rs @@ -0,0 +1,46 @@ +// --8<-- [start:setup] +use polars::prelude::*; +// --8<-- [end:setup] + +fn main() -> Result<(), Box> { + // --8<-- [start:example] + let ts = ["2021-03-27 03:00", "2021-03-28 03:00"]; + let tz_naive = Series::new("tz_naive", &ts); + let time_zones_df = DataFrame::new(vec![tz_naive])? + .lazy() + .select([col("tz_naive").str().strptime( + DataType::Datetime(TimeUnit::Milliseconds, None), + StrptimeOptions::default(), + )]) + .with_columns([col("tz_naive") + .dt() + .replace_time_zone(Some("UTC".to_string()), None) + .alias("tz_aware")]) + .collect()?; + + println!("{}", &time_zones_df); + // --8<-- [end:example] + + // --8<-- [start:example2] + let time_zones_operations = time_zones_df + .lazy() + .select([ + col("tz_aware") + .dt() + .replace_time_zone(Some("Europe/Brussels".to_string()), None) + .alias("replace time zone"), + col("tz_aware") + .dt() + .convert_time_zone("Asia/Kathmandu".to_string()) + .alias("convert time zone"), + col("tz_aware") + .dt() + .replace_time_zone(None, None) + .alias("unset time zone"), + ]) + .collect()?; + println!("{}", &time_zones_operations); + // --8<-- [end:example2] + + Ok(()) +} diff --git a/docs/user-guide/concepts/contexts.md b/docs/user-guide/concepts/contexts.md new file mode 100644 index 000000000000..604ff311ca63 --- /dev/null +++ b/docs/user-guide/concepts/contexts.md @@ -0,0 +1,64 @@ +# Contexts + +Polars has developed its own Domain Specific Language (DSL) for transforming data. The language is very easy to use and allows for complex queries that remain human readable. The two core components of the language are Contexts and Expressions, the latter we will cover in the next section. + +A context, as implied by the name, refers to the context in which an expression needs to be evaluated. There are three main contexts [^1]: + +1. Selection: `df.select([..])`, `df.with_columns([..])` +1. Filtering: `df.filter()` +1. 
Group by / Aggregation: `df.group_by(..).agg([..])` + +The examples below are performed on the following `DataFrame`: + +{{code_block('user-guide/concepts/contexts','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:setup" +--8<-- "python/user-guide/concepts/contexts.py:dataframe" +``` + +## Select + +In the `select` context the selection applies expressions over columns. The expressions in this context must produce `Series` that are all the same length or have a length of 1. + +A `Series` of a length of 1 will be broadcasted to match the height of the `DataFrame`. Note that a select may produce new columns that are aggregations, combinations of expressions, or literals. + +{{code_block('user-guide/concepts/contexts','select',['select'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:select" +``` + +As you can see from the query the `select` context is very powerful and allows you to perform arbitrary expressions independent (and in parallel) of each other. + +Similarly to the `select` statement there is the `with_columns` statement which also is an entrance to the selection context. The main difference is that `with_columns` retains the original columns and adds new ones while `select` drops the original columns. + +{{code_block('user-guide/concepts/contexts','with_columns',['with_columns'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:with_columns" +``` + +## Filter + +In the `filter` context you filter the existing dataframe based on arbitrary expression which evaluates to the `Boolean` data type. + +{{code_block('user-guide/concepts/contexts','filter',['filter'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:filter" +``` + +## Group by / aggregation + +In the `group_by` context, expressions work on groups and thus may yield results of any length (a group may have many members). + +{{code_block('user-guide/concepts/contexts','group_by',['group_by'])}} + +```python exec="on" result="text" session="user-guide/contexts" +--8<-- "python/user-guide/concepts/contexts.py:group_by" +``` + +As you can see from the result all expressions are applied to the group defined by the `group_by` context. Besides the standard `group_by`, `group_by_dynamic`, and `group_by_rolling` are also entrances to the group by context. + +[^1]: There are additional List and SQL contexts which are covered later in this guide. But for simplicity, we leave them out of scope for now. diff --git a/docs/user-guide/concepts/data-structures.md b/docs/user-guide/concepts/data-structures.md new file mode 100644 index 000000000000..1825f8bbc892 --- /dev/null +++ b/docs/user-guide/concepts/data-structures.md @@ -0,0 +1,68 @@ +# Data structures + +The core base data structures provided by Polars are `Series` and `DataFrames`. + +## Series + +Series are a 1-dimensional data structure. Within a series all elements have the same [Data Type](data-types.md) . +The snippet below shows how to create a simple named `Series` object. 
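+
+For illustration, a minimal sketch of such a construction (the values here are arbitrary):
+
+```python
+import polars as pl
+
+# A named Series of integers; the dtype (Int64) is inferred from the values.
+s = pl.Series("a", [1, 2, 3, 4, 5])
+print(s)
+```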
+ +{{code_block('getting-started/series-dataframes','series',['Series'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:series" +``` + +## DataFrame + +A `DataFrame` is a 2-dimensional data structure that is backed by a `Series`, and it can be seen as an abstraction of a collection (e.g. list) of `Series`. Operations that can be executed on a `DataFrame` are very similar to what is done in a `SQL` like query. You can `GROUP BY`, `JOIN`, `PIVOT`, but also define custom functions. + +{{code_block('getting-started/series-dataframes','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:dataframe" +``` + +### Viewing data + +This part focuses on viewing data in a `DataFrame`. We will use the `DataFrame` from the previous example as a starting point. + +#### Head + +The `head` function shows by default the first 5 rows of a `DataFrame`. You can specify the number of rows you want to see (e.g. `df.head(10)`). + +{{code_block('getting-started/series-dataframes','head',['head'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:head" +``` + +#### Tail + +The `tail` function shows the last 5 rows of a `DataFrame`. You can also specify the number of rows you want to see, similar to `head`. + +{{code_block('getting-started/series-dataframes','tail',['tail'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:tail" +``` + +#### Sample + +If you want to get an impression of the data of your `DataFrame`, you can also use `sample`. With `sample` you get an _n_ number of random rows from the `DataFrame`. + +{{code_block('getting-started/series-dataframes','sample',['sample'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:sample" +``` + +#### Describe + +`Describe` returns summary statistics of your `DataFrame`. It will provide several quick statistics if possible. + +{{code_block('getting-started/series-dataframes','describe',['describe'])}} + +```python exec="on" result="text" session="getting-started/series" +--8<-- "python/getting-started/series-dataframes.py:describe" +``` diff --git a/docs/user-guide/concepts/data-types.md b/docs/user-guide/concepts/data-types.md new file mode 100644 index 000000000000..c63c9b4a37f7 --- /dev/null +++ b/docs/user-guide/concepts/data-types.md @@ -0,0 +1,31 @@ +# Data types + +`Polars` is entirely based on `Arrow` data types and backed by `Arrow` memory arrays. This makes data processing +cache-efficient and well-supported for Inter Process Communication. Most data types follow the exact implementation +from `Arrow`, with the exception of `Utf8` (this is actually `LargeUtf8`), `Categorical`, and `Object` (support is limited). The data types are: + +| Group | Type | Details | +| -------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------- | +| Numeric | `Int8` | 8-bit signed integer. | +| | `Int16` | 16-bit signed integer. | +| | `Int32` | 32-bit signed integer. | +| | `Int64` | 64-bit signed integer. | +| | `UInt8` | 8-bit unsigned integer. | +| | `UInt16` | 16-bit unsigned integer. | +| | `UInt32` | 32-bit unsigned integer. | +| | `UInt64` | 64-bit unsigned integer. 
| +| | `Float32` | 32-bit floating point. | +| | `Float64` | 64-bit floating point. | +| Nested | `Struct` | A struct array is represented as a `Vec` and is useful to pack multiple/heterogenous values in a single column. | +| | `List` | A list array contains a child array containing the list values and an offset array. (this is actually `Arrow` `LargeList` internally). | +| Temporal | `Date` | Date representation, internally represented as days since UNIX epoch encoded by a 32-bit signed integer. | +| | `Datetime` | Datetime representation, internally represented as microseconds since UNIX epoch encoded by a 64-bit signed integer. | +| | `Duration` | A timedelta type, internally represented as microseconds. Created when subtracting `Date/Datetime`. | +| | `Time` | Time representation, internally represented as nanoseconds since midnight. | +| Other | `Boolean` | Boolean type effectively bit packed. | +| | `Utf8` | String data (this is actually `Arrow` `LargeUtf8` internally). | +| | `Binary` | Store data as bytes. | +| | `Object` | A limited supported data type that can be any value. | +| | `Categorical` | A categorical encoding of a set of strings. | + +To learn more about the internal representation of these data types, check the [`Arrow` columnar format](https://arrow.apache.org/docs/format/Columnar.html). diff --git a/docs/user-guide/concepts/expressions.md b/docs/user-guide/concepts/expressions.md new file mode 100644 index 000000000000..b276c494a4a3 --- /dev/null +++ b/docs/user-guide/concepts/expressions.md @@ -0,0 +1,49 @@ +# Expressions + +`Polars` has a powerful concept called expressions that is central to its very fast performance. + +Expressions are at the core of many data science operations: + +- taking a sample of rows from a column +- multiplying values in a column +- extracting a column of years from dates +- convert a column of strings to lowercase +- and so on! + +However, expressions are also used within other operations: + +- taking the mean of a group in a `group_by` operation +- calculating the size of groups in a `group_by` operation +- taking the sum horizontally across columns + +`Polars` performs these core data transformations very quickly by: + +- automatic query optimization on each expression +- automatic parallelization of expressions on many columns + +Polars expressions are a mapping from a series to a series (or mathematically `Fn(Series) -> Series`). As expressions have a `Series` as an input and a `Series` as an output then it is straightforward to do a sequence of expressions (similar to method chaining in `Pandas`). + +## Examples + +The following is an expression: + +{{code_block('user-guide/concepts/expressions','example1',['col','sort','head'])}} + +The snippet above says: + +1. Select column "foo" +1. Then sort the column (not in reversed order) +1. Then take the first two values of the sorted output + +The power of expressions is that every expression produces a new expression, and that they +can be _piped_ together. You can run an expression by passing them to one of `Polars` execution contexts. + +Here we run two expressions by running `df.select`: + +{{code_block('user-guide/concepts/expressions','example2',['select'])}} + +All expressions are run in parallel, meaning that separate `Polars` expressions are **embarrassingly parallel**. Note that within an expression there may be more parallelization going on. + +## Conclusion + +This is the tip of the iceberg in terms of possible expressions. 
There are a ton more, and they can be combined in a variety of ways. This page is intended to get you familiar with the concept of expressions, in the section on [expressions](../expressions/operators.md) we will dive deeper. diff --git a/docs/user-guide/concepts/lazy-vs-eager.md b/docs/user-guide/concepts/lazy-vs-eager.md new file mode 100644 index 000000000000..1b84a0272aa5 --- /dev/null +++ b/docs/user-guide/concepts/lazy-vs-eager.md @@ -0,0 +1,28 @@ +# Lazy / eager API + +`Polars` supports two modes of operation: lazy and eager. In the eager API the query is executed immediately while in the lazy API the query is only evaluated once it is 'needed'. Deferring the execution to the last minute can have significant performance advantages that is why the Lazy API is preferred in most cases. Let us demonstrate this with an example: + +{{code_block('user-guide/concepts/lazy-vs-eager','eager',['read_csv'])}} + +In this example we use the eager API to: + +1. Read the iris [dataset](https://archive.ics.uci.edu/ml/datasets/iris). +1. Filter the dataset based on sepal length +1. Calculate the mean of the sepal width per species + +Every step is executed immediately returning the intermediate results. This can be very wasteful as we might do work or load extra data that is not being used. If we instead used the lazy API and waited on execution until all the steps are defined then the query planner could perform various optimizations. In this case: + +- Predicate pushdown: Apply filters as early as possible while reading the dataset, thus only reading rows with sepal length greater than 5. +- Projection pushdown: Select only the columns that are needed while reading the dataset, thus removing the need to load additional columns (e.g. petal length & petal width) + +{{code_block('user-guide/concepts/lazy-vs-eager','lazy',['scan_csv'])}} + +These will significantly lower the load on memory & CPU thus allowing you to fit bigger datasets in memory and process faster. Once the query is defined you call `collect` to inform `Polars` that you want to execute it. In the section on Lazy API we will go into more details on its implementation. + +!!! info "Eager API" + + In many cases the eager API is actually calling the lazy API under the hood and immediately collecting the result. This has the benefit that within the query itself optimization(s) made by the query planner can still take place. + +### When to use which + +In general the lazy API should be preferred unless you are either interested in the intermediate results or are doing exploratory work and don't know yet what your query is going to look like. diff --git a/docs/user-guide/concepts/streaming.md b/docs/user-guide/concepts/streaming.md new file mode 100644 index 000000000000..e52e28bf2cfe --- /dev/null +++ b/docs/user-guide/concepts/streaming.md @@ -0,0 +1,21 @@ +# Streaming API + +One additional benefit of the lazy API is that it allows queries to be executed in a streaming manner. Instead of processing the data all-at-once `Polars` can execute the query in batches allowing you to process datasets that are larger-than-memory. + +To tell Polars we want to execute a query in streaming mode we pass the `streaming=True` argument to `collect` + +{{code_block('user-guide/concepts/streaming','streaming',['collect'])}} + +## When is streaming available? + +Streaming is still in development. We can ask Polars to execute any lazy query in streaming mode. However, not all lazy operations support streaming. 
If there is an operation for which streaming is not supported Polars will run the query in non-streaming mode. + +Streaming is supported for many operations including: + +- `filter`,`slice`,`head`,`tail` +- `with_columns`,`select` +- `group_by` +- `join` +- `sort` +- `explode`,`melt` +- `scan_csv`,`scan_parquet`,`scan_ipc` diff --git a/docs/user-guide/expressions/aggregation.md b/docs/user-guide/expressions/aggregation.md new file mode 100644 index 000000000000..6b5fb8bcaf48 --- /dev/null +++ b/docs/user-guide/expressions/aggregation.md @@ -0,0 +1,122 @@ +# Aggregation + +`Polars` implements a powerful syntax defined not only in its lazy API, but also in its eager API. Let's take a look at what that means. + +We can start with the simple [US congress `dataset`](https://github.com/unitedstates/congress-legislators). + +{{code_block('user-guide/expressions/aggregation','dataframe',['DataFrame','Categorical'])}} + +#### Basic aggregations + +You can easily combine different aggregations by adding multiple expressions in a +`list`. There is no upper bound on the number of aggregations you can do, and you can +make any combination you want. In the snippet below we do the following aggregations: + +Per GROUP `"first_name"` we + +- count the number of rows in the group: + - short form: `pl.count("party")` + - full form: `pl.col("party").count()` +- aggregate the gender values groups: + - full form: `pl.col("gender")` +- get the first value of column `"last_name"` in the group: + - short form: `pl.first("last_name")` (not available in Rust) + - full form: `pl.col("last_name").first()` + +Besides the aggregation, we immediately sort the result and limit to the top `5` so that +we have a nice summary overview. + +{{code_block('user-guide/expressions/aggregation','basic',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:setup" +--8<-- "python/user-guide/expressions/aggregation.py:dataframe" +--8<-- "python/user-guide/expressions/aggregation.py:basic" +``` + +#### Conditionals + +It's that easy! Let's turn it up a notch. Let's say we want to know how +many delegates of a "state" are "Pro" or "Anti" administration. We could directly query +that in the aggregation without the need of a `lambda` or grooming the `DataFrame`. + +{{code_block('user-guide/expressions/aggregation','conditional',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:conditional" +``` + +Similarly, this could also be done with a nested GROUP BY, but that doesn't help show off some of these nice features. 😉 + +{{code_block('user-guide/expressions/aggregation','nested',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:nested" +``` + +#### Filtering + +We can also filter the groups. Let's say we want to compute a mean per group, but we +don't want to include all values from that group, and we also don't want to filter the +rows from the `DataFrame` (because we need those rows for another aggregation). + +In the example below we show how this can be done. + +!!! note + + Note that we can make `Python` functions for clarity. These functions don't cost us anything. That is because we only create `Polars` expressions, we don't apply a custom function over a `Series` during runtime of the query. Of course, you can make functions that return expressions in Rust, too. 
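+
+As a sketch of what such an expression-returning helper can look like (the column and group names below are hypothetical and not taken from the dataset above; this assumes a Polars version where `group_by` is available — older releases spell it `groupby`):
+
+```python
+import polars as pl
+
+def avg_where(value_col: str, predicate: pl.Expr) -> pl.Expr:
+    # Build an expression: the mean of `value_col` restricted to rows where
+    # `predicate` holds. Nothing is executed until the query runs.
+    return pl.col(value_col).filter(predicate).mean().alias(f"avg_{value_col}_filtered")
+
+df = pl.DataFrame(
+    {
+        "group": ["a", "a", "b", "b"],
+        "value": [1.0, 5.0, 3.0, 7.0],
+        "flag": [True, False, True, True],
+    }
+)
+print(df.group_by("group").agg(avg_where("value", pl.col("flag"))))
+```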
+ +{{code_block('user-guide/expressions/aggregation','filter',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:filter" +``` + +#### Sorting + +It's common to see a `DataFrame` being sorted for the sole purpose of managing the ordering during a GROUP BY operation. Let's say that we want to get the names of the oldest and youngest politicians per state. We could SORT and GROUP BY. + +{{code_block('user-guide/expressions/aggregation','sort',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort" +``` + +However, **if** we also want to sort the names alphabetically, this breaks. Luckily we can sort in a `group_by` context separate from the `DataFrame`. + +{{code_block('user-guide/expressions/aggregation','sort2',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort2" +``` + +We can even sort by another column in the `group_by` context. If we want to know if the alphabetically sorted name is male or female we could add: `pl.col("gender").sort_by("first_name").first().alias("gender")` + +{{code_block('user-guide/expressions/aggregation','sort3',['group_by'])}} + +```python exec="on" result="text" session="user-guide/expressions" +--8<-- "python/user-guide/expressions/aggregation.py:sort3" +``` + +### Do not kill parallelization + +!!! warning "Python Users Only" + + The following section is specific to `Python`, and doesn't apply to `Rust`. Within `Rust`, blocks and closures (lambdas) can, and will, be executed concurrently. + +We have all heard that `Python` is slow, and does "not scale." Besides the overhead of +running "slow" bytecode, `Python` has to remain within the constraints of the Global +Interpreter Lock (GIL). This means that if you were to use a `lambda` or a custom `Python` +function to apply during a parallelized phase, `Polars` speed is capped running `Python` +code preventing any multiple threads from executing the function. + +This all feels terribly limiting, especially because we often need those `lambda` functions in a +`.group_by()` step, for example. This approach is still supported by `Polars`, but +keeping in mind bytecode **and** the GIL costs have to be paid. It is recommended to try to solve your queries using the expression syntax before moving to `lambdas`. If you want to learn more about using `lambdas`, go to the [user defined functions section](./user-defined-functions.md). + +### Conclusion + +In the examples above we've seen that we can do a lot by combining expressions. By doing so we delay the use of custom `Python` functions that slow down the queries (by the slow nature of Python AND the GIL). + +If we are missing a type expression let us know by opening a +[feature request](https://github.com/pola-rs/polars/issues/new/choose)! diff --git a/docs/user-guide/expressions/casting.md b/docs/user-guide/expressions/casting.md new file mode 100644 index 000000000000..cb06699fa2ed --- /dev/null +++ b/docs/user-guide/expressions/casting.md @@ -0,0 +1,100 @@ +# Casting + +Casting converts the underlying [`DataType`](../concepts/data-types.md) of a column to a new one. Polars uses Arrow to manage the data in memory and relies on the compute kernels in the [rust implementation](https://github.com/jorgecarleitao/arrow2) to do the conversion. Casting is available with the `cast()` method. 
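+
+A minimal sketch of the basic pattern (with a made-up column):
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"value": [1.0, 2.5, 3.9]})
+
+out = df.select(
+    [
+        pl.col("value").cast(pl.Int64).alias("value_int"),  # Float64 -> Int64
+        pl.col("value").cast(pl.Utf8).alias("value_str"),   # Float64 -> Utf8 (string)
+    ]
+)
+print(out)
+```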
+ +The `cast` method includes a `strict` parameter that determines how Polars behaves when it encounters a value that can't be converted from the source `DataType` to the target `DataType`. By default, `strict=True`, which means that Polars will throw an error to notify the user of the failed conversion and provide details on the values that couldn't be cast. On the other hand, if `strict=False`, any values that can't be converted to the target `DataType` will be quietly converted to `null`. + +## Numerics + +Let's take a look at the following `DataFrame` which contains both integers and floating point numbers. + +{{code_block('user-guide/expressions/casting','dfnum',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:setup" +--8<-- "python/user-guide/expressions/casting.py:dfnum" +``` + +To perform casting operations between floats and integers, or vice versa, we can invoke the `cast()` function. + +{{code_block('user-guide/expressions/casting','castnum',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:castnum" +``` + +Note that in the case of decimal values these are rounded downwards when casting to an integer. + +##### Downcast + +Reducing the memory footprint is also achievable by modifying the number of bits allocated to an element. As an illustration, the code below demonstrates how casting from `Int64` to `Int16` and from `Float64` to `Float32` can be used to lower memory usage. + +{{code_block('user-guide/expressions/casting','downcast',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:downcast" +``` + +#### Overflow + +When performing downcasting, it is crucial to ensure that the chosen number of bits (such as 64, 32, or 16) is sufficient to accommodate the largest and smallest numbers in the column. For example, using a 32-bit signed integer (`Int32`) allows handling integers within the range of -2147483648 to +2147483647, while using `Int8` covers integers between -128 to 127. Attempting to cast to a `DataType` that is too small will result in a `ComputeError` thrown by Polars, as the operation is not supported. + +{{code_block('user-guide/expressions/casting','overflow',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:overflow" +``` + +You can set the `strict` parameter to `False`, this converts values that are overflowing to null values. + +{{code_block('user-guide/expressions/casting','overflow2',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:overflow2" +``` + +## Strings + +Strings can be casted to numerical data types and vice versa: + +{{code_block('user-guide/expressions/casting','strings',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:strings" +``` + +In case the column contains a non-numerical value, Polars will throw a `ComputeError` detailing the conversion error. Setting `strict=False` will convert the non float value to `null`. + +{{code_block('user-guide/expressions/casting','strings2',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:strings2" +``` + +## Booleans + +Booleans can be expressed as either 1 (`True`) or 0 (`False`). 
It's possible to perform casting operations between a numerical `DataType` and a boolean, and vice versa. However, keep in mind that casting from a string (`Utf8`) to a boolean is not permitted. + +{{code_block('user-guide/expressions/casting','bool',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:bool" +``` + +## Dates + +Temporal data types such as `Date` or `Datetime` are represented as the number of days (`Date`) and microseconds (`Datetime`) since epoch. Therefore, casting between the numerical types and the temporal data types is allowed. + +{{code_block('user-guide/expressions/casting','dates',['cast'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:dates" +``` + +To perform casting operations between strings and `Dates`/`Datetimes`, `strftime` and `strptime` are utilized. Polars adopts the [chrono format syntax](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) for when formatting. It's worth noting that `strptime` features additional options that support timezone functionality. Refer to the API documentation for further information. + +{{code_block('user-guide/expressions/casting','dates2',['strftime','strptime'])}} + +```python exec="on" result="text" session="user-guide/cast" +--8<-- "python/user-guide/expressions/casting.py:dates2" +``` diff --git a/docs/user-guide/expressions/column-selections.md b/docs/user-guide/expressions/column-selections.md new file mode 100644 index 000000000000..0f6b1a82f018 --- /dev/null +++ b/docs/user-guide/expressions/column-selections.md @@ -0,0 +1,134 @@ +# Column selections + +Let's create a dataset to use in this section: + +{{code_block('user-guide/expressions/column-selections','selectors_df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:setup" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_df" +``` + +## Expression expansion + +As we've seen in the previous section, we can select specific columns using the `pl.col` method. It can also select multiple columns - both as a means of convenience, and to _expand_ the expression. + +This kind of convenience feature isn't just decorative or syntactic sugar. It allows for a very powerful application of [DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself) principles in your code: a single expression that specifies multiple columns expands into a list of expressions (depending on the DataFrame schema), resulting in being able to select multiple columns + run computation on them! + +### Select all, or all but some + +We can select all columns in the `DataFrame` object by providing the argument `*`: + +{{code_block('user-guide/expressions/column-selections', 'all',['all'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:all" +``` + +Often, we don't just want to include all columns, but include all _while_ excluding a few. 
This can be done easily as well: + +{{code_block('user-guide/expressions/column-selections','exclude',['exclude'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:exclude" +``` + +### By multiple strings + +Specifying multiple strings allows expressions to _expand_ to all matching columns: + +{{code_block('user-guide/expressions/column-selections','expansion_by_names',['dt_to_string'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:expansion_by_names" +``` + +### By regular expressions + +Multiple column selection is possible by regular expressions also, by making sure to wrap the regex by `^` and `$` to let `pl.col` know that a regex selection is expected: + +{{code_block('user-guide/expressions/column-selections','expansion_by_regex',[''])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:expansion_by_regex" +``` + +### By data type + +`pl.col` can select multiple columns using Polars data types: + +{{code_block('user-guide/expressions/column-selections','expansion_by_dtype',['n_unique'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:expansion_by_dtype" +``` + +## Using `selectors` + +Polars also allows for the use of intuitive selections for columns based on their name, `dtype` or other properties; and this is built on top of existing functionality outlined in `col` used above. It is recommended to use them by importing and aliasing `polars.selectors` as `cs`. + +### By `dtype` + +To select just the integer and string columns, we can do: + +{{code_block('user-guide/expressions/column-selections','selectors_intro',['selectors'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_intro" +``` + +### Applying set operations + +These _selectors_ also allow for set based selection operations. For instance, to select the **numeric** columns **except** the **first** column that indicates row numbers: + +{{code_block('user-guide/expressions/column-selections','selectors_diff',['cs_first', 'cs_numeric'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_diff" +``` + +We can also select the row number by name **and** any **non**-numeric columns: + +{{code_block('user-guide/expressions/column-selections','selectors_union',['cs_by_name', 'cs_numeric'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_union" +``` + +### By patterns and substrings + +_Selectors_ can also be matched by substring and regex patterns: + +{{code_block('user-guide/expressions/column-selections','selectors_by_name',['cs_contains', 'cs_matches'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_by_name" +``` + +### Converting to expressions + +What if we want to apply a specific operation on the selected columns (i.e. get back to representing them as **expressions** to operate upon)? 
We can simply convert them using `as_expr` and then proceed as normal: + +{{code_block('user-guide/expressions/column-selections','selectors_to_expr',['cs_temporal'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_to_expr" +``` + +### Debugging `selectors` + +Polars also provides two helpful utility functions to aid with using selectors: `is_selector` and `selector_column_names`: + +{{code_block('user-guide/expressions/column-selections','selectors_is_selector_utility',['is_selector'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_is_selector_utility" +``` + +To predetermine the column names that are selected, which is especially useful for a LazyFrame object: + +{{code_block('user-guide/expressions/column-selections','selectors_colnames_utility',['selector_column_names'])}} + +```python exec="on" result="text" session="user-guide/column-selections" +--8<-- "python/user-guide/expressions/column-selections.py:selectors_colnames_utility" +``` diff --git a/docs/user-guide/expressions/folds.md b/docs/user-guide/expressions/folds.md new file mode 100644 index 000000000000..2339f8f114e5 --- /dev/null +++ b/docs/user-guide/expressions/folds.md @@ -0,0 +1,43 @@ +# Folds + +`Polars` provides expressions/methods for horizontal aggregations like `sum`,`min`, `mean`, +etc. However, when you need a more complex aggregation the default methods `Polars` supplies may not be sufficient. That's when `folds` come in handy. + +The `fold` expression operates on columns for maximum speed. It utilizes the data layout very efficiently and often has vectorized execution. + +### Manual sum + +Let's start with an example by implementing the `sum` operation ourselves, with a `fold`. + +{{code_block('user-guide/expressions/folds','mansum',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:setup" +--8<-- "python/user-guide/expressions/folds.py:mansum" +``` + +The snippet above recursively applies the function `f(acc, x) -> acc` to an accumulator `acc` and a new column `x`. The function operates on columns individually and can take advantage of cache efficiency and vectorization. + +### Conditional + +In the case where you'd want to apply a condition/predicate on all columns in a `DataFrame` a `fold` operation can be a very concise way to express this. + +{{code_block('user-guide/expressions/folds','conditional',['fold'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:conditional" +``` + +In the snippet we filter all rows where **each** column value is `> 1`. + +### Folds and string data + +Folds could be used to concatenate string data. However, due to the materialization of intermediate columns, this operation will have squared complexity. + +Therefore, we recommend using the `concat_str` expression for this. 
+ +{{code_block('user-guide/expressions/folds','string',['concat_str'])}} + +```python exec="on" result="text" session="user-guide/folds" +--8<-- "python/user-guide/expressions/folds.py:string" +``` diff --git a/docs/user-guide/expressions/functions.md b/docs/user-guide/expressions/functions.md new file mode 100644 index 000000000000..fde219cb25dd --- /dev/null +++ b/docs/user-guide/expressions/functions.md @@ -0,0 +1,65 @@ +# Functions + +`Polars` expressions have a large number of built in functions. These allow you to create complex queries without the need for [user defined functions](user-defined-functions.md). There are too many to go through here, but we will cover some of the more popular use cases. If you want to view all the functions go to the API Reference for your programming language. + +In the examples below we will use the following `DataFrame`: + +{{code_block('user-guide/expressions/functions','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:setup" +--8<-- "python/user-guide/expressions/functions.py:dataframe" +``` + +## Column naming + +By default if you perform an expression it will keep the same name as the original column. In the example below we perform an expression on the `nrs` column. Note that the output `DataFrame` still has the same name. + +{{code_block('user-guide/expressions/functions','samename',[])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:samename" +``` + +This might get problematic in the case you use the same column multiple times in your expression as the output columns will get duplicated. For example, the following query will fail. + +{{code_block('user-guide/expressions/functions','samenametwice',[])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:samenametwice" +``` + +You can change the output name of an expression by using the `alias` function + +{{code_block('user-guide/expressions/functions','samenamealias',['alias'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:samenamealias" +``` + +In case of multiple columns for example when using `all()` or `col(*)` you can apply a mapping function `map_alias` to change the original column name into something else. In case you want to add a suffix (`suffix()`) or prefix (`prefix()`) these are also built in. + +=== ":fontawesome-brands-python: Python" +[:material-api: `prefix`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.prefix.html) +[:material-api: `suffix`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.suffix.html) +[:material-api: `map_alias`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.map_alias.html) + +## Count unique values + +There are two ways to count unique values in `Polars`: an exact methodology and an approximation. The approximation uses the [HyperLogLog++](https://en.wikipedia.org/wiki/HyperLogLog) algorithm to approximate the cardinality and is especially useful for very large datasets where an approximation is good enough. 
+ +{{code_block('user-guide/expressions/functions','countunique',['n_unique','approx_n_unique'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:countunique" +``` + +## Conditionals + +`Polars` supports if-else like conditions in expressions with the `when`, `then`, `otherwise` syntax. The predicate is placed in the `when` clause and when this evaluates to `true` the `then` expression is applied otherwise the `otherwise` expression is applied (row-wise). + +{{code_block('user-guide/expressions/functions','conditional',['when'])}} + +```python exec="on" result="text" session="user-guide/functions" +--8<-- "python/user-guide/expressions/functions.py:conditional" +``` diff --git a/docs/user-guide/expressions/lists.md b/docs/user-guide/expressions/lists.md new file mode 100644 index 000000000000..b7c508f11b90 --- /dev/null +++ b/docs/user-guide/expressions/lists.md @@ -0,0 +1,119 @@ +# Lists and Arrays + +`Polars` has first-class support for `List` columns: that is, columns where each row is a list of homogeneous elements, of varying lengths. `Polars` also has an `Array` datatype, which is analogous to `numpy`'s `ndarray` objects, where the length is identical across rows. + +Note: this is different from Python's `list` object, where the elements can be of any type. Polars can store these within columns, but as a generic `Object` datatype that doesn't have the special list manipulation features that we're about to discuss. + +## Powerful `List` manipulation + +Let's say we had the following data from different weather stations across a state. When the weather station is unable to get a result, an error code is recorded instead of the actual temperature at that time. + +{{code_block('user-guide/expressions/lists','weather_df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:setup" +--8<-- "python/user-guide/expressions/lists.py:weather_df" +``` + +### Creating a `List` column + +For the `weather` `DataFrame` created above, it's very likely we need to run some analysis on the temperatures that are captured by each station. To make this happen, we need to first be able to get individual temperature measurements. This is done by: + +{{code_block('user-guide/expressions/lists','string_to_list',['str.split'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:string_to_list" +``` + +One way we could go post this would be to convert each temperature measurement into its own row: + +{{code_block('user-guide/expressions/lists','explode_to_atomic',['DataFrame.explode'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:explode_to_atomic" +``` + +However, in Polars, we often do not need to do this to operate on the `List` elements. + +### Operating on `List` columns + +Polars provides several standard operations on `List` columns. If we want the first three measurements, we can do a `head(3)`. The last three can be obtained via a `tail(3)`, or alternately, via `slice` (negative indexing is supported). We can also identify the number of observations via `lengths`. Let's see them in action: + +{{code_block('user-guide/expressions/lists','list_ops',['Expr.List'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:list_ops" +``` + +!!! 
warning "`arr` then, `list` now" + + If you find references to the `arr` API on Stackoverflow or other sources, just replace `arr` with `list`, this was the old accessor for the `List` datatype. `arr` now refers to the newly introduced `Array` datatype (see below). + +### Element-wise computation within `List`s + +If we need to identify the stations that are giving the most number of errors from the starting `DataFrame`, we need to: + +1. Parse the string input as a `List` of string values (already done). +2. Identify those strings that can be converted to numbers. +3. Identify the number of non-numeric values (i.e. `null` values) in the list, by row. +4. Rename this output as `errors` so that we can easily identify the stations. + +The third step requires a casting (or alternately, a regex pattern search) operation to be perform on each element of the list. We can do this using by applying the operation on each element by first referencing them in the `pl.element()` context, and then calling a suitable Polars expression on them. Let's see how: + +{{code_block('user-guide/expressions/lists','count_errors',['Expr.List', 'element'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:count_errors" +``` + +What if we chose the regex route (i.e. recognizing the presence of _any_ alphabetical character?) + +{{code_block('user-guide/expressions/lists','count_errors_regex',['str.contains'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:count_errors_regex" +``` + +If you're unfamiliar with the `(?i)`, it's a good time to look at the documentation for the `str.contains` function in Polars! The rust regex crate provides a lot of additional regex flags that might come in handy. + +## Row-wise computations + +This context is ideal for computing in row orientation. + +We can apply **any** Polars operations on the elements of the list with the `list.eval` (`list().eval` in Rust) expression! These expressions run entirely on Polars' query engine and can run in parallel, so will be well optimized. Let's say we have another set of weather data across three days, for different stations: + +{{code_block('user-guide/expressions/lists','weather_by_day',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:weather_by_day" +``` + +Let's do something interesting, where we calculate the percentage rank of the temperatures by day, measured across stations. Pandas allows you to compute the percentages of the `rank` values. `Polars` doesn't provide a special function to do this directly, but because expressions are so versatile we can create our own percentage rank expression for highest temperature. Let's try that! + +{{code_block('user-guide/expressions/lists','weather_by_day_rank',['list.eval'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:weather_by_day_rank" +``` + +## Polars `Array`s + +`Array`s are a new data type that was recently introduced, and are still pretty nascent in features that it offers. The major difference between a `List` and an `Array` is that the latter is limited to having the same number of elements per row, while a `List` can have a variable number of elements. Both still require that each element's data type is the same. 
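+
+As a quick illustration of that difference (data made up), a column whose rows contain different numbers of elements can only be a `List`, never an `Array`:
+
+```python
+import polars as pl
+
+# rows hold 3, 1 and 2 elements respectively, so this column is inferred as List(Int64)
+df = pl.DataFrame({"values": [[1, 2, 3], [4], [5, 6]]})
+
+print(df.schema)
+print(df.select(pl.col("values").list.lengths()))  # per-row lengths vary
+```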
+ +We can define `Array` columns in this manner: + +{{code_block('user-guide/expressions/lists','array_df',['Array'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:array_df" +``` + +Basic operations are available on it: + +{{code_block('user-guide/expressions/lists','array_ops',['arr'])}} + +```python exec="on" result="text" session="user-guide/lists" +--8<-- "python/user-guide/expressions/lists.py:array_ops" +``` + +Polars `Array`s are still being actively developed, so this section will likely change in the future. diff --git a/docs/user-guide/expressions/null.md b/docs/user-guide/expressions/null.md new file mode 100644 index 000000000000..5ded317ac2b5 --- /dev/null +++ b/docs/user-guide/expressions/null.md @@ -0,0 +1,140 @@ +# Missing data + +This page sets out how missing data is represented in `Polars` and how missing data can be filled. + +## `null` and `NaN` values + +Each column in a `DataFrame` (or equivalently a `Series`) is an Arrow array or a collection of Arrow arrays [based on the Apache Arrow format](https://arrow.apache.org/docs/format/Columnar.html#null-count). Missing data is represented in Arrow and `Polars` with a `null` value. This `null` missing value applies for all data types including numerical values. + +`Polars` also allows `NotaNumber` or `NaN` values for float columns. These `NaN` values are considered to be a type of floating point data rather than missing data. We discuss `NaN` values separately below. + +You can manually define a missing value with the python `None` value: + +{{code_block('user-guide/expressions/null','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:setup" +--8<-- "python/user-guide/expressions/null.py:dataframe" +``` + +!!! info + + In `Pandas` the value for missing data depends on the dtype of the column. In `Polars` missing data is always represented as a `null` value. + +## Missing data metadata + +Each Arrow array used by `Polars` stores two kinds of metadata related to missing data. This metadata allows `Polars` to quickly show how many missing values there are and which values are missing. + +The first piece of metadata is the `null_count` - this is the number of rows with `null` values in the column: + +{{code_block('user-guide/expressions/null','count',['null_count'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:count" +``` + +The `null_count` method can be called on a `DataFrame`, a column from a `DataFrame` or a `Series`. The `null_count` method is a cheap operation as `null_count` is already calculated for the underlying Arrow array. + +The second piece of metadata is an array called a _validity bitmap_ that indicates whether each data value is valid or missing. +The validity bitmap is memory efficient as it is bit encoded - each value is either a 0 or a 1. This bit encoding means the memory overhead per array is only (array length / 8) bytes. The validity bitmap is used by the `is_null` method in `Polars`. 
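+
+As a back-of-the-envelope illustration of that overhead (the row count here is made up):
+
+```python
+# one bit per row: a hypothetical 10-million-row column needs ~1.25 MB for its validity bitmap
+n_rows = 10_000_000
+bitmap_bytes = n_rows / 8
+print(f"{bitmap_bytes / 1_000_000:.2f} MB")
+```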
+ +You can return a `Series` based on the validity bitmap for a column in a `DataFrame` or a `Series` with the `is_null` method: + +{{code_block('user-guide/expressions/null','isnull',['is_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:isnull" +``` + +The `is_null` method is a cheap operation that does not require scanning the full column for `null` values. This is because the validity bitmap already exists and can be returned as a Boolean array. + +## Filling missing data + +Missing data in a `Series` can be filled with the `fill_null` method. You have to specify how you want the `fill_null` method to fill the missing data. The main ways to do this are filling with: + +- a literal such as 0 or "0" +- a strategy such as filling forwards +- an expression such as replacing with values from another column +- interpolation + +We illustrate each way to fill nulls by defining a simple `DataFrame` with a missing value in `col2`: + +{{code_block('user-guide/expressions/null','dataframe2',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:dataframe2" +``` + +### Fill with specified literal value + +We can fill the missing data with a specified literal value with `pl.lit`: + +{{code_block('user-guide/expressions/null','fill',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fill" +``` + +### Fill with a strategy + +We can fill the missing data with a strategy such as filling forward: + +{{code_block('user-guide/expressions/null','fillstrategy',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fillstrategy" +``` + +You can find other fill strategies in the API docs. + +### Fill with an expression + +For more flexibility we can fill the missing data with an expression. For example, +to fill nulls with the median value from that column: + +{{code_block('user-guide/expressions/null','fillexpr',['fill_null'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fillexpr" +``` + +In this case the column is cast from integer to float because the median is a float statistic. + +### Fill with interpolation + +In addition, we can fill nulls with interpolation (without using the `fill_null` function): + +{{code_block('user-guide/expressions/null','fillinterpolate',['interpolate'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:fillinterpolate" +``` + +## `NotaNumber` or `NaN` values + +Missing data in a `Series` has a `null` value. However, you can use `NotaNumber` or `NaN` values in columns with float datatypes. These `NaN` values can be created from Numpy's `np.nan` or the native python `float('nan')`: + +{{code_block('user-guide/expressions/null','nan',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:nan" +``` + +!!! info + + In `Pandas` by default a `NaN` value in an integer column causes the column to be cast to float. This does not happen in `Polars` - instead an exception is raised. + +`NaN` values are considered to be a type of floating point data and are **not considered to be missing data** in `Polars`. 
This means: + +- `NaN` values are **not** counted with the `null_count` method +- `NaN` values are filled when you use `fill_nan` method but are **not** filled with the `fill_null` method + +`Polars` has `is_nan` and `fill_nan` methods which work in a similar way to the `is_null` and `fill_null` methods. The underlying Arrow arrays do not have a pre-computed validity bitmask for `NaN` values so this has to be computed for the `is_nan` method. + +One further difference between `null` and `NaN` values is that taking the `mean` of a column with `null` values excludes the `null` values from the calculation but with `NaN` values taking the mean results in a `NaN`. This behaviour can be avoided by replacing the `NaN` values with `null` values; + +{{code_block('user-guide/expressions/null','nanfill',['fill_nan'])}} + +```python exec="on" result="text" session="user-guide/null" +--8<-- "python/user-guide/expressions/null.py:nanfill" +``` diff --git a/docs/user-guide/expressions/numpy.md b/docs/user-guide/expressions/numpy.md new file mode 100644 index 000000000000..6449ffd634bf --- /dev/null +++ b/docs/user-guide/expressions/numpy.md @@ -0,0 +1,22 @@ +# Numpy + +`Polars` expressions support `NumPy` [ufuncs](https://numpy.org/doc/stable/reference/ufuncs.html). See [here](https://numpy.org/doc/stable/reference/ufuncs.html#available-ufuncs) +for a list on all supported numpy functions. + +This means that if a function is not provided by `Polars`, we can use `NumPy` and we still have fast columnar operation through the `NumPy` API. + +### Example + +{{code_block('user-guide/expressions/numpy-example',api_functions=['DataFrame','np.log'])}} + +```python exec="on" result="text" session="user-guide/numpy" +--8<-- "python/user-guide/expressions/numpy-example.py" +``` + +### Interoperability + +Polars `Series` have support for NumPy universal functions (ufuncs). Element-wise functions such as `np.exp()`, `np.cos()`, `np.div()`, etc. all work with almost zero overhead. + +However, as a Polars-specific remark: missing values are a separate bitmask and are not visible by NumPy. This can lead to a window function or a `np.convolve()` giving flawed or incomplete results. + +Convert a Polars `Series` to a NumPy array with the `.to_numpy()` method. Missing values will be replaced by `np.nan` during the conversion. If the `Series` does not include missing values, or those values are not desired anymore, the `.view()` method can be used instead, providing a zero-copy NumPy array of the data. diff --git a/docs/user-guide/expressions/operators.md b/docs/user-guide/expressions/operators.md new file mode 100644 index 000000000000..24cb4e6834b8 --- /dev/null +++ b/docs/user-guide/expressions/operators.md @@ -0,0 +1,30 @@ +# Basic operators + +This section describes how to use basic operators (e.g. addition, subtraction) in conjunction with Expressions. We will provide various examples using different themes in the context of the following dataframe. + +!!! note Operator Overloading + + In Rust and Python it is possible to use the operators directly (as in `+ - * / < > `) as the language allows operator overloading. For instance, the operator `+` translates to the `.add()` method. You can choose the one you prefer. 
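+
+For example, the two spellings below build the same expression (the frame and column name are made up for illustration):
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"nrs": [1, 2, 3]})
+
+out = df.select(
+    (pl.col("nrs") + 5).alias("plus_operator"),  # operator syntax
+    pl.col("nrs").add(5).alias("plus_method"),   # equivalent method syntax
+)
+print(out)
+```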
+ +{{code_block('user-guide/expressions/operators','dataframe',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/operators" +--8<-- "python/user-guide/expressions/operators.py:setup" +--8<-- "python/user-guide/expressions/operators.py:dataframe" +``` + +### Numerical + +{{code_block('user-guide/expressions/operators','numerical',['operators'])}} + +```python exec="on" result="text" session="user-guide/operators" +--8<-- "python/user-guide/expressions/operators.py:numerical" +``` + +### Logical + +{{code_block('user-guide/expressions/operators','logical',['operators'])}} + +```python exec="on" result="text" session="user-guide/operators" +--8<-- "python/user-guide/expressions/operators.py:logical" +``` diff --git a/docs/user-guide/expressions/strings.md b/docs/user-guide/expressions/strings.md new file mode 100644 index 000000000000..ccb06de30f20 --- /dev/null +++ b/docs/user-guide/expressions/strings.md @@ -0,0 +1,62 @@ +# Strings + +The following section discusses operations performed on `Utf8` strings, which are a frequently used `DataType` when working with `DataFrames`. However, processing strings can often be inefficient due to their unpredictable memory size, causing the CPU to access many random memory locations. To address this issue, Polars utilizes `Arrow` as its backend, which stores all strings in a contiguous block of memory. As a result, string traversal is cache-optimal and predictable for the CPU. + +String processing functions are available in the `str` namespace. + +##### Accessing the string namespace + +The `str` namespace can be accessed through the `.str` attribute of a column with `Utf8` data type. In the following example, we create a column named `animal` and compute the length of each element in the column in terms of the number of bytes and the number of characters. If you are working with ASCII text, then the results of these two computations will be the same, and using `lengths` is recommended since it is faster. + +{{code_block('user-guide/expressions/strings','df',['lengths','n_chars'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:setup" +--8<-- "python/user-guide/expressions/strings.py:df" +``` + +#### String parsing + +`Polars` offers multiple methods for checking and parsing elements of a string. Firstly, we can use the `contains` method to check whether a given pattern exists within a substring. Subsequently, we can extract these patterns and replace them using other methods, which will be demonstrated in upcoming examples. + +##### Check for existence of a pattern + +To check for the presence of a pattern within a string, we can use the contains method. The `contains` method accepts either a regular substring or a regex pattern, depending on the value of the `literal` parameter. If the pattern we're searching for is a simple substring located either at the beginning or end of the string, we can alternatively use the `starts_with` and `ends_with` functions. + +{{code_block('user-guide/expressions/strings','existence',['str.contains', 'starts_with','ends_with'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:existence" +``` + +##### Extract a pattern + +The `extract` method allows us to extract a pattern from a specified string. This method takes a regex pattern containing one or more capture groups, which are defined by parentheses `()` in the pattern. 
The group index indicates which capture group to output. + +{{code_block('user-guide/expressions/strings','extract',['extract'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:extract" +``` + +To extract all occurrences of a pattern within a string, we can use the `extract_all` method. In the example below, we extract all numbers from a string using the regex pattern `(\d+)`, which matches one or more digits. The resulting output of the `extract_all` method is a list containing all instances of the matched pattern within the string. + +{{code_block('user-guide/expressions/strings','extract_all',['extract_all'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:extract_all" +``` + +##### Replace a pattern + +We have discussed two methods for pattern matching and extraction thus far, and now we will explore how to replace a pattern within a string. Similar to `extract` and `extract_all`, Polars provides the `replace` and `replace_all` methods for this purpose. In the example below we replace one match of `abc` at the end of a word (`\b`) by `ABC` and we replace all occurrence of `a` with `-`. + +{{code_block('user-guide/expressions/strings','replace',['replace','replace_all'])}} + +```python exec="on" result="text" session="user-guide/strings" +--8<-- "python/user-guide/expressions/strings.py:replace" +``` + +#### API documentation + +In addition to the examples covered above, Polars offers various other string manipulation methods for tasks such as formatting, stripping, splitting, and more. To explore these additional methods, you can go to the API documentation of your chosen programming language for Polars. diff --git a/docs/user-guide/expressions/structs.md b/docs/user-guide/expressions/structs.md new file mode 100644 index 000000000000..9973e61d4c68 --- /dev/null +++ b/docs/user-guide/expressions/structs.md @@ -0,0 +1,99 @@ +# The Struct datatype + +Polars `Struct`s are the idiomatic way of working with multiple columns. It is also a free operation i.e. moving columns into `Struct`s does not copy any data! + +For this section, let's start with a `DataFrame` that captures the average rating of a few movies across some states in the U.S.: + +{{code_block('user-guide/expressions/structs','ratings_df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:setup" +--8<-- "python/user-guide/expressions/structs.py:ratings_df" +``` + +## Encountering the `Struct` type + +A common operation that will lead to a `Struct` column is the ever so popular `value_counts` function that is commonly used in exploratory data analysis. Checking the number of times a state appears the data will be done as so: + +{{code_block('user-guide/expressions/structs','state_value_counts',['value_counts'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:state_value_counts" +``` + +Quite unexpected an output, especially if coming from tools that do not have such a data type. We're not in peril though, to get back to a more familiar output, all we need to do is `unnest` the `Struct` column into its constituent columns: + +{{code_block('user-guide/expressions/structs','struct_unnest',['unnest'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_unnest" +``` + +!!! 
note "Why `value_counts` returns a `Struct`" + + Polars expressions always have a `Fn(Series) -> Series` signature and `Struct` is thus the data type that allows us to provide multiple columns as input/output of an expression. In other words, all expressions have to return a `Series` object, and `Struct` allows us to stay consistent with that requirement. + +## Structs as `dict`s + +Polars will interpret a `dict` sent to the `Series` constructor as a `Struct`: + +{{code_block('user-guide/expressions/structs','series_struct',['Series'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct" +``` + +!!! note "Constructing `Series` objects" + + Note that `Series` here was constructed with the `name` of the series in the beginning, followed by the `values`. Providing the latter first + is considered an anti-pattern in Polars, and must be avoided. + +### Extracting individual values of a `Struct` + +Let's say that we needed to obtain just the `movie` value in the `Series` that we created above. We can use the `field` method to do so: + +{{code_block('user-guide/expressions/structs','series_struct_extract',['field'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct_extract" +``` + +### Renaming individual keys of a `Struct` + +What if we need to rename individual `field`s of a `Struct` column? We first convert the `rating_Series` object to a `DataFrame` so that we can view the changes easily, and then use the `rename_fields` method: + +{{code_block('user-guide/expressions/structs','series_struct_rename',['rename_fields'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:series_struct_rename" +``` + +## Practical use-cases of `Struct` columns + +### Identifying duplicate rows + +Let's get back to the `ratings` data. We want to identify cases where there are duplicates at a `Movie` and `Theatre` level. This is where the `Struct` datatype shines: + +{{code_block('user-guide/expressions/structs','struct_duplicates',['is_duplicated', 'struct'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_duplicates" +``` + +We can identify the unique cases at this level also with `is_unique`! + +### Multi-column ranking + +Suppose, given that we know there are duplicates, we want to choose which rank gets a higher priority. We define the _Count_ of ratings to be more important than the actual `Avg_Rating` itself, and only use the latter to break a tie. We can then do: + +{{code_block('user-guide/expressions/structs','struct_ranking',['is_duplicated', 'struct'])}} + +```python exec="on" result="text" session="user-guide/structs" +--8<-- "python/user-guide/expressions/structs.py:struct_ranking" +``` + +That's a pretty complex set of requirements done very elegantly in Polars! + +### Using multi-column apply + +This was discussed in the previous section on _User Defined Functions_. diff --git a/docs/user-guide/expressions/user-defined-functions.md b/docs/user-guide/expressions/user-defined-functions.md new file mode 100644 index 000000000000..dd83cb13c382 --- /dev/null +++ b/docs/user-guide/expressions/user-defined-functions.md @@ -0,0 +1,187 @@ +# User-defined functions + +!!! warning "Not updated for Python Polars `0.19.0`" + + This section of the user guide still needs to be updated for the latest Polars release.
+ +You should be convinced by now that Polars expressions are so powerful and flexible that there is much less need for custom Python functions +than in other libraries. + +Still, you need to have the power to be able to pass an expression's state to a third party library or apply your black box function +over data in Polars. + +For this we provide the following expressions: + +- `map` +- `apply` + +## To `map` or to `apply`. + +These functions have an important distinction in how they operate and consequently what data they will pass to the user. + +A `map` passes the `Series` backed by the `expression` as is. + +`map` follows the same rules in both the `select` and the `group_by` context, this will +mean that the `Series` represents a column in a `DataFrame`. Note that in the `group_by` context, that column is not yet +aggregated! + +Use cases for `map` are for instance passing the `Series` in an expression to a third party library. Below we show how +we could use `map` to pass an expression column to a neural network model. + +=== ":fontawesome-brands-python: Python" +[:material-api: `map`](https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.map.html) + +```python +df.with_columns([ + pl.col("features").map(lambda s: MyNeuralNetwork.forward(s.to_numpy())).alias("activations") +]) +``` + +=== ":fontawesome-brands-rust: Rust" + +```rust +df.with_columns([ + col("features").map(|s| Ok(my_nn.forward(s))).alias("activations") +]) +``` + +Use cases for `map` in the `group_by` context are slim. They are only used for performance reasons, but can quite easily lead to incorrect results. Let me explain why. + +{{code_block('user-guide/expressions/user-defined-functions','dataframe',['map'])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:setup" +--8<-- "python/user-guide/expressions/user-defined-functions.py:dataframe" +``` + +In the snippet above we group by the `"keys"` column. That means we have the following groups: + +```c +"a" -> [10, 7] +"b" -> [1] +``` + +If we would then apply a `shift` operation to the right, we'd expect: + +```c +"a" -> [null, 10] +"b" -> [null] +``` + +Now, let's print and see what we've got. + +```python +print(out) +``` + +``` +shape: (2, 3) +┌──────┬────────────┬──────────────────┐ +│ keys ┆ shift_map ┆ shift_expression │ +│ --- ┆ --- ┆ --- │ +│ str ┆ list[i64] ┆ list[i64] │ +╞══════╪════════════╪══════════════════╡ +│ a ┆ [null, 10] ┆ [null, 10] │ +│ b ┆ [7] ┆ [null] │ +└──────┴────────────┴──────────────────┘ +``` + +Ouch.. we clearly get the wrong results here. Group `"b"` even got a value from group `"a"` 😵. + +This went horribly wrong, because the `map` applies the function before we aggregate! So that means the whole column `[10, 7, 1`\] got shifted to `[null, 10, 7]` and was then aggregated. + +So my advice is to never use `map` in the `group_by` context unless you know you need it and know what you are doing. + +## To `apply` + +Luckily we can fix previous example with `apply`. `apply` works on the smallest logical elements for that operation. + +That is: + +- `select context` -> single elements +- `group by context` -> single groups + +So with `apply` we should be able to fix our example: + +{{code_block('user-guide/expressions/user-defined-functions','apply',['apply'])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:apply" +``` + +And observe, a valid result! 
🎉 + +## `apply` in the `select` context + +In the `select` context, the `apply` expression passes elements of the column to the python function. + +_Note that you are now running Python, this will be slow._ + +Let's go through some examples to see what to expect. We will continue with the `DataFrame` we defined at the start of +this section and show an example with the `apply` function and a counter example where we use the expression API to +achieve the same goals. + +### Adding a counter + +In this example we create a global `counter` and then add the integer `1` to the global state at every element processed. +Every iteration the result of the increment will be added to the element value. + +> Note, this example isn't provided in Rust. The reason is that the global `counter` value would lead to data races when this apply is evaluated in parallel. It would be possible to wrap it in a `Mutex` to protect the variable, but that would be obscuring the point of the example. This is a case where the Python Global Interpreter Lock's performance tradeoff provides some safety guarantees. + +{{code_block('user-guide/expressions/user-defined-functions','counter',['apply'])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:counter" +``` + +### Combining multiple column values + +If we want to have access to values of different columns in a single `apply` function call, we can create `struct` data +type. This data type collects those columns as fields in the `struct`. So if we'd create a struct from the columns +`"keys"` and `"values"`, we would get the following struct elements: + +```python +[ + {"keys": "a", "values": 10}, + {"keys": "a", "values": 7}, + {"keys": "b", "values": 1}, +] +``` + +In Python, those would be passed as `dict` to the calling python function and can thus be indexed by `field: str`. In rust, you'll get a `Series` with the `Struct` type. The fields of the struct can then be indexed and downcast. + +{{code_block('user-guide/expressions/user-defined-functions','combine',['apply','struct'])}} + +```python exec="on" result="text" session="user-guide/udf" +--8<-- "python/user-guide/expressions/user-defined-functions.py:combine" +``` + +`Structs` are covered in detail in the next section. + +### Return types? + +Custom python functions are black boxes for polars. We really don't know what kind of black arts you are doing, so we have +to infer and try our best to understand what you meant. + +As a user it helps to understand what we do to better utilize custom functions. + +The data type is automatically inferred. We do that by waiting for the first non-null value. That value will then be used +to determine the type of the `Series`. 
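+
+A small sketch of that inference in practice (the data is made up; `apply` is used here to match the rest of this section):
+
+```python
+import polars as pl
+
+df = pl.DataFrame({"values": [1, 2, 3]})
+
+out = df.select(
+    pl.col("values").apply(lambda x: x * 2).alias("as_int"),           # first value is an int   -> Int64
+    pl.col("values").apply(lambda x: float(x) * 2).alias("as_float"),  # first value is a float  -> Float64
+)
+print(out.schema)
+```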
+ +The mapping of python types to polars data types is as follows: + +- `int` -> `Int64` +- `float` -> `Float64` +- `bool` -> `Boolean` +- `str` -> `Utf8` +- `list[tp]` -> `List[tp]` (where the inner type is inferred with the same rules) +- `dict[str, [tp]]` -> `struct` +- `Any` -> `object` (Prevent this at all times) + +Rust types map as follows: + +- `i32` or `i64` -> `Int64` +- `f32` or `f64` -> `Float64` +- `bool` -> `Boolean` +- `String` or `str` -> `Utf8` +- `Vec` -> `List[tp]` (where the inner type is inferred with the same rules) diff --git a/docs/user-guide/expressions/window.md b/docs/user-guide/expressions/window.md new file mode 100644 index 000000000000..7ea426ccb1b9 --- /dev/null +++ b/docs/user-guide/expressions/window.md @@ -0,0 +1,91 @@ +# Window functions + +Window functions are expressions with superpowers. They allow you to perform aggregations on groups in the +`select` context. Let's get a feel for what that means. First we create a dataset. The dataset loaded in the +snippet below contains information about pokemon: + +{{code_block('user-guide/expressions/window','pokemon',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:pokemon" +``` + +## Group by aggregations in selection + +Below we show how to use window functions to group over different columns and perform an aggregation on them. +Doing so allows us to use multiple group by operations in parallel, using a single query. The results of the aggregation +are projected back to the original rows. Therefore, a window function will almost always lead to a `DataFrame` with the same size as the original. + +We will discuss later the cases where a window function can change the numbers of rows in a `DataFrame`. + +Note how we call `.over("Type 1")` and `.over(["Type 1", "Type 2"])`. Using window functions we can aggregate over different groups in a single `select` call! Note that, in Rust, the type of the argument to `over()` must be a collection, so even when you're only using one column, you must provided it in an array. + +The best part is, this won't cost you anything. The computed groups are cached and shared between different `window` expressions. + +{{code_block('user-guide/expressions/window','group_by',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:group_by" +``` + +## Operations per group + +Window functions can do more than aggregation. They can also be viewed as an operation within a group. If, for instance, you +want to `sort` the values within a `group`, you can write `col("value").sort().over("group")` and voilà! We sorted by group! + +Let's filter out some rows to make this more clear. + +{{code_block('user-guide/expressions/window','operations',['filter'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:operations" +``` + +Observe that the group `Water` of column `Type 1` is not contiguous. There are two rows of `Grass` in between. Also note +that each pokemon within a group are sorted by `Speed` in `ascending` order. Unfortunately, for this example we want them sorted in +`descending` speed order. Luckily with window functions this is easy to accomplish. 
+ +{{code_block('user-guide/expressions/window','sort',['over'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:sort" +``` + +`Polars` keeps track of each group's location and maps the expressions to the proper row locations. This will also work over different groups in a single `select`. + +The power of window expressions is that you often don't need a `group_by -> explode` combination, but you can put the logic in a single expression. It also makes the API cleaner. If properly used a: + +- `group_by` -> marks that groups are aggregated and we expect a `DataFrame` of size `n_groups` +- `over` -> marks that we want to compute something within a group, and doesn't modify the original size of the `DataFrame` except in specific cases + +## Map the expression result to the DataFrame rows + +In cases where the expression results in multiple values per group, the Window function has 3 strategies for linking the values back to the `DataFrame` rows: + +- `mapping_strategy = 'group_to_rows'` -> each value is assigned back to one row. The number of values returned should match the number of rows. + +- `mapping_strategy = 'join'` -> the values are imploded in a list, and the list is repeated on all rows. This can be memory intensive. + +- `mapping_strategy = 'explode'` -> the values are exploded to new rows. This operation changes the number of rows. + +## Window expression rules + +The evaluations of window expressions are as follows (assuming we apply it to a `pl.Int32` column): + +{{code_block('user-guide/expressions/window','rules',['over'])}} + +## More examples + +For more exercise, below are some window functions for us to compute: + +- sort all pokemon by type +- select the first `3` pokemon per type as `"Type 1"` +- sort the pokemon within a type by speed in descending order and select the first `3` as `"fastest/group"` +- sort the pokemon within a type by attack in descending order and select the first `3` as `"strongest/group"` +- sort the pokemon within a type by name and select the first `3` as `"sorted_by_alphabet"` + +{{code_block('user-guide/expressions/window','examples',['over','implode'])}} + +```python exec="on" result="text" session="user-guide/window" +--8<-- "python/user-guide/expressions/window.py:examples" +``` diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md new file mode 100644 index 000000000000..8fb27a98c743 --- /dev/null +++ b/docs/user-guide/index.md @@ -0,0 +1,31 @@ +# Introduction + +This User Guide is an introduction to the [`Polars` DataFrame library](https://github.com/pola-rs/polars). Its goal is to introduce you to `Polars` by going through examples and comparing it to other +solutions. Some design choices are introduced here. The guide will also introduce you to optimal usage of `Polars`. + +Even though `Polars` is completely written in [`Rust`](https://www.rust-lang.org/) (no runtime overhead!) and uses [`Arrow`](https://arrow.apache.org/) -- the +[native arrow2 `Rust` implementation](https://github.com/jorgecarleitao/arrow2) -- as its foundation, the examples presented in this guide will be mostly using its higher-level language +bindings. Higher-level bindings only serve as a thin wrapper for functionality implemented in the core library. + +For [`Pandas`](https://pandas.pydata.org/) users, our [Python package](https://pypi.org/project/polars/) will offer the easiest way to get started with `Polars`. 
+ +### Philosophy + +The goal of `Polars` is to provide a lightning fast `DataFrame` library that: + +- Utilizes all available cores on your machine. +- Optimizes queries to reduce unneeded work/memory allocations. +- Handles datasets much larger than your available RAM. +- Has an API that is consistent and predictable. +- Has a strict schema (data-types should be known before running the query). + +Polars is written in Rust which gives it C/C++ performance and allows it to fully control performance critical parts +in a query engine. + +As such `Polars` goes to great lengths to: + +- Reduce redundant copies. +- Traverse memory cache efficiently. +- Minimize contention in parallelism. +- Process data in chunks. +- Reuse memory allocations. diff --git a/docs/user-guide/installation.md b/docs/user-guide/installation.md new file mode 100644 index 000000000000..1f990779d09d --- /dev/null +++ b/docs/user-guide/installation.md @@ -0,0 +1,174 @@ +# Installation + +Polars is a library and installation is as simple as invoking the package manager of the corresponding programming language. + +=== ":fontawesome-brands-python: Python" + + ``` bash + pip install polars + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` shell + cargo add polars -F lazy + + # Or Cargo.toml + [dependencies] + polars = { version = "x", features = ["lazy", ...]} + ``` + +## Importing + +To use the library import it into your project + +=== ":fontawesome-brands-python: Python" + + ``` python + import polars as pl + ``` + +=== ":fontawesome-brands-rust: Rust" + + ``` rust + use polars::prelude::*; + ``` + +## Feature Flags + +By using the above command you install the core of `Polars` onto your system. However depending on your use case you might want to install the optional dependencies as well. These are made optional to minimize the footprint. The flags are different depending on the programming language. Throughout the user guide we will mention when a functionality is used that requires an additional dependency. + +### Python + +```text +# For example +pip install polars[numpy, fsspec] +``` + +| Tag | Description | +| ---------- | ------------------------------------------------------------------------------------------------------------------------------------- | +| all | Install all optional dependencies (all of the following) | +| pandas | Install with Pandas for converting data to and from Pandas Dataframes/Series | +| numpy | Install with numpy for converting data to and from numpy arrays | +| pyarrow | Reading data formats using PyArrow | +| fsspec | Support for reading from remote file systems | +| connectorx | Support for reading from SQL databases | +| xlsx2csv | Support for reading from Excel files | +| deltalake | Support for reading from Delta Lake Tables | +| timezone | Timezone support, only needed if 1. you are on Python < 3.9 and/or 2. you are on Windows, otherwise no dependencies will be installed | + +### Rust + +```toml +# Cargo.toml +[dependencies] +polars = { version = "0.26.1", features = ["lazy", "temporal", "describe", "json", "parquet", "dtype-datetime"] } +``` + +The opt-in features are: + +- Additional data types: + - `dtype-date` + - `dtype-datetime` + - `dtype-time` + - `dtype-duration` + - `dtype-i8` + - `dtype-i16` + - `dtype-u8` + - `dtype-u16` + - `dtype-categorical` + - `dtype-struct` +- `performant` - Longer compile times more fast paths. 
+- `lazy` - Lazy API + - `lazy_regex` - Use regexes in [column selection](crate::lazy::dsl::col) + - `dot_diagram` - Create dot diagrams from lazy logical plans. +- `sql` - Pass SQL queries to polars. +- `streaming` - Be able to process datasets that are larger than RAM. +- `random` - Generate arrays with randomly sampled values +- `ndarray`- Convert from `DataFrame` to `ndarray` +- `temporal` - Conversions between [Chrono](https://docs.rs/chrono/) and Polars for temporal data types +- `timezones` - Activate timezone support. +- `strings` - Extra string utilities for `Utf8Chunked` + - `string_justify` - `zfill`, `ljust`, `rjust` + - `string_from_radix` - `parse_int` +- `object` - Support for generic ChunkedArrays called `ObjectChunked` (generic over `T`). + These are downcastable from Series through the [Any](https://doc.rust-lang.org/std/any/index.html) trait. +- Performance related: + - `nightly` - Several nightly only features such as SIMD and specialization. + - `performant` - more fast paths, slower compile times. + - `bigidx` - Activate this feature if you expect >> 2^32 rows. This has not been needed by anyone. + This allows polars to scale up way beyond that by using `u64` as an index. + Polars will be a bit slower with this feature activated as many data structures + are less cache efficient. + - `cse` - Activate common subplan elimination optimization +- IO related: + + - `serde` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. + Can be used for JSON and more serde supported serialization formats. + - `serde-lazy` - Support for [serde](https://crates.io/crates/serde) serialization and deserialization. + Can be used for JSON and more serde supported serialization formats. + + - `parquet` - Read Apache Parquet format + - `json` - JSON serialization + - `ipc` - Arrow's IPC format serialization + - `decompress` - Automatically infer compression of csvs and decompress them. + Supported compressions: + - zip + - gzip + +- `DataFrame` operations: + - `dynamic_group_by` - Group by based on a time window instead of predefined keys. + Also activates rolling window group by operations. + - `sort_multiple` - Allow sorting a `DataFrame` on multiple columns + - `rows` - Create `DataFrame` from rows and extract rows from `DataFrames`. + And activates `pivot` and `transpose` operations + - `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match. + - `cross_join` - Create the cartesian product of two DataFrames. + - `semi_anti_join` - SEMI and ANTI joins. + - `group_by_list` - Allow group by operation on keys of type List. + - `row_hash` - Utility to hash DataFrame rows to UInt64Chunked + - `diagonal_concat` - Concat diagonally thereby combining different schemas. + - `horizontal_concat` - Concat horizontally and extend with null values if lengths don't match + - `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series) + - `partition_by` - Split into multiple DataFrames partitioned by groups. +- `Series`/`Expression` operations: + - `is_in` - [Check for membership in `Series`](crate::chunked_array::ops::IsIn) + - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip) + - `round_series` - round underlying float types of `Series`. + - `repeat_by` - [Repeat element in an Array N times, where N is given by another array. + - `is_first_distinct` - Check if element is first unique value. + - `is_last_distinct` - Check if element is last unique value. 
+ - `checked_arithmetic` - checked arithmetic/ returning `None` on invalid operations. + - `dot_product` - Dot/inner product on Series and Expressions. + - `concat_str` - Concat string data in linear time. + - `reinterpret` - Utility to reinterpret bits to signed/unsigned + - `take_opt_iter` - Take from a Series with `Iterator>` + - `mode` - [Return the most occurring value(s)](crate::chunked_array::ops::ChunkUnique::mode) + - `cum_agg` - cumsum, cummin, cummax aggregation. + - `rolling_window` - rolling window functions, like rolling_mean + - `interpolate` [interpolate None values](crate::chunked_array::ops::Interpolate) + - `extract_jsonpath` - [Run jsonpath queries on Utf8Chunked](https://goessner.net/articles/JsonPath/) + - `list` - List utils. + - `list_take` take sublist by multiple indices + - `rank` - Ranking algorithms. + - `moment` - kurtosis and skew statistics + - `ewma` - Exponential moving average windows + - `abs` - Get absolute values of Series + - `arange` - Range operation on Series + - `product` - Compute the product of a Series. + - `diff` - `diff` operation. + - `pct_change` - Compute change percentages. + - `unique_counts` - Count unique values in expressions. + - `log` - Logarithms for `Series`. + - `list_to_struct` - Convert `List` to `Struct` dtypes. + - `list_count` - Count elements in lists. + - `list_eval` - Apply expressions over list elements. + - `cumulative_eval` - Apply expressions over cumulatively increasing windows. + - `arg_where` - Get indices where condition holds. + - `search_sorted` - Find indices where elements should be inserted to maintain order. + - `date_offset` Add an offset to dates that take months and leap years into account. + - `trigonometry` Trigonometric functions. + - `sign` Compute the element-wise sign of a Series. + - `propagate_nans` NaN propagating min/max aggregations. +- `DataFrame` pretty printing + - `fmt` - Activate DataFrame formatting diff --git a/docs/user-guide/io/aws.md b/docs/user-guide/io/aws.md new file mode 100644 index 000000000000..e19efc74b580 --- /dev/null +++ b/docs/user-guide/io/aws.md @@ -0,0 +1,20 @@ +# AWS + +--8<-- "docs/_build/snippets/under_construction.md" + +To read from or write to an AWS bucket, additional dependencies are needed in Rust: + +=== ":fontawesome-brands-rust: Rust" + +```shell +$ cargo add aws_sdk_s3 aws_config tokio --features tokio/full +``` + +In the next few snippets we'll demonstrate interacting with a `Parquet` file +located on an AWS bucket. 
+ +## Read + +Load a `.parquet` file using: + +{{code_block('user-guide/io/aws','bucket',['from_arrow'])}} diff --git a/docs/user-guide/io/bigquery.md b/docs/user-guide/io/bigquery.md new file mode 100644 index 000000000000..21287cd448d2 --- /dev/null +++ b/docs/user-guide/io/bigquery.md @@ -0,0 +1,19 @@ +# Google BigQuery + +To read or write from GBQ, additional dependencies are needed: + +=== ":fontawesome-brands-python: Python" + +```shell +$ pip install google-cloud-bigquery +``` + +## Read + +We can load a query into a `DataFrame` like this: + +{{code_block('user-guide/io/bigquery','read',['from_arrow'])}} + +## Write + +{{code_block('user-guide/io/bigquery','write',[])}} diff --git a/docs/user-guide/io/csv.md b/docs/user-guide/io/csv.md new file mode 100644 index 000000000000..eeb209dfb34e --- /dev/null +++ b/docs/user-guide/io/csv.md @@ -0,0 +1,21 @@ +# CSV + +## Read & write + +Reading a CSV file should look familiar: + +{{code_block('user-guide/io/csv','read',['read_csv'])}} + +Writing a CSV file is similar with the `write_csv` function: + +{{code_block('user-guide/io/csv','write',['write_csv'])}} + +## Scan + +`Polars` allows you to _scan_ a CSV input. Scanning delays the actual parsing of the +file and instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/csv','scan',['scan_csv'])}} + +If you want to know why this is desirable, you can read more about these `Polars` +optimizations [here](../concepts/lazy-vs-eager.md). diff --git a/docs/user-guide/io/database.md b/docs/user-guide/io/database.md new file mode 100644 index 000000000000..4444e7be799e --- /dev/null +++ b/docs/user-guide/io/database.md @@ -0,0 +1,70 @@ +# Databases + +## Read from a database + +We can read from a database with Polars using the `pl.read_database` function. To use this function you need an SQL query string and a connection string called a `connection_uri`. + +For example, the following snippet shows the general patterns for reading all columns from the `foo` table in a Postgres database: + +{{code_block('user-guide/io/database','read',['read_database_connectorx'])}} + +### Engines + +Polars doesn't manage connections and data transfer from databases by itself. Instead external libraries (known as _engines_) handle this. At present Polars can use two engines to read from databases: + +- [ConnectorX](https://github.com/sfu-db/connector-x) and +- [ADBC](https://arrow.apache.org/docs/format/ADBC.html) + +#### ConnectorX + +ConnectorX is the default engine and [supports numerous databases](https://github.com/sfu-db/connector-x#sources) including Postgres, Mysql, SQL Server and Redshift. ConnectorX is written in Rust and stores data in Arrow format to allow for zero-copy to Polars. + +To read from one of the supported databases with `ConnectorX` you need to activate the additional dependency `ConnectorX` when installing Polars or install it manually with + +```shell +$ pip install connectorx +``` + +#### ADBC + +ADBC (Arrow Database Connectivity) is an engine supported by the Apache Arrow project. ADBC aims to be both an API standard for connecting to databases and libraries implementing this standard in a range of languages. + +It is still early days for ADBC so support for different databases is still limited. At present drivers for ADBC are only available for [Postgres and SQLite](https://arrow.apache.org/adbc/0.1.0/driver/cpp/index.html). To install ADBC you need to install the driver for your database. 
For example to install the driver for SQLite you run + +```shell +$ pip install adbc-driver-sqlite +``` + +As ADBC is not the default engine you must specify the engine as an argument to `pl.read_database` + +{{code_block('user-guide/io/database','adbc',['read_database'])}} + +## Write to a database + +We can write to a database with Polars using the `pl.write_database` function. + +### Engines + +As with reading from a database above Polars uses an _engine_ to write to a database. The currently supported engines are: + +- [SQLAlchemy](https://www.sqlalchemy.org/) and +- Arrow Database Connectivity (ADBC) + +#### SQLAlchemy + +With the default engine SQLAlchemy you can write to any database supported by SQLAlchemy. To use this engine you need to install SQLAlchemy and Pandas + +```shell +$ pip install SQLAlchemy pandas +``` + +In this example, we write the `DataFrame` to a table called `records` in the database + +{{code_block('user-guide/io/database','write',['write_database'])}} + +In the SQLAlchemy approach Polars converts the `DataFrame` to a Pandas `DataFrame` backed by PyArrow and then uses SQLAlchemy methods on a Pandas `DataFrame` to write to the database. + +#### ADBC + +As with reading from a database you can also use ADBC to write to a SQLite or Posgres database. As shown above you need to install the appropriate ADBC driver for your database. +{{code_block('user-guide/io/database','write_adbc',['write_database'])}} diff --git a/docs/user-guide/io/json_file.md b/docs/user-guide/io/json_file.md new file mode 100644 index 000000000000..352904829c7b --- /dev/null +++ b/docs/user-guide/io/json_file.md @@ -0,0 +1,26 @@ +# JSON files + +## Read & write + +### JSON + +Reading a JSON file should look familiar: + +{{code_block('user-guide/io/json-file','read',['read_json'])}} + +### Newline Delimited JSON + +JSON objects that are delimited by newlines can be read into polars in a much more performant way than standard json. + +{{code_block('user-guide/io/json-file','readnd',['read_ndjson'])}} + +## Write + +{{code_block('user-guide/io/json-file','write',['write_json','write_ndjson'])}} + +## Scan + +`Polars` allows you to _scan_ a JSON input **only for newline delimited json**. Scanning delays the actual parsing of the +file and instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/json-file','scan',['scan_ndjson'])}} diff --git a/docs/user-guide/io/multiple.md b/docs/user-guide/io/multiple.md new file mode 100644 index 000000000000..c5a66b03940f --- /dev/null +++ b/docs/user-guide/io/multiple.md @@ -0,0 +1,40 @@ +## Dealing with multiple files. + +Polars can deal with multiple files differently depending on your needs and memory strain. + +Let's create some files to give us some context: + +{{code_block('user-guide/io/multiple','create',['write_csv'])}} + +## Reading into a single `DataFrame` + +To read multiple files into a single `DataFrame`, we can use globbing patterns: + +{{code_block('user-guide/io/multiple','read',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:create" +--8<-- "python/user-guide/io/multiple.py:read" +``` + +To see how this works we can take a look at the query plan. Below we see that all files are read separately and +concatenated into a single `DataFrame`. `Polars` will try to parallelize the reading. 
+ +{{code_block('user-guide/io/multiple','graph',['show_graph'])}} + +```python exec="on" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:creategraph" +``` + +## Reading and processing in parallel + +If your files don't have to be in a single table you can also build a query plan for each file and execute them in parallel +on the `Polars` thread pool. + +All query plan execution is embarrassingly parallel and doesn't require any communication. + +{{code_block('user-guide/io/multiple','glob',['scan_csv'])}} + +```python exec="on" result="text" session="user-guide/io/multiple" +--8<-- "python/user-guide/io/multiple.py:glob" +``` diff --git a/docs/user-guide/io/parquet.md b/docs/user-guide/io/parquet.md new file mode 100644 index 000000000000..71a5399bb393 --- /dev/null +++ b/docs/user-guide/io/parquet.md @@ -0,0 +1,24 @@ +# Parquet + +Loading or writing [`Parquet` files](https://parquet.apache.org/) is lightning fast. +`Pandas` uses [`PyArrow`](https://arrow.apache.org/docs/python/) -`Python` bindings +exposed by `Arrow`- to load `Parquet` files into memory, but it has to copy that data into +`Pandas` memory. With `Polars` there is no extra cost due to +copying as we read `Parquet` directly into `Arrow` memory and _keep it there_. + +## Read + +{{code_block('user-guide/io/parquet','read',['read_parquet'])}} + +## Write + +{{code_block('user-guide/io/parquet','write',['write_parquet'])}} + +## Scan + +`Polars` allows you to _scan_ a `Parquet` input. Scanning delays the actual parsing of the +file and instead returns a lazy computation holder called a `LazyFrame`. + +{{code_block('user-guide/io/parquet','scan',['scan_parquet'])}} + +If you want to know why this is desirable, you can read more about those `Polars` optimizations [here](../concepts/lazy-vs-eager.md). diff --git a/docs/user-guide/lazy/execution.md b/docs/user-guide/lazy/execution.md new file mode 100644 index 000000000000..975f52a0ac4a --- /dev/null +++ b/docs/user-guide/lazy/execution.md @@ -0,0 +1,79 @@ +# Query execution + +Our example query on the Reddit dataset is: + +{{code_block('user-guide/lazy/execution','df',['scan_csv'])}} + +If we were to run the code above on the Reddit CSV the query would not be evaluated. Instead Polars takes each line of code, adds it to the internal query graph and optimizes the query graph. + +When we execute the code Polars executes the optimized query graph by default. + +### Execution on the full dataset + +We can execute our query on the full dataset by calling the `.collect` method on the query. 
+ +{{code_block('user-guide/lazy/execution','collect',['scan_csv','collect'])}} + +```text +shape: (14_029, 6) +┌─────────┬───────────────────────────┬─────────────┬────────────┬───────────────┬────────────┐ +│ id ┆ name ┆ created_utc ┆ updated_on ┆ comment_karma ┆ link_karma │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═════════╪═══════════════════════════╪═════════════╪════════════╪═══════════════╪════════════╡ +│ 6 ┆ TAOJIANLONG_JASONBROKEN ┆ 1397113510 ┆ 1536527864 ┆ 4 ┆ 0 │ +│ 17 ┆ SSAIG_JASONBROKEN ┆ 1397113544 ┆ 1536527864 ┆ 1 ┆ 0 │ +│ 19 ┆ FDBVFDSSDGFDS_JASONBROKEN ┆ 1397113552 ┆ 1536527864 ┆ 3 ┆ 0 │ +│ 37 ┆ IHATEWHOWEARE_JASONBROKEN ┆ 1397113636 ┆ 1536527864 ┆ 61 ┆ 0 │ +│ … ┆ … ┆ … ┆ … ┆ … ┆ … │ +│ 1229384 ┆ DSFOX ┆ 1163177415 ┆ 1536497412 ┆ 44411 ┆ 7917 │ +│ 1229459 ┆ NEOCARTY ┆ 1163177859 ┆ 1536533090 ┆ 40 ┆ 0 │ +│ 1229587 ┆ TEHSMA ┆ 1163178847 ┆ 1536497412 ┆ 14794 ┆ 5707 │ +│ 1229621 ┆ JEREMYLOW ┆ 1163179075 ┆ 1536497412 ┆ 411 ┆ 1063 │ +└─────────┴───────────────────────────┴─────────────┴────────────┴───────────────┴────────────┘ +``` + +Above we see that from the 10 million rows there are 14,029 rows that match our predicate. + +With the default `collect` method Polars processes all of your data as one batch. This means that all the data has to fit into your available memory at the point of peak memory usage in your query. + +!!! warning "Reusing `LazyFrame` objects" + + Remember that `LazyFrame`s are query plans i.e. a promise on computation and is not guaranteed to cache common subplans. This means that every time you reuse it in separate downstream queries after it is defined, it is computed all over again. If you define an operation on a `LazyFrame` that doesn't maintain row order (such as a `group_by`), then the order will also change every time it is run. To avoid this, use `maintain_order=True` arguments for such operations. + +### Execution on larger-than-memory data + +If your data requires more memory than you have available Polars may be able to process the data in batches using _streaming_ mode. To use streaming mode you simply pass the `streaming=True` argument to `collect` + +{{code_block('user-guide/lazy/execution','stream',['scan_csv','collect'])}} + +We look at [streaming in more detail here](streaming.md). + +### Execution on a partial dataset + +While you're writing, optimizing or checking your query on a large dataset, querying all available data may lead to a slow development process. + +You can instead execute the query with the `.fetch` method. The `.fetch` method takes a parameter `n_rows` and tries to 'fetch' that number of rows at the data source. The number of rows cannot be guaranteed, however, as the lazy API does not count how many rows there are at each stage of the query. + +Here we "fetch" 100 rows from the source file and apply the predicates. 
+ +{{code_block('user-guide/lazy/execution','partial',['scan_csv','collect','fetch'])}} + +```text +shape: (27, 6) +┌───────┬───────────────────────────┬─────────────┬────────────┬───────────────┬────────────┐ +│ id ┆ name ┆ created_utc ┆ updated_on ┆ comment_karma ┆ link_karma │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 ┆ i64 │ +╞═══════╪═══════════════════════════╪═════════════╪════════════╪═══════════════╪════════════╡ +│ 6 ┆ TAOJIANLONG_JASONBROKEN ┆ 1397113510 ┆ 1536527864 ┆ 4 ┆ 0 │ +│ 17 ┆ SSAIG_JASONBROKEN ┆ 1397113544 ┆ 1536527864 ┆ 1 ┆ 0 │ +│ 19 ┆ FDBVFDSSDGFDS_JASONBROKEN ┆ 1397113552 ┆ 1536527864 ┆ 3 ┆ 0 │ +│ 37 ┆ IHATEWHOWEARE_JASONBROKEN ┆ 1397113636 ┆ 1536527864 ┆ 61 ┆ 0 │ +│ … ┆ … ┆ … ┆ … ┆ … ┆ … │ +│ 77763 ┆ LUNCHY ┆ 1137599510 ┆ 1536528275 ┆ 65 ┆ 0 │ +│ 77765 ┆ COMPOSTELLAS ┆ 1137474000 ┆ 1536528276 ┆ 6 ┆ 0 │ +│ 77766 ┆ GENERICBOB ┆ 1137474000 ┆ 1536528276 ┆ 291 ┆ 14 │ +│ 77768 ┆ TINHEADNED ┆ 1139665457 ┆ 1536497404 ┆ 4434 ┆ 103 │ +└───────┴───────────────────────────┴─────────────┴────────────┴───────────────┴────────────┘ +``` diff --git a/docs/user-guide/lazy/optimizations.md b/docs/user-guide/lazy/optimizations.md new file mode 100644 index 000000000000..576413833a3a --- /dev/null +++ b/docs/user-guide/lazy/optimizations.md @@ -0,0 +1,17 @@ +# Optimizations + +If you use `Polars`' lazy API, `Polars` will run several optimizations on your query. Some of them are executed up front, +others are determined just in time as the materialized data comes in. + +Here is a non-complete overview of optimizations done by polars, what they do and how often they run. + +| Optimization | Explanation | runs | +| -------------------------- | ------------------------------------------------------------------------------------------------------------ | ----------------------------- | +| Predicate pushdown | Applies filters as early as possible/ at scan level. | 1 time | +| Projection pushdown | Select only the columns that are needed at the scan level. | 1 time | +| Slice pushdown | Only load the required slice from the scan level. Don't materialize sliced outputs (e.g. join.head(10)). | 1 time | +| Common subplan elimination | Cache subtrees/file scans that are used by multiple subtrees in the query plan. | 1 time | +| Simplify expressions | Various optimizations, such as constant folding and replacing expensive operations with faster alternatives. | until fixed point | +| Join ordering | Estimates the branches of joins that should be executed first in order to reduce memory pressure. | 1 time | +| Type coercion | Coerce types such that operations succeed and run on minimal required memory. | until fixed point | +| Cardinality estimation | Estimates cardinality in order to determine optimal group by strategy. | 0/n times; dependent on query | diff --git a/docs/user-guide/lazy/query_plan.md b/docs/user-guide/lazy/query_plan.md new file mode 100644 index 000000000000..bb57a74168de --- /dev/null +++ b/docs/user-guide/lazy/query_plan.md @@ -0,0 +1,96 @@ +# Query plan + +For any lazy query `Polars` has both: + +- a non-optimized plan with the set of steps code as we provided it and +- an optimized plan with changes made by the query optimizer + +We can understand both the non-optimized and optimized query plans with visualization and by printing them as text. + +
+```python exec="on" result="text" session="user-guide/lazy/query_plan" +--8<-- "python/user-guide/lazy/query_plan.py:setup" +``` +
+ +Below we consider the following query: + +{{code_block('user-guide/lazy/query_plan','plan',[])}} + +```python exec="on" session="user-guide/lazy/query_plan" +--8<-- "python/user-guide/lazy/query_plan.py:plan" +``` + +## Non-optimized query plan + +### Graphviz visualization + +First we visualise the non-optimized plan by setting `optimized=False`. + +{{code_block('user-guide/lazy/query_plan','showplan',['show_graph'])}} + +```python exec="on" session="user-guide/lazy/query_plan" +--8<-- "python/user-guide/lazy/query_plan.py:createplan" +``` + +The query plan visualization should be read from bottom to top. In the visualization: + +- each box corresponds to a stage in the query plan +- the `sigma` stands for `SELECTION` and indicates any filter conditions +- the `pi` stands for `PROJECTION` and indicates choosing a subset of columns + +### Printed query plan + +We can also print the non-optimized plan with `explain(optimized=False)` + +{{code_block('user-guide/lazy/query_plan','describe',['explain'])}} + +```python exec="on" session="user-guide/lazy/query_plan" +--8<-- "python/user-guide/lazy/query_plan.py:describe" +``` + +```text +FILTER [(col("comment_karma")) > (0)] FROM WITH_COLUMNS: + [col("name").str.uppercase()] + + CSV SCAN data/reddit.csv + PROJECT */6 COLUMNS +``` + +The printed plan should also be read from bottom to top. This non-optimized plan is roughly equal to: + +- read from the `data/reddit.csv` file +- read all 6 columns (where the * wildcard in PROJECT \*/6 COLUMNS means take all columns) +- transform the `name` column to uppercase +- apply a filter on the `comment_karma` column + +## Optimized query plan + +Now we visualize the optimized plan with `show_graph`. + +{{code_block('user-guide/lazy/query_plan','show',['show_graph'])}} + +```python exec="on" session="user-guide/lazy/query_plan" +--8<-- "python/user-guide/lazy/query_plan.py:createplan2" +``` + +We can also print the optimized plan with `explain` + +{{code_block('user-guide/lazy/query_plan','optimized',['explain'])}} + +```text + WITH_COLUMNS: + [col("name").str.uppercase()] + + CSV SCAN data/reddit.csv + PROJECT */6 COLUMNS + SELECTION: [(col("comment_karma")) > (0)] +``` + +The optimized plan is to: + +- read the data from the Reddit CSV +- apply the filter on the `comment_karma` column while the CSV is being read line-by-line +- transform the `name` column to uppercase + +In this case the query optimizer has identified that the `filter` can be applied while the CSV is read from disk rather than reading the whole file into memory and then applying the filter. This optimization is called _Predicate Pushdown_. diff --git a/docs/user-guide/lazy/schemas.md b/docs/user-guide/lazy/schemas.md new file mode 100644 index 000000000000..77d2be54b722 --- /dev/null +++ b/docs/user-guide/lazy/schemas.md @@ -0,0 +1,60 @@ +# Schema + +The schema of a Polars `DataFrame` or `LazyFrame` sets out the names of the columns and their datatypes. You can see the schema with the `.schema` method on a `DataFrame` or `LazyFrame` + +{{code_block('user-guide/lazy/schema','schema',['DataFrame','lazy'])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:setup" +--8<-- "python/user-guide/lazy/schema.py:schema" +``` + +The schema plays an important role in the lazy API. + +## Type checking in the lazy API + +One advantage of the lazy API is that Polars will check the schema before any data is processed. This check happens when you execute your lazy query. 
+ +We see how this works in the following simple example where we call the `.round` expression on the integer `bar` column. + +{{code_block('user-guide/lazy/schema','typecheck',['lazy','with_columns'])}} + +The `.round` expression is only valid for columns with a floating point dtype. Calling `.round` on an integer column means the operation will raise an `InvalidOperationError` when we evaluate the query with `collect`. This schema check happens before the data is processed when we call `collect`. + +`python exec="on" result="text" session="user-guide/lazy/schemas"` + +If we executed this query in eager mode the error would only be found once the data had been processed in all earlier steps. + +When we execute a lazy query Polars checks for any potential `InvalidOperationError` before the time-consuming step of actually processing the data in the pipeline. + +## The lazy API must know the schema + +In the lazy API the Polars query optimizer must be able to infer the schema at every step of a query plan. This means that operations where the schema is not knowable in advance cannot be used with the lazy API. + +The classic example of an operation where the schema is not knowable in advance is a `.pivot` operation. In a `.pivot` the new column names come from data in one of the columns. As these column names cannot be known in advance a `.pivot` is not available in the lazy API. + +## Dealing with operations not available in the lazy API + +If your pipeline includes an operation that is not available in the lazy API it is normally best to: + +- run the pipeline in lazy mode up until that point +- execute the pipeline with `.collect` to materialize a `DataFrame` +- do the non-lazy operation on the `DataFrame` +- convert the output back to a `LazyFrame` with `.lazy` and continue in lazy mode + +We show how to deal with a non-lazy operation in this example where we: + +- create a simple `DataFrame` +- convert it to a `LazyFrame` with `.lazy` +- do a transformation using `.with_columns` +- execute the query before the pivot with `.collect` to get a `DataFrame` +- do the `.pivot` on the `DataFrame` +- convert back in lazy mode +- do a `.filter` +- finish by executing the query with `.collect` to get a `DataFrame` + +{{code_block('user-guide/lazy/schema','lazyeager',['collect','pivot','filter'])}} + +```python exec="on" result="text" session="user-guide/lazy/schemas" +--8<-- "python/user-guide/lazy/schema.py:lazyeager" +``` diff --git a/docs/user-guide/lazy/streaming.md b/docs/user-guide/lazy/streaming.md new file mode 100644 index 000000000000..3f9d268443ca --- /dev/null +++ b/docs/user-guide/lazy/streaming.md @@ -0,0 +1,3 @@ +# Streaming + +--8<-- "docs/_build/snippets/under_construction.md" diff --git a/docs/user-guide/lazy/using.md b/docs/user-guide/lazy/using.md new file mode 100644 index 000000000000..d777557da550 --- /dev/null +++ b/docs/user-guide/lazy/using.md @@ -0,0 +1,37 @@ +# Usage + +With the lazy API, Polars doesn't run each query line-by-line but instead processes the full query end-to-end. To get the most out of Polars it is important that you use the lazy API because: + +- the lazy API allows Polars to apply automatic query optimization with the query optimizer +- the lazy API allows you to work with larger than memory datasets using streaming +- the lazy API can catch schema errors before processing the data + +Here we see how to use the lazy API starting from either a file or an existing `DataFrame`. 
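+
+In outline, both routes produce a `LazyFrame`, and nothing is executed until we explicitly ask for a result. A minimal sketch (assuming the Reddit CSV used throughout this guide is available locally):
+
+```python
+import polars as pl
+
+# Entry point 1: scan a file lazily (no data is read yet)
+lazy_from_file = pl.scan_csv("data/reddit.csv")
+
+# Entry point 2: start from an in-memory DataFrame and switch to the lazy API
+lazy_from_df = pl.DataFrame({"foo": [1, 2, 3]}).lazy()
+
+# Both are LazyFrames; the query only runs when .collect() is called
+print(lazy_from_df.filter(pl.col("foo") > 1).collect())
+```
+
+Both entry points are covered in more detail below.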
+ +## Using the lazy API from a file + +In the ideal case we would use the lazy API right from a file as the query optimizer may help us to reduce the amount of data we read from the file. + +We create a lazy query from the Reddit CSV data and apply some transformations. + +By starting the query with `pl.scan_csv` we are using the lazy API. + +{{code_block('user-guide/lazy/using','dataframe',['scan_csv','with_columns','filter','col'])}} + +A `pl.scan_` function is available for a number of file types including CSV, IPC, Parquet and JSON. + +In this query we tell Polars that we want to: + +- load data from the Reddit CSV file +- convert the `name` column to uppercase +- apply a filter to the `comment_karma` column + +The lazy query will not be executed at this point. See this page on [executing lazy queries](execution.md) for more on running lazy queries. + +## Using the lazy API from a `DataFrame` + +An alternative way to access the lazy API is to call `.lazy` on a `DataFrame` that has already been created in memory. + +{{code_block('user-guide/lazy/using','fromdf',['lazy'])}} + +By calling `.lazy` we convert the `DataFrame` to a `LazyFrame`. diff --git a/docs/user-guide/migration/pandas.md b/docs/user-guide/migration/pandas.md new file mode 100644 index 000000000000..d781ae290f96 --- /dev/null +++ b/docs/user-guide/migration/pandas.md @@ -0,0 +1,328 @@ +# Coming from Pandas + +Here we set out the key points that anyone who has experience with `Pandas` and wants to +try `Polars` should know. We include both differences in the concepts the libraries are +built on and differences in how you should write `Polars` code compared to `Pandas` +code. + +## Differences in concepts between `Polars` and `Pandas` + +### `Polars` does not have a multi-index/index + +`Pandas` gives a label to each row with an index. `Polars` does not use an index and +each row is indexed by its integer position in the table. + +Polars aims to have predictable results and readable queries, as such we think an index does not help us reach that +objective. We believe the semantics of a query should not change by the state of an index or a `reset_index` call. + +In Polars a DataFrame will always be a 2D table with heterogeneous data-types. The data-types may have nesting, but the +table itself will not. +Operations like resampling will be done by specialized functions or methods that act like 'verbs' on a table explicitly +stating the columns that that 'verb' operates on. As such, it is our conviction that not having indices make things simpler, +more explicit, more readable and less error-prone. + +Note that an 'index' data structure as known in databases will be used by polars as an optimization technique. + +### `Polars` uses Apache Arrow arrays to represent data in memory while `Pandas` uses `Numpy` arrays + +`Polars` represents data in memory with Arrow arrays while `Pandas` represents data in +memory with `Numpy` arrays. Apache Arrow is an emerging standard for in-memory columnar +analytics that can accelerate data load times, reduce memory usage and accelerate +calculations. + +`Polars` can convert data to `Numpy` format with the `to_numpy` method. + +### `Polars` has more support for parallel operations than `Pandas` + +`Polars` exploits the strong support for concurrency in Rust to run many operations in +parallel. While some operations in `Pandas` are multi-threaded the core of the library +is single-threaded and an additional library such as `Dask` must be used to parallelize +operations. 
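+
+As a quick illustration of the memory models described above, the sketch below (assuming `pandas` and `pyarrow` are installed) converts a NumPy-backed `Pandas` DataFrame into an Arrow-backed `Polars` DataFrame and back out to NumPy:
+
+```python
+import pandas as pd
+import polars as pl
+
+pd_df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
+
+# Pandas (NumPy-backed) -> Polars (Arrow-backed); the data is copied into Arrow memory
+pl_df = pl.from_pandas(pd_df)
+
+# Polars -> NumPy, for libraries that expect NumPy arrays
+arr = pl_df.to_numpy()
+print(type(pl_df), arr.shape)
+```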
+ +### `Polars` can lazily evaluate queries and apply query optimization + +Eager evaluation is when code is evaluated as soon as you run the code. Lazy evaluation +is when running a line of code means that the underlying logic is added to a query plan +rather than being evaluated. + +`Polars` supports eager evaluation and lazy evaluation whereas `Pandas` only supports +eager evaluation. The lazy evaluation mode is powerful because `Polars` carries out +automatic query optimization when it examines the query plan and looks for ways to +accelerate the query or reduce memory usage. + +`Dask` also supports lazy evaluation when it generates a query plan. However, `Dask` +does not carry out query optimization on the query plan. + +## Key syntax differences + +Users coming from `Pandas` generally need to know one thing... + +``` +polars != pandas +``` + +If your `Polars` code looks like it could be `Pandas` code, it might run, but it likely +runs slower than it should. + +Let's go through some typical `Pandas` code and see how we might rewrite it in `Polars`. + +### Selecting data + +As there is no index in `Polars` there is no `.loc` or `iloc` method in `Polars` - and +there is also no `SettingWithCopyWarning` in `Polars`. + +However, the best way to select data in `Polars` is to use the expression API. For +example, if you want to select a column in `Pandas` you can do one of the following: + +```python +df['a'] +df.loc[:,'a'] +``` + +but in `Polars` you would use the `.select` method: + +```python +df.select('a') +``` + +If you want to select rows based on the values then in `Polars` you use the `.filter` +method: + +```python +df.filter(pl.col('a') < 10) +``` + +As noted in the section on expressions below, `Polars` can run operations in `.select` +and `filter` in parallel and `Polars` can carry out query optimization on the full set +of data selection criteria. + +### Be lazy + +Working in lazy evaluation mode is straightforward and should be your default in +`Polars` as the lazy mode allows `Polars` to do query optimization. + +We can run in lazy mode by either using an implicitly lazy function (such as `scan_csv`) +or explicitly using the `lazy` method. + +Take the following simple example where we read a CSV file from disk and do a group by. +The CSV file has numerous columns but we just want to do a group by on one of the id +columns (`id1`) and then sum by a value column (`v1`). In `Pandas` this would be: + +```python +df = pd.read_csv(csv_file, usecols=['id1','v1']) +grouped_df = df.loc[:,['id1','v1']].groupby('id1').sum('v1') +``` + +In `Polars` you can build this query in lazy mode with query optimization and evaluate +it by replacing the eager `Pandas` function `read_csv` with the implicitly lazy `Polars` +function `scan_csv`: + +```python +df = pl.scan_csv(csv_file) +grouped_df = df.group_by('id1').agg(pl.col('v1').sum()).collect() +``` + +`Polars` optimizes this query by identifying that only the `id1` and `v1` columns are +relevant and so will only read these columns from the CSV. By calling the `.collect` +method at the end of the second line we instruct `Polars` to eagerly evaluate the query. + +If you do want to run this query in eager mode you can just replace `scan_csv` with +`read_csv` in the `Polars` code. + +Read more about working with lazy evaluation in the +[lazy API](../lazy/using.md) section. + +### Express yourself + +A typical `Pandas` script consists of multiple data transformations that are executed +sequentially. 
However, in `Polars` these transformations can be executed in parallel +using expressions. + +#### Column assignment + +We have a dataframe `df` with a column called `value`. We want to add two new columns, a +column called `tenXValue` where the `value` column is multiplied by 10 and a column +called `hundredXValue` where the `value` column is multiplied by 100. + +In `Pandas` this would be: + +```python +df["tenXValue"] = df["value"] * 10 +df["hundredXValue"] = df["value"] * 100 +``` + +These column assignments are executed sequentially. + +In `Polars` we add columns to `df` using the `.with_columns` method and name them with +the `.alias` method: + +```python +df.with_columns( + (pl.col("value") * 10).alias("tenXValue"), + (pl.col("value") * 100).alias("hundredXValue"), +) +``` + +These column assignments are executed in parallel. + +#### Column assignment based on predicate + +In this case we have a dataframe `df` with columns `a`,`b` and `c`. We want to re-assign +the values in column `a` based on a condition. When the value in column `c` is equal to +2 then we replace the value in `a` with the value in `b`. + +In `Pandas` this would be: + +```python +df.loc[df["c"] == 2, "a"] = df.loc[df["c"] == 2, "b"] +``` + +while in `Polars` this would be: + +```python +df.with_columns( + pl.when(pl.col("c") == 2) + .then(pl.col("b")) + .otherwise(pl.col("a")).alias("a") +) +``` + +The `Polars` way is pure in that the original `DataFrame` is not modified. The `mask` is +also not computed twice as in `Pandas` (you could prevent this in `Pandas`, but that +would require setting a temporary variable). + +Additionally `Polars` can compute every branch of an `if -> then -> otherwise` in +parallel. This is valuable, when the branches get more expensive to compute. + +#### Filtering + +We want to filter the dataframe `df` with housing data based on some criteria. + +In `Pandas` you filter the dataframe by passing Boolean expressions to the `loc` method: + +```python +df.loc[(df['sqft_living'] > 2500) & (df['price'] < 300000)] +``` + +while in `Polars` you call the `filter` method: + +```python +df.filter( + (pl.col("m2_living") > 2500) & (pl.col("price") < 300000) +) +``` + +The query optimizer in `Polars` can also detect if you write multiple filters separately +and combine them into a single filter in the optimized plan. + +## `Pandas` transform + +The `Pandas` documentation demonstrates an operation on a group by called `transform`. In +this case we have a dataframe `df` and we want a new column showing the number of rows +in each group. 
+ +In `Pandas` we have: + +```python +df = pd.DataFrame({ + "type": ["m", "n", "o", "m", "m", "n", "n"], + "c": [1, 1, 1, 2, 2, 2, 2], +}) + +df["size"] = df.groupby("c")["type"].transform(len) +``` + +Here `Pandas` does a group by on `"c"`, takes column `"type"`, computes the group length +and then joins the result back to the original `DataFrame` producing: + +``` + c type size +0 1 m 3 +1 1 n 3 +2 1 o 3 +3 2 m 4 +4 2 m 4 +5 2 n 4 +6 2 n 4 +``` + +In `Polars` the same can be achieved with `window` functions: + +```python +df.select( + pl.all(), + pl.col("type").count().over("c").alias("size") +) +``` + +``` +shape: (7, 3) +┌─────┬──────┬──────┐ +│ c ┆ type ┆ size │ +│ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ u32 │ +╞═════╪══════╪══════╡ +│ 1 ┆ m ┆ 3 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 1 ┆ n ┆ 3 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 1 ┆ o ┆ 3 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 │ +└─────┴──────┴──────┘ +``` + +Because we can store the whole operation in a single expression, we can combine several +`window` functions and even combine different groups! + +`Polars` will cache window expressions that are applied over the same group, so storing +them in a single `select` is both convenient **and** optimal. In the following example +we look at a case where we are calculating group statistics over `"c"` twice: + +```python +df.select( + pl.all(), + pl.col("c").count().over("c").alias("size"), + pl.col("c").sum().over("type").alias("sum"), + pl.col("c").reverse().over("c").flatten().alias("reverse_type") +) +``` + +``` +shape: (7, 5) +┌─────┬──────┬──────┬─────┬──────────────┐ +│ c ┆ type ┆ size ┆ sum ┆ reverse_type │ +│ --- ┆ --- ┆ --- ┆ --- ┆ --- │ +│ i64 ┆ str ┆ u32 ┆ i64 ┆ i64 │ +╞═════╪══════╪══════╪═════╪══════════════╡ +│ 1 ┆ m ┆ 3 ┆ 5 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 1 ┆ n ┆ 3 ┆ 5 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 1 ┆ o ┆ 3 ┆ 1 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 ┆ 5 ┆ 2 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ m ┆ 4 ┆ 5 ┆ 1 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 ┆ 5 ┆ 1 │ +├╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌┤ +│ 2 ┆ n ┆ 4 ┆ 5 ┆ 1 │ +└─────┴──────┴──────┴─────┴──────────────┘ +``` + +## Missing data + +`Pandas` uses `NaN` and/or `None` values to indicate missing values depending on the dtype of the column. In addition the behaviour in `Pandas` varies depending on whether the default dtypes or optional nullable arrays are used. In `Polars` missing data corresponds to a `null` value for all data types. + +For float columns `Polars` permits the use of `NaN` values. These `NaN` values are not considered to be missing data but instead a special floating point value. + +In `Pandas` an integer column with missing values is cast to be a float column with `NaN` values for the missing values (unless using optional nullable integer dtypes). In `Polars` any missing values in an integer column are simply `null` values and the column remains an integer column. + +See the [missing data](../expressions/null.md) section for more details. diff --git a/docs/user-guide/migration/spark.md b/docs/user-guide/migration/spark.md new file mode 100644 index 000000000000..ea1a41abbd71 --- /dev/null +++ b/docs/user-guide/migration/spark.md @@ -0,0 +1,158 @@ +# Coming from Apache Spark + +## Column-based API vs. 
Row-based API + +Whereas the `Spark` `DataFrame` is analogous to a collection of rows, a `Polars` `DataFrame` is closer to a collection of columns. This means that you can combine columns in `Polars` in ways that are not possible in `Spark`, because `Spark` preserves the relationship of the data in each row. + +Consider this sample dataset: + +```python +import polars as pl + +df = pl.DataFrame({ + "foo": ["a", "b", "c", "d", "d"], + "bar": [1, 2, 3, 4, 5], +}) + +dfs = spark.createDataFrame( + [ + ("a", 1), + ("b", 2), + ("c", 3), + ("d", 4), + ("d", 5), + ], + schema=["foo", "bar"], +) +``` + +### Example 1: Combining `head` and `sum` + +In `Polars` you can write something like this: + +```python +df.select( + pl.col("foo").sort().head(2), + pl.col("bar").filter(pl.col("foo") == "d").sum() +) +``` + +Output: + +``` +shape: (2, 2) +┌─────┬─────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ str ┆ i64 │ +╞═════╪═════╡ +│ a ┆ 9 │ +├╌╌╌╌╌┼╌╌╌╌╌┤ +│ b ┆ 9 │ +└─────┴─────┘ +``` + +The expressions on columns `foo` and `bar` are completely independent. Since the expression on `bar` returns a single value, that value is repeated for each value output by the expression on `foo`. But `a` and `b` have no relation to the data that produced the sum of `9`. + +To do something similar in `Spark`, you'd need to compute the sum separately and provide it as a literal: + +```python +from pyspark.sql.functions import col, sum, lit + +bar_sum = ( + dfs + .where(col("foo") == "d") + .groupBy() + .agg(sum(col("bar"))) + .take(1)[0][0] +) + +( + dfs + .orderBy("foo") + .limit(2) + .withColumn("bar", lit(bar_sum)) + .show() +) +``` + +Output: + +``` ++---+---+ +|foo|bar| ++---+---+ +| a| 9| +| b| 9| ++---+---+ +``` + +### Example 2: Combining Two `head`s + +In `Polars` you can combine two different `head` expressions on the same DataFrame, provided that they return the same number of values. + +```python +df.select( + pl.col("foo").sort().head(2), + pl.col("bar").sort(descending=True).head(2), +) +``` + +Output: + +``` +shape: (3, 2) +┌─────┬─────┐ +│ foo ┆ bar │ +│ --- ┆ --- │ +│ str ┆ i64 │ +╞═════╪═════╡ +│ a ┆ 5 │ +├╌╌╌╌╌┼╌╌╌╌╌┤ +│ b ┆ 4 │ +└─────┴─────┘ +``` + +Again, the two `head` expressions here are completely independent, and the pairing of `a` to `5` and `b` to `4` results purely from the juxtaposition of the two columns output by the expressions. + +To accomplish something similar in `Spark`, you would need to generate an artificial key that enables you to join the values in this way. + +```python +from pyspark.sql import Window +from pyspark.sql.functions import row_number + +foo_dfs = ( + dfs + .withColumn( + "rownum", + row_number().over(Window.orderBy("foo")) + ) +) + +bar_dfs = ( + dfs + .withColumn( + "rownum", + row_number().over(Window.orderBy(col("bar").desc())) + ) +) + +( + foo_dfs.alias("foo") + .join(bar_dfs.alias("bar"), on="rownum") + .select("foo.foo", "bar.bar") + .limit(2) + .show() +) +``` + +Output: + +``` ++---+---+ +|foo|bar| ++---+---+ +| a| 5| +| b| 4| ++---+---+ +``` diff --git a/docs/user-guide/misc/alternatives.md b/docs/user-guide/misc/alternatives.md new file mode 100644 index 000000000000..a5544e7db354 --- /dev/null +++ b/docs/user-guide/misc/alternatives.md @@ -0,0 +1,66 @@ +# Alternatives + +These are some tools that share similar functionality to what polars does. + +- Pandas + + A very versatile tool for small data. Read [10 things I hate about pandas](https://wesmckinney.com/blog/apache-arrow-pandas-internals/) + written by the author himself. 
Polars has solved all those 10 things. + Polars is a versatile tool for small and large data with a more predictable, less ambiguous, and stricter API. + +- Pandas the API + + The API of pandas was designed for in memory data. This makes it a poor fit for performant analysis on large data + (read anything that does not fit into RAM). Any tool that tries to distribute that API will likely have a + suboptimal query plan compared to plans that follow from a declarative API like SQL or Polars' API. + +- Dask + + Parallelizes existing single-threaded libraries like `NumPy` and `Pandas`. As a consumer of those libraries Dask + therefore has less control over low level performance and semantics. + Those libraries are treated like a black box. + On a single machine the parallelization effort can also be seriously stalled by pandas strings. + Pandas strings, by default, are stored as python objects in + numpy arrays meaning that any operation on them is GIL bound and therefore single threaded. This can be circumvented + by multi-processing but has a non-trivial cost. + +- Modin + + Similar to Dask + +- Vaex + + Vaexs method of out-of-core analysis is memory mapping files. This works until it doesn't. For instance parquet + or csv files first need to be read and converted to a file format that can be memory mapped. Another downside is + that the OS determines when pages will be swapped. Operations that need a full data shuffle, such as + sorts, have terrible performance on memory mapped data. + Polars' out of core processing is not based on memory mapping, but on streaming data in batches (and spilling to disk + if needed), we control which data must be hold in memory, not the OS, meaning that we don't have unexpected IO stalls. + +- DuckDB + + Polars and DuckDB have many similarities. DuckDB is focused on providing an in-process OLAP Sqlite alternative, + Polars is focused on providing a scalable `DataFrame` interface to many languages. Those different front-ends lead to + different optimization strategies and different algorithm prioritization. The interoperability between both is zero-copy. + See more: https://duckdb.org/docs/guides/python/polars + +- Spark + + Spark is designed for distributed workloads and uses the JVM. The setup for spark is complicated and the startup-time + is slow. On a single machine Polars has much better performance characteristics. If you need to process TB's of data + Spark is a better choice. + +- CuDF + + GPU's and CuDF are fast! + However, GPU's are not readily available and expensive in production. The amount of memory available on a GPU + is often a fraction of the available RAM. + This (and out-of-core) processing means that Polars can handle much larger data-sets. + Next to that Polars can be close in [performance to CuDF](https://zakopilo.hatenablog.jp/entry/2023/02/04/220552). + CuDF doesn't optimize your query, so is not uncommon that on ETL jobs Polars will be faster because it can elide + unneeded work and materializations. + +- Any + + Polars is written in Rust. This gives it strong safety, performance and concurrency guarantees. + Polars is written in a modular manner. Parts of Polars can be used in other query programs and can be added as a library. 
diff --git a/docs/user-guide/misc/contributing.md b/docs/user-guide/misc/contributing.md new file mode 100644 index 000000000000..abd4d4d229be --- /dev/null +++ b/docs/user-guide/misc/contributing.md @@ -0,0 +1,11 @@ +# Contributing + +See the [`CONTRIBUTING.md`](https://github.com/pola-rs/polars/blob/master/CONTRIBUTING.md) if you would like to contribute to the `Polars` project. + +If you're new to this we recommend starting out with contributing examples to the Python API documentation. The Python API docs are generated from the docstrings of the Python wrapper located in `polars/py-polars`. + +Here is an example [commit](https://github.com/pola-rs/polars/pull/3567/commits/5db9e335f3f2777dd1d6f80df765c6bca8f307b0) that adds a docstring. + +If you spot any gaps in this User Guide you can submit fixes to the [`pola-rs/polars`](https://github.com/pola-rs/polars) repo. + +Happy hunting! diff --git a/docs/user-guide/misc/multiprocessing.md b/docs/user-guide/misc/multiprocessing.md new file mode 100644 index 000000000000..4973da8c0155 --- /dev/null +++ b/docs/user-guide/misc/multiprocessing.md @@ -0,0 +1,104 @@ +# Multiprocessing + +TLDR: if you find that using Python's built-in `multiprocessing` module together with Polars results in a Polars error about multiprocessing methods, you should make sure you are using `spawn`, not `fork`, as the starting method: + +{{code_block('user-guide/misc/multiprocess','recommendation',[])}} + +## When not to use multiprocessing + +Before we dive into the details, it is important to emphasize that Polars has been built from the start to use all your CPU cores. +It does this by executing computations which can be done in parallel in separate threads. +For example, requesting two expressions in a `select` statement can be done in parallel, with the results only being combined at the end. +Another example is aggregating a value within groups using `group_by().agg()`, each group can be evaluated separately. +It is very unlikely that the `multiprocessing` module can improve your code performance in these cases. + +See [the optimizations section](../lazy/optimizations.md) for more optimizations. + +## When to use multiprocessing + +Although Polars is multithreaded, other libraries may be single-threaded. +When the other library is the bottleneck, and the problem at hand is parallelizable, it makes sense to use multiprocessing to gain a speed up. + +## The problem with the default multiprocessing config + +### Summary + +The [Python multiprocessing documentation](https://docs.python.org/3/library/multiprocessing.html) lists the three methods to create a process pool: + +1. spawn +1. fork +1. forkserver + +The description of fork is (as of 2022-10-15): + +> The parent process uses os.fork() to fork the Python interpreter. The child process, when it begins, is effectively identical to the parent process. All resources of the parent are inherited by the child process. Note that safely forking a multithreaded process is problematic. + +> Available on Unix only. The default on Unix. + +The short summary is: Polars is multithreaded as to provide strong performance out-of-the-box. +Thus, it cannot be combined with `fork`. +If you are on Unix (Linux, BSD, etc), you are using `fork`, unless you explicitly override it. + +The reason you may not have encountered this before is that pure Python code, and most Python libraries, are (mostly) single threaded. 
+Alternatively, you are on Windows or MacOS, on which `fork` is not even available as a method (for MacOS it was up to Python 3.7). + +Thus one should use `spawn`, or `forkserver`, instead. `spawn` is available on all platforms and the safest choice, and hence the recommended method. + +### Example + +The problem with `fork` is in the copying of the parent's process. +Consider the example below, which is a slightly modified example posted on the [Polars issue tracker](https://github.com/pola-rs/polars/issues/3144): + +{{code_block('user-guide/misc/multiprocess','example1',[])}} + +Using `fork` as the method, instead of `spawn`, will cause a dead lock. +Please note: Polars will not even start and raise the error on multiprocessing method being set wrong, but if the check had not been there, the deadlock would exist. + +The fork method is equivalent to calling `os.fork()`, which is a system call as defined in [the POSIX standard](https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html): + +> A process shall be created with a single thread. If a multi-threaded process calls fork(), the new process shall contain a replica of the calling thread and its entire address space, possibly including the states of mutexes and other resources. Consequently, to avoid errors, the child process may only execute async-signal-safe operations until such time as one of the exec functions is called. + +In contrast, `spawn` will create a completely new fresh Python interpreter, and not inherit the state of mutexes. + +So what happens in the code example? +For reading the file with `pl.read_parquet` the file has to be locked. +Then `os.fork()` is called, copying the state of the parent process, including mutexes. +Thus all child processes will copy the file lock in an acquired state, leaving them hanging indefinitely waiting for the file lock to be released, which never happens. + +What makes debugging these issues tricky is that `fork` can work. +Change the example to not having the call to `pl.read_parquet`: + +{{code_block('user-guide/misc/multiprocess','example2',[])}} + +This works fine. +Therefore debugging these issues in larger code bases, i.e. not the small toy examples here, can be a real pain, as a seemingly unrelated change can break your multiprocessing code. +In general, one should therefore never use the `fork` start method with multithreaded libraries unless there are very specific requirements that cannot be met otherwise. + +### Pro's and cons of fork + +Based on the example, you may think, why is `fork` available in Python to start with? + +First, probably because of historical reasons: `spawn` was added to Python in version 3.4, whilst `fork` has been part of Python from the 2.x series. + +Second, there are several limitations for `spawn` and `forkserver` that do not apply to `fork`, in particular all arguments should be pickable. +See the [Python multiprocessing docs](https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods) for more information. + +Third, because it is faster to create new processes compared to `spawn`, as `spawn` is effectively `fork` + creating a brand new Python process without the locks by calling [execv](https://pubs.opengroup.org/onlinepubs/9699919799/functions/exec.html). +Hence the warning in the Python docs that it is slower: there is more overhead to `spawn`. 
+However, in almost all cases, one would like to use multiple processes to speed up computations that take multiple minutes or even hours, meaning the overhead is negligible in the grand scheme of things. +And more importantly, it actually works in combination with multithreaded libraries. + +Fourth, `spawn` starts a new process, and therefore it requires code to be importable, in contrast to `fork`. +In particular, this means that when using `spawn` the relevant code should not be in the global scope, such as in Jupyter notebooks or in plain scripts. +Hence in the examples above, we define functions where we spawn within, and run those functions from a `__main__` clause. +This is not an issue for typical projects, but during quick experimentation in notebooks it could fail. + +## References + +1. https://docs.python.org/3/library/multiprocessing.html + +1. https://pythonspeed.com/articles/python-multiprocessing/ + +1. https://pubs.opengroup.org/onlinepubs/9699919799/functions/fork.html + +1. https://bnikolic.co.uk/blog/python/parallelism/2019/11/13/python-forkserver-preload.html diff --git a/docs/user-guide/misc/reference-guides.md b/docs/user-guide/misc/reference-guides.md new file mode 100644 index 000000000000..c0e082d08447 --- /dev/null +++ b/docs/user-guide/misc/reference-guides.md @@ -0,0 +1,6 @@ +# Reference guides + +The api documentations with details on function / object signatures can be found here: + +- [Python](https://pola-rs.github.io/polars/py-polars/html/reference/index.html) +- [Rust](https://docs.rs/polars/latest/polars/) diff --git a/docs/user-guide/sql/create.md b/docs/user-guide/sql/create.md new file mode 100644 index 000000000000..a5a1922b7f23 --- /dev/null +++ b/docs/user-guide/sql/create.md @@ -0,0 +1,28 @@ +# CREATE + +In Polars, the `SQLContext` provides a way to execute SQL statements against `LazyFrames` and `DataFrames` using SQL syntax. One of the SQL statements that can be executed using `SQLContext` is the `CREATE TABLE` statement, which is used to create a new table. + +The syntax for the `CREATE TABLE` statement in Polars is as follows: + +``` +CREATE TABLE table_name +AS +SELECT ... +``` + +In this syntax, `table_name` is the name of the new table that will be created, and `SELECT ...` is a SELECT statement that defines the data that will be inserted into the table. + +Here's an example of how to use the `CREATE TABLE` statement in Polars: + +{{code_block('user-guide/sql/create','create',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/create.py:setup" +--8<-- "python/user-guide/sql/create.py:create" +``` + +In this example, we use the `execute()` method of the `SQLContext` to execute a `CREATE TABLE` statement that creates a new table called `older_people` based on a SELECT statement that selects all rows from the `my_table` DataFrame where the `age` column is greater than 30. + +!!! note Result + + Note that the result of a `CREATE TABLE` statement is not the table itself. The table is registered in the `SQLContext`. 
In case you want to turn the table back to a `DataFrame` you can use a `SELECT * FROM ...` statement diff --git a/docs/user-guide/sql/cte.md b/docs/user-guide/sql/cte.md new file mode 100644 index 000000000000..1129f6d19230 --- /dev/null +++ b/docs/user-guide/sql/cte.md @@ -0,0 +1,27 @@ +# Common Table Expressions + +Common Table Expressions (CTEs) are a feature of SQL that allow you to define a temporary named result set that can be referenced within a SQL statement. CTEs provide a way to break down complex SQL queries into smaller, more manageable pieces, making them easier to read, write, and maintain. + +A CTE is defined using the `WITH` keyword followed by a comma-separated list of subqueries, each of which defines a named result set that can be used in subsequent queries. The syntax for a CTE is as follows: + +``` +WITH cte_name AS ( + subquery +) +SELECT ... +``` + +In this syntax, `cte_name` is the name of the CTE, and `subquery` is the subquery that defines the result set. The CTE can then be referenced in subsequent queries as if it were a table or view. + +CTEs are particularly useful when working with complex queries that involve multiple levels of subqueries, as they allow you to break down the query into smaller, more manageable pieces that are easier to understand and debug. Additionally, CTEs can help improve query performance by allowing the database to optimize and cache the results of subqueries, reducing the number of times they need to be executed. + +Polars supports Common Table Expressions (CTEs) using the WITH clause in SQL syntax. Below is an example + +{{code_block('user-guide/sql/cte','cte',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/cte" +--8<-- "python/user-guide/sql/cte.py:setup" +--8<-- "python/user-guide/sql/cte.py:cte" +``` + +In this example, we use the `execute()` method of the `SQLContext` to execute a SQL query that includes a CTE. The CTE selects all rows from the `my_table` LazyFrame where the `age` column is greater than 30 and gives it the alias `older_people`. We then execute a second SQL query that selects all rows from the `older_people` CTE where the `name` column starts with the letter 'C'. diff --git a/docs/user-guide/sql/intro.md b/docs/user-guide/sql/intro.md new file mode 100644 index 000000000000..815231e3d59c --- /dev/null +++ b/docs/user-guide/sql/intro.md @@ -0,0 +1,106 @@ +# Introduction + +While Polars does support writing queries in SQL, it's recommended that users familiarize themselves with the [expression syntax](../concepts/expressions.md) for more readable and expressive code. As a primarily DataFrame library, new features will typically be added to the expression API first. However, if you already have an existing SQL codebase or prefer to use SQL, Polars also offers support for SQL queries. + +!!! note Execution + + In Polars, there is no separate SQL engine because Polars translates SQL queries into [expressions](../concepts/expressions.md), which are then executed using its built-in execution engine. This approach ensures that Polars maintains its performance and scalability advantages as a native DataFrame library while still providing users with the ability to work with SQL queries. + +## Context + +Polars uses the `SQLContext` to manage SQL queries . The context contains a dictionary mapping `DataFrames` and `LazyFrames` names to their corresponding datasets[^1]. 
The example below starts a `SQLContext`: + +{{code_block('user-guide/sql/intro','context',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:setup" +--8<-- "python/user-guide/sql/intro.py:context" +``` + +## Register Dataframes + +There are 2 ways to register DataFrames in the `SQLContext`: + +- register all `LazyFrames` and `DataFrames` in the global namespace +- register them one by one + +{{code_block('user-guide/sql/intro','register_context',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:register_context" +``` + +We can also register Pandas DataFrames by converting them to Polars first. + +{{code_block('user-guide/sql/intro','register_pandas',['SQLContext'])}} + +```python exec="on" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:register_pandas" +``` + +!!! note Pandas + + Converting a Pandas DataFrame backed by Numpy to Polars triggers a conversion to the Arrow format. This conversion has a computation cost. Converting a Pandas DataFrame backed by Arrow on the other hand will be free or almost free. + +Once the `SQLContext` is initialized, we can register additional Dataframes or unregister existing Dataframes with: + +- `register` +- `register_globals` +- `register_many` +- `unregister` + +## Execute queries and collect results + +SQL queries are always executed in lazy mode to benefit from lazy optimizations, so we have 2 options to collect the result: + +- Set the parameter `eager_execution` to True in `SQLContext`. With this parameter, Polars will automatically collect SQL results +- Set the parameter `eager` to True when executing a query with `execute`, or collect the result with `collect`. + +We execute SQL queries by calling `execute` on a `SQLContext`. + +{{code_block('user-guide/sql/intro','execute',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:execute" +``` + +## Execute queries from multiple sources + +SQL queries can be executed just as easily from multiple sources. +In the example below, we register : + +- a CSV file loaded lazily +- a NDJSON file loaded lazily +- a Pandas DataFrame + +And we join them together with SQL. +Lazy reading allows to only load the necessary rows and columns from the files. + +In the same way, it's possible to register cloud datalakes (S3, Azure Data Lake). A PyArrow dataset can point to the datalake, then Polars can read it with `scan_pyarrow_dataset`. + +{{code_block('user-guide/sql/intro','execute_multiple_sources',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql" +--8<-- "python/user-guide/sql/intro.py:prepare_multiple_sources" +--8<-- "python/user-guide/sql/intro.py:execute_multiple_sources" +--8<-- "python/user-guide/sql/intro.py:clean_multiple_sources" +``` + +[^1]: Additionally it also tracks the [common table expressions](./cte.md) as well. + +## Compatibility + +Polars does not support the full SQL language, in Polars you are allowed to: + +- Write a `CREATE` statements `CREATE TABLE xxx AS ...` +- Write a `SELECT` statements with all generic elements (`GROUP BY`, `WHERE`,`ORDER`,`LIMIT`,`JOIN`, ...) +- Write Common Table Expressions (CTE's) (`WITH tablename AS`) +- Show an overview of all tables `SHOW TABLES` + +The following is not yet supported: + +- `INSERT`, `UPDATE` or `DELETE` statements +- Table aliasing (e.g. 
`SELECT p.Name from pokemon AS p`) +- Meta queries such as `ANALYZE`, `EXPLAIN` + +In the upcoming sections we will cover each of the statements in more details. diff --git a/docs/user-guide/sql/select.md b/docs/user-guide/sql/select.md new file mode 100644 index 000000000000..1c643895dec7 --- /dev/null +++ b/docs/user-guide/sql/select.md @@ -0,0 +1,72 @@ +# SELECT + +In Polars SQL, the `SELECT` statement is used to retrieve data from a table into a `DataFrame`. The basic syntax of a `SELECT` statement in Polars SQL is as follows: + +```sql +SELECT column1, column2, ... +FROM table_name; +``` + +Here, `column1`, `column2`, etc. are the columns that you want to select from the table. You can also use the wildcard `*` to select all columns. `table_name` is the name of the table or that you want to retrieve data from. In the sections below we will cover some of the more common SELECT variants + +{{code_block('user-guide/sql/sql_select','df',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/sql_select.py:setup" +--8<-- "python/user-guide/sql/sql_select.py:df" +``` + +### GROUP BY + +The `GROUP BY` statement is used to group rows in a table by one or more columns and compute aggregate functions on each group. + +{{code_block('user-guide/sql/sql_select','group_by',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/sql_select.py:group_by" +``` + +### ORDER BY + +The `ORDER BY` statement is used to sort the result set of a query by one or more columns in ascending or descending order. + +{{code_block('user-guide/sql/sql_select','orderby',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/sql_select.py:orderby" +``` + +### JOIN + +{{code_block('user-guide/sql/sql_select','join',['SQLregister_many','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/sql_select.py:join" +``` + +### Functions + +Polars provides a wide range of SQL functions, including: + +- Mathematical functions: `ABS`, `EXP`, `LOG`, `ASIN`, `ACOS`, `ATAN`, etc. +- String functions: `LOWER`, `UPPER`, `LTRIM`, `RTRIM`, `STARTS_WITH`,`ENDS_WITH`. +- Aggregation functions: `SUM`, `AVG`, `MIN`, `MAX`, `COUNT`, `STDDEV`, `FIRST` etc. +- Array functions: `EXPLODE`, `UNNEST`,`ARRAY_SUM`,`ARRAY_REVERSE`, etc. + +For a full list of supported functions go the [API documentation](https://docs.rs/polars-sql/latest/src/polars_sql/keywords.rs.html). The example below demonstrates how to use a function in a query + +{{code_block('user-guide/sql/sql_select','functions',['SQLquery'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/sql_select.py:functions" +``` + +### Table Functions + +In the examples earlier we first generated a DataFrame which we registered in the `SQLContext`. Polars also support directly reading from CSV, Parquet, JSON and IPC in your SQL query using table functions `read_xxx`. 
+ +{{code_block('user-guide/sql/sql_select','tablefunctions',['SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/select" +--8<-- "python/user-guide/sql/sql_select.py:tablefunctions" +``` diff --git a/docs/user-guide/sql/show.md b/docs/user-guide/sql/show.md new file mode 100644 index 000000000000..70453ebcb6dd --- /dev/null +++ b/docs/user-guide/sql/show.md @@ -0,0 +1,22 @@ +# SHOW TABLES + +In Polars, the `SHOW TABLES` statement is used to list all the tables that have been registered in the current `SQLContext`. When you register a DataFrame with the `SQLContext`, you give it a name that can be used to refer to the DataFrame in subsequent SQL statements. The `SHOW TABLES` statement allows you to see a list of all the registered tables, along with their names. + +The syntax for the `SHOW TABLES` statement in Polars is as follows: + +``` +SHOW TABLES +``` + +Here's an example of how to use the `SHOW TABLES` statement in Polars: + +{{code_block('user-guide/sql/show','show',['SQLregister','SQLexecute'])}} + +```python exec="on" result="text" session="user-guide/sql/show" +--8<-- "python/user-guide/sql/show.py:setup" +--8<-- "python/user-guide/sql/show.py:show" +``` + +In this example, we create two DataFrames and register them with the `SQLContext` using different names. We then execute a `SHOW TABLES` statement using the `execute()` method of the `SQLContext` object, which returns a DataFrame containing a list of all the registered tables and their names. The resulting DataFrame is then printed using the `print()` function. + +Note that the `SHOW TABLES` statement only lists tables that have been registered with the current `SQLContext`. If you register a DataFrame with a different `SQLContext` or in a different Python session, it will not appear in the list of tables returned by `SHOW TABLES`. diff --git a/docs/user-guide/transformations/concatenation.md b/docs/user-guide/transformations/concatenation.md new file mode 100644 index 000000000000..8deff923acee --- /dev/null +++ b/docs/user-guide/transformations/concatenation.md @@ -0,0 +1,51 @@ +# Concatenation + +There are a number of ways to concatenate data from separate DataFrames: + +- two dataframes with **the same columns** can be **vertically** concatenated to make a **longer** dataframe +- two dataframes with the **same number of rows** and **non-overlapping columns** can be **horizontally** concatenated to make a **wider** dataframe +- two dataframes with **different numbers of rows and columns** can be **diagonally** concatenated to make a dataframe which might be longer and/ or wider. Where column names overlap values will be vertically concatenated. Where column names do not overlap new rows and columns will be added. Missing values will be set as `null` + +## Vertical concatenation - getting longer + +In a vertical concatenation you combine all of the rows from a list of `DataFrames` into a single longer `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','vertical',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:setup" +--8<-- "python/user-guide/transformations/concatenation.py:vertical" +``` + +Vertical concatenation fails when the dataframes do not have the same column names. + +## Horizontal concatenation - getting wider + +In a horizontal concatenation you combine all of the columns from a list of `DataFrames` into a single wider `DataFrame`. 
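As a minimal sketch of the horizontal case (the frames and column names here are made up; the guide's rendered example follows):

```python
import polars as pl

df_left = pl.DataFrame({"id": [1, 2], "name": ["alice", "bob"]})
df_right = pl.DataFrame({"score": [9.5, 7.0], "rank": [1, 2]})

# Same number of rows, non-overlapping column names -> a wider frame.
wide = pl.concat([df_left, df_right], how="horizontal")
print(wide)
```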
+ +{{code_block('user-guide/transformations/concatenation','horizontal',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:horizontal" +``` + +Horizontal concatenation fails when dataframes have overlapping columns or a different number of rows. + +## Diagonal concatenation - getting longer, wider and `null`ier + +In a diagonal concatenation you combine all of the rows and columns from a list of `DataFrames` into a single longer and/or wider `DataFrame`. + +{{code_block('user-guide/transformations/concatenation','cross',['concat'])}} + +```python exec="on" result="text" session="user-guide/transformations/concatenation" +--8<-- "python/user-guide/transformations/concatenation.py:cross" +``` + +Diagonal concatenation generates nulls when the column names do not overlap. + +When the dataframe shapes do not match and we have an overlapping semantic key, then [we can join the dataframes](joins.md) instead of concatenating them. + +## Rechunking + +Before a concatenation we have two dataframes `df1` and `df2`. Each column in `df1` and `df2` is in one or more chunks in memory. By default, during concatenation the chunks in each column are copied to a single new chunk - this is known as **rechunking**. Rechunking is an expensive operation, but it is often worth it because future operations will be faster. +If you do not want Polars to rechunk the concatenated `DataFrame`, you can specify `rechunk=False` when doing the concatenation. diff --git a/docs/user-guide/transformations/joins.md b/docs/user-guide/transformations/joins.md new file mode 100644 index 000000000000..ad233cf060fb --- /dev/null +++ b/docs/user-guide/transformations/joins.md @@ -0,0 +1,183 @@ +# Joins + +## Join strategies + +`Polars` supports the following join strategies by specifying the `strategy` argument: + +| Strategy | Description | +| -------- | ----------- | +| `inner` | Returns rows with matching keys in _both_ frames. Non-matching rows in either the left or right frame are discarded. | +| `left` | Returns all rows in the left dataframe, whether or not a match in the right frame is found. Non-matching rows have their right columns null-filled. | +| `outer` | Returns all rows from both the left and right dataframe. If no match is found in one frame, columns from the other frame are null-filled. | +| `cross` | Returns the Cartesian product of all rows from the left frame with all rows from the right frame. Duplicate rows are retained; the table length of `A` cross-joined with `B` is always `len(A) × len(B)`. | +| `asof` | A left join in which the match is performed on the _nearest_ key rather than on equal keys. | +| `semi` | Returns all rows from the left frame in which the join key is also present in the right frame. | +| `anti` | Returns all rows from the left frame in which the join key is _not_ present in the right frame. | + +### Inner join + +An `inner` join produces a `DataFrame` that contains only the rows where the join key exists in both `DataFrames`.
Let's take for example the following two `DataFrames`: + +{{code_block('user-guide/transformations/joins','innerdf',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:setup" +--8<-- "python/user-guide/transformations/joins.py:innerdf" +``` + +

+ +{{code_block('user-guide/transformations/joins','innerdf2',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:innerdf2" +``` + +To get a `DataFrame` with the orders and their associated customer we can do an `inner` join on the `customer_id` column: + +{{code_block('user-guide/transformations/joins','inner',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:inner" +``` + +### Left join + +The `left` join produces a `DataFrame` that contains all the rows from the left `DataFrame` and only the rows from the right `DataFrame` where the join key exists in the left `DataFrame`. If we now take the example from above and want to have a `DataFrame` with all the customers and their associated orders (regardless of whether they have placed an order or not) we can do a `left` join: + +{{code_block('user-guide/transformations/joins','left',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:left" +``` + +Notice, that the fields for the customer with the `customer_id` of `3` are null, as there are no orders for this customer. + +### Outer join + +The `outer` join produces a `DataFrame` that contains all the rows from both `DataFrames`. Columns are null, if the join key does not exist in the source `DataFrame`. Doing an `outer` join on the two `DataFrames` from above produces a similar `DataFrame` to the `left` join: + +{{code_block('user-guide/transformations/joins','outer',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:outer" +``` + +### Cross join + +A `cross` join is a cartesian product of the two `DataFrames`. This means that every row in the left `DataFrame` is joined with every row in the right `DataFrame`. The `cross` join is useful for creating a `DataFrame` with all possible combinations of the columns in two `DataFrames`. Let's take for example the following two `DataFrames`. + +{{code_block('user-guide/transformations/joins','df3',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df3" +``` + +

+ +{{code_block('user-guide/transformations/joins','df4',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df4" +``` + +We can now create a `DataFrame` containing all possible combinations of the colors and sizes with a `cross` join: + +{{code_block('user-guide/transformations/joins','cross',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:cross" +``` + +
+ +The `inner`, `left`, `outer` and `cross` join strategies are standard amongst dataframe libraries. We provide more details on the less familiar `semi`, `anti` and `asof` join strategies below. + +### Semi join + +The `semi` join returns all rows from the left frame in which the join key is also present in the right frame. Consider the following scenario: a car rental company has a `DataFrame` showing the cars that it owns with each car having a unique `id`. + +{{code_block('user-guide/transformations/joins','df5',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df5" +``` + +The company has another `DataFrame` showing each repair job carried out on a vehicle. + +{{code_block('user-guide/transformations/joins','df6',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df6" +``` + +You want to answer this question: which of the cars have had repairs carried out? + +An inner join does not answer this question directly as it produces a `DataFrame` with multiple rows for each car that has had multiple repair jobs: + +{{code_block('user-guide/transformations/joins','inner2',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:inner2" +``` + +However, a semi join produces a single row for each car that has had a repair job carried out. + +{{code_block('user-guide/transformations/joins','semi',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:semi" +``` + +### Anti join + +Continuing this example, an alternative question might be: which of the cars have **not** had a repair job carried out? An anti join produces a `DataFrame` showing all the cars from `df_cars` where the `id` is not present in the `df_repairs` `DataFrame`. + +{{code_block('user-guide/transformations/joins','anti',['join'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:anti" +``` + +### Asof join + +An `asof` join is like a left join except that we match on nearest key rather than equal keys. +In `Polars` we can do an asof join with the `join` method and specifying `strategy="asof"`. However, for more flexibility we can use the `join_asof` method. + +Consider the following scenario: a stock market broker has a `DataFrame` called `df_trades` showing transactions it has made for different stocks. + +{{code_block('user-guide/transformations/joins','df7',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df7" +``` + +The broker has another `DataFrame` called `df_quotes` showing prices it has quoted for these stocks. + +{{code_block('user-guide/transformations/joins','df8',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:df8" +``` + +You want to produce a `DataFrame` showing for each trade the most recent quote provided _before_ the trade. You do this with `join_asof` (using the default `strategy = "backward"`). +To avoid joining between trades on one stock with a quote on another you must specify an exact preliminary join on the stock column with `by="stock"`. 
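Before the guide's rendered example below, here is a small self-contained sketch of the same idea, using made-up trades and quotes:

```python
from datetime import datetime

import polars as pl

df_trades = pl.DataFrame(
    {
        "time": [datetime(2023, 1, 1, 9, 1), datetime(2023, 1, 1, 9, 3)],
        "stock": ["A", "A"],
        "trade_price": [101.0, 102.5],
    }
).sort("time")

df_quotes = pl.DataFrame(
    {
        "time": [datetime(2023, 1, 1, 9, 0), datetime(2023, 1, 1, 9, 2)],
        "stock": ["A", "A"],
        "quote_price": [100.5, 102.0],
    }
).sort("time")

# For each trade, take the most recent quote at or before the trade time,
# matching exactly on the stock symbol first.
result = df_trades.join_asof(df_quotes, on="time", by="stock", strategy="backward")
print(result)
```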
+ +{{code_block('user-guide/transformations/joins','asof',['join_asof'])}} + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:asofpre" +--8<-- "python/user-guide/transformations/joins.py:asof" +``` + +If you want to make sure that only quotes within a certain time range are joined to the trades, you can specify the `tolerance` argument. In this case we want to make sure that the last preceding quote is within 1 minute of the trade, so we set `tolerance = "1m"`. + +=== ":fontawesome-brands-python: Python" + +```python +--8<-- "python/user-guide/transformations/joins.py:asof2" +``` + +```python exec="on" result="text" session="user-guide/transformations/joins" +--8<-- "python/user-guide/transformations/joins.py:asof2" +``` diff --git a/docs/user-guide/transformations/melt.md b/docs/user-guide/transformations/melt.md new file mode 100644 index 000000000000..3e6efe35723e --- /dev/null +++ b/docs/user-guide/transformations/melt.md @@ -0,0 +1,21 @@ +# Melts + +Melt operations unpivot a DataFrame from wide format to long format. + +## Dataset + +{{code_block('user-guide/transformations/melt','df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/melt" +--8<-- "python/user-guide/transformations/melt.py:df" +``` + +## Eager + lazy + +`Eager` and `lazy` have the same API. + +{{code_block('user-guide/transformations/melt','melt',['melt'])}} + +```python exec="on" result="text" session="user-guide/transformations/melt" +--8<-- "python/user-guide/transformations/melt.py:melt" +``` diff --git a/docs/user-guide/transformations/pivot.md b/docs/user-guide/transformations/pivot.md new file mode 100644 index 000000000000..9850dbed0330 --- /dev/null +++ b/docs/user-guide/transformations/pivot.md @@ -0,0 +1,46 @@ +# Pivots + +Pivot a column in a `DataFrame` and perform one of the following aggregations: + +- first +- sum +- min +- max +- mean +- median + +The pivot operation consists of a group by on one or multiple columns (these will be the new y-axis), the column that will be pivoted (this will be the new x-axis) and an aggregation. + +## Dataset + +{{code_block('user-guide/transformations/pivot','df',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:setup" +--8<-- "python/user-guide/transformations/pivot.py:df" +``` + +## Eager + +{{code_block('user-guide/transformations/pivot','eager',['pivot'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:eager" +``` + +## Lazy + +A Polars `LazyFrame` always needs to know the schema of a computation statically (before collecting the query). +A pivot's output schema depends on the data, so it is impossible to determine the schema without running the query. + +Polars could have abstracted this fact for you just like Spark does, but we don't want you to shoot yourself in the foot with a shotgun. The cost should be clear upfront.
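The guide's snippet below makes that cost explicit by collecting first. A minimal sketch of the same pattern, with a made-up frame, looks like this:

```python
import polars as pl

lf = pl.LazyFrame(
    {
        "foo": ["A", "A", "B", "B", "C"],
        "bar": ["k", "l", "m", "n", "o"],
        "baz": [1, 2, 2, 4, 2],
    }
)

# The output schema depends on the values in "bar", so we materialize first,
# pivot eagerly, and then return to the lazy API for any further work.
out = (
    lf.collect()
    .pivot(index="foo", columns="bar", values="baz", aggregate_function="first")
    .lazy()
    .fill_null(0)
    .collect()
)
print(out)
```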
+ +{{code_block('user-guide/transformations/pivot','lazy',['pivot'])}} + +```python exec="on" result="text" session="user-guide/transformations/pivot" +--8<-- "python/user-guide/transformations/pivot.py:lazy" +``` diff --git a/docs/user-guide/transformations/time-series/filter.md b/docs/user-guide/transformations/time-series/filter.md new file mode 100644 index 000000000000..326969c34e11 --- /dev/null +++ b/docs/user-guide/transformations/time-series/filter.md @@ -0,0 +1,48 @@ +# Filtering + +Filtering date columns works in the same way as with other types of columns using the `.filter` method. + +Polars uses Python's native `datetime`, `date` and `timedelta` for equality comparisons between the datatypes `pl.Datetime`, `pl.Date` and `pl.Duration`. + +In the following example we use a time series of Apple stock prices. + +{{code_block('user-guide/transformations/time-series/filter','df',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:df" +``` + +## Filtering by single dates + +We can filter by a single date by casting the desired date string to a `Date` object +in a filter expression: + +{{code_block('user-guide/transformations/time-series/filter','filter',['filter'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:filter" +``` + +Note we are using the lowercase `datetime` method rather than the uppercase `Datetime` data type. + +## Filtering by a date range + +We can filter by a range of dates using the `is_between` method in a filter expression with the start and end dates: + +{{code_block('user-guide/transformations/time-series/filter','range',['filter','is_between'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:range" +``` + +## Filtering with negative dates + +Say you are working with an archeologist and are dealing in negative dates. +Polars can parse and store them just fine, but the Python `datetime` library +does not. So for filtering, you should use attributes in the `.dt` namespace: + +{{code_block('user-guide/transformations/time-series/filter','negative',['strptime'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/filter" +--8<-- "python/user-guide/transformations/time-series/filter.py:negative" +``` diff --git a/docs/user-guide/transformations/time-series/parsing.md b/docs/user-guide/transformations/time-series/parsing.md new file mode 100644 index 000000000000..a31095d07434 --- /dev/null +++ b/docs/user-guide/transformations/time-series/parsing.md @@ -0,0 +1,58 @@ +# Parsing + +Polars has native support for parsing time series data and doing more sophisticated operations such as temporal grouping and resampling. + +## Datatypes + +`Polars` has the following datetime datatypes: + +- `Date`: Date representation e.g. 2014-07-08. It is internally represented as days since UNIX epoch encoded by a 32-bit signed integer. +- `Datetime`: Datetime representation e.g. 2014-07-08 07:00:00. It is internally represented as a 64 bit integer since the Unix epoch and can have different units such as ns, us, ms. +- `Duration`: A time delta type that is created when subtracting `Date/Datetime`. Similar to `timedelta` in python. +- `Time`: Time representation, internally represented as nanoseconds since midnight. 
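A tiny illustration of these four dtypes, using made-up values:

```python
from datetime import date, datetime, time, timedelta

import polars as pl

df = pl.DataFrame(
    {
        "date": [date(2014, 7, 8)],
        "datetime": [datetime(2014, 7, 8, 7, 0)],
        "duration": [timedelta(hours=26)],
        "time": [time(7, 0)],
    }
)

# Each Python value maps onto the corresponding Polars temporal dtype.
print(df.schema)
```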
+ +## Parsing dates from a file + +When loading from a CSV file `Polars` attempts to parse dates and times if the `try_parse_dates` flag is set to `True`: + +{{code_block('user-guide/transformations/time-series/parsing','df',['read_csv'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:setup" +--8<-- "python/user-guide/transformations/time-series/parsing.py:df" +``` + +On the other hand binary formats such as parquet have a schema that is respected by `Polars`. + +## Casting strings to dates + +You can also cast a column of datetimes encoded as strings to a datetime type. You do this by calling the string `str.strptime` method and passing the format of the date string: + +{{code_block('user-guide/transformations/time-series/parsing','cast',['read_csv','strptime'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:cast" +``` + +[The strptime date formats can be found here.](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). + +## Extracting date features from a date column + +You can extract data features such as the year or day from a date column using the `.dt` namespace on a date column: + +{{code_block('user-guide/transformations/time-series/parsing','extract',['year'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:extract" +``` + +## Mixed offsets + +If you have mixed offsets (say, due to crossing daylight saving time), +then you can use `utc=True` and then convert to your time zone: + +{{code_block('user-guide/transformations/time-series/parsing','mixed',['strptime','convert_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/parsing" +--8<-- "python/user-guide/transformations/time-series/parsing.py:mixed" +``` diff --git a/docs/user-guide/transformations/time-series/resampling.md b/docs/user-guide/transformations/time-series/resampling.md new file mode 100644 index 000000000000..63ad583a9bec --- /dev/null +++ b/docs/user-guide/transformations/time-series/resampling.md @@ -0,0 +1,42 @@ +# Resampling + +We can resample by either: + +- upsampling (moving data to a higher frequency) +- downsampling (moving data to a lower frequency) +- combinations of these e.g. first upsample and then downsample + +## Downsampling to a lower frequency + +`Polars` views downsampling as a special case of the **group_by** operation and you can do this with `group_by_dynamic` and `group_by_rolling` - [see the temporal group by page for examples](rolling.md). + +## Upsampling to a higher frequency + +Let's go through an example where we generate data at 30 minute intervals: + +{{code_block('user-guide/transformations/time-series/resampling','df',['DataFrame','date_range'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:setup" +--8<-- "python/user-guide/transformations/time-series/resampling.py:df" +``` + +Upsampling can be done by defining the new sampling interval. By upsampling we are adding in extra rows where we do not have data. As such upsampling by itself gives a DataFrame with nulls. These nulls can then be filled with a fill strategy or interpolation. 
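As a rough sketch of that idea, here is a made-up 30-minute series upsampled to 15 minutes and then forward-filled:

```python
from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {
        "time": [
            datetime(2021, 12, 16, 0, 0),
            datetime(2021, 12, 16, 0, 30),
            datetime(2021, 12, 16, 1, 0),
        ],
        "values": [0.0, 1.0, 2.0],
    }
).set_sorted("time")

# Upsampling to 15 minutes inserts rows of nulls, which we then forward-fill.
upsampled = df.upsample(time_column="time", every="15m").fill_null(strategy="forward")
print(upsampled)
```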
+ +### Upsampling strategies + +In this example we upsample from the original 30 minutes to 15 minutes and then use a `forward` strategy to replace the nulls with the previous non-null value: + +{{code_block('user-guide/transformations/time-series/resampling','upsample',['upsample'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:upsample" +``` + +In this example we instead fill the nulls by linear interpolation: + +{{code_block('user-guide/transformations/time-series/resampling','upsample2',['upsample','interpolate','fill_null'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/resampling" +--8<-- "python/user-guide/transformations/time-series/resampling.py:upsample2" +``` diff --git a/docs/user-guide/transformations/time-series/rolling.md b/docs/user-guide/transformations/time-series/rolling.md new file mode 100644 index 000000000000..a88373caada2 --- /dev/null +++ b/docs/user-guide/transformations/time-series/rolling.md @@ -0,0 +1,148 @@ +# Grouping + +## Grouping by fixed windows + +We can calculate temporal statistics using `group_by_dynamic` to group rows into days/months/years etc. + +### Annual average example + +In the following simple example we calculate the annual average closing price of the Apple stock. We first load the data from CSV: + +{{code_block('user-guide/transformations/time-series/rolling','df',['upsample'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:setup" +--8<-- "python/user-guide/transformations/time-series/rolling.py:df" +``` + +!!! info + +    The dates are sorted in ascending order - if they are not sorted in this way the `group_by_dynamic` output will not be correct! + +To get the annual average closing price we tell `group_by_dynamic` that we want to: + +- group by the `Date` column on an annual (`1y`) basis +- take the mean values of the `Close` column for each year: + +{{code_block('user-guide/transformations/time-series/rolling','group_by',['group_by_dynamic'])}} + +The annual average closing price is then: + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by" +``` + +### Parameters for `group_by_dynamic` + +A dynamic window is defined by: + +- **every**: indicates the interval of the window +- **period**: indicates the duration of the window +- **offset**: can be used to offset the start of the windows + +The value for `every` sets how often the groups start. The time period values are flexible - for example we could take: + +- the average over 2-year intervals by replacing `1y` with `2y` +- the average over 18-month periods by replacing `1y` with `1y6mo` + +We can also use the `period` parameter to set how long the time period for each group is. For example, if we set the `every` parameter to `1y` and the `period` parameter to `2y` then we would get groups at one-year intervals where each group spans two years. + +If the `period` parameter is not specified, it is set equal to the `every` parameter, so that if `every` is set to `1y` then each group spans `1y` as well. + +Because _**every**_ does not have to be equal to _**period**_, we can create many groups in a very flexible way. They may overlap or leave boundaries between them.
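For instance, a sketch (with a small made-up daily series) of windows that start every month but each span two months:

```python
from datetime import date

import polars as pl

df = pl.DataFrame(
    {
        "day": [date(2021, 1, 1), date(2021, 1, 20), date(2021, 2, 10), date(2021, 3, 5)],
        "value": [1, 2, 3, 4],
    }
).set_sorted("day")

# Windows start every month ("1mo") but each covers two months ("2mo"),
# so a row can belong to more than one group.
out = df.group_by_dynamic("day", every="1mo", period="2mo").agg(pl.col("value").sum())
print(out)
```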
+ +Let's see how the windows for some parameter combinations would look. Let's start out boring. 🥱 + +- every: 1 day -> `"1d"` +- period: 1 day -> `"1d"` + +```text +this creates adjacent windows of the same size +|--| + |--| + |--| +``` + +- every: 1 day -> `"1d"` +- period: 2 days -> `"2d"` + +```text +these windows have an overlap of 1 day +|----| + |----| + |----| +``` + +- every: 2 days -> `"2d"` +- period: 1 day -> `"1d"` + +```text +this would leave gaps between the windows +data points that in these gaps will not be a member of any group +|--| + |--| + |--| +``` + +#### `truncate` + +The `truncate` parameter is a Boolean variable that determines what datetime value is associated with each group in the output. In the example above the first data point is on 23rd February 1981. If `truncate = True` (the default) then the date for the first year in the annual average is 1st January 1981. However, if `truncate = False` then the date for the first year in the annual average is the date of the first data point on 23rd February 1981. Note that `truncate` only affects what's shown in the +`Date` column and does not affect the window boundaries. + +### Using expressions in `group_by_dynamic` + +We aren't restricted to using simple aggregations like `mean` in a group by operation - we can use the full range of expressions available in Polars. + +In the snippet below we create a `date range` with every **day** (`"1d"`) in 2021 and turn this into a `DataFrame`. + +Then in the `group_by_dynamic` we create dynamic windows that start every **month** (`"1mo"`) and have a window length of `1` month. The values that match these dynamic windows are then assigned to that group and can be aggregated with the powerful expression API. + +Below we show an example where we use **group_by_dynamic** to compute: + +- the number of days until the end of the month +- the number of days in a month + +{{code_block('user-guide/transformations/time-series/rolling','group_by_dyn',['group_by_dynamic','explode','date_range'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_dyn" +``` + +## Grouping by rolling windows + +The rolling group by, `group_by_rolling`, is another entrance to the `group_by` context. But different from the `group_by_dynamic` the windows are +not fixed by a parameter `every` and `period`. In a rolling group by, the windows are not fixed at all! They are determined +by the values in the `index_column`. + +So imagine having a time column with the values `{2021-01-06, 2021-01-10}` and a `period="5d"` this would create the following +windows: + +```text +2021-01-01 2021-01-06 + |----------| + + 2021-01-05 2021-01-10 + |----------| +``` + +Because the windows of a rolling group by are always determined by the values in the `DataFrame` column, the number of +groups is always equal to the original `DataFrame`. + +## Combining group by operations + +Rolling and dynamic group by operations can be combined with normal group by operations. + +Below is an example with a dynamic group by. 
+ +{{code_block('user-guide/transformations/time-series/rolling','group_by_roll',['DataFrame'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_roll" +``` + +{{code_block('user-guide/transformations/time-series/rolling','group_by_dyn2',['group_by_dynamic'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/rolling" +--8<-- "python/user-guide/transformations/time-series/rolling.py:group_by_dyn2" +``` diff --git a/docs/user-guide/transformations/time-series/timezones.md b/docs/user-guide/transformations/time-series/timezones.md new file mode 100644 index 000000000000..48f6870e8b20 --- /dev/null +++ b/docs/user-guide/transformations/time-series/timezones.md @@ -0,0 +1,46 @@ +--- +hide: + - toc +--- + +# Time zones + +!!! quote "Tom Scott" + + You really should never, ever deal with time zones if you can help it. + +The `Datetime` datatype can have a time zone associated with it. +Examples of valid time zones are: + +- `None`: no time zone, also known as "time zone naive"; +- `UTC`: Coordinated Universal Time; +- `Asia/Kathmandu`: time zone in "area/location" format. + See the [list of tz database time zones](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) + to see what's available; +- `+01:00`: fixed offsets. May be useful when parsing, but you almost certainly want the "Area/Location" + format above instead as it will deal with irregularities such as DST (Daylight Saving Time) for you. + +Note that, because a `Datetime` can only have a single time zone, it is +impossible to have a column with multiple time zones. If you are parsing data +with multiple offsets, you may want to pass `utc=True` to convert +them all to a common time zone (`UTC`), see [parsing dates and times](parsing.md). 
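A small sketch of that pattern, using two made-up strings with different fixed offsets (the exact parsing keyword may differ slightly between Polars versions):

```python
import polars as pl

mixed = pl.Series(["2021-03-27T00:00:00+0100", "2021-03-28T00:00:00+0200"])

# Parse the mixed offsets into a single UTC column, then convert for display.
parsed = mixed.str.strptime(pl.Datetime, "%Y-%m-%dT%H:%M:%S%z", utc=True)
print(parsed.dt.convert_time_zone("Europe/Brussels"))
```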
+ +The main methods for setting and converting between time zones are: + +- `dt.convert_time_zone`: convert from one time zone to another; +- `dt.replace_time_zone`: set/unset/change time zone; + +Let's look at some examples of common operations: + +{{code_block('user-guide/transformations/time-series/timezones','example',['strptime','replace_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/timezones" +--8<-- "python/user-guide/transformations/time-series/timezones.py:setup" +--8<-- "python/user-guide/transformations/time-series/timezones.py:example" +``` + +{{code_block('user-guide/transformations/time-series/timezones','example2',['convert_time_zone','replace_time_zone'])}} + +```python exec="on" result="text" session="user-guide/transformations/ts/timezones" +--8<-- "python/user-guide/transformations/time-series/timezones.py:example2" +``` diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000000..65e961b13225 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,163 @@ +# https://www.mkdocs.org/user-guide/configuration/ + +# Project information +site_name: Polars documentation +site_url: https://pola-rs.github.io/polars +repo_url: https://github.com/pola-rs/polars +repo_name: pola-rs/polars + +# Documentation layout +nav: + - Home: index.md + - Getting started: + - getting-started/intro.md + - getting-started/installation.md + - getting-started/series-dataframes.md + - getting-started/reading-writing.md + - getting-started/expressions.md + - getting-started/joins.md + - User guide: + - user-guide/index.md + - user-guide/installation.md + - Concepts: + - user-guide/concepts/data-types.md + - user-guide/concepts/data-structures.md + - user-guide/concepts/contexts.md + - user-guide/concepts/expressions.md + - user-guide/concepts/lazy-vs-eager.md + - user-guide/concepts/streaming.md + - Expressions: + - user-guide/expressions/operators.md + - user-guide/expressions/column-selections.md + - user-guide/expressions/functions.md + - user-guide/expressions/casting.md + - user-guide/expressions/strings.md + - user-guide/expressions/aggregation.md + - user-guide/expressions/null.md + - user-guide/expressions/window.md + - user-guide/expressions/folds.md + - user-guide/expressions/lists.md + - user-guide/expressions/user-defined-functions.md + - user-guide/expressions/structs.md + - user-guide/expressions/numpy.md + - Transformations: + - user-guide/transformations/joins.md + - user-guide/transformations/concatenation.md + - user-guide/transformations/pivot.md + - user-guide/transformations/melt.md + - Time series: + - user-guide/transformations/time-series/parsing.md + - user-guide/transformations/time-series/filter.md + - user-guide/transformations/time-series/rolling.md + - user-guide/transformations/time-series/resampling.md + - user-guide/transformations/time-series/timezones.md + - Lazy API: + - user-guide/lazy/using.md + - user-guide/lazy/optimizations.md + - user-guide/lazy/schemas.md + - user-guide/lazy/query_plan.md + - user-guide/lazy/execution.md + - user-guide/lazy/streaming.md + - IO: + - user-guide/io/csv.md + - user-guide/io/parquet.md + - user-guide/io/json_file.md + - user-guide/io/multiple.md + - user-guide/io/database.md + - user-guide/io/aws.md + - user-guide/io/bigquery.md + - SQL: + - user-guide/sql/intro.md + - user-guide/sql/show.md + - user-guide/sql/select.md + - user-guide/sql/create.md + - user-guide/sql/cte.md + - Migrating: + - user-guide/migration/pandas.md + - user-guide/migration/spark.md + - Misc: + - 
user-guide/misc/multiprocessing.md + - user-guide/misc/alternatives.md + - user-guide/misc/reference-guides.md + - user-guide/misc/contributing.md +not_in_nav: | + /_build/ + people.md +validation: + links: + # Allow an absolute link to the features page for our code snippets + absolute_links: ignore + +# Build directories +theme: + name: material + locale: en + custom_dir: docs/_build/overrides + palette: + # Palette toggle for light mode + - media: "(prefers-color-scheme: light)" + scheme: default + toggle: + icon: material/brightness-7 + name: Switch to dark mode + # Palette toggle for dark mode + - media: "(prefers-color-scheme: dark)" + scheme: slate + toggle: + icon: material/brightness-4 + name: Switch to light mode + logo: _build/assets/logo.png + features: + - navigation.tracking + - navigation.instant + - navigation.tabs + - navigation.tabs.sticky + - navigation.footer + - content.tabs.link + icon: + repo: fontawesome/brands/github + +extra_css: + - _build/css/extra.css +extra: + consent: + title: Cookie consent + description: >- + We use cookies to recognize your repeated visits and preferences, as well + as to measure the effectiveness of our documentation and whether users + find what they're searching for. With your consent, you're helping us to + make our documentation better. + analytics: + provider: google + property: G-LKNVFWD3T5 + +# Preview controls +# TODO: Fix warnings and turn on strict mode +strict: false + +# Formatting options +markdown_extensions: + - admonition + - pymdownx.details + - attr_list + - pymdownx.emoji: + emoji_index: !!python/name:materialx.emoji.twemoji + emoji_generator: !!python/name:materialx.emoji.to_svg + - pymdownx.superfences + - pymdownx.tabbed: + alternate_style: true + - pymdownx.snippets: + base_path: ['.','docs/src/'] + check_paths: true + dedent_subsections: true + - footnotes + +hooks: + - docs/_build/scripts/people.py + +plugins: + - search: + lang: en + - markdown-exec + - macros: + module_name: docs/_build/scripts/macro diff --git a/py-polars/Cargo.lock b/py-polars/Cargo.lock index 38b9d1e834e3..45dd9b735978 100644 --- a/py-polars/Cargo.lock +++ b/py-polars/Cargo.lock @@ -96,41 +96,6 @@ dependencies = [ "serde", ] -[[package]] -name = "arrow2" -version = "0.17.4" -source = "git+https://github.com/jorgecarleitao/arrow2?rev=7c93e358fc400bf3c0c0219c22eefc6b38fc2d12#7c93e358fc400bf3c0c0219c22eefc6b38fc2d12" -dependencies = [ - "ahash", - "arrow-format", - "avro-schema", - "base64", - "bytemuck", - "chrono", - "chrono-tz", - "dyn-clone", - "either", - "ethnum", - "fallible-streaming-iterator", - "foreign_vec", - "futures", - "getrandom", - "hash_hasher", - "hashbrown 0.14.0", - "lexical-core", - "lz4", - "multiversion", - "num-traits", - "parquet2", - "regex", - "regex-syntax", - "rustc_version", - "simdutf8", - "streaming-iterator", - "strength_reduce", - "zstd", -] - [[package]] name = "async-stream" version = "0.3.5" @@ -150,7 +115,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -161,7 +126,7 @@ checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -195,9 +160,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.3" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" +checksum 
= "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "bitflags" @@ -248,35 +213,35 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.13.0" +version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "bytemuck" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.4.1" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdde5c9cd29ebd706ce1b35600920a33550e402fc998a2e53ad3b42c3c47a192" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" [[package]] name = "cargo-lock" @@ -308,14 +273,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.27" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f56b4c72906975ca04becb8a30e102dfecddd0c06181e3e95ddc444be28881f8" +checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-targets", ] @@ -518,7 +485,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -636,7 +603,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -717,12 +684,6 @@ dependencies = [ "serde", ] -[[package]] -name = "hash_hasher" -version = "2.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74721d007512d0cb3338cd20f0654ac913920061a4c4d0d8708edb3f2a698c0c" - [[package]] name = "hashbrown" version = "0.13.2" @@ -961,9 +922,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "libflate" @@ -997,6 +958,16 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libloading" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d580318f95776505201b28cf98eb1fa5e4be3b689633ba6a3e6cd880ff22d8cb" +dependencies = [ + "cfg-if", + "windows-sys", +] + [[package]] name = "libm" version = "0.2.7" @@ -1005,9 +976,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "libmimalloc-sys" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"25d058a81af0d1c22d7a1c948576bee6d673f7af3c0f35564abd6c81122f513d" +checksum = "3979b5c37ece694f1f5e51e7ecc871fdb0f517ed04ee45f88d15d6d553cb9664" dependencies = [ "cc", "libc", @@ -1083,9 +1054,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.1" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f478948fd84d9f8e86967bf432640e46adfb5a4bd4f14ef7e864ab38220534ae" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memmap2" @@ -1107,9 +1078,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.38" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "972e5f23f6716f62665760b0f4cbf592576a80c7b879ba9beaafc0e558894127" +checksum = "fa01922b5ea280a911e323e4d2fd24b7fe5cc4042e0d2cda3c40775cdc4bdc9c" dependencies = [ "libmimalloc-sys", ] @@ -1157,6 +1128,37 @@ dependencies = [ "target-features", ] +[[package]] +name = "nano-arrow" +version = "0.1.0" +dependencies = [ + "ahash", + "arrow-format", + "avro-schema", + "base64", + "bytemuck", + "chrono", + "chrono-tz", + "dyn-clone", + "either", + "ethnum", + "fallible-streaming-iterator", + "foreign_vec", + "futures", + "getrandom", + "hashbrown 0.14.0", + "lexical-core", + "lz4", + "multiversion", + "num-traits", + "parquet2", + "rustc_version", + "simdutf8", + "streaming-iterator", + "strength_reduce", + "zstd", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -1381,7 +1383,7 @@ dependencies = [ [[package]] name = "polars" -version = "0.32.0" +version = "0.33.2" dependencies = [ "getrandom", "polars-core", @@ -1395,7 +1397,7 @@ dependencies = [ [[package]] name = "polars-algo" -version = "0.32.0" +version = "0.33.2" dependencies = [ "polars-core", "polars-lazy", @@ -1404,15 +1406,15 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", "atoi", "chrono", "chrono-tz", "ethnum", "hashbrown 0.14.0", "multiversion", + "nano-arrow", "num-traits", "polars-error", "serde", @@ -1422,10 +1424,9 @@ dependencies = [ [[package]] name = "polars-core" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", "bitflags 2.4.0", "chrono", "chrono-tz", @@ -1434,6 +1435,7 @@ dependencies = [ "hashbrown 0.14.0", "indexmap", "itoap", + "nano-arrow", "ndarray", "num-traits", "once_cell", @@ -1455,29 +1457,38 @@ dependencies = [ [[package]] name = "polars-error" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", + "nano-arrow", "regex", "thiserror", ] +[[package]] +name = "polars-ffi" +version = "0.33.2" +dependencies = [ + "nano-arrow", + "polars-core", +] + [[package]] name = "polars-io" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", "bytes", "chrono", "chrono-tz", "fast-float", "flate2", "home", + "itoa", "lexical", "lexical-core", "memchr", "memmap2", + "nano-arrow", "num-traits", "once_cell", "polars-arrow", @@ -1488,6 +1499,7 @@ dependencies = [ "polars-utils", "rayon", "regex", + "ryu", "serde", "serde_json", "simd-json", @@ -1496,23 +1508,27 @@ dependencies = [ [[package]] name = "polars-json" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", + "chrono", "fallible-streaming-iterator", "hashbrown 0.14.0", "indexmap", + "itoa", + "nano-arrow", "num-traits", "polars-arrow", "polars-error", "polars-utils", + "ryu", "simd-json", + "streaming-iterator", ] [[package]] name = "polars-lazy" -version = "0.32.0" +version = "0.33.2" dependencies 
= [ "ahash", "bitflags 2.4.0", @@ -1535,10 +1551,9 @@ dependencies = [ [[package]] name = "polars-ops" -version = "0.32.0" +version = "0.33.2" dependencies = [ "argminmax", - "arrow2", "base64", "chrono", "chrono-tz", @@ -1547,6 +1562,7 @@ dependencies = [ "indexmap", "jsonpath_lib", "memchr", + "nano-arrow", "polars-arrow", "polars-core", "polars-json", @@ -1560,7 +1576,7 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.32.0" +version = "0.33.2" dependencies = [ "crossbeam-channel", "crossbeam-queue", @@ -1581,16 +1597,18 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", - "arrow2", "chrono", "chrono-tz", "ciborium", + "libloading", + "nano-arrow", "once_cell", "polars-arrow", "polars-core", + "polars-ffi", "polars-io", "polars-ops", "polars-time", @@ -1606,16 +1624,16 @@ dependencies = [ [[package]] name = "polars-row" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", + "nano-arrow", "polars-error", "polars-utils", ] [[package]] name = "polars-sql" -version = "0.32.0" +version = "0.33.2" dependencies = [ "polars-arrow", "polars-core", @@ -1628,12 +1646,12 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.32.0" +version = "0.33.2" dependencies = [ - "arrow2", "atoi", "chrono", "chrono-tz", + "nano-arrow", "now", "once_cell", "polars-arrow", @@ -1647,7 +1665,7 @@ dependencies = [ [[package]] name = "polars-utils" -version = "0.32.0" +version = "0.33.2" dependencies = [ "ahash", "bytemuck", @@ -1669,16 +1687,16 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" dependencies = [ "unicode-ident", ] [[package]] name = "py-polars" -version = "0.19.2" +version = "0.19.3" dependencies = [ "ahash", "built", @@ -1697,6 +1715,7 @@ dependencies = [ "polars-error", "polars-lazy", "polars-ops", + "polars-plan", "pyo3", "pyo3-built", "serde_json", @@ -1859,9 +1878,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -1871,9 +1890,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -1957,14 +1976,14 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "indexmap", "itoa", @@ -2013,11 +2032,12 @@ dependencies = [ [[package]] name = "simd-json" -version 
= "0.10.6" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de7f1293f0e4e11d52e588766fe9de8caa2857ff63809d40de83245452ca7c5c" +checksum = "80ea1dfc2c400965867fc4ddd6f502572be2de2074b39f90984ed15fbdbdd8eb" dependencies = [ "ahash", + "getrandom", "halfbrown", "lexical-core", "once_cell", @@ -2137,7 +2157,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -2153,9 +2173,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.29" +version = "2.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +checksum = "91e02e55d62894af2a08aca894c6577281f76769ba47c94d5756bec8ac6e7373" dependencies = [ "proc-macro2", "quote", @@ -2164,9 +2184,9 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.29.9" +version = "0.29.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d0e9cc2273cc8d31377bdd638d72e3ac3e5607b18621062b169d02787f1bab" +checksum = "0a18d114d420ada3a891e6bc8e96a2023402203296a47cdd65083377dad18ba5" dependencies = [ "cfg-if", "core-foundation-sys", @@ -2190,22 +2210,22 @@ checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" [[package]] name = "thiserror" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", ] [[package]] @@ -2225,9 +2245,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "toml" -version = "0.7.6" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ "serde", "serde_spanned", @@ -2246,9 +2266,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap", "serde", @@ -2265,9 +2285,9 @@ checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" @@ -2352,7 +2372,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.36", "wasm-bindgen-shared", ] @@ -2374,7 +2394,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 
2.0.29", + "syn 2.0.36", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2493,9 +2513,9 @@ dependencies = [ [[package]] name = "xxhash-rust" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "735a71d46c4d68d71d4b24d03fdc2b98e38cea81730595801db779c04fe80d70" +checksum = "9828b178da53440fa9c766a3d2f73f7cf5d0ac1fe3980c1e5018d899fd19e07b" [[package]] name = "zstd" diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index cc0d0ced39e5..318e5979979e 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-polars" -version = "0.19.2" +version = "0.19.3" edition = "2021" [lib] @@ -13,6 +13,7 @@ polars-core = { path = "../crates/polars-core", default-features = false, featur polars-error = { path = "../crates/polars-error" } polars-lazy = { path = "../crates/polars-lazy", default-features = false, features = ["python"] } polars-ops = { path = "../crates/polars-ops", default-features = false } +polars-plan = { path = "../crates/polars-plan", default-features = false } ahash = "0.8" ciborium = "0.2" @@ -51,8 +52,8 @@ features = [ "fmt", "horizontal_concat", "interpolate", - "is_first", - "is_last", + "is_first_distinct", + "is_last_distinct", "is_unique", "lazy", "list_eval", @@ -139,6 +140,7 @@ list_any_all = ["polars/list_any_all"] cutqcut = ["polars/cutqcut"] rle = ["polars/rle"] extract_groups = ["polars/extract_groups"] +ffi_plugin = ["polars-plan/ffi_plugin"] all = [ "dtype-i8", @@ -185,6 +187,7 @@ all = [ "rle", "extract_groups", "polars-ops/convert_index", + "ffi_plugin", ] # we cannot conditionally activate simd diff --git a/py-polars/Makefile b/py-polars/Makefile index f5fd7c4404cd..74336116b0d0 100644 --- a/py-polars/Makefile +++ b/py-polars/Makefile @@ -76,6 +76,7 @@ test: .venv build ## Run fast unittests .PHONY: doctest doctest: .venv build ## Run doctests $(VENV_BIN)/python tests/docs/run_doctest.py + $(VENV_BIN)/pytest tests/docs/test_user_guide.py -m docs .PHONY: test-all test-all: .venv build ## Run all tests diff --git a/py-polars/docs/source/_static/css/custom.css b/py-polars/docs/source/_static/css/custom.css index 029703f6ec91..9cdd3b3591d8 100644 --- a/py-polars/docs/source/_static/css/custom.css +++ b/py-polars/docs/source/_static/css/custom.css @@ -43,3 +43,7 @@ div.bd-sidebar-secondary { label.sidebar-toggle.secondary-toggle { display: none !important; } + +a:visited { + color: var(--pst-color-link); +} diff --git a/py-polars/docs/source/conf.py b/py-polars/docs/source/conf.py index 2e3f1b34290a..e4f3a71c516d 100644 --- a/py-polars/docs/source/conf.py +++ b/py-polars/docs/source/conf.py @@ -106,7 +106,7 @@ "external_links": [ { "name": "User Guide", - "url": f"{web_root}/polars-book/user-guide/index.html", + "url": f"{web_root}/polars/user-guide/index.html", }, { "name": "Powered by Xomnia", diff --git a/py-polars/docs/source/reference/expressions/boolean.rst b/py-polars/docs/source/reference/expressions/boolean.rst index c08465ce4bb4..73c68917d515 100644 --- a/py-polars/docs/source/reference/expressions/boolean.rst +++ b/py-polars/docs/source/reference/expressions/boolean.rst @@ -12,9 +12,11 @@ Boolean Expr.is_duplicated Expr.is_finite Expr.is_first + Expr.is_first_distinct Expr.is_in Expr.is_infinite Expr.is_last + Expr.is_last_distinct Expr.is_nan Expr.is_not Expr.is_not_nan diff --git a/py-polars/docs/source/reference/io.rst b/py-polars/docs/source/reference/io.rst index a5bf2f91916a..9b0b91335c09 100644 --- a/py-polars/docs/source/reference/io.rst +++ 
b/py-polars/docs/source/reference/io.rst @@ -66,14 +66,22 @@ AVRO read_avro DataFrame.write_avro -Excel -~~~~~ +Spreadsheet +~~~~~~~~~~~ .. autosummary:: :toctree: api/ read_excel + read_ods DataFrame.write_excel +Apache Iceberg +~~~~~~~~~~~~~~ +.. autosummary:: + :toctree: api/ + + scan_iceberg + Delta Lake ~~~~~~~~~~ .. autosummary:: diff --git a/py-polars/docs/source/reference/series/descriptive.rst b/py-polars/docs/source/reference/series/descriptive.rst index d30868055c3e..6ec39e326b9f 100644 --- a/py-polars/docs/source/reference/series/descriptive.rst +++ b/py-polars/docs/source/reference/series/descriptive.rst @@ -15,11 +15,13 @@ Descriptive Series.is_empty Series.is_finite Series.is_first + Series.is_first_distinct Series.is_float Series.is_in Series.is_infinite Series.is_integer Series.is_last + Series.is_last_distinct Series.is_nan Series.is_not_nan Series.is_not_null diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index f203591e5112..8500a34fd0de 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -67,6 +67,7 @@ DuplicateError, InvalidOperationError, NoDataError, + OutOfBoundsError, PolarsPanicError, SchemaError, SchemaFieldNotFoundError, @@ -167,10 +168,12 @@ read_ipc_stream, read_json, read_ndjson, + read_ods, read_parquet, read_parquet_schema, scan_csv, scan_delta, + scan_iceberg, scan_ipc, scan_ndjson, scan_parquet, @@ -201,6 +204,7 @@ "DuplicateError", "InvalidOperationError", "NoDataError", + "OutOfBoundsError", "PolarsPanicError", "SchemaError", "SchemaFieldNotFoundError", @@ -261,10 +265,12 @@ "read_ipc_stream", "read_json", "read_ndjson", + "read_ods", "read_parquet", "read_parquet_schema", "scan_csv", "scan_delta", + "scan_iceberg", "scan_ipc", "scan_ndjson", "scan_parquet", diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 476d32151308..02e03bdc5fb9 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -55,7 +55,7 @@ from polars.exceptions import NoRowsReturnedError, TooManyRowsReturnedError from polars.functions import col, lit from polars.io._utils import _is_glob_pattern, _is_local_file -from polars.io.excel._write_utils import ( +from polars.io.spreadsheet._write_utils import ( _unpack_multi_column_dict, _xl_apply_conditional_formats, _xl_inject_sparklines, @@ -1853,10 +1853,11 @@ def item(self, row: int | None = None, column: int | str | None = None) -> Any: if row is None and column is None: if self.shape != (1, 1): raise ValueError( - f"can only call `.item()` if the dataframe is of shape (1, 1), or if" - f" explicit row/col values are provided; frame has shape {self.shape!r}" + "can only call `.item()` if the dataframe is of shape (1, 1)," + " or if explicit row/col values are provided;" + f" frame has shape {self.shape!r}" ) - return self._df.select_at_idx(0).get_idx(0) + return self._df.select_at_idx(0).get_index(0) elif row is None or column is None: raise ValueError("cannot call `.item()` with only one of `row` or `column`") @@ -1868,7 +1869,7 @@ def item(self, row: int | None = None, column: int | str | None = None) -> Any: ) if s is None: raise IndexError(f"column index {column!r} is out of bounds") - return s.get_idx(row) + return s.get_index_signed(row) def to_arrow(self) -> pa.Table: """ @@ -2528,7 +2529,7 @@ def write_csv( ``Float64`` datatypes. null_value A string representing null values (defaulting to the empty string). 
- quote_style : {'necessary', 'always', 'non_numeric'} + quote_style : {'necessary', 'always', 'non_numeric', 'never'} Determines the quoting strategy used. - necessary (default): This puts quotes around fields only when necessary. They are necessary when fields contain a quote, @@ -2537,6 +2538,8 @@ def write_csv( (which is indistinguishable from a record with one empty field). This is the default. - always: This puts quotes around every field. Always. + - never: This never puts quotes around fields, even if that results in + invalid CSV data (e.g.: by not quoting strings containing the separator). - non_numeric: This puts quotes around all fields that are non-numeric. Namely, when writing a field that does not parse as a valid float or integer, then quotes will be used even if they aren`t strictly @@ -5251,47 +5254,16 @@ def group_by_dynamic( Group based on a time value (or index value of type Int32, Int64). Time windows are calculated and rows are assigned to windows. Different from a - normal group by is that a row can be member of multiple groups. The time/index - window could be seen as a rolling window, with a window size determined by - dates/times/values instead of slots in the DataFrame. + normal group by is that a row can be member of multiple groups. + By default, the windows look like: - A window is defined by: + - [start, start + period) + - [start + every, start + every + period) + - [start + 2*every, start + 2*every + period) + - ... - - every: interval of the window - - period: length of the window - - offset: offset of the window - - The `every`, `period` and `offset` arguments are created with - the following string language: - - - 1ns (1 nanosecond) - - 1us (1 microsecond) - - 1ms (1 millisecond) - - 1s (1 second) - - 1m (1 minute) - - 1h (1 hour) - - 1d (1 calendar day) - - 1w (1 calendar week) - - 1mo (1 calendar month) - - 1q (1 calendar quarter) - - 1y (1 calendar year) - - 1i (1 index count) - - Or combine them: - "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds - - Suffix with `"_saturating"` to indicate that dates too large for - their month should saturate at the largest date (e.g. 2022-02-29 -> 2022-02-28) - instead of erroring. - - By "calendar day", we mean the corresponding time on the next day (which may - not be 24 hours, due to daylight savings). Similarly for "calendar week", - "calendar month", "calendar quarter", and "calendar year". - - In case of a group_by_dynamic on an integer column, the windows are defined by: - - - "1i" # length 1 - - "10i" # length 10 + where `start` is determined by `start_by`, `offset`, and `every` (see parameter + descriptions below). .. warning:: The index column must be sorted in ascending order. If `by` is passed, then @@ -5311,10 +5283,10 @@ def group_by_dynamic( every interval of the window period - length of the window, if None it is equal to 'every' + length of the window, if None it will equal 'every' offset - offset of the window if None and period is None it will be equal to negative - `every` + offset of the window, only takes effect if `start_by` is ``'window'``. + Defaults to negative `every`. truncate truncate the time value to the window lower bound include_boundaries @@ -5328,7 +5300,8 @@ def group_by_dynamic( start_by : {'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'} The strategy to determine the start of the first window by. - * 'window': Truncate the start of the window with the 'every' argument. 
+ * 'window': Start by taking the earliest timestamp, truncating it with + `every`, and then adding `offset`. Note that weekly windows start on Monday. * 'datapoint': Start from the first encountered data point. * a day of the week (only takes effect if `every` contains ``'w'``): @@ -5351,25 +5324,61 @@ def group_by_dynamic( of which will be sorted by `index_column` (but note that if `by` columns are passed, it will only be sorted within each `by` group). + See Also + -------- + group_by_rolling + Notes ----- - If you're coming from pandas, then + 1) If you're coming from pandas, then + + .. code-block:: python + + # polars + df.group_by_dynamic("ts", every="1d").agg(pl.col("value").sum()) - .. code-block:: python + is equivalent to - # polars - df.group_by_dynamic("ts", every="1d").agg(pl.col("value").sum()) + .. code-block:: python - is equivalent to + # pandas + df.set_index("ts").resample("D")["value"].sum().reset_index() - .. code-block:: python + though note that, unlike pandas, polars doesn't add extra rows for empty + windows. If you need `index_column` to be evenly spaced, then please combine + with :func:`DataFrame.upsample`. - # pandas - df.set_index("ts").resample("D")["value"].sum().reset_index() + 2) The `every`, `period` and `offset` arguments are created with + the following string language: - though note that, unlike pandas, polars doesn't add extra rows for empty - windows. If you need `index_column` to be evenly spaced, then please combine - with :func:`DataFrame.upsample`. + - 1ns (1 nanosecond) + - 1us (1 microsecond) + - 1ms (1 millisecond) + - 1s (1 second) + - 1m (1 minute) + - 1h (1 hour) + - 1d (1 calendar day) + - 1w (1 calendar week) + - 1mo (1 calendar month) + - 1q (1 calendar quarter) + - 1y (1 calendar year) + - 1i (1 index count) + + Or combine them: + "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds + + Suffix with `"_saturating"` to indicate that dates too large for + their month should saturate at the largest date (e.g. 2022-02-29 -> 2022-02-28) + instead of erroring. + + By "calendar day", we mean the corresponding time on the next day (which may + not be 24 hours, due to daylight savings). Similarly for "calendar week", + "calendar month", "calendar quarter", and "calendar year". + + In case of a group_by_dynamic on an integer column, the windows are defined by: + + - "1i" # length 1 + - "10i" # length 10 Examples -------- @@ -5545,12 +5554,13 @@ def group_by_dynamic( ... closed="right", ... ).agg(pl.col("A").alias("A_agg_list")) ... ) - shape: (3, 4) + shape: (4, 4) ┌─────────────────┬─────────────────┬─────┬─────────────────┐ │ _lower_boundary ┆ _upper_boundary ┆ idx ┆ A_agg_list │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ list[str] │ ╞═════════════════╪═════════════════╪═════╪═════════════════╡ + │ -2 ┆ 1 ┆ -2 ┆ ["A", "A"] │ │ 0 ┆ 3 ┆ 0 ┆ ["A", "B", "B"] │ │ 2 ┆ 5 ┆ 2 ┆ ["B", "B", "C"] │ │ 4 ┆ 7 ┆ 4 ┆ ["C"] │ @@ -8696,8 +8706,8 @@ def sample( set to False (default), the order of the returned rows will be neither stable nor fully random. seed - Seed for the random number generator. If set to None (default), a random - seed is generated using the ``random`` module. + Seed for the random number generator. If set to None (default), a + random seed is generated for each sample operation. Examples -------- @@ -9765,7 +9775,8 @@ def groupby( """ Start a group by operation. - Alias for :func:`DataFrame.group_by`. + .. deprecated:: 0.19.0 + This method has been renamed to :func:`DataFrame.group_by`. 
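(A minimal usage sketch, separate from the diff itself, of the 0.19 rename from ``groupby`` to ``group_by``; it assumes polars >= 0.19.0 is installed and is not part of the patch.)

import polars as pl

df = pl.DataFrame({"key": ["a", "a", "b"], "value": [1, 2, 3]})

# Preferred spelling from 0.19.0 onwards.
new_style = df.group_by("key").agg(pl.col("value").sum())

# Deprecated alias kept for backwards compatibility; it dispatches to the same implementation.
old_style = df.groupby("key").agg(pl.col("value").sum())

assert new_style.sort("key").frame_equal(old_style.sort("key"))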
Parameters ---------- @@ -9806,7 +9817,8 @@ def groupby_rolling( """ Create rolling groups based on a time, Int32, or Int64 column. - Alias for :func:`DataFrame.group_by_rolling`. + .. deprecated:: 0.19.0 + This method has been renamed to :func:`DataFrame.group_by_rolling`. Parameters ---------- @@ -9862,7 +9874,8 @@ def groupby_dynamic( """ Group based on a time value (or index value of type Int32, Int64). - Alias for :func:`DataFrame.group_by_rolling`. + .. deprecated:: 0.19.0 + This method has been renamed to :func:`DataFrame.group_by_dynamic`. Parameters ---------- @@ -9878,10 +9891,10 @@ def groupby_dynamic( every interval of the window period - length of the window, if None it is equal to 'every' + length of the window, if None it will equal 'every' offset - offset of the window if None and period is None it will be equal to negative - `every` + offset of the window, only takes effect if `start_by` is ``'window'``. + Defaults to negative `every`. truncate truncate the time value to the window lower bound include_boundaries @@ -9895,7 +9908,8 @@ def groupby_dynamic( start_by : {'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'} The strategy to determine the start of the first window by. - * 'window': Truncate the start of the window with the 'every' argument. + * 'window': Start by taking the earliest timestamp, truncating it with + `every`, and then adding `offset`. Note that weekly windows start on Monday. * 'datapoint': Start from the first encountered data point. * a day of the week (only takes effect if `every` contains ``'w'``): diff --git a/py-polars/polars/dependencies.py b/py-polars/polars/dependencies.py index 715c3bec3eed..bf853d1d2cc4 100644 --- a/py-polars/polars/dependencies.py +++ b/py-polars/polars/dependencies.py @@ -11,11 +11,13 @@ _DATAFRAME_API_COMPAT_AVAILABLE = True _DELTALAKE_AVAILABLE = True _FSSPEC_AVAILABLE = True +_GEVENT_AVAILABLE = True _HYPOTHESIS_AVAILABLE = True _NUMPY_AVAILABLE = True _PANDAS_AVAILABLE = True _PYARROW_AVAILABLE = True _PYDANTIC_AVAILABLE = True +_PYICEBERG_AVAILABLE = True _ZONEINFO_AVAILABLE = True @@ -155,11 +157,13 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: import dataframe_api_compat import deltalake import fsspec + import gevent import hypothesis import numpy import pandas import pyarrow import pydantic + import pyiceberg if sys.version_info >= (3, 9): import zoneinfo @@ -184,11 +188,13 @@ def _lazy_import(module_name: str) -> tuple[ModuleType, bool]: pandas, _PANDAS_AVAILABLE = _lazy_import("pandas") pyarrow, _PYARROW_AVAILABLE = _lazy_import("pyarrow") pydantic, _PYDANTIC_AVAILABLE = _lazy_import("pydantic") + pyiceberg, _PYICEBERG_AVAILABLE = _lazy_import("pyiceberg") zoneinfo, _ZONEINFO_AVAILABLE = ( _lazy_import("zoneinfo") if sys.version_info >= (3, 9) else _lazy_import("backports.zoneinfo") ) + gevent, _GEVENT_AVAILABLE = _lazy_import("gevent") @lru_cache(maxsize=None) @@ -228,9 +234,11 @@ def _check_for_pydantic(obj: Any) -> bool: "dataframe_api_compat", "deltalake", "fsspec", + "gevent", "numpy", "pandas", "pydantic", + "pyiceberg", "pyarrow", "zoneinfo", # lazy utilities @@ -241,7 +249,9 @@ def _check_for_pydantic(obj: Any) -> bool: "_LazyModule", # exported flags/guards "_DELTALAKE_AVAILABLE", + "_PYICEBERG_AVAILABLE", "_FSSPEC_AVAILABLE", + "_GEVENT_AVAILABLE", "_HYPOTHESIS_AVAILABLE", "_NUMPY_AVAILABLE", "_PANDAS_AVAILABLE", diff --git a/py-polars/polars/exceptions.py b/py-polars/polars/exceptions.py index 795cb6bd6863..963c5379cbb5 100644 --- 
a/py-polars/polars/exceptions.py +++ b/py-polars/polars/exceptions.py @@ -6,6 +6,7 @@ DuplicateError, InvalidOperationError, NoDataError, + OutOfBoundsError, PolarsPanicError, SchemaError, SchemaFieldNotFoundError, @@ -35,6 +36,12 @@ class InvalidOperationError(Exception): # type: ignore[no-redef] class NoDataError(Exception): # type: ignore[no-redef] """Exception raised when an operation can not be performed on an empty data structure.""" # noqa: W505 + class OutOfBoundsError(Exception): # type: ignore[no-redef] + """Exception raised when the given index is out of bounds.""" + + class PolarsPanicError(Exception): # type: ignore[no-redef] + """Exception raised when an unexpected state causes a panic in the underlying Rust library.""" # noqa: W505 + class SchemaError(Exception): # type: ignore[no-redef] """Exception raised when trying to combine data structures with mismatched schemas.""" # noqa: W505 @@ -50,9 +57,6 @@ class StringCacheMismatchError(Exception): # type: ignore[no-redef] class StructFieldNotFoundError(Exception): # type: ignore[no-redef] """Exception raised when a specified schema field is not found.""" - class PolarsPanicError(Exception): # type: ignore[no-redef] - """Exception raised when an unexpected state causes a panic in the underlying Rust library.""" # noqa: W505 - class ChronoFormatWarning(Warning): """ @@ -78,6 +82,10 @@ class NoRowsReturnedError(RowsError): """Exception raised when no rows are returned, but at least one row is expected.""" +class ParameterCollisionError(RuntimeError): + """Exception raised when the same parameter occurs multiple times.""" + + class PolarsInefficientMapWarning(Warning): """Warning raised when a potentially slow `apply` operation is performed.""" @@ -103,6 +111,7 @@ class UnsuitableSQLError(ValueError): "InvalidOperationError", "NoDataError", "NoRowsReturnedError", + "OutOfBoundsError", "PolarsInefficientMapWarning", "PolarsPanicError", "RowsError", diff --git a/py-polars/polars/expr/binary.py b/py-polars/polars/expr/binary.py index b83c4d95fa11..ed8a80d20c40 100644 --- a/py-polars/polars/expr/binary.py +++ b/py-polars/polars/expr/binary.py @@ -2,11 +2,12 @@ from typing import TYPE_CHECKING +from polars.utils._parse_expr_input import parse_as_expression from polars.utils._wrap import wrap_expr if TYPE_CHECKING: from polars import Expr - from polars.type_aliases import TransferEncoding + from polars.type_aliases import IntoExpr, TransferEncoding class ExprBinaryNameSpace: @@ -17,7 +18,7 @@ class ExprBinaryNameSpace: def __init__(self, expr: Expr): self._pyexpr = expr._pyexpr - def contains(self, literal: bytes) -> Expr: + def contains(self, literal: IntoExpr) -> Expr: r""" Check if binaries in Series contain a binary substring. @@ -42,29 +43,29 @@ def contains(self, literal: bytes) -> Expr: ... { ... "name": ["black", "yellow", "blue"], ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], + ... "lit": [b"\x00", b"\xff\x00", b"\xff\xff"], ... } ... ) >>> colors.select( ... "name", - ... pl.col("code").bin.encode("hex").alias("code_encoded_hex"), - ... pl.col("code").bin.contains(b"\xff").alias("contains_ff"), - ... pl.col("code").bin.starts_with(b"\xff").alias("starts_with_ff"), - ... pl.col("code").bin.ends_with(b"\xff").alias("ends_with_ff"), + ... pl.col("code").bin.contains(b"\xff").alias("contains_with_lit"), + ... pl.col("code").bin.contains(pl.col("lit")).alias("contains_with_expr"), ... 
) - shape: (3, 5) - ┌────────┬──────────────────┬─────────────┬────────────────┬──────────────┐ - │ name ┆ code_encoded_hex ┆ contains_ff ┆ starts_with_ff ┆ ends_with_ff │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ bool ┆ bool ┆ bool │ - ╞════════╪══════════════════╪═════════════╪════════════════╪══════════════╡ - │ black ┆ 000000 ┆ false ┆ false ┆ false │ - │ yellow ┆ ffff00 ┆ true ┆ true ┆ false │ - │ blue ┆ 0000ff ┆ true ┆ false ┆ true │ - └────────┴──────────────────┴─────────────┴────────────────┴──────────────┘ + shape: (3, 3) + ┌────────┬───────────────────┬────────────────────┐ + │ name ┆ contains_with_lit ┆ contains_with_expr │ + │ --- ┆ --- ┆ --- │ + │ str ┆ bool ┆ bool │ + ╞════════╪═══════════════════╪════════════════════╡ + │ black ┆ false ┆ true │ + │ yellow ┆ true ┆ true │ + │ blue ┆ true ┆ false │ + └────────┴───────────────────┴────────────────────┘ """ + literal = parse_as_expression(literal, str_as_lit=True) return wrap_expr(self._pyexpr.bin_contains(literal)) - def ends_with(self, suffix: bytes) -> Expr: + def ends_with(self, suffix: IntoExpr) -> Expr: r""" Check if string values end with a binary substring. @@ -89,29 +90,29 @@ def ends_with(self, suffix: bytes) -> Expr: ... { ... "name": ["black", "yellow", "blue"], ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], + ... "suffix": [b"\x00", b"\xff\x00", b"\x00\x00"], ... } ... ) >>> colors.select( ... "name", - ... pl.col("code").bin.encode("hex").alias("code_encoded_hex"), - ... pl.col("code").bin.contains(b"\xff").alias("contains_ff"), - ... pl.col("code").bin.starts_with(b"\xff").alias("starts_with_ff"), - ... pl.col("code").bin.ends_with(b"\xff").alias("ends_with_ff"), + ... pl.col("code").bin.ends_with(b"\xff").alias("ends_with_lit"), + ... pl.col("code").bin.ends_with(pl.col("suffix")).alias("ends_with_expr"), ... ) - shape: (3, 5) - ┌────────┬──────────────────┬─────────────┬────────────────┬──────────────┐ - │ name ┆ code_encoded_hex ┆ contains_ff ┆ starts_with_ff ┆ ends_with_ff │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ bool ┆ bool ┆ bool │ - ╞════════╪══════════════════╪═════════════╪════════════════╪══════════════╡ - │ black ┆ 000000 ┆ false ┆ false ┆ false │ - │ yellow ┆ ffff00 ┆ true ┆ true ┆ false │ - │ blue ┆ 0000ff ┆ true ┆ false ┆ true │ - └────────┴──────────────────┴─────────────┴────────────────┴──────────────┘ + shape: (3, 3) + ┌────────┬───────────────┬────────────────┐ + │ name ┆ ends_with_lit ┆ ends_with_expr │ + │ --- ┆ --- ┆ --- │ + │ str ┆ bool ┆ bool │ + ╞════════╪═══════════════╪════════════════╡ + │ black ┆ false ┆ true │ + │ yellow ┆ false ┆ true │ + │ blue ┆ true ┆ false │ + └────────┴───────────────┴────────────────┘ """ + suffix = parse_as_expression(suffix, str_as_lit=True) return wrap_expr(self._pyexpr.bin_ends_with(suffix)) - def starts_with(self, prefix: bytes) -> Expr: + def starts_with(self, prefix: IntoExpr) -> Expr: r""" Check if values start with a binary substring. @@ -136,26 +137,28 @@ def starts_with(self, prefix: bytes) -> Expr: ... { ... "name": ["black", "yellow", "blue"], ... "code": [b"\x00\x00\x00", b"\xff\xff\x00", b"\x00\x00\xff"], + ... "prefix": [b"\x00", b"\xff\x00", b"\x00\x00"], ... } ... ) >>> colors.select( ... "name", - ... pl.col("code").bin.encode("hex").alias("code_encoded_hex"), - ... pl.col("code").bin.contains(b"\xff").alias("contains_ff"), - ... pl.col("code").bin.starts_with(b"\xff").alias("starts_with_ff"), - ... pl.col("code").bin.ends_with(b"\xff").alias("ends_with_ff"), + ... 
pl.col("code").bin.starts_with(b"\xff").alias("starts_with_lit"), + ... pl.col("code") + ... .bin.starts_with(pl.col("prefix")) + ... .alias("starts_with_expr"), ... ) - shape: (3, 5) - ┌────────┬──────────────────┬─────────────┬────────────────┬──────────────┐ - │ name ┆ code_encoded_hex ┆ contains_ff ┆ starts_with_ff ┆ ends_with_ff │ - │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ - │ str ┆ str ┆ bool ┆ bool ┆ bool │ - ╞════════╪══════════════════╪═════════════╪════════════════╪══════════════╡ - │ black ┆ 000000 ┆ false ┆ false ┆ false │ - │ yellow ┆ ffff00 ┆ true ┆ true ┆ false │ - │ blue ┆ 0000ff ┆ true ┆ false ┆ true │ - └────────┴──────────────────┴─────────────┴────────────────┴──────────────┘ + shape: (3, 3) + ┌────────┬─────────────────┬──────────────────┐ + │ name ┆ starts_with_lit ┆ starts_with_expr │ + │ --- ┆ --- ┆ --- │ + │ str ┆ bool ┆ bool │ + ╞════════╪═════════════════╪══════════════════╡ + │ black ┆ false ┆ true │ + │ yellow ┆ true ┆ false │ + │ blue ┆ false ┆ true │ + └────────┴─────────────────┴──────────────────┘ """ + prefix = parse_as_expression(prefix, str_as_lit=True) return wrap_expr(self._pyexpr.bin_starts_with(prefix)) def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> Expr: diff --git a/py-polars/polars/expr/datetime.py b/py-polars/polars/expr/datetime.py index a519e3259842..d12b4682e8f3 100644 --- a/py-polars/polars/expr/datetime.py +++ b/py-polars/polars/expr/datetime.py @@ -28,7 +28,7 @@ def __init__(self, expr: Expr): def truncate( self, - every: str | timedelta, + every: str | timedelta | Expr, offset: str | timedelta | None = None, *, use_earliest: bool | None = None, @@ -221,12 +221,17 @@ def truncate( ambiguous = rename_use_earliest_to_ambiguous(use_earliest, ambiguous) if not isinstance(ambiguous, pl.Expr): ambiguous = F.lit(ambiguous) + + if not isinstance(every, pl.Expr): + every = _timedelta_to_pl_duration(every) + every = parse_as_expression(every, str_as_lit=True) + if offset is None: offset = "0ns" return wrap_expr( self._pyexpr.dt_truncate( - _timedelta_to_pl_duration(every), + every, _timedelta_to_pl_duration(offset), ambiguous._pyexpr, ) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 8dea34ae6984..f42cf5fff2ef 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -195,7 +195,10 @@ def __ne__(self, other: Any) -> Self: # type: ignore[override] return self._from_pyexpr(self._pyexpr.neq(self._to_expr(other)._pyexpr)) def __neg__(self) -> Expr: - return F.lit(0) - self + neg_expr = F.lit(0) - self + if (name := self.meta.output_name(raise_if_undetermined=False)) is not None: + neg_expr = neg_expr.alias(name) + return neg_expr def __or__(self, other: Expr | int | bool) -> Self: return self._from_pyexpr(self._pyexpr._or(self._to_pyexpr(other))) @@ -204,7 +207,10 @@ def __ror__(self, other: Any) -> Self: return self._from_pyexpr(self._to_pyexpr(other)._or(self._pyexpr)) def __pos__(self) -> Expr: - return F.lit(0) + self + pos_expr = F.lit(0) + self + if (name := self.meta.output_name(raise_if_undetermined=False)) is not None: + pos_expr = pos_expr.alias(name) + return pos_expr def __pow__(self, power: int | float | Series | Expr) -> Self: return self.pow(power) @@ -2116,9 +2122,7 @@ def arg_min(self) -> Self: """ return self._from_pyexpr(self._pyexpr.arg_min()) - def search_sorted( - self, element: Expr | int | float | Series, side: SearchSortedSide = "any" - ) -> Self: + def search_sorted(self, element: IntoExpr, side: SearchSortedSide = "any") -> Self: """ Find indices where 
elements should be inserted to maintain order. @@ -3193,9 +3197,9 @@ def is_unique(self) -> Self: """ return self._from_pyexpr(self._pyexpr.is_unique()) - def is_first(self) -> Self: + def is_first_distinct(self) -> Self: """ - Get a mask of the first unique value. + Return a boolean mask indicating the first occurrence of each distinct value. Returns ------- @@ -3204,31 +3208,27 @@ def is_first(self) -> Self: Examples -------- - >>> df = pl.DataFrame( - ... { - ... "num": [1, 2, 3, 1, 5], - ... } - ... ) - >>> df.with_columns(pl.col("num").is_first().alias("is_first")) + >>> df = pl.DataFrame({"a": [1, 1, 2, 3, 2]}) + >>> df.with_columns(pl.col("a").is_first_distinct().alias("first")) shape: (5, 2) - ┌─────┬──────────┐ - │ num ┆ is_first │ - │ --- ┆ --- │ - │ i64 ┆ bool │ - ╞═════╪══════════╡ - │ 1 ┆ true │ - │ 2 ┆ true │ - │ 3 ┆ true │ - │ 1 ┆ false │ - │ 5 ┆ true │ - └─────┴──────────┘ + ┌─────┬───────┐ + │ a ┆ first │ + │ --- ┆ --- │ + │ i64 ┆ bool │ + ╞═════╪═══════╡ + │ 1 ┆ true │ + │ 1 ┆ false │ + │ 2 ┆ true │ + │ 3 ┆ true │ + │ 2 ┆ false │ + └─────┴───────┘ """ - return self._from_pyexpr(self._pyexpr.is_first()) + return self._from_pyexpr(self._pyexpr.is_first_distinct()) - def is_last(self) -> Self: + def is_last_distinct(self) -> Self: """ - Get a mask of the last unique value. + Return a boolean mask indicating the last occurrence of each distinct value. Returns ------- @@ -3237,31 +3237,32 @@ def is_last(self) -> Self: Examples -------- - >>> df = pl.DataFrame( - ... { - ... "num": [1, 2, 3, 1, 5], - ... } - ... ) - >>> df.with_columns(pl.col("num").is_last().alias("is_last")) + >>> df = pl.DataFrame({"a": [1, 1, 2, 3, 2]}) + >>> df.with_columns(pl.col("a").is_last_distinct().alias("last")) shape: (5, 2) - ┌─────┬─────────┐ - │ num ┆ is_last │ - │ --- ┆ --- │ - │ i64 ┆ bool │ - ╞═════╪═════════╡ - │ 1 ┆ false │ - │ 2 ┆ true │ - │ 3 ┆ true │ - │ 1 ┆ true │ - │ 5 ┆ true │ - └─────┴─────────┘ + ┌─────┬───────┐ + │ a ┆ last │ + │ --- ┆ --- │ + │ i64 ┆ bool │ + ╞═════╪═══════╡ + │ 1 ┆ false │ + │ 1 ┆ true │ + │ 2 ┆ false │ + │ 3 ┆ true │ + │ 2 ┆ true │ + └─────┴───────┘ """ - return self._from_pyexpr(self._pyexpr.is_last()) + return self._from_pyexpr(self._pyexpr.is_last_distinct()) def is_duplicated(self) -> Self: """ - Get mask of duplicated values. + Return a boolean mask indicating duplicated values. + + Returns + ------- + Expr + Expression of data type :class:`Boolean`. Examples -------- @@ -3696,7 +3697,7 @@ def map_batches( represented by an expression using a third-party library. Read more in `the book - `_. + `_. Parameters ---------- @@ -9267,6 +9268,96 @@ def rolling_apply( function, window_size, weights, min_periods, center=center ) + @deprecate_renamed_function("is_first_distinct", version="0.19.3") + def is_first(self) -> Self: + """ + Return a boolean mask indicating the first occurrence of each distinct value. + + .. deprecated:: 0.19.3 + This method has been renamed to :func:`Expr.is_first_distinct`. + + Returns + ------- + Expr + Expression of data type :class:`Boolean`. + + """ + return self.is_first_distinct() + + @deprecate_renamed_function("is_last_distinct", version="0.19.3") + def is_last(self) -> Self: + """ + Return a boolean mask indicating the last occurrence of each distinct value. + + .. deprecated:: 0.19.3 + This method has been renamed to :func:`Expr.is_last_distinct`. + + Returns + ------- + Expr + Expression of data type :class:`Boolean`. 
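(A short migration sketch for the ``is_first``/``is_last`` renames shown in this hunk; separate from the patch, assuming polars 0.19.3 or later.)

import polars as pl

s = pl.Series("a", [1, 1, 2, 3, 2])

# New names introduced in 0.19.3.
first_mask = s.is_first_distinct()
last_mask = s.is_last_distinct()

# The old names remain as deprecated aliases and return the same masks.
assert first_mask.series_equal(s.is_first())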
+ + """ + return self.is_last_distinct() + + def _register_plugin( + self, + lib: str, + symbol: str, + args: list[IntoExpr] | None = None, + *, + is_elementwise: bool = False, + input_wildcard_expansion: bool = False, + auto_explode: bool = False, + cast_to_supertypes: bool = False, + ) -> Self: + """ + Register a shared library as a plugin. + + .. warning:: + This is highly unsafe as this will call the C function + loaded by ``lib::symbol`` + + .. note:: + This functionality is unstable and may change without it + being considered breaking. + + Parameters + ---------- + lib + Library to load. + symbol + Function to load. + args + Arguments (other than self) passed to this function. + is_elementwise + If the function only operates on scalars + this will trigger fast paths. + input_wildcard_expansion + Expand expressions as input of this function. + auto_explode + Explode the results in a group_by. + This is recommended for aggregation functions. + cast_to_supertypes + Cast the input datatypes to their supertype. + + """ + if args is None: + args = [] + else: + args = [parse_as_expression(a) for a in args] + return self._from_pyexpr( + self._pyexpr.register_plugin( + lib, + symbol, + args, + is_elementwise, + input_wildcard_expansion, + auto_explode, + cast_to_supertypes, + ) + ) + @property def bin(self) -> ExprBinaryNameSpace: """ diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 642ca1afcabc..487369284cf3 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -459,7 +459,7 @@ def contains( item = parse_as_expression(item, str_as_lit=True) return wrap_expr(self._pyexpr.list_contains(item)) - def join(self, separator: str) -> Expr: + def join(self, separator: IntoExpr) -> Expr: """ Join all string items in a sublist and place a separator between them. @@ -489,7 +489,21 @@ def join(self, separator: str) -> Expr: │ x y │ └───────┘ + >>> df = pl.DataFrame( + ... {"s": [["a", "b", "c"], ["x", "y"]], "separator": ["*", "_"]} + ... ) + >>> df.select(pl.col("s").list.join(pl.col("separator"))) + shape: (2, 1) + ┌───────┐ + │ s │ + │ --- │ + │ str │ + ╞═══════╡ + │ a*b*c │ + │ x_y │ + └───────┘ """ + separator = parse_as_expression(separator, str_as_lit=True) return wrap_expr(self._pyexpr.list_join(separator)) def arg_min(self) -> Expr: diff --git a/py-polars/polars/expr/meta.py b/py-polars/polars/expr/meta.py index 59285c3f78c5..7bb334c6c81d 100644 --- a/py-polars/polars/expr/meta.py +++ b/py-polars/polars/expr/meta.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal, overload +from polars.exceptions import ComputeError from polars.utils._wrap import wrap_expr from polars.utils.deprecation import deprecate_nonkeyword_arguments from polars.utils.various import normalize_filepath @@ -88,12 +89,21 @@ def is_regex_projection(self) -> bool: """ return self._pyexpr.meta_is_regex_projection() - def output_name(self) -> str: + @overload + def output_name(self, *, raise_if_undetermined: Literal[True] = True) -> str: + ... + + @overload + def output_name(self, *, raise_if_undetermined: Literal[False]) -> str | None: + ... + + def output_name(self, *, raise_if_undetermined: bool = True) -> str | None: """ Get the column name that this expression would produce. - It may not always be possible to determine the output name, as that can depend - on the schema of the context; in that case this will raise ``ComputeError``. 
+ It may not always be possible to determine the output name as that can depend + on the schema of the context; in that case this will raise ``ComputeError`` if + ``raise_if_undetermined`` is True (the default), or ``None`` otherwise. Examples -------- @@ -109,12 +119,16 @@ def output_name(self) -> str: >>> e_sum_slice = pl.sum("foo").slice(pl.count() - 10, pl.col("bar")) >>> e_sum_slice.meta.output_name() 'foo' - >>> e_count = pl.count() - >>> e_count.meta.output_name() + >>> pl.count().meta.output_name() 'count' """ - return self._pyexpr.meta_output_name() + try: + return self._pyexpr.meta_output_name() + except ComputeError: + if not raise_if_undetermined: + return None + raise def pop(self) -> list[Expr]: """ diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index d1ebf6af33c5..1bd7d945eb09 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -19,6 +19,7 @@ from polars import Expr from polars.type_aliases import ( Ambiguous, + IntoExpr, PolarsDataType, PolarsTemporalType, TimeUnit, @@ -952,17 +953,34 @@ def ends_with(self, suffix: str | Expr) -> Expr: │ null ┆ null │ └────────┴────────────┘ + >>> df = pl.DataFrame( + ... {"fruits": ["apple", "mango", "banana"], "suffix": ["le", "go", "nu"]} + ... ) + >>> df.with_columns( + ... pl.col("fruits").str.ends_with(pl.col("suffix")).alias("has_suffix"), + ... ) + shape: (3, 3) + ┌────────┬────────┬────────────┐ + │ fruits ┆ suffix ┆ has_suffix │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ bool │ + ╞════════╪════════╪════════════╡ + │ apple ┆ le ┆ true │ + │ mango ┆ go ┆ true │ + │ banana ┆ nu ┆ false │ + └────────┴────────┴────────────┘ + Using ``ends_with`` as a filter condition: >>> df.filter(pl.col("fruits").str.ends_with("go")) - shape: (1, 1) - ┌────────┐ - │ fruits │ - │ --- │ - │ str │ - ╞════════╡ - │ mango │ - └────────┘ + shape: (1, 2) + ┌────────┬────────┐ + │ fruits ┆ suffix │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞════════╪════════╡ + │ mango ┆ go │ + └────────┴────────┘ """ suffix = parse_as_expression(suffix, str_as_lit=True) @@ -999,17 +1017,34 @@ def starts_with(self, prefix: str | Expr) -> Expr: │ null ┆ null │ └────────┴────────────┘ + >>> df = pl.DataFrame( + ... {"fruits": ["apple", "mango", "banana"], "prefix": ["app", "na", "ba"]} + ... ) + >>> df.with_columns( + ... pl.col("fruits").str.starts_with(pl.col("prefix")).alias("has_prefix"), + ... ) + shape: (3, 3) + ┌────────┬────────┬────────────┐ + │ fruits ┆ prefix ┆ has_prefix │ + │ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ bool │ + ╞════════╪════════╪════════════╡ + │ apple ┆ app ┆ true │ + │ mango ┆ na ┆ false │ + │ banana ┆ ba ┆ true │ + └────────┴────────┴────────────┘ + Using ``starts_with`` as a filter condition: >>> df.filter(pl.col("fruits").str.starts_with("app")) - shape: (1, 1) - ┌────────┐ - │ fruits │ - │ --- │ - │ str │ - ╞════════╡ - │ apple │ - └────────┘ + shape: (1, 2) + ┌────────┬────────┐ + │ fruits ┆ prefix │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞════════╪════════╡ + │ apple ┆ app │ + └────────┴────────┘ """ prefix = parse_as_expression(prefix, str_as_lit=True) @@ -1477,7 +1512,7 @@ def count_matches(self, pattern: str | Expr, *, literal: bool = False) -> Expr: pattern = parse_as_expression(pattern, str_as_lit=True) return wrap_expr(self._pyexpr.str_count_matches(pattern, literal)) - def split(self, by: str, *, inclusive: bool = False) -> Expr: + def split(self, by: IntoExpr, *, inclusive: bool = False) -> Expr: """ Split the string by a substring. 
@@ -1490,18 +1525,41 @@ def split(self, by: str, *, inclusive: bool = False) -> Expr: Examples -------- - >>> df = pl.DataFrame({"s": ["foo bar", "foo-bar", "foo bar baz"]}) - >>> df.select(pl.col("s").str.split(by=" ")) - shape: (3, 1) - ┌───────────────────────┐ - │ s │ - │ --- │ - │ list[str] │ - ╞═══════════════════════╡ - │ ["foo", "bar"] │ - │ ["foo-bar"] │ - │ ["foo", "bar", "baz"] │ - └───────────────────────┘ + >>> df = pl.DataFrame({"s": ["foo bar", "foo_bar", "foo_bar_baz"]}) + >>> df.with_columns( + ... pl.col("s").str.split(by="_").alias("split"), + ... pl.col("s").str.split(by="_", inclusive=True).alias("split_inclusive"), + ... ) + shape: (3, 3) + ┌─────────────┬───────────────────────┬─────────────────────────┐ + │ s ┆ split ┆ split_inclusive │ + │ --- ┆ --- ┆ --- │ + │ str ┆ list[str] ┆ list[str] │ + ╞═════════════╪═══════════════════════╪═════════════════════════╡ + │ foo bar ┆ ["foo bar"] ┆ ["foo bar"] │ + │ foo_bar ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │ + │ foo_bar_baz ┆ ["foo", "bar", "baz"] ┆ ["foo_", "bar_", "baz"] │ + └─────────────┴───────────────────────┴─────────────────────────┘ + + >>> df = pl.DataFrame( + ... {"s": ["foo^bar", "foo_bar", "foo*bar*baz"], "by": ["_", "_", "*"]} + ... ) + >>> df.with_columns( + ... pl.col("s").str.split(by=pl.col("by")).alias("split"), + ... pl.col("s") + ... .str.split(by=pl.col("by"), inclusive=True) + ... .alias("split_inclusive"), + ... ) + shape: (3, 4) + ┌─────────────┬─────┬───────────────────────┬─────────────────────────┐ + │ s ┆ by ┆ split ┆ split_inclusive │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ str ┆ str ┆ list[str] ┆ list[str] │ + ╞═════════════╪═════╪═══════════════════════╪═════════════════════════╡ + │ foo^bar ┆ _ ┆ ["foo^bar"] ┆ ["foo^bar"] │ + │ foo_bar ┆ _ ┆ ["foo", "bar"] ┆ ["foo_", "bar"] │ + │ foo*bar*baz ┆ * ┆ ["foo", "bar", "baz"] ┆ ["foo*", "bar*", "baz"] │ + └─────────────┴─────┴───────────────────────┴─────────────────────────┘ Returns ------- @@ -1509,6 +1567,7 @@ def split(self, by: str, *, inclusive: bool = False) -> Expr: Expression of data type :class:`Utf8`. 
""" + by = parse_as_expression(by, str_as_lit=True) if inclusive: return wrap_expr(self._pyexpr.str_split_inclusive(by)) return wrap_expr(self._pyexpr.str_split(by)) diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 4dd2f1699bcd..f800df84b4bb 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -6,7 +6,7 @@ import polars._reexport as pl import polars.functions as F from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Datetime, Int64 -from polars.utils._async import _AsyncDataFrameResult +from polars.utils._async import _AioDataFrameResult, _GeventDataFrameResult from polars.utils._parse_expr_input import ( parse_as_expression, parse_as_list_of_expressions, @@ -23,8 +23,7 @@ if TYPE_CHECKING: - from queue import Queue - from typing import Collection, Literal + from typing import Awaitable, Collection, Literal from polars import DataFrame, Expr, LazyFrame, Series from polars.type_aliases import ( @@ -1652,10 +1651,46 @@ def collect_all( return result +@overload +def collect_all_async( + lazy_frames: Sequence[LazyFrame], + *, + gevent: Literal[True], + type_coercion: bool = True, + predicate_pushdown: bool = True, + projection_pushdown: bool = True, + simplify_expression: bool = True, + no_optimization: bool = True, + slice_pushdown: bool = True, + comm_subplan_elim: bool = True, + comm_subexpr_elim: bool = True, + streaming: bool = True, +) -> _GeventDataFrameResult[list[DataFrame]]: + ... + + +@overload +def collect_all_async( + lazy_frames: Sequence[LazyFrame], + *, + gevent: Literal[False] = False, + type_coercion: bool = True, + predicate_pushdown: bool = True, + projection_pushdown: bool = True, + simplify_expression: bool = True, + no_optimization: bool = False, + slice_pushdown: bool = True, + comm_subplan_elim: bool = True, + comm_subexpr_elim: bool = True, + streaming: bool = False, +) -> Awaitable[list[DataFrame]]: + ... + + def collect_all_async( lazy_frames: Sequence[LazyFrame], - queue: Queue[list[DataFrame] | Exception], *, + gevent: bool = False, type_coercion: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, @@ -1665,33 +1700,46 @@ def collect_all_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = False, -) -> _AsyncDataFrameResult[list[DataFrame]]: +) -> Awaitable[list[DataFrame]] | _GeventDataFrameResult[list[DataFrame]]: """ Collect multiple LazyFrames at the same time asynchronously in thread pool. - Collects into a list of DataFrame, like :func:`polars.collect_all` - but instead of returning them directly its collected inside thread pool - and gets put into `queue` with `put_nowait` method, - while this method returns almost instantly. + Collects into a list of DataFrame (like :func:`polars.collect_all`), + but instead of returning them directly, they are scheduled to be collected + inside thread pool, while this method returns almost instantly. May be useful if you use gevent or asyncio and want to release control to other greenlets/tasks while LazyFrames are being collected. - You must use correct queue in that case. - Given `queue` must be thread safe! - - For gevent use - [`gevent.queue.Queue`](https://www.gevent.org/api/gevent.queue.html#gevent.queue.Queue). - For asyncio - [`asyncio.queues.Queue`](https://docs.python.org/3/library/asyncio-queue.html#queue) - can not be used, since it's not thread safe! - For that purpose use [janus](https://github.com/aio-libs/janus) library. 
+ Parameters + ---------- + lazy_frames + A list of LazyFrames to collect. + gevent + Return wrapper to `gevent.event.AsyncResult` instead of Awaitable + type_coercion + Do type coercion optimization. + predicate_pushdown + Do predicate pushdown optimization. + projection_pushdown + Do projection pushdown optimization. + simplify_expression + Run simplify expressions optimization. + no_optimization + Turn off (certain) optimizations. + slice_pushdown + Slice pushdown optimization. + comm_subplan_elim + Will try to cache branching subplans that occur on self-joins or unions. + comm_subexpr_elim + Common subexpressions will be cached and reused. + streaming + Run parts of the query in a streaming fashion (this is in an alpha state) Notes ----- - Results are put in queue exactly once using `put_nowait`. - If error occurred then Exception will be put in the queue instead of result - which is then raised by returned wrapper `get` method. + In case of error `set_exception` is used on + `asyncio.Future`/`gevent.event.AsyncResult` and will be reraised by them. Warnings -------- @@ -1705,8 +1753,10 @@ def collect_all_async( Returns ------- - Wrapper that has `get` method and `queue` attribute with given queue. - `get` accepts kwargs that are passed down to `queue.get`. + If `gevent=False` (default) then returns awaitable. + + If `gevent=True` then returns wrapper that has + `.get(block=True, timeout=None)` method. """ if no_optimization: predicate_pushdown = False @@ -1731,9 +1781,9 @@ def collect_all_async( ) prepared.append(ldf) - result = _AsyncDataFrameResult(queue) - plr.collect_all_with_callback(prepared, result._callback_all) - return result + result = _GeventDataFrameResult() if gevent else _AioDataFrameResult() + plr.collect_all_with_callback(prepared, result._callback_all) # type: ignore[attr-defined] + return result # type: ignore[return-value] def select(*exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr) -> DataFrame: diff --git a/py-polars/polars/io/__init__.py b/py-polars/polars/io/__init__.py index 995bc4552c55..f4a39b5f778a 100644 --- a/py-polars/polars/io/__init__.py +++ b/py-polars/polars/io/__init__.py @@ -4,12 +4,13 @@ from polars.io.csv import read_csv, read_csv_batched, scan_csv from polars.io.database import read_database, read_database_uri from polars.io.delta import read_delta, scan_delta -from polars.io.excel import read_excel +from polars.io.iceberg import scan_iceberg from polars.io.ipc import read_ipc, read_ipc_schema, read_ipc_stream, scan_ipc from polars.io.json import read_json from polars.io.ndjson import read_ndjson, scan_ndjson from polars.io.parquet import read_parquet, read_parquet_schema, scan_parquet from polars.io.pyarrow_dataset import scan_pyarrow_dataset +from polars.io.spreadsheet import read_excel, read_ods __all__ = [ "read_avro", @@ -24,10 +25,12 @@ "read_ipc_schema", "read_json", "read_ndjson", + "read_ods", "read_parquet", "read_parquet_schema", "scan_csv", "scan_delta", + "scan_iceberg", "scan_ipc", "scan_ndjson", "scan_parquet", diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index babb641f56b6..347d137b3b7b 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -273,9 +273,10 @@ def read_database( # noqa: D417 An instantiated connection (or cursor/client object) that the query can be executed against. 
batch_size - The number of rows to fetch each time as data is collected; if this option is - supported by the backend it will be passed to the underlying query execution - method (if the backend does not have such support it is ignored without error). + Enable batched data fetching and set the number of rows to fetch each time as + data is collected. If supported by the backend, it will be passed to the + underlying query execution method. If the backend does not support changing the + batch size, it is ignored without error. schema_overrides A dictionary mapping column names to dtypes, used to override the schema inferred from the query cursor or given by the incoming Arrow data (depending diff --git a/py-polars/polars/io/excel/__init__.py b/py-polars/polars/io/excel/__init__.py deleted file mode 100644 index f51df1698e1f..000000000000 --- a/py-polars/polars/io/excel/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from polars.io.excel.functions import read_excel - -__all__ = ["read_excel"] diff --git a/py-polars/polars/io/excel/functions.py b/py-polars/polars/io/excel/functions.py deleted file mode 100644 index 2fc750e1ecac..000000000000 --- a/py-polars/polars/io/excel/functions.py +++ /dev/null @@ -1,380 +0,0 @@ -from __future__ import annotations - -import re -from io import StringIO -from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, NoReturn, Sequence, overload - -import polars._reexport as pl -from polars.exceptions import NoDataError -from polars.io.csv.functions import read_csv -from polars.utils.various import normalize_filepath - -if TYPE_CHECKING: - from io import BytesIO - from typing import Literal - - -@overload -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: None = ..., - sheet_name: str, - xlsx2csv_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., - engine: Literal["xlsx2csv", "openpyxl"] | None = ..., - raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... - - -@overload -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: None = ..., - sheet_name: None = ..., - xlsx2csv_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., - engine: Literal["xlsx2csv", "openpyxl"] | None = ..., - raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... - - -@overload -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: int, - sheet_name: str, - xlsx2csv_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., - engine: Literal["xlsx2csv", "openpyxl"] | None = ..., - raise_if_empty: bool = ..., -) -> NoReturn: - ... - - -# note: 'ignore' required as mypy thinks that the return value for -# Literal[0] overlaps with the return value for other integers -@overload # type: ignore[misc] -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: Literal[0] | Sequence[int], - sheet_name: None = ..., - xlsx2csv_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., - engine: Literal["xlsx2csv", "openpyxl"] | None = ..., - raise_if_empty: bool = ..., -) -> dict[str, pl.DataFrame]: - ... 
- - -@overload -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: None, - sheet_name: list[str] | tuple[str], - xlsx2csv_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., - engine: Literal["xlsx2csv", "openpyxl"] | None = ..., - raise_if_empty: bool = ..., -) -> dict[str, pl.DataFrame]: - ... - - -@overload -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: int, - sheet_name: None = ..., - xlsx2csv_options: dict[str, Any] | None = ..., - read_csv_options: dict[str, Any] | None = ..., - engine: Literal["xlsx2csv", "openpyxl"] | None = ..., - raise_if_empty: bool = ..., -) -> pl.DataFrame: - ... - - -def read_excel( - source: str | BytesIO | Path | BinaryIO | bytes, - *, - sheet_id: int | Sequence[int] | None = None, - sheet_name: str | list[str] | tuple[str] | None = None, - xlsx2csv_options: dict[str, Any] | None = None, - read_csv_options: dict[str, Any] | None = None, - engine: Literal["xlsx2csv", "openpyxl"] | None = None, - raise_if_empty: bool = True, -) -> pl.DataFrame | dict[str, pl.DataFrame]: - """ - Read Excel (XLSX) sheet into a DataFrame. - - If using the ``xlsx2csv`` engine, converts an Excel sheet with - ``xlsx2csv.Xlsx2csv().convert()`` to CSV and parses the CSV output with - :func:`read_csv`. You can pass additional options to ``read_csv_options`` to - influence parsing behaviour. - - When using the ``openpyxl`` engine, reads an Excel sheet with - ``openpyxl.load_workbook(source)``. - - Parameters - ---------- - source - Path to a file or a file-like object (by file-like object, we refer to objects - that have a ``read()`` method, such as a file handler (e.g. via builtin ``open`` - function) or ``BytesIO``). - sheet_id - Sheet number to convert (set ``0`` to load all sheets as DataFrames) and return - a ``{sheetname:frame,}`` dict. (Defaults to `1` if neither this nor `sheet_name` - are specified). Can also take a sequence of sheet numbers. - sheet_name - Sheet name()s to convert; cannot be used in conjunction with `sheet_id`. If more - than one is given then a ``{sheetname:frame,}`` dict is returned. - xlsx2csv_options - Extra options passed to ``xlsx2csv.Xlsx2csv()``, - e.g. ``{"skip_empty_lines": True}`` - read_csv_options - Extra options passed to :func:`read_csv` for parsing the CSV file returned by - ``xlsx2csv.Xlsx2csv().convert()`` - e.g.: ``{"has_header": False, "new_columns": ["a", "b", "c"], - "infer_schema_length": None}`` - engine - Library used to parse Excel, either openpyxl or xlsx2csv (default is xlsx2csv). - Please note that xlsx2csv converts first to csv, making type inference worse - than openpyxl. To remedy that, you can use the extra options defined on - `xlsx2csv_options` and `read_csv_options` - raise_if_empty - When there is no data in the sheet,``NoDataError`` is raised. If this parameter - is set to False, an empty DataFrame (with no columns) is returned instead. - - Returns - ------- - DataFrame, or a sheetname to DataFrame dict when ``sheet_id == 0``. - - Examples - -------- - Read the "data" worksheet from an Excel file into a DataFrame. - - >>> pl.read_excel( - ... source="test.xlsx", - ... sheet_name="data", - ... ) # doctest: +SKIP - - Read sheet 3 from Excel sheet file to a DataFrame while skipping empty lines in the - sheet. As sheet 3 does not have header row, pass the needed settings to - :func:`read_csv`. - - >>> pl.read_excel( - ... source="test.xlsx", - ... sheet_id=3, - ... xlsx2csv_options={"skip_empty_lines": True}, - ... 
read_csv_options={"has_header": False, "new_columns": ["a", "b", "c"]}, - ... ) # doctest: +SKIP - - If the correct datatypes can't be determined by polars, look at the :func:`read_csv` - documentation to see which options you can pass to fix this issue. For example - ``"infer_schema_length": None`` can be used to read the data twice, once to infer - the correct output types and once to convert the input to the correct types. - When `"infer_schema_length": 1000``, only the first 1000 lines are read twice. - - >>> pl.read_excel( - ... source="test.xlsx", - ... read_csv_options={"infer_schema_length": None}, - ... ) # doctest: +SKIP - - The ``openpyxl`` engine can also be used to provide automatic type inference. - To do so, specify the right engine (`xlsx2csv_options` and `read_csv_options` - will be ignored): - - >>> pl.read_excel( - ... source="test.xlsx", - ... engine="openpyxl", - ... ) # doctest: +SKIP - - If :func:`read_excel` does not work or you need to read other types of - spreadsheet files, you can try pandas ``pd.read_excel()`` - (supports `xls`, `xlsx`, `xlsm`, `xlsb`, `odf`, `ods` and `odt`). - - >>> pl.from_pandas(pd.read_excel("test.xlsx")) # doctest: +SKIP - - """ - if sheet_id is not None and sheet_name is not None: - raise ValueError( - f"cannot specify both `sheet_name` ({sheet_name!r}) and `sheet_id` ({sheet_id!r})" - ) - - if xlsx2csv_options is None: - xlsx2csv_options = {} - if read_csv_options is None: - read_csv_options = {"truncate_ragged_lines": True} - elif "truncate_ragged_lines" not in read_csv_options: - read_csv_options["truncate_ragged_lines"] = True - - # establish the reading function, parser, and available worksheets - reader_fn, parser, worksheets = _initialise_excel_parser( - engine, source, xlsx2csv_options - ) - - # use the parser to read data from one or more sheets - if ( - sheet_id == 0 - or isinstance(sheet_id, Sequence) - or (sheet_name and not isinstance(sheet_name, str)) - ): - # read multiple sheets by id - sheet_ids = sheet_id or () - sheet_names = sheet_name or () - return { - sheet["name"]: reader_fn( - parser=parser, - sheet_id=sheet["index"], - sheet_name=None, - read_csv_options=read_csv_options, - raise_if_empty=raise_if_empty, - ) - for sheet in worksheets - if sheet_id == 0 or sheet["index"] in sheet_ids or sheet["name"] in sheet_names # type: ignore[operator] - } - else: - # read a specific sheet by id or name - if sheet_name is None: - sheet_id = sheet_id or 1 - return reader_fn( - parser=parser, - sheet_id=sheet_id, - sheet_name=sheet_name, - read_csv_options=read_csv_options, - raise_if_empty=raise_if_empty, - ) - - -def _initialise_excel_parser( - engine: str | None, - source: str | BytesIO | Path | BinaryIO | bytes, - xlsx2csv_options: dict[str, Any], -) -> tuple[Any, Any, list[dict[str, Any]]]: - """Instantiate the indicated Excel parser and establish related properties.""" - if isinstance(source, (str, Path)): - source = normalize_filepath(source) - - if engine == "openpyxl": - try: - import openpyxl - except ImportError: - raise ImportError( - "openpyxl is not installed\n\nPlease run `pip install openpyxl`" - ) from None - parser: openpyxl.Workbook = openpyxl.load_workbook(source, data_only=True) - sheets = [ - {"index": i + 1, "name": sheet.title} for i, sheet in enumerate(parser) - ] - return _read_excel_sheet_openpyxl, parser, sheets - - elif engine == "xlsx2csv" or engine is None: # default - try: - import xlsx2csv - except ImportError: - raise ModuleNotFoundError( - "xlsx2csv is not installed\n\nPlease run: `pip install 
xlsx2csv`" - ) from None - parser: xlsx2csv.Xlsx2csv = xlsx2csv.Xlsx2csv(source, **xlsx2csv_options) # type: ignore[no-redef] - sheets = parser.workbook.sheets - return _read_excel_sheet_xlsx2csv, parser, sheets - - raise NotImplementedError(f"Unrecognised engine: {engine!r}") - - -def _drop_unnamed_null_columns(df: pl.DataFrame) -> pl.DataFrame: - """If DataFrame contains unnamed columns that contain only nulls, drop them.""" - if "" in df.columns: - null_cols = [] - for col_name in df.columns: - # note that if multiple unnamed columns are found then all but - # the first one will be ones will be named as "_duplicated_{n}" - if col_name == "" or re.match(r"_duplicated_\d+$", col_name): - if df[col_name].null_count() == len(df): - null_cols.append(col_name) - if null_cols: - df = df.drop(*null_cols) - return df - - -def _read_excel_sheet_openpyxl( - parser: Any, - sheet_id: int | None, - sheet_name: str | None, - read_csv_options: dict[str, Any] | None, - *, - raise_if_empty: bool, -) -> pl.DataFrame: - """Use the 'openpyxl' library to read data from the given worksheet.""" - # read requested sheet if provided on kwargs, otherwise read active sheet - if sheet_name is not None: - ws = parser[sheet_name] - elif sheet_id is not None: - ws = parser.worksheets[sheet_id - 1] - else: - ws = parser.active - - # prefer detection of actual table objects; otherwise read - # data in the used worksheet range, dropping null columns - header: list[str | None] = [] - if tables := getattr(ws, "tables", None): - table = next(iter(tables.values())) - rows = list(ws[table.ref]) - header.extend(cell.value for cell in rows.pop(0)) - if table.totalsRowCount: - rows = rows[: -table.totalsRowCount] - rows_iter = iter(rows) - else: - rows_iter = ws.iter_rows() - for row in rows_iter: - row_values = [cell.value for cell in row] - if any(v is not None for v in row_values): - header.extend(row_values) - break - - series_data = [ - pl.Series(name, [cell.value for cell in column_data]) - for name, column_data in zip(header, zip(*rows_iter)) - ] - df = pl.DataFrame({s.name: s for s in series_data if s.name}) - if raise_if_empty and len(df) == 0 and len(df.columns) == 0: - raise NoDataError( - "Empty Excel sheet; if you want to read this as " - "an empty DataFrame, set `raise_if_empty=False`" - ) - return _drop_unnamed_null_columns(df) - - -def _read_excel_sheet_xlsx2csv( - parser: Any, - sheet_id: int | None, - sheet_name: str | None, - read_csv_options: dict[str, Any], - *, - raise_if_empty: bool, -) -> pl.DataFrame: - """Use the 'xlsx2csv' library to read data from the given worksheet.""" - # parse sheet data into the given buffer - csv_buffer = StringIO() - parser.convert(outfile=csv_buffer, sheetid=sheet_id, sheetname=sheet_name) - - # handle (completely) empty sheet data - if csv_buffer.tell() == 0: - if raise_if_empty: - raise NoDataError( - "Empty Excel sheet; if you want to read this as " - "an empty DataFrame, set `raise_if_empty=False`" - ) - return pl.DataFrame() - - # otherwise rewind the buffer and parse as csv - csv_buffer.seek(0) - df = read_csv(csv_buffer, **read_csv_options) - return _drop_unnamed_null_columns(df) diff --git a/py-polars/polars/io/iceberg.py b/py-polars/polars/io/iceberg.py new file mode 100644 index 000000000000..2feedaaff255 --- /dev/null +++ b/py-polars/polars/io/iceberg.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +import ast +from _ast import GtE, Lt, LtE +from ast import ( + Attribute, + BinOp, + BitAnd, + BitOr, + Call, + Compare, + Constant, + Eq, + Gt, + Invert, + 
List, + Name, + UnaryOp, +) +from functools import partial, singledispatch +from typing import TYPE_CHECKING, Any, Callable + +import polars._reexport as pl +from polars.dependencies import pyiceberg +from polars.utils.convert import _to_python_date, _to_python_datetime + +if TYPE_CHECKING: + from datetime import date, datetime + + from pyiceberg.table import Table + + from polars import DataFrame, LazyFrame, Series + +__all__ = ["scan_iceberg"] + +_temporal_conversions: dict[str, Callable[..., datetime | date]] = { + "_to_python_date": _to_python_date, + "_to_python_datetime": _to_python_datetime, +} + + +def scan_iceberg( + source: str | Table, + *, + storage_options: dict[str, Any] | None = None, +) -> LazyFrame: + """ + Lazily read from an Apache Iceberg table. + + Parameters + ---------- + source + A PyIceberg table, or a direct path to the metadata. + + Note: For Local filesystem, absolute and relative paths are supported but + for the supported object storages - GCS, Azure and S3 full URI must be provided. + storage_options + Extra options for the storage backends supported by `pyiceberg`. + For cloud storages, this may include configurations for authentication etc. + + More info is available `here `__. + + Returns + ------- + LazyFrame + + Examples + -------- + Creates a scan for an Iceberg table from local filesystem, or object store. + + >>> table_path = "file:/path/to/iceberg-table/metadata.json" + >>> pl.scan_iceberg(table_path).collect() # doctest: +SKIP + + Creates a scan for an Iceberg table from S3. + See a list of supported storage options for S3 `here + `__. + + >>> table_path = "s3://bucket/path/to/iceberg-table/metadata.json" + >>> storage_options = { + ... "s3.region": "eu-central-1", + ... "s3.access-key-id": "THE_AWS_ACCESS_KEY_ID", + ... "s3.secret-access-key": "THE_AWS_SECRET_ACCESS_KEY", + ... } + >>> pl.scan_iceberg( + ... table_path, storage_options=storage_options + ... ).collect() # doctest: +SKIP + + Creates a scan for an Iceberg table from Azure. + Supported options for Azure are available `here + `__. + + Following type of table paths are supported: + * az:////metadata.json + * adl:////metadata.json + * abfs[s]:////metadata.json + + >>> table_path = "az://container/path/to/iceberg-table/metadata.json" + >>> storage_options = { + ... "adlfs.account-name": "AZURE_STORAGE_ACCOUNT_NAME", + ... "adlfs.account-key": "AZURE_STORAGE_ACCOUNT_KEY", + ... } + >>> pl.scan_iceberg( + ... table_path, storage_options=storage_options + ... ).collect() # doctest: +SKIP + + Creates a scan for an Iceberg table from Google Cloud Storage. + Supported options for GCS are available `here + `__. + + >>> table_path = "s3://bucket/path/to/iceberg-table/metadata.json" + >>> storage_options = { + ... "gcs.project-id": "my-gcp-project", + ... "gcs.oauth.token": "ya29.dr.AfM...", + ... } + >>> pl.scan_iceberg( + ... table_path, storage_options=storage_options + ... ).collect() # doctest: +SKIP + + Creates a scan for an Iceberg table with additional options. + In the below example, `without_files` option is used which loads the table without + file tracking information. + + >>> table_path = "/path/to/iceberg-table/metadata.json" + >>> storage_options = {"py-io-impl": "pyiceberg.io.fsspec.FsspecFileIO"} + >>> pl.scan_iceberg( + ... table_path, storage_options=storage_options + ... 
).collect() # doctest: +SKIP + + """ + from pyiceberg.io.pyarrow import schema_to_pyarrow + from pyiceberg.table import StaticTable + + if isinstance(source, str): + source = StaticTable.from_metadata( + metadata_location=source, properties=storage_options or {} + ) + + func = partial(_scan_pyarrow_dataset_impl, source) + arrow_schema = schema_to_pyarrow(source.schema()) + return pl.LazyFrame._scan_python_function(arrow_schema, func, pyarrow=True) + + +def _scan_pyarrow_dataset_impl( + tbl: Table, + with_columns: list[str] | None = None, + predicate: str = "", + n_rows: int | None = None, + **kwargs: Any, +) -> DataFrame | Series: + """ + Take the projected columns and materialize an arrow table. + + Parameters + ---------- + tbl + pyarrow dataset + with_columns + Columns that are projected + predicate + pyarrow expression that can be evaluated with eval + n_rows: + Materialize only n rows from the arrow dataset. + batch_size + The maximum row count for scanned pyarrow record batches. + kwargs: + For backward compatibility + + Returns + ------- + DataFrame + + """ + from polars import from_arrow + + scan = tbl.scan(limit=n_rows) + + if with_columns is not None: + scan = scan.select(*with_columns) + + if predicate is not None: + try: + expr_ast = _to_ast(predicate) + pyiceberg_expr = _convert_predicate(expr_ast) + except ValueError as e: + raise ValueError( + f"Could not convert predicate to PyIceberg: {predicate}" + ) from e + + scan = scan.filter(pyiceberg_expr) + + return from_arrow(scan.to_arrow()) + + +def _to_ast(expr: str) -> ast.expr: + """ + Converts a Python string to an AST. + + This will take the Python Arrow expression (as a string), and it will + be converted into a Python AST that can be traversed to convert it to a PyIceberg + expression. + + The reason to convert it to an AST is because the PyArrow expression + itself doesn't have any methods/properties to traverse the expression. + We need this to convert it into a PyIceberg expression. 
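To make the conversion path described above concrete: the predicate string is parsed into a Python AST, and its nodes are then dispatched to the PyIceberg expression builders registered below via `_convert_predicate`. A minimal sketch, assuming a simplified predicate string (the exact strings pushed down by the optimizer may differ):

    import ast

    predicate = "(pa.compute.field('a') > 1)"      # hypothetical pushed-down filter
    tree = ast.parse(predicate, mode="eval").body  # Compare(Call(field, 'a'), [Gt()], [Constant(1)])
    # walking this tree with _convert_predicate would yield
    # pyiceberg.expressions.GreaterThan("a", 1)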
+ + Parameters + ---------- + expr + The string expression + + Returns + ------- + The AST representing the Arrow expression + """ + return ast.parse(expr, mode="eval").body + + +@singledispatch +def _convert_predicate(a: Any) -> Any: + """Walks the AST to convert the PyArrow expression to a PyIceberg expression.""" + raise ValueError(f"Unexpected symbol: {a}") + + +@_convert_predicate.register(Constant) +def _(a: Constant) -> Any: + return a.value + + +@_convert_predicate.register(Name) +def _(a: Name) -> Any: + return a.id + + +@_convert_predicate.register(UnaryOp) +def _(a: UnaryOp) -> Any: + if isinstance(a.op, Invert): + return pyiceberg.expressions.Not(_convert_predicate(a.operand)) + else: + raise TypeError(f"Unexpected UnaryOp: {a}") + + +@_convert_predicate.register(Call) +def _(a: Call) -> Any: + args = [_convert_predicate(arg) for arg in a.args] + f = _convert_predicate(a.func) + if f == "field": + return args + elif f in _temporal_conversions: + # convert from polars-native i64 to ISO8601 string + return _temporal_conversions[f](*args).isoformat() + else: + ref = _convert_predicate(a.func.value)[0] # type: ignore[attr-defined] + if f == "isin": + return pyiceberg.expressions.In(ref, args[0]) + elif f == "is_null": + return pyiceberg.expressions.IsNull(ref) + elif f == "is_nan": + return pyiceberg.expressions.IsNaN(ref) + + raise ValueError(f"Unknown call: {f!r}") + + +@_convert_predicate.register(Attribute) +def _(a: Attribute) -> Any: + return a.attr + + +@_convert_predicate.register(BinOp) +def _(a: BinOp) -> Any: + lhs = _convert_predicate(a.left) + rhs = _convert_predicate(a.right) + + op = a.op + if isinstance(op, BitAnd): + return pyiceberg.expressions.And(lhs, rhs) + if isinstance(op, BitOr): + return pyiceberg.expressions.Or(lhs, rhs) + else: + raise TypeError(f"Unknown: {lhs} {op} {rhs}") + + +@_convert_predicate.register(Compare) +def _(a: Compare) -> Any: + op = a.ops[0] + lhs = _convert_predicate(a.left)[0] + rhs = _convert_predicate(a.comparators[0]) + + if isinstance(op, Gt): + return pyiceberg.expressions.GreaterThan(lhs, rhs) + if isinstance(op, GtE): + return pyiceberg.expressions.GreaterThanOrEqual(lhs, rhs) + if isinstance(op, Eq): + return pyiceberg.expressions.EqualTo(lhs, rhs) + if isinstance(op, Lt): + return pyiceberg.expressions.LessThan(lhs, rhs) + if isinstance(op, LtE): + return pyiceberg.expressions.LessThanOrEqual(lhs, rhs) + else: + raise TypeError(f"Unknown comparison: {op}") + + +@_convert_predicate.register(List) +def _(a: List) -> Any: + return [_convert_predicate(e) for e in a.elts] diff --git a/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py b/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py index 49c465f3f24e..cb81aea73ad0 100644 --- a/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py +++ b/py-polars/polars/io/pyarrow_dataset/anonymous_scan.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING import polars._reexport as pl -from polars.dependencies import pyarrow as pa # noqa: TCH001 +from polars.dependencies import pyarrow as pa if TYPE_CHECKING: from polars import DataFrame, LazyFrame @@ -71,16 +71,29 @@ def _scan_pyarrow_dataset_impl( from polars import from_arrow _filter = None + if predicate: - # imports are used by inline python evaluated by `eval` - from polars.datatypes import Date, Datetime, Duration # noqa: F401 + from polars.datatypes import Date, Datetime, Duration from polars.utils.convert import ( - _to_python_datetime, # noqa: F401 - _to_python_time, # noqa: F401 - _to_python_timedelta, # noqa: F401 + 
_to_python_date, + _to_python_datetime, + _to_python_time, + _to_python_timedelta, ) - _filter = eval(predicate) + _filter = eval( + predicate, + { + "pa": pa, + "Date": Date, + "Datetime": Datetime, + "Duration": Duration, + "_to_python_date": _to_python_date, + "_to_python_datetime": _to_python_datetime, + "_to_python_time": _to_python_time, + "_to_python_timedelta": _to_python_timedelta, + }, + ) common_params = {"columns": with_columns, "filter": _filter} if batch_size is not None: diff --git a/py-polars/polars/io/spreadsheet/__init__.py b/py-polars/polars/io/spreadsheet/__init__.py new file mode 100644 index 000000000000..ba6af1cd69d6 --- /dev/null +++ b/py-polars/polars/io/spreadsheet/__init__.py @@ -0,0 +1,3 @@ +from polars.io.spreadsheet.functions import read_excel, read_ods + +__all__ = ["read_excel", "read_ods"] diff --git a/py-polars/polars/io/excel/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py similarity index 100% rename from py-polars/polars/io/excel/_write_utils.py rename to py-polars/polars/io/spreadsheet/_write_utils.py diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py new file mode 100644 index 000000000000..ecbf01a4d9c2 --- /dev/null +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -0,0 +1,687 @@ +from __future__ import annotations + +import re +from io import StringIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, NoReturn, Sequence, overload + +import polars._reexport as pl +from polars import functions as F +from polars.datatypes import Date, Datetime +from polars.exceptions import NoDataError, ParameterCollisionError +from polars.io.csv.functions import read_csv +from polars.utils.various import normalize_filepath + +if TYPE_CHECKING: + from io import BytesIO + from typing import Literal + + from polars.type_aliases import SchemaDict + + +@overload +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: None = ..., + sheet_name: str, + engine: Literal["xlsx2csv", "openpyxl"] | None = ..., + xlsx2csv_options: dict[str, Any] | None = ..., + read_csv_options: dict[str, Any] | None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> pl.DataFrame: + ... + + +@overload +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: None = ..., + sheet_name: None = ..., + engine: Literal["xlsx2csv", "openpyxl"] | None = ..., + xlsx2csv_options: dict[str, Any] | None = ..., + read_csv_options: dict[str, Any] | None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> pl.DataFrame: + ... + + +@overload +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: int, + sheet_name: str, + engine: Literal["xlsx2csv", "openpyxl"] | None = ..., + xlsx2csv_options: dict[str, Any] | None = ..., + read_csv_options: dict[str, Any] | None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> NoReturn: + ... 
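The `read_excel` overloads (continuing below) encode how the return type follows from the sheet selection: a single `sheet_id` or `sheet_name` yields one DataFrame, `sheet_id=0` or a sequence of ids/names yields a `{sheetname: DataFrame}` dict, and combining `sheet_id` with `sheet_name` is rejected. A minimal usage sketch, assuming a hypothetical workbook "test.xlsx" with a sheet named "data":

    import polars as pl

    df = pl.read_excel("test.xlsx", sheet_name="data")    # one sheet -> pl.DataFrame
    frames = pl.read_excel("test.xlsx", sheet_id=0)       # all sheets -> dict[str, pl.DataFrame]
    subset = pl.read_excel("test.xlsx", sheet_id=[1, 2])  # several sheets -> dict[str, pl.DataFrame]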
+ + +# note: 'ignore' required as mypy thinks that the return value for +# Literal[0] overlaps with the return value for other integers +@overload # type: ignore[misc] +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: Literal[0] | Sequence[int], + sheet_name: None = ..., + engine: Literal["xlsx2csv", "openpyxl"] | None = ..., + xlsx2csv_options: dict[str, Any] | None = ..., + read_csv_options: dict[str, Any] | None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> dict[str, pl.DataFrame]: + ... + + +@overload +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: int, + sheet_name: None = ..., + engine: Literal["xlsx2csv", "openpyxl"] | None = ..., + xlsx2csv_options: dict[str, Any] | None = ..., + read_csv_options: dict[str, Any] | None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> pl.DataFrame: + ... + + +@overload +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: None, + sheet_name: list[str] | tuple[str], + engine: Literal["xlsx2csv", "openpyxl"] | None = ..., + xlsx2csv_options: dict[str, Any] | None = ..., + read_csv_options: dict[str, Any] | None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> dict[str, pl.DataFrame]: + ... + + +def read_excel( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: int | Sequence[int] | None = None, + sheet_name: str | list[str] | tuple[str] | None = None, + engine: Literal["xlsx2csv", "openpyxl"] | None = None, + xlsx2csv_options: dict[str, Any] | None = None, + read_csv_options: dict[str, Any] | None = None, + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = True, +) -> pl.DataFrame | dict[str, pl.DataFrame]: + """ + Read Excel (XLSX) spreadsheet data into a DataFrame. + + If using the ``xlsx2csv`` engine, converts an Excel sheet with + ``xlsx2csv.Xlsx2csv().convert()`` to CSV and parses the CSV output with + :func:`read_csv`. You can pass additional options to ``read_csv_options`` to + influence parsing behaviour. + + When using the ``openpyxl`` engine, reads an Excel sheet with + ``openpyxl.load_workbook(source)``. + + Parameters + ---------- + source + Path to a file or a file-like object (by file-like object, we refer to objects + that have a ``read()`` method, such as a file handler (e.g. via builtin ``open`` + function) or ``BytesIO``). + sheet_id + Sheet number(s) to convert (set ``0`` to load all sheets as DataFrames) and + return a ``{sheetname:frame,}`` dict. (Defaults to `1` if neither this nor + `sheet_name` are specified). Can also take a sequence of sheet numbers. + sheet_name + Sheet name(s) to convert; cannot be used in conjunction with `sheet_id`. If more + than one is given then a ``{sheetname:frame,}`` dict is returned. + engine + Library used to parse the spreadsheet file; defaults to "xlsx2csv" if not set. + + * "xlsx2csv": the fastest engine; converts the data to an in-memory CSV first + and then uses the polars ``read_csv`` method to parse the result. You can + pass `xlsx2csv_options` and/or `read_csv_options` to refine the conversion. + * "openpyxl": slower than ``xlsx2csv`` but supports additional automatic type + inference; potentially useful if you are unable to parse your sheet with + the ``xlsx2csv`` engine. + * "odf": this engine is only used for OpenOffice files; it will be used + automatically for files with the ".ods" extension. 
+ xlsx2csv_options + Extra options passed to ``xlsx2csv.Xlsx2csv()``, + e.g. ``{"skip_empty_lines": True}`` + read_csv_options + Extra options passed to :func:`read_csv` for parsing the CSV file returned by + ``xlsx2csv.Xlsx2csv().convert()`` + e.g.: ``{"has_header": False, "new_columns": ["a", "b", "c"], + "infer_schema_length": None}`` + schema_overrides + Support type specification or override of one or more columns. + raise_if_empty + When there is no data in the sheet,``NoDataError`` is raised. If this parameter + is set to False, an empty DataFrame (with no columns) is returned instead. + + Returns + ------- + DataFrame, or a ``{sheetname: DataFrame, ...}`` dict if reading multiple sheets. + + Examples + -------- + Read the "data" worksheet from an Excel file into a DataFrame. + + >>> pl.read_excel( + ... source="test.xlsx", + ... sheet_name="data", + ... ) # doctest: +SKIP + + Read table data from sheet 3 in an Excel workbook as a DataFrame while skipping + empty lines in the sheet. As sheet 3 does not have a header row and the default + engine is ``xlsx2csv`` you can pass the necessary additional settings for this + to the "read_csv_options" parameter; these will be passed to :func:`read_csv`. + + >>> pl.read_excel( + ... source="test.xlsx", + ... sheet_id=3, + ... xlsx2csv_options={"skip_empty_lines": True}, + ... read_csv_options={"has_header": False, "new_columns": ["a", "b", "c"]}, + ... ) # doctest: +SKIP + + If the correct datatypes can't be determined you can use ``schema_overrides`` and/or + some of the :func:`read_csv` documentation to see which options you can pass to fix + this issue. For example ``"infer_schema_length": None`` can be used to read the + data twice, once to infer the correct output types and once more to then read the + data with those types. If the types are known in advance then ``schema_overrides`` + is the more efficient option. + + >>> pl.read_excel( + ... source="test.xlsx", + ... read_csv_options={"infer_schema_length": 1000}, + ... schema_overrides={"dt": pl.Date}, + ... ) # doctest: +SKIP + + The ``openpyxl`` package can also be used to parse Excel data; it has slightly + better default type detection, but is slower than ``xlsx2csv``. If you have a sheet + that is better read using this package you can set the engine as "openpyxl" (if you + use this engine then both `xlsx2csv_options` and `read_csv_options` cannot be set). + + >>> pl.read_excel( + ... source="test.xlsx", + ... engine="openpyxl", + ... schema_overrides={"dt": pl.Datetime, "value": pl.Int32}, + ... ) # doctest: +SKIP + + """ + if xlsx2csv_options is None: + xlsx2csv_options = {} + + if read_csv_options is None: + read_csv_options = {"truncate_ragged_lines": True} + elif "truncate_ragged_lines" not in read_csv_options: + read_csv_options["truncate_ragged_lines"] = True + + return _read_spreadsheet( + sheet_id, + sheet_name, + source=source, + engine=engine, + engine_options=xlsx2csv_options, + read_csv_options=read_csv_options, + schema_overrides=schema_overrides, + raise_if_empty=raise_if_empty, + ) + + +@overload +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: None = ..., + sheet_name: str, + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> pl.DataFrame: + ... + + +@overload +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: None = ..., + sheet_name: None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> pl.DataFrame: + ... 
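One interaction worth noting, taken from `_csv_buffer_to_frame` further down: with the default `xlsx2csv` engine, `schema_overrides` is merged into `read_csv_options["dtypes"]`, and naming the same column in both places raises `ParameterCollisionError`. A minimal sketch with hypothetical file and column names:

    import polars as pl

    # dtype supplied once, via schema_overrides -- fine
    pl.read_excel("test.xlsx", schema_overrides={"dt": pl.Date})

    # the same column specified in both places -- raises ParameterCollisionError
    pl.read_excel(
        "test.xlsx",
        schema_overrides={"dt": pl.Date},
        read_csv_options={"dtypes": {"dt": pl.Utf8}},
    )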
+ + +@overload +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: int, + sheet_name: str, + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> NoReturn: + ... + + +@overload # type: ignore[misc] +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: Literal[0] | Sequence[int], + sheet_name: None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> dict[str, pl.DataFrame]: + ... + + +@overload +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: int, + sheet_name: None = ..., + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> pl.DataFrame: + ... + + +@overload +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: None, + sheet_name: list[str] | tuple[str], + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = ..., +) -> dict[str, pl.DataFrame]: + ... + + +def read_ods( + source: str | BytesIO | Path | BinaryIO | bytes, + *, + sheet_id: int | Sequence[int] | None = None, + sheet_name: str | list[str] | tuple[str] | None = None, + schema_overrides: SchemaDict | None = None, + raise_if_empty: bool = True, +) -> pl.DataFrame | dict[str, pl.DataFrame]: + """ + Read OpenOffice (ODS) spreadsheet data into a DataFrame. + + Parameters + ---------- + source + Path to a file or a file-like object (by file-like object, we refer to objects + that have a ``read()`` method, such as a file handler (e.g. via builtin ``open`` + function) or ``BytesIO``). + sheet_id + Sheet number(s) to convert (set ``0`` to load all sheets as DataFrames) and + return a ``{sheetname:frame,}`` dict. (Defaults to `1` if neither this nor + `sheet_name` are specified). Can also take a sequence of sheet numbers. + sheet_name + Sheet name(s) to convert; cannot be used in conjunction with `sheet_id`. If more + than one is given then a ``{sheetname:frame,}`` dict is returned. + schema_overrides + Support type specification or override of one or more columns. + raise_if_empty + When there is no data in the sheet,``NoDataError`` is raised. If this parameter + is set to False, an empty DataFrame (with no columns) is returned instead. + + Returns + ------- + DataFrame, or a ``{sheetname: DataFrame, ...}`` dict if reading multiple sheets. + + Examples + -------- + Read the "data" worksheet from an OpenOffice spreadsheet file into a DataFrame. + + >>> pl.read_ods( + ... source="test.ods", + ... sheet_name="data", + ... ) # doctest: +SKIP + + If the correct dtypes can't be determined, use the ``schema_overrides`` parameter + to specify them. + + >>> pl.read_ods( + ... source="test.ods", + ... sheet_id=3, + ... schema_overrides={"dt": pl.Date}, + ... raise_if_empty=False, + ... 
) # doctest: +SKIP + + """ + return _read_spreadsheet( + sheet_id, + sheet_name, + source=source, + engine="ods", + engine_options={}, + read_csv_options={}, + schema_overrides=schema_overrides, + raise_if_empty=raise_if_empty, + ) + + +def _read_spreadsheet( + sheet_id: int | Sequence[int] | None, + sheet_name: str | list[str] | tuple[str] | None, + source: str | BytesIO | Path | BinaryIO | bytes, + engine: Literal["xlsx2csv", "openpyxl", "ods"] | None, + engine_options: dict[str, Any] | None = None, + read_csv_options: dict[str, Any] | None = None, + schema_overrides: SchemaDict | None = None, + *, + raise_if_empty: bool = True, +) -> pl.DataFrame | dict[str, pl.DataFrame]: + if sheet_id is not None and sheet_name is not None: + raise ValueError( + f"cannot specify both `sheet_name` ({sheet_name!r}) and `sheet_id` ({sheet_id!r})" + ) + + if engine_options is None: + engine_options = {} + + # establish the reading function, parser, and available worksheets + reader_fn, parser, worksheets = _initialise_spreadsheet_parser( + engine, source, engine_options + ) + + # use the parser to read data from one or more sheets + if ( + sheet_id == 0 + or isinstance(sheet_id, Sequence) + or (sheet_name and not isinstance(sheet_name, str)) + ): + # read multiple sheets by id + sheet_ids = sheet_id or () + sheet_names = sheet_name or () + return { + sheet["name"]: reader_fn( + parser=parser, + sheet_id=sheet["index"], + sheet_name=None, + read_csv_options=read_csv_options, + schema_overrides=schema_overrides, + raise_if_empty=raise_if_empty, + ) + for sheet in worksheets + if sheet_id == 0 or sheet["index"] in sheet_ids or sheet["name"] in sheet_names # type: ignore[operator] + } + else: + # read a specific sheet by id or name + if sheet_name is None: + sheet_id = sheet_id or 1 + + return reader_fn( + parser=parser, + sheet_id=sheet_id, + sheet_name=sheet_name, + read_csv_options=read_csv_options, + schema_overrides=schema_overrides, + raise_if_empty=raise_if_empty, + ) + + +def _initialise_spreadsheet_parser( + engine: Literal["xlsx2csv", "openpyxl", "ods"] | None, + source: str | BytesIO | Path | BinaryIO | bytes, + engine_options: dict[str, Any], +) -> tuple[Callable[..., pl.DataFrame], Any, list[dict[str, Any]]]: + """Instantiate the indicated spreadsheet parser and establish related properties.""" + if isinstance(source, (str, Path)): + source = normalize_filepath(source) + if engine is None and str(source).lower().endswith(".ods"): + engine = "ods" + + if engine == "xlsx2csv" or engine is None: # default + try: + import xlsx2csv + except ImportError: + raise ModuleNotFoundError( + "Required package not installed\n\nPlease run: `pip install xlsx2csv`" + ) from None + parser = xlsx2csv.Xlsx2csv(source, **engine_options) + sheets = parser.workbook.sheets + return _read_spreadsheet_xlsx2csv, parser, sheets + + elif engine == "openpyxl": + try: + import openpyxl + except ImportError: + raise ImportError( + "Required package not installed\n\nPlease run `pip install openpyxl`" + ) from None + parser = openpyxl.load_workbook(source, data_only=True, **engine_options) + sheets = [{"index": i + 1, "name": ws.title} for i, ws in enumerate(parser)] + return _read_spreadsheet_openpyxl, parser, sheets + + elif engine == "ods": + try: + import ezodf + except ImportError: + raise ImportError( + "Required package not installed\n\nPlease run `pip install ezodf lxml`" + ) from None + parser = ezodf.opendoc(source, **engine_options) + sheets = [ + {"index": i + 1, "name": ws.name} for i, ws in enumerate(parser.sheets) 
+ ] + return _read_spreadsheet_ods, parser, sheets + + raise NotImplementedError(f"Unrecognised engine: {engine!r}") + + +def _csv_buffer_to_frame( + csv: StringIO, + separator: str, + read_csv_options: dict[str, Any] | None, + schema_overrides: SchemaDict | None, + *, + raise_if_empty: bool, +) -> pl.DataFrame: + """Translate StringIO buffer containing delimited data as a DataFrame.""" + # handle (completely) empty sheet data + if csv.tell() == 0: + if raise_if_empty: + raise NoDataError( + "Empty Excel sheet; if you want to read this as " + "an empty DataFrame, set `raise_if_empty=False`" + ) + return pl.DataFrame() + + if read_csv_options is None: + read_csv_options = {} + if schema_overrides: + if (csv_dtypes := read_csv_options.get("dtypes", {})) and set( + csv_dtypes + ).intersection(schema_overrides): + raise ParameterCollisionError( + "Cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" + ) + read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} + + # otherwise rewind the buffer and parse as csv + csv.seek(0) + df = read_csv( + csv, + separator=separator, + **read_csv_options, + ) + return _drop_unnamed_null_columns(df) + + +def _drop_unnamed_null_columns(df: pl.DataFrame) -> pl.DataFrame: + """If DataFrame contains unnamed columns that contain only nulls, drop them.""" + null_cols = [] + for col_name in df.columns: + # note that if multiple unnamed columns are found then all but + # the first one will be ones will be named as "_duplicated_{n}" + if col_name == "" or re.match(r"_duplicated_\d+$", col_name): + if df[col_name].null_count() == len(df): + null_cols.append(col_name) + if null_cols: + df = df.drop(*null_cols) + return df + + +def _read_spreadsheet_ods( + parser: Any, + sheet_id: int | None, + sheet_name: str | None, + read_csv_options: dict[str, Any] | None, + schema_overrides: SchemaDict | None, + *, + raise_if_empty: bool, +) -> pl.DataFrame: + """Use the 'ezodf' library to read data from the given worksheet.""" + sheets = parser.sheets + if sheet_id is not None: + ws = sheets[sheet_id - 1] + elif sheet_name is not None: + ws = next((s for s in sheets if s.name == sheet_name), None) + if ws is None: + raise ValueError(f"Sheet {sheet_name!r} not found") + else: + ws = sheets[0] + + row_data = [] + found_row_data = False + for row in ws.rows(): + row_values = [c.value for c in row] + if found_row_data or (found_row_data := any(v is not None for v in row_values)): + row_data.append(row_values) + + overrides = {} + strptime_cols = {} + headers: list[str] = [] + + if not row_data: + df = pl.DataFrame() + else: + for idx, name in enumerate(row_data[0]): + headers.append(name or (f"_duplicated_{idx}" if headers else "")) + + trailing_null_row = all(v is None for v in row_data[-1]) + row_data = row_data[1 : -1 if trailing_null_row else None] + + if schema_overrides: + for nm, dtype in schema_overrides.items(): + if dtype in (Datetime, Date): + strptime_cols[nm] = dtype + else: + overrides[nm] = dtype + + df = pl.DataFrame( + row_data, + orient="row", + schema=headers, + schema_overrides=overrides, + ) + if raise_if_empty and len(df) == 0 and len(df.columns) == 0: + raise NoDataError( + "Empty Excel sheet; if you want to read this as " + "an empty DataFrame, set `raise_if_empty=False`" + ) + + if strptime_cols: + df = df.with_columns( + F.col(nm).str.strptime(dtype) # type: ignore[arg-type] + for nm, dtype in strptime_cols.items() + ) + + df.columns = headers + return _drop_unnamed_null_columns(df) + + +def _read_spreadsheet_openpyxl( + 
parser: Any, + sheet_id: int | None, + sheet_name: str | None, + read_csv_options: dict[str, Any] | None, + schema_overrides: SchemaDict | None, + *, + raise_if_empty: bool, +) -> pl.DataFrame: + """Use the 'openpyxl' library to read data from the given worksheet.""" + # read requested sheet if provided on kwargs, otherwise read active sheet + if sheet_name is not None: + ws = parser[sheet_name] + elif sheet_id is not None: + ws = parser.worksheets[sheet_id - 1] + else: + ws = parser.active + + # prefer detection of actual table objects; otherwise read + # data in the used worksheet range, dropping null columns + header: list[str | None] = [] + if tables := getattr(ws, "tables", None): + table = next(iter(tables.values())) + rows = list(ws[table.ref]) + header.extend(cell.value for cell in rows.pop(0)) + if table.totalsRowCount: + rows = rows[: -table.totalsRowCount] + rows_iter = iter(rows) + else: + rows_iter = ws.iter_rows() + for row in rows_iter: + row_values = [cell.value for cell in row] + if any(v is not None for v in row_values): + header.extend(row_values) + break + + series_data = [ + pl.Series(name, [cell.value for cell in column_data]) + for name, column_data in zip(header, zip(*rows_iter)) + ] + df = pl.DataFrame( + {s.name: s for s in series_data if s.name}, + schema_overrides=schema_overrides, + ) + if raise_if_empty and len(df) == 0 and len(df.columns) == 0: + raise NoDataError( + "Empty Excel sheet; if you want to read this as " + "an empty DataFrame, set `raise_if_empty=False`" + ) + return _drop_unnamed_null_columns(df) + + +def _read_spreadsheet_xlsx2csv( + parser: Any, + sheet_id: int | None, + sheet_name: str | None, + read_csv_options: dict[str, Any] | None, + schema_overrides: SchemaDict | None, + *, + raise_if_empty: bool, +) -> pl.DataFrame: + """Use the 'xlsx2csv' library to read data from the given worksheet.""" + csv_buffer = StringIO() + parser.convert( + outfile=csv_buffer, + sheetid=sheet_id, + sheetname=sheet_name, + ) + return _csv_buffer_to_frame( + csv_buffer, + separator=",", + read_csv_options=read_csv_options, + schema_overrides=schema_overrides, + raise_if_empty=raise_if_empty, + ) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 7f5315eba0e3..a243f76ebe8b 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -50,13 +50,13 @@ from polars.lazyframe.group_by import LazyGroupBy from polars.selectors import _expand_selectors, expand_selector from polars.slice import LazyPolarsSlice -from polars.utils._async import _AsyncDataFrameResult +from polars.utils._async import _AioDataFrameResult, _GeventDataFrameResult from polars.utils._parse_expr_input import ( parse_as_expression, parse_as_list_of_expressions, ) from polars.utils._wrap import wrap_df, wrap_expr -from polars.utils.convert import _timedelta_to_pl_duration +from polars.utils.convert import _negate_duration, _timedelta_to_pl_duration from polars.utils.deprecation import ( deprecate_function, deprecate_renamed_function, @@ -75,8 +75,7 @@ if TYPE_CHECKING: import sys from io import IOBase - from queue import Queue - from typing import Literal + from typing import Awaitable, Literal import pyarrow as pa @@ -1703,10 +1702,44 @@ def collect( ) return wrap_df(ldf.collect()) + @overload def collect_async( self, - queue: Queue[DataFrame | Exception], *, + gevent: Literal[True], + type_coercion: bool = True, + predicate_pushdown: bool = True, + projection_pushdown: bool = True, + simplify_expression: bool = True, + 
no_optimization: bool = True, + slice_pushdown: bool = True, + comm_subplan_elim: bool = True, + comm_subexpr_elim: bool = True, + streaming: bool = True, + ) -> _GeventDataFrameResult[DataFrame]: + ... + + @overload + def collect_async( + self, + *, + gevent: Literal[False] = False, + type_coercion: bool = True, + predicate_pushdown: bool = True, + projection_pushdown: bool = True, + simplify_expression: bool = True, + no_optimization: bool = True, + slice_pushdown: bool = True, + comm_subplan_elim: bool = True, + comm_subexpr_elim: bool = True, + streaming: bool = True, + ) -> Awaitable[DataFrame]: + ... + + def collect_async( + self, + *, + gevent: bool = False, type_coercion: bool = True, predicate_pushdown: bool = True, projection_pushdown: bool = True, @@ -1716,33 +1749,44 @@ def collect_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, streaming: bool = False, - ) -> _AsyncDataFrameResult[DataFrame]: + ) -> Awaitable[DataFrame] | _GeventDataFrameResult[DataFrame]: """ Collect DataFrame asynchronously in thread pool. - Collects into a DataFrame, like :func:`collect` - but instead of returning DataFrame directly its collected inside thread pool - and gets put into `queue` with `put_nowait` method, + Collects into a DataFrame (like :func:`collect`), but instead of returning + dataframe directly, they are scheduled to be collected inside thread pool, while this method returns almost instantly. May be useful if you use gevent or asyncio and want to release control to other greenlets/tasks while LazyFrames are being collected. - You must use correct queue in that case. - Given `queue` must be thread safe! - - For gevent use - [`gevent.queue.Queue`](https://www.gevent.org/api/gevent.queue.html#gevent.queue.Queue). - For asyncio - [`asyncio.queues.Queue`](https://docs.python.org/3/library/asyncio-queue.html#queue) - can not be used, since it's not thread safe! - For that purpose use [janus](https://github.com/aio-libs/janus) library. + Parameters + ---------- + gevent + Return wrapper to `gevent.event.AsyncResult` instead of Awaitable + type_coercion + Do type coercion optimization. + predicate_pushdown + Do predicate pushdown optimization. + projection_pushdown + Do projection pushdown optimization. + simplify_expression + Run simplify expressions optimization. + no_optimization + Turn off (certain) optimizations. + slice_pushdown + Slice pushdown optimization. + comm_subplan_elim + Will try to cache branching subplans that occur on self-joins or unions. + comm_subexpr_elim + Common subexpressions will be cached and reused. + streaming + Run parts of the query in a streaming fashion (this is in an alpha state) Notes ----- - Results are put in queue exactly once using `put_nowait`. - If error occurred then Exception will be put in the queue instead of result - which is then raised by returned wrapper `get` method. + In case of error `set_exception` is used on + `asyncio.Future`/`gevent.event.AsyncResult` and will be reraised by them. Warnings -------- @@ -1756,12 +1800,14 @@ def collect_async( Returns ------- - Wrapper that has `get` method and `queue` attribute with given queue. - `get` accepts kwargs that are passed down to `queue.get`. + If `gevent=False` (default) then returns awaitable. + + If `gevent=True` then returns wrapper that has + `.get(block=True, timeout=None)` method. Examples -------- - >>> import queue + >>> import asyncio >>> lf = pl.LazyFrame( ... { ... "a": ["a", "b", "a", "b", "b", "c"], @@ -1769,12 +1815,14 @@ def collect_async( ... 
"c": [6, 5, 4, 3, 2, 1], ... } ... ) - >>> a = ( - ... lf.group_by("a", maintain_order=True) - ... .agg(pl.all().sum()) - ... .collect_async(queue.Queue()) - ... ) - >>> a.get() + >>> async def main(): + ... return await ( + ... lf.group_by("a", maintain_order=True) + ... .agg(pl.all().sum()) + ... .collect_async() + ... ) + ... + >>> asyncio.run(main()) shape: (3, 3) ┌─────┬─────┬─────┐ │ a ┆ b ┆ c │ @@ -1785,7 +1833,6 @@ def collect_async( │ b ┆ 11 ┆ 10 │ │ c ┆ 6 ┆ 1 │ └─────┴─────┴─────┘ - """ if no_optimization: predicate_pushdown = False @@ -1809,9 +1856,9 @@ def collect_async( eager=False, ) - result = _AsyncDataFrameResult(queue) - ldf.collect_with_callback(result._callback) - return result + result = _GeventDataFrameResult() if gevent else _AioDataFrameResult() + ldf.collect_with_callback(result._callback) # type: ignore[attr-defined] + return result # type: ignore[return-value] def sink_parquet( self, @@ -2033,7 +2080,7 @@ def sink_csv( ``Float64`` datatypes. null_value A string representing null values (defaulting to the empty string). - quote_style : {'necessary', 'always', 'non_numeric'} + quote_style : {'necessary', 'always', 'non_numeric', 'never'} Determines the quoting strategy used. - necessary (default): This puts quotes around fields only when necessary. They are necessary when fields contain a quote, @@ -2042,6 +2089,8 @@ def sink_csv( (which is indistinguishable from a record with one empty field). This is the default. - always: This puts quotes around every field. Always. + - never: This never puts quotes around fields, even if that results in + invalid CSV data (e.g.: by not quoting strings containing the separator). - non_numeric: This puts quotes around all fields that are non-numeric. Namely, when writing a field that does not parse as a valid float or integer, then quotes will be used even if they aren`t strictly @@ -2151,14 +2200,6 @@ def fetch( """ Collect a small number of rows for debugging purposes. - Fetch is like a :func:`collect` operation, but it overwrites the number of rows - read by every scan operation. This is a utility that helps debug a query on a - smaller number of rows. - - Note that the fetch does not guarantee the final number of rows in the - DataFrame. Filter, join operations and a lower number of rows available in the - scanned file influence the final number of rows. - Parameters ---------- n_rows @@ -2182,6 +2223,20 @@ def fetch( streaming Run parts of the query in a streaming fashion (this is in an alpha state) + Notes + ----- + This is similar to a :func:`collect` operation, but it overwrites the number of + rows read by *every* scan operation. Be aware that ``fetch`` does not guarantee + the final number of rows in the DataFrame. Filters, join operations and fewer + rows being available in the scanned data will all influence the final number + of rows (joins are especially susceptible to this, and may return no data + at all if ``n_rows`` is too small as the join keys may not be present). + + Warnings + -------- + This is strictly a utility function that can help to debug queries using a + smaller number of rows, and should *not* be used in production code. 
+ Returns ------- DataFrame @@ -2853,7 +2908,7 @@ def group_by_rolling( """ index_column = parse_as_expression(index_column) if offset is None: - offset = f"-{_timedelta_to_pl_duration(period)}" + offset = _negate_duration(_timedelta_to_pl_duration(period)) pyexprs_by = parse_as_list_of_expressions(by) if by is not None else [] period = _timedelta_to_pl_duration(period) @@ -2882,47 +2937,16 @@ def group_by_dynamic( Group based on a time value (or index value of type Int32, Int64). Time windows are calculated and rows are assigned to windows. Different from a - normal group by is that a row can be member of multiple groups. The time/index - window could be seen as a rolling window, with a window size determined by - dates/times/values instead of slots in the DataFrame. - - A window is defined by: - - - every: interval of the window - - period: length of the window - - offset: offset of the window + normal group by is that a row can be member of multiple groups. + By default, the windows look like: - The `every`, `period` and `offset` arguments are created with - the following string language: + - [start, start + period) + - [start + every, start + every + period) + - [start + 2*every, start + 2*every + period) + - ... - - 1ns (1 nanosecond) - - 1us (1 microsecond) - - 1ms (1 millisecond) - - 1s (1 second) - - 1m (1 minute) - - 1h (1 hour) - - 1d (1 calendar day) - - 1w (1 calendar week) - - 1mo (1 calendar month) - - 1q (1 calendar quarter) - - 1y (1 calendar year) - - 1i (1 index count) - - Or combine them: - "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds - - Suffix with `"_saturating"` to indicate that dates too large for - their month should saturate at the largest date (e.g. 2022-02-29 -> 2022-02-28) - instead of erroring. - - By "calendar day", we mean the corresponding time on the next day (which may - not be 24 hours, due to daylight savings). Similarly for "calendar week", - "calendar month", "calendar quarter", and "calendar year". - - In case of a group_by_dynamic on an integer column, the windows are defined by: - - - "1i" # length 1 - - "10i" # length 10 + where `start` is determined by `start_by`, `offset`, and `every` (see parameter + descriptions below). .. warning:: The index column must be sorted in ascending order. If `by` is passed, then @@ -2942,10 +2966,10 @@ def group_by_dynamic( every interval of the window period - length of the window, if None it is equal to 'every' + length of the window, if None it will equal 'every' offset - offset of the window if None and period is None it will be equal to negative - `every` + offset of the window, only takes effect if `start_by` is `'window'`. + Defaults to negative `every`. truncate truncate the time value to the window lower bound include_boundaries @@ -2959,7 +2983,8 @@ def group_by_dynamic( start_by : {'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'} The strategy to determine the start of the first window by. - * 'window': Truncate the start of the window with the 'every' argument. + * 'window': Start by taking the earliest timestamp, truncating it with + `every`, and then adding `offset`. Note that weekly windows start on Monday. * 'datapoint': Start from the first encountered data point. * a day of the week (only takes effect if `every` contains ``'w'``): @@ -2988,23 +3013,55 @@ def group_by_dynamic( Notes ----- - If you're coming from pandas, then + 1) If you're coming from pandas, then + + .. 
code-block:: python + + # polars + df.group_by_dynamic("ts", every="1d").agg(pl.col("value").sum()) + + is equivalent to - .. code-block:: python + .. code-block:: python - # polars - df.group_by_dynamic("ts", every="1d").agg(pl.col("value").sum()) + # pandas + df.set_index("ts").resample("D")["value"].sum().reset_index() - is equivalent to + though note that, unlike pandas, polars doesn't add extra rows for empty + windows. If you need `index_column` to be evenly spaced, then please combine + with :func:`DataFrame.upsample`. - .. code-block:: python + 2) The `every`, `period` and `offset` arguments are created with + the following string language: - # pandas - df.set_index("ts").resample("D")["value"].sum().reset_index() + - 1ns (1 nanosecond) + - 1us (1 microsecond) + - 1ms (1 millisecond) + - 1s (1 second) + - 1m (1 minute) + - 1h (1 hour) + - 1d (1 calendar day) + - 1w (1 calendar week) + - 1mo (1 calendar month) + - 1q (1 calendar quarter) + - 1y (1 calendar year) + - 1i (1 index count) - though note that, unlike pandas, polars doesn't add extra rows for empty - windows. If you need `index_column` to be evenly spaced, then please combine - with :func:`DataFrame.upsample`. + Or combine them: + "3d12h4m25s" # 3 days, 12 hours, 4 minutes, and 25 seconds + + Suffix with `"_saturating"` to indicate that dates too large for + their month should saturate at the largest date (e.g. 2022-02-29 -> 2022-02-28) + instead of erroring. + + By "calendar day", we mean the corresponding time on the next day (which may + not be 24 hours, due to daylight savings). Similarly for "calendar week", + "calendar month", "calendar quarter", and "calendar year". + + In case of a group_by_dynamic on an integer column, the windows are defined by: + + - "1i" # length 1 + - "10i" # length 10 Examples -------- @@ -3180,12 +3237,13 @@ def group_by_dynamic( ... include_boundaries=True, ... closed="right", ... ).agg(pl.col("A").alias("A_agg_list")).collect() - shape: (3, 4) + shape: (4, 4) ┌─────────────────┬─────────────────┬─────┬─────────────────┐ │ _lower_boundary ┆ _upper_boundary ┆ idx ┆ A_agg_list │ │ --- ┆ --- ┆ --- ┆ --- │ │ i64 ┆ i64 ┆ i64 ┆ list[str] │ ╞═════════════════╪═════════════════╪═════╪═════════════════╡ + │ -2 ┆ 1 ┆ -2 ┆ ["A", "A"] │ │ 0 ┆ 3 ┆ 0 ┆ ["A", "B", "B"] │ │ 2 ┆ 5 ┆ 2 ┆ ["B", "B", "C"] │ │ 4 ┆ 7 ┆ 4 ┆ ["C"] │ @@ -3194,7 +3252,7 @@ def group_by_dynamic( """ # noqa: W505 index_column = parse_as_expression(index_column) if offset is None: - offset = f"-{_timedelta_to_pl_duration(every)}" if period is None else "0ns" + offset = _negate_duration(_timedelta_to_pl_duration(every)) if period is None: period = every @@ -5660,10 +5718,10 @@ def groupby_dynamic( every interval of the window period - length of the window, if None it is equal to 'every' + length of the window, if None it will equal 'every' offset - offset of the window if None and period is None it will be equal to negative - `every` + offset of the window, only takes effect if `start_by` is ``'window'``. + Defaults to negative `every`. truncate truncate the time value to the window lower bound include_boundaries @@ -5677,7 +5735,8 @@ def groupby_dynamic( start_by : {'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'} The strategy to determine the start of the first window by. - * 'window': Truncate the start of the window with the 'every' argument. + * 'window': Start by taking the earliest timestamp, truncating it with + `every`, and then adding `offset`. 
Note that weekly windows start on Monday. * 'datapoint': Start from the first encountered data point. * a day of the week (only takes effect if `every` contains ``'w'``): diff --git a/py-polars/polars/series/binary.py b/py-polars/polars/series/binary.py index 9511f7e85067..69a273f7c2f6 100644 --- a/py-polars/polars/series/binary.py +++ b/py-polars/polars/series/binary.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from polars import Series from polars.polars import PySeries - from polars.type_aliases import TransferEncoding + from polars.type_aliases import IntoExpr, TransferEncoding @expr_dispatch @@ -19,7 +19,7 @@ class BinaryNameSpace: def __init__(self, series: Series): self._s: PySeries = series._s - def contains(self, literal: bytes) -> Series: + def contains(self, literal: IntoExpr) -> Series: """ Check if binaries in Series contain a binary substring. @@ -35,7 +35,7 @@ def contains(self, literal: bytes) -> Series: """ - def ends_with(self, suffix: bytes) -> Series: + def ends_with(self, suffix: IntoExpr) -> Series: """ Check if string values end with a binary substring. @@ -46,7 +46,7 @@ def ends_with(self, suffix: bytes) -> Series: """ - def starts_with(self, prefix: bytes) -> Series: + def starts_with(self, prefix: IntoExpr) -> Series: """ Check if values start with a binary substring. diff --git a/py-polars/polars/series/datetime.py b/py-polars/polars/series/datetime.py index 8774bf639149..a3cd30aef055 100644 --- a/py-polars/polars/series/datetime.py +++ b/py-polars/polars/series/datetime.py @@ -1435,7 +1435,7 @@ def offset_by(self, by: str | Expr) -> Series: def truncate( self, - every: str | dt.timedelta, + every: str | dt.timedelta | Expr, offset: str | dt.timedelta | None = None, *, use_earliest: bool | None = None, diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 403958404c1c..15a4f41354ac 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -12,7 +12,7 @@ from polars import Expr, Series from polars.polars import PySeries - from polars.type_aliases import NullBehavior, ToStructStrategy + from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy @expr_dispatch @@ -198,7 +198,7 @@ def take( def __getitem__(self, item: int) -> Series: return self.get(item) - def join(self, separator: str) -> Series: + def join(self, separator: IntoExpr) -> Series: """ Join all string items in a sublist and place a separator between them. diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index b581bd4807ce..82ae481eb9b4 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -933,11 +933,11 @@ def __contains__(self, item: Any) -> bool: def __iter__(self) -> Generator[Any, None, None]: if self.dtype == List: # TODO: either make a change and return py-native list data here, or find - # a faster way to return nested/List series; sequential 'get_idx' calls + # a faster way to return nested/List series; sequential 'get_index' calls # make this path a lot slower (~10x) than it needs to be. - get_idx = self._s.get_idx + get_index = self._s.get_index for idx in range(self.len()): - yield get_idx(idx) + yield get_index(idx) else: buffer_size = 25_000 for offset in range(0, self.len(), buffer_size): @@ -1019,17 +1019,15 @@ def __getitem__( elif _check_for_numpy(item) and isinstance(item, np.ndarray): return self._take_with_series(numpy_to_idxs(item, self.len())) - # Integer. 
+ # Integer elif isinstance(item, int): - if item < 0: - item = self.len() + item - return self._s.get_idx(item) + return self._s.get_index_signed(item) - # Slice. + # Slice elif isinstance(item, slice): return PolarsSlice(self).apply(item) - # Range. + # Range elif isinstance(item, range): return self[range_to_slice(item)] @@ -1191,12 +1189,13 @@ def _repr_html_(self) -> str: """Format output data in HTML for display in Jupyter Notebooks.""" return self.to_frame()._repr_html_(from_series=True) - def item(self, row: int | None = None) -> Any: + @deprecate_renamed_parameter("row", "index", version="0.19.3") + def item(self, index: int | None = None) -> Any: """ - Return the series as a scalar, or return the element at the given row index. + Return the series as a scalar, or return the element at the given index. - If no row index is provided, this is equivalent to ``s[0]``, with a check - that the shape is (1,). With a row index, this is equivalent to ``s[row]``. + If no index is provided, this is equivalent to ``s[0]``, with a check + that the shape is (1,). With an index, this is equivalent to ``s[index]``. Examples -------- @@ -1208,12 +1207,15 @@ def item(self, row: int | None = None) -> Any: 24 """ - if row is None and len(self) != 1: - raise ValueError( - f"can only call '.item()' if the series is of length 1, or an" - f" explicit row index is provided (series is of length {len(self)})" - ) - return self[row or 0] + if index is None: + if len(self) != 1: + raise ValueError( + "can only call '.item()' if the series is of length 1," + f" or an explicit index is provided (series is of length {len(self)})" + ) + return self._s.get_index(0) + + return self._s.get_index_signed(index) def estimated_size(self, unit: SizeUnit = "b") -> int | float: """ @@ -3581,24 +3583,53 @@ def is_unique(self) -> Series: """ - def is_first(self) -> Series: + def is_first_distinct(self) -> Series: """ - Get a mask of the first unique value. + Return a boolean mask indicating the first occurrence of each distinct value. Returns ------- Series Series of data type :class:`Boolean`. + Examples + -------- + >>> s = pl.Series([1, 1, 2, 3, 2]) + >>> s.is_first_distinct() + shape: (5,) + Series: '' [bool] + [ + true + false + true + true + false + ] + """ - def is_last(self) -> Series: + def is_last_distinct(self) -> Series: """ - Get a mask of the last unique value. + Return a boolean mask indicating the last occurrence of each distinct value. Returns ------- - Boolean Series + Series + Series of data type :class:`Boolean`. + + Examples + -------- + >>> s = pl.Series([1, 1, 2, 3, 2]) + >>> s.is_last_distinct() + shape: (5,) + Series: '' [bool] + [ + false + true + false + true + true + ] """ @@ -5672,8 +5703,8 @@ def sample( shuffle Shuffle the order of sampled data points. seed - Seed for the random number generator. If set to None (default), a random - seed is generated using the ``random`` module. + Seed for the random number generator. If set to None (default), a + random seed is generated for each sample operation. Examples -------- @@ -6311,8 +6342,8 @@ def shuffle(self, seed: int | None = None) -> Series: Parameters ---------- seed - Seed for the random number generator. If set to None (default), a random - seed is generated using the ``random`` module. + Seed for the random number generator. If set to None (default), a + random seed is generated each time the shuffle is called. 
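As the wording above implies, omitting `seed` gives a fresh random seed on every call, so repeated shuffles are not reproducible unless a seed is fixed. A minimal sketch (values are illustrative):

    import polars as pl

    s = pl.Series("a", [1, 2, 3, 4, 5])
    s.shuffle(seed=1).to_list() == s.shuffle(seed=1).to_list()  # True: a fixed seed is reproducible
    # two calls to s.shuffle() without a seed will generally produce different orders,
    # because a new random seed is drawn for each call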
Examples -------- @@ -6701,6 +6732,36 @@ def rolling_apply( """ + @deprecate_renamed_function("is_first_distinct", version="0.19.3") + def is_first(self) -> Series: + """ + Return a boolean mask indicating the first occurrence of each distinct value. + + .. deprecated:: 0.19.3 + This method has been renamed to :func:`Series.is_first_distinct`. + + Returns + ------- + Series + Series of data type :class:`Boolean`. + + """ + + @deprecate_renamed_function("is_last_distinct", version="0.19.3") + def is_last(self) -> Series: + """ + Return a boolean mask indicating the last occurrence of each distinct value. + + .. deprecated:: 0.19.3 + This method has been renamed to :func:`Series.is_last_distinct`. + + Returns + ------- + Series + Series of data type :class:`Boolean`. + + """ + # Keep the `list` and `str` properties below at the end of the definition of Series, # as to not confuse mypy with the type annotation `str` and `list` diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index d22c9478356e..0a00f53c4e23 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -10,6 +10,7 @@ from polars.polars import PySeries from polars.type_aliases import ( Ambiguous, + IntoExpr, PolarsDataType, PolarsTemporalType, TimeUnit, @@ -872,7 +873,7 @@ def count_matches(self, pattern: str | Series, *, literal: bool = False) -> Seri """ - def split(self, by: str, *, inclusive: bool = False) -> Series: + def split(self, by: IntoExpr, *, inclusive: bool = False) -> Series: """ Split the string by a substring. diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index 81503da7cf3c..610d78bc19d9 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -55,7 +55,6 @@ ] SchemaDefinition: TypeAlias = Union[ - Sequence[str], Mapping[str, Union[PolarsDataType, PythonDataType]], Sequence[Union[str, Tuple[str, Union[PolarsDataType, PythonDataType, None]]]], ] @@ -76,7 +75,7 @@ # User-facing string literal types # The following all have an equivalent Rust enum with the same name AvroCompression: TypeAlias = Literal["uncompressed", "snappy", "deflate"] -CsvQuoteStyle: TypeAlias = Literal["necessary", "always", "non_numeric"] +CsvQuoteStyle: TypeAlias = Literal["necessary", "always", "non_numeric", "never"] CategoricalOrdering: TypeAlias = Literal["physical", "lexical"] CsvEncoding: TypeAlias = Literal["utf8", "utf8-lossy"] FillNullStrategy: TypeAlias = Literal[ diff --git a/py-polars/polars/utils/_async.py b/py-polars/polars/utils/_async.py index d35956156b8c..42ddfe85c313 100644 --- a/py-polars/polars/utils/_async.py +++ b/py-polars/polars/utils/_async.py @@ -1,11 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, TypeVar +from polars.dependencies import _GEVENT_AVAILABLE from polars.utils._wrap import wrap_df if TYPE_CHECKING: - from queue import Queue + from asyncio.futures import Future from polars.polars import PyDataFrame @@ -13,33 +14,81 @@ T = TypeVar("T") -class _AsyncDataFrameResult(Generic[T]): - queue: Queue[Exception | T] - _result: Exception | T | None +class _GeventDataFrameResult(Generic[T]): + __slots__ = ("_watcher", "_value", "_result") - __slots__ = ("queue", "_result") + def __init__(self) -> None: + if not _GEVENT_AVAILABLE: + raise ImportError( + "gevent is required for using LazyFrame.collect_async(gevent=True) or" + "polars.collect_all_async(gevent=True)" + ) 
- def __init__(self, queue: Queue[Exception | T]) -> None: - self.queue = queue - self._result = None + from gevent.event import AsyncResult # type: ignore[import] + from gevent.hub import get_hub # type: ignore[import] - def get(self, **kwargs: Any) -> T: - if self._result is not None: - if isinstance(self._result, Exception): - raise self._result - return self._result + self._value: None | Exception | PyDataFrame | list[PyDataFrame] = None + self._result = AsyncResult() - self._result = self.queue.get(**kwargs) - if isinstance(self._result, Exception): - raise self._result + self._watcher = get_hub().loop.async_() + self._watcher.start(self._watcher_callback) + + def get( + self, block: bool = True, timeout: float | int | None = None # noqa: FBT001 + ) -> T: + return self.result.get(block=block, timeout=timeout) + + @property + def result(self) -> Any: + # required if we did not made any switches and just want results later + # with block=False and possibly without timeout + if self._value is not None and not self._result.ready(): + self._watcher_callback() return self._result + def _watcher_callback(self) -> None: + if isinstance(self._value, Exception): + self._result.set_exception(self._value) + else: + self._result.set(self._value) + self._watcher.close() + def _callback(self, obj: PyDataFrame | Exception) -> None: if not isinstance(obj, Exception): obj = wrap_df(obj) - self.queue.put_nowait(obj) + self._value = obj + self._watcher.send() def _callback_all(self, obj: list[PyDataFrame] | Exception) -> None: if not isinstance(obj, Exception): obj = [wrap_df(pydf) for pydf in obj] - self.queue.put_nowait(obj) # type: ignore[arg-type] + self._value = obj + self._watcher.send() + + +class _AioDataFrameResult(Awaitable[T], Generic[T]): + __slots__ = ("loop", "result") + + def __init__(self) -> None: + from asyncio import get_event_loop + + self.loop = get_event_loop() + self.result: Future[T] = self.loop.create_future() + + def __await__(self) -> Generator[Any, None, T]: + return self.result.__await__() + + def _callback(self, obj: PyDataFrame | Exception) -> None: + if isinstance(obj, Exception): + self.loop.call_soon_threadsafe(self.result.set_exception, obj) + else: + self.loop.call_soon_threadsafe(self.result.set_result, wrap_df(obj)) + + def _callback_all(self, obj: list[PyDataFrame] | Exception) -> None: + if isinstance(obj, Exception): + self.loop.call_soon_threadsafe(self.result.set_exception, obj) + else: + self.loop.call_soon_threadsafe( + self.result.set_result, + [wrap_df(pydf) for pydf in obj], + ) diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 95fee7c63b1b..43069ed2977c 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -644,14 +644,13 @@ def _post_apply_columns( column_casts = [] for i, col in enumerate(columns): dtype = dtypes.get(col) - if dtype == Categorical != pydf_dtypes[i]: + pydf_dtype = pydf_dtypes[i] + if dtype == Categorical != pydf_dtype: column_casts.append(F.col(col).cast(Categorical)._pyexpr) - elif structs and col in structs and structs[col] != pydf_dtypes[i]: - column_casts.append(F.col(col).cast(structs[col])._pyexpr) - elif dtype not in (None, Unknown) and dtype != pydf_dtypes[i]: - column_casts.append( - F.col(col).cast(dtype)._pyexpr, # type: ignore[arg-type] - ) + elif structs and (struct := structs.get(col)) and struct != pydf_dtype: + column_casts.append(F.col(col).cast(struct)._pyexpr) + elif dtype is not None and dtype != Unknown and dtype != 
pydf_dtype: + column_casts.append(F.col(col).cast(dtype)._pyexpr) if column_casts or column_subset: pydf = pydf.lazy() @@ -678,42 +677,59 @@ def _unpack_schema( Works for any (name, dtype) pairs or schema dict input, overriding any inferred dtypes with explicit dtypes if supplied. """ + # coerce schema_overrides to dict[str, PolarsDataType] + if schema_overrides: + schema_overrides = { + name: dtype + if is_polars_dtype(dtype, include_unknown=True) + else py_type_to_dtype(dtype) + for name, dtype in schema_overrides.items() + } + else: + schema_overrides = {} + + # fastpath for empty schema + if not schema: + return ( + [f"column_{i}" for i in range(n_expected)] if n_expected else [], + schema_overrides, + ) + + # determine column names from schema if isinstance(schema, dict): + column_names: list[str] = list(schema) + # coerce schema to list[str | tuple[str, PolarsDataType | PythonDataType | None] schema = list(schema.items()) - column_names = [ - (col or f"column_{i}") if isinstance(col, str) else col[0] - for i, col in enumerate(schema or []) - ] - if not column_names and n_expected: - column_names = [f"column_{i}" for i in range(n_expected)] - lookup = ( - { - col: name - for col, name in zip_longest(column_names, lookup_names or []) - if name - } + else: + column_names = [ + (col or f"column_{i}") if isinstance(col, str) else col[0] + for i, col in enumerate(schema) + ] + + # determine column dtypes from schema and lookup_names + lookup: dict[str, str] | None = ( + {col: name for col, name in zip_longest(column_names, lookup_names) if name} if lookup_names else None ) - column_dtypes = { - lookup.get(col[0], col[0]) if lookup else col[0]: col[1] - for col in (schema or []) - if not isinstance(col, str) and col[1] is not None + column_dtypes: dict[str, PolarsDataType] = { + lookup.get((name := col[0]), name) + if lookup + else col[0]: dtype # type: ignore[misc] + if is_polars_dtype(dtype, include_unknown=True) + else py_type_to_dtype(dtype) + for col in schema + if isinstance(col, tuple) and (dtype := col[1]) is not None } + + # apply schema overrides if schema_overrides: column_dtypes.update(schema_overrides) - if schema and include_overrides_in_columns: - column_names = column_names + [ - col for col in column_dtypes if col not in column_names - ] - for col, dtype in column_dtypes.items(): - if not is_polars_dtype(dtype, include_unknown=True) and dtype is not None: - column_dtypes[col] = py_type_to_dtype(dtype) - return ( - column_names, # type: ignore[return-value] - column_dtypes, - ) + if include_overrides_in_columns: + column_names.extend(col for col in column_dtypes if col not in column_names) + + return column_names, column_dtypes def _expand_dict_data( @@ -1090,14 +1106,14 @@ def _sequence_of_dict_to_pydf( ) dicts_schema = ( include_unknowns(schema_overrides, column_names or list(schema_overrides)) - if schema_overrides and column_names + if column_names else None ) pydf = PyDataFrame.read_dicts(data, infer_schema_length, dicts_schema) - if not schema_overrides and set(pydf.columns()) == set(column_names): - pass - elif column_names or schema_overrides: + # TODO: we can remove this `schema_overrides` block completely + # once https://github.com/pola-rs/polars/issues/11044 is fixed + if schema_overrides: pydf = _post_apply_columns( pydf, columns=column_names, diff --git a/py-polars/polars/utils/convert.py b/py-polars/polars/utils/convert.py index 2c6431837183..241e8bd1f125 100644 --- a/py-polars/polars/utils/convert.py +++ b/py-polars/polars/utils/convert.py @@ -91,6 +91,12 @@ 
def _timedelta_to_pl_duration(td: timedelta | str | None) -> str | None: return f"{d}{s}{us}" +def _negate_duration(duration: str) -> str: + if duration.startswith("-"): + return duration[1:] + return f"-{duration}" + + def _datetime_to_pl_timestamp(dt: datetime, time_unit: TimeUnit | None) -> int: """Convert a python datetime to a timestamp in nanoseconds.""" dt = dt.replace(tzinfo=timezone.utc) if dt.tzinfo != timezone.utc else dt diff --git a/py-polars/polars/utils/show_versions.py b/py-polars/polars/utils/show_versions.py index c92b46f2e875..f20185c9ca5d 100644 --- a/py-polars/polars/utils/show_versions.py +++ b/py-polars/polars/utils/show_versions.py @@ -63,11 +63,13 @@ def _get_dependency_info() -> dict[str, str]: "connectorx", "deltalake", "fsspec", + "gevent", "matplotlib", "numpy", "pandas", "pyarrow", "pydantic", + "pyiceberg", "sqlalchemy", "xlsx2csv", "xlsxwriter", diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 34100d91f301..78ecba839e6c 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -868,7 +868,17 @@ def warn_on_inefficient_map( ) -__all__ = [ - "BytecodeParser", - "warn_on_inefficient_map", -] +def is_shared_lib(file: str) -> bool: + return file.endswith((".so", ".dll")) + + +def _get_shared_lib_location(main_file: Any) -> str: + import os + + directory = os.path.dirname(main_file) # noqa: PTH120 + return os.path.join( # noqa: PTH118 + directory, next(filter(is_shared_lib, os.listdir(directory))) + ) + + +__all__ = ["BytecodeParser", "warn_on_inefficient_map", "_get_shared_lib_location"] diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 50628492165a..7b75e280c38c 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -49,12 +49,14 @@ deltalake = ["deltalake >= 0.10.0"] timezone = ["backports.zoneinfo; python_version < '3.9'", "tzdata; platform_system == 'Windows'"] matplotlib = ["matplotlib"] pydantic = ["pydantic"] +pyiceberg = ["pyiceberg >= 0.5.0"] sqlalchemy = ["sqlalchemy", "pandas"] xlsxwriter = ["xlsxwriter"] adbc = ["adbc_driver_sqlite"] cloudpickle = ["cloudpickle"] +gevent = ["gevent"] all = [ - "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,deltalake,timezone,matplotlib,pydantic,sqlalchemy,xlsxwriter,adbc,cloudpickle]", + "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,deltalake,timezone,matplotlib,pydantic,pyiceberg,sqlalchemy,xlsxwriter,adbc,cloudpickle,gevent]", ] [tool.mypy] @@ -77,22 +79,27 @@ module = [ "backports", "connectorx", "deltalake.*", + "ezodf.*", "fsspec.*", + "gevent", "matplotlib.*", "moto.server", + "openpyxl", "polars.polars", "pyarrow.*", "pydantic", "sqlalchemy.*", "xlsx2csv", "xlsxwriter.*", - "openpyxl", "zoneinfo", ] ignore_missing_imports = true [[tool.mypy.overrides]] -module = ["IPython.*"] +module = [ + "IPython.*", + "matplotlib.*", +] follow_imports = "skip" [[tool.mypy.overrides]] @@ -184,12 +191,13 @@ addopts = [ "--strict-markers", "--import-mode=importlib", # Default to running fast tests only. To run ALL tests, run: pytest -m "" - "-m not slow and not hypothesis and not benchmark and not write_disk", + "-m not slow and not hypothesis and not benchmark and not write_disk and not docs", ] markers = [ "write_disk: Tests that write to disk", "slow: Tests with a longer than average runtime.", "benchmark: Tests that should be run on a Polars release build.", + "docs: Documentation code snippets", ] filterwarnings = [ # Fail on warnings... 
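Illustrative sketch (not from the diff; assumes pytest is installed and the working directory is py-polars): the `addopts` above deselect `docs`-marked tests by default, so the new user-guide snippet tests only run when that marker is requested explicitly, e.g.

import pytest

# Select only the tests registered under the new "docs" marker; a `-m` passed on the
# command line (or here) takes precedence over the "not ... docs" default in addopts.
pytest.main(["-m", "docs", "tests/docs/test_user_guide.py"])
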
@@ -197,6 +205,7 @@ filterwarnings = [ # ...except where it prevents test debugging in an IPython console "ignore:.*unrecognized arguments.*PyDevIPCompleter:DeprecationWarning", "ignore:.*is_sparse is deprecated.*:FutureWarning", + "ignore:FigureCanvasAgg is non-interactive:UserWarning", ] xfail_strict = true diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 9765dfb15f15..0a45cad3e538 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -4,28 +4,47 @@ --prefer-binary -# Dependencies -dataframe-api-compat >= 0.1.6 -# pin deltalake until issues with pyarrow 13 resolved -# https://github.com/delta-io/delta-rs/pull/1602 -deltalake == 0.10.1 +# ------------ +# DEPENDENCIES +# ------------ + +# Interoperability numpy pandas pyarrow pydantic >= 2.0.0 +# Datetime / time zones backports.zoneinfo; python_version < '3.9' tzdata; platform_system == 'Windows' +# Database SQLAlchemy -xlsx2csv -openpyxl -XlsxWriter adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows' connectorx +# Cloud cloudpickle fsspec s3fs[boto3] +# Spreadsheet +ezodf +lxml +openpyxl +xlsx2csv +XlsxWriter +# Deltalake +# note: pinning deltalake until issues with pyarrow 13 resolved +# https://github.com/delta-io/delta-rs/pull/1602 +deltalake == 0.10.1 +# Dataframe interchange protocol +dataframe-api-compat >= 0.1.6 +pyiceberg >= 0.5.0 +# Other +matplotlib +gevent + +# ------- +# TOOLING +# ------- -# Tooling hypothesis==6.82.6 maturin==1.2.3 patchelf; platform_system == 'Linux' # Extra dependency for maturin, only for Linux diff --git a/py-polars/src/conversion.rs b/py-polars/src/conversion.rs index 557e4c5916c7..e040fd4f3d54 100644 --- a/py-polars/src/conversion.rs +++ b/py-polars/src/conversion.rs @@ -1068,7 +1068,7 @@ impl FromPyObject<'_> for Wrap { "nearest" => AsofStrategy::Nearest, v => { return Err(PyValueError::new_err(format!( - "strategy must be one of {{'backward', 'forward', 'nearest'}}, got {v}", + "asof `strategy` must be one of {{'backward', 'forward', 'nearest'}}, got {v}", ))) }, }; @@ -1083,7 +1083,7 @@ impl FromPyObject<'_> for Wrap { "nearest" => InterpolationMethod::Nearest, v => { return Err(PyValueError::new_err(format!( - "method must be one of {{'linear', 'nearest'}}, got {v}", + "interpolation `method` must be one of {{'linear', 'nearest'}}, got {v}", ))) }, }; @@ -1100,8 +1100,8 @@ impl FromPyObject<'_> for Wrap> { "deflate" => Some(AvroCompression::Deflate), v => { return Err(PyValueError::new_err(format!( - "compression must be one of {{'uncompressed', 'snappy', 'deflate'}}, got {v}", - ))) + "avro `compression` must be one of {{'uncompressed', 'snappy', 'deflate'}}, got {v}", + ))) }, }; Ok(Wrap(parsed)) @@ -1115,7 +1115,7 @@ impl FromPyObject<'_> for Wrap { "lexical" => CategoricalOrdering::Lexical, v => { return Err(PyValueError::new_err(format!( - "ordering must be one of {{'physical', 'lexical'}}, got {v}", + "categorical `ordering` must be one of {{'physical', 'lexical'}}, got {v}", ))) }, }; @@ -1137,7 +1137,7 @@ impl FromPyObject<'_> for Wrap { "sunday" => StartBy::Sunday, v => { return Err(PyValueError::new_err(format!( - "closed must be one of {{'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'}}, got {v}", + "`start_by` must be one of {{'window', 'datapoint', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday'}}, got {v}", ))) } }; @@ -1154,7 +1154,7 @@ impl FromPyObject<'_> for Wrap { "none" => ClosedWindow::None, v => { return 
Err(PyValueError::new_err(format!( - "closed must be one of {{'left', 'right', 'both', 'none'}}, got {v}", + "`closed` must be one of {{'left', 'right', 'both', 'none'}}, got {v}", ))) }, }; @@ -1169,7 +1169,7 @@ impl FromPyObject<'_> for Wrap { "utf8-lossy" => CsvEncoding::LossyUtf8, v => { return Err(PyValueError::new_err(format!( - "encoding must be one of {{'utf8', 'utf8-lossy'}}, got {v}", + "csv `encoding` must be one of {{'utf8', 'utf8-lossy'}}, got {v}", ))) }, }; @@ -1186,7 +1186,7 @@ impl FromPyObject<'_> for Wrap> { "zstd" => Some(IpcCompression::ZSTD), v => { return Err(PyValueError::new_err(format!( - "compression must be one of {{'uncompressed', 'lz4', 'zstd'}}, got {v}", + "ipc `compression` must be one of {{'uncompressed', 'lz4', 'zstd'}}, got {v}", ))) }, }; @@ -1206,7 +1206,7 @@ impl FromPyObject<'_> for Wrap { "cross" => JoinType::Cross, v => { return Err(PyValueError::new_err(format!( - "how must be one of {{'inner', 'left', 'outer', 'semi', 'anti', 'cross'}}, got {v}", + "`how` must be one of {{'inner', 'left', 'outer', 'semi', 'anti', 'cross'}}, got {v}", ))) }, }; @@ -1221,7 +1221,7 @@ impl FromPyObject<'_> for Wrap { "max_width" => ListToStructWidthStrategy::MaxWidth, v => { return Err(PyValueError::new_err(format!( - "n_field_strategy must be one of {{'first_non_null', 'max_width'}}, got {v}", + "`n_field_strategy` must be one of {{'first_non_null', 'max_width'}}, got {v}", ))) }, }; @@ -1236,7 +1236,7 @@ impl FromPyObject<'_> for Wrap { "ignore" => NullBehavior::Ignore, v => { return Err(PyValueError::new_err(format!( - "null behavior must be one of {{'drop', 'ignore'}}, got {v}", + "`null_behavior` must be one of {{'drop', 'ignore'}}, got {v}", ))) }, }; @@ -1251,7 +1251,7 @@ impl FromPyObject<'_> for Wrap { "propagate" => NullStrategy::Propagate, v => { return Err(PyValueError::new_err(format!( - "null strategy must be one of {{'ignore', 'propagate'}}, got {v}", + "`null_strategy` must be one of {{'ignore', 'propagate'}}, got {v}", ))) }, }; @@ -1269,8 +1269,8 @@ impl FromPyObject<'_> for Wrap { "none" => ParallelStrategy::None, v => { return Err(PyValueError::new_err(format!( - "parallel must be one of {{'auto', 'columns', 'row_groups', 'none'}}, got {v}", - ))) + "`parallel` must be one of {{'auto', 'columns', 'row_groups', 'none'}}, got {v}", + ))) }, }; Ok(Wrap(parsed)) @@ -1284,7 +1284,7 @@ impl FromPyObject<'_> for Wrap { "c" => IndexOrder::C, v => { return Err(PyValueError::new_err(format!( - "order must be one of {{'fortran', 'c'}}, got {v}", + "`order` must be one of {{'fortran', 'c'}}, got {v}", ))) }, }; @@ -1302,7 +1302,7 @@ impl FromPyObject<'_> for Wrap { "midpoint" => QuantileInterpolOptions::Midpoint, v => { return Err(PyValueError::new_err(format!( - "interpolation must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint'}}, got {v}", + "`interpolation` must be one of {{'lower', 'higher', 'nearest', 'linear', 'midpoint'}}, got {v}", ))) } }; @@ -1321,7 +1321,7 @@ impl FromPyObject<'_> for Wrap { "random" => RankMethod::Random, v => { return Err(PyValueError::new_err(format!( - "method must be one of {{'min', 'max', 'average', 'dense', 'ordinal', 'random'}}, got {v}", + "rank `method` must be one of {{'min', 'max', 'average', 'dense', 'ordinal', 'random'}}, got {v}", ))) } }; @@ -1337,7 +1337,7 @@ impl FromPyObject<'_> for Wrap { "ms" => TimeUnit::Milliseconds, v => { return Err(PyValueError::new_err(format!( - "time unit must be one of {{'ns', 'us', 'ms'}}, got {v}", + "`time_unit` must be one of {{'ns', 'us', 'ms'}}, got {v}", ))) }, 
}; @@ -1354,7 +1354,7 @@ impl FromPyObject<'_> for Wrap { "any" => UniqueKeepStrategy::Any, v => { return Err(PyValueError::new_err(format!( - "keep must be one of {{'first', 'last', 'any', 'none'}}, got {v}", + "`keep` must be one of {{'first', 'last', 'any', 'none'}}, got {v}", ))) }, }; @@ -1370,7 +1370,7 @@ impl FromPyObject<'_> for Wrap { "lz4" => IpcCompression::LZ4, v => { return Err(PyValueError::new_err(format!( - "compression must be one of {{'zstd', 'lz4'}}, got {v}", + "ipc `compression` must be one of {{'zstd', 'lz4'}}, got {v}", ))) }, }; @@ -1386,7 +1386,7 @@ impl FromPyObject<'_> for Wrap { "right" => SearchSortedSide::Right, v => { return Err(PyValueError::new_err(format!( - "side must be one of {{'any', 'left', 'right'}}, got {v}", + "sorted `side` must be one of {{'any', 'left', 'right'}}, got {v}", ))) }, }; @@ -1402,8 +1402,8 @@ impl FromPyObject<'_> for Wrap { "explode" => WindowMapping::Explode, v => { return Err(PyValueError::new_err(format!( - "side must be one of {{'group_to_rows', 'join', 'explode'}}, got {v}", - ))) + "`mapping_strategy` must be one of {{'group_to_rows', 'join', 'explode'}}, got {v}", + ))) }, }; Ok(Wrap(parsed)) @@ -1419,7 +1419,7 @@ impl FromPyObject<'_> for Wrap { "m:1" => JoinValidation::ManyToOne, v => { return Err(PyValueError::new_err(format!( - "validate must be one of {{'m:m', 'm:1', '1:m', '1:1'}}, got {v}", + "`validate` must be one of {{'m:m', 'm:1', '1:m', '1:1'}}, got {v}", ))) }, }; @@ -1433,9 +1433,10 @@ impl FromPyObject<'_> for Wrap { "always" => QuoteStyle::Always, "necessary" => QuoteStyle::Necessary, "non_numeric" => QuoteStyle::NonNumeric, + "never" => QuoteStyle::Never, v => { return Err(PyValueError::new_err(format!( - "validate must be one of {{'always', 'necessary', 'non_numeric'}}, got {v}", + "`quote_style` must be one of {{'always', 'necessary', 'non_numeric', 'never'}}, got {v}", ))) }, }; @@ -1453,7 +1454,7 @@ impl FromPyObject<'_> for Wrap { "symmetric_difference" => SetOperation::SymmetricDifference, v => { return Err(PyValueError::new_err(format!( - "validate must be one of {{'union', 'difference', 'intersection', 'symmetric_difference'}}, got {v}", + "set operation must be one of {{'union', 'difference', 'intersection', 'symmetric_difference'}}, got {v}", ))) } }; @@ -1475,7 +1476,7 @@ pub(crate) fn parse_fill_null_strategy( "one" => FillNullStrategy::One, e => { return Err(PyValueError::new_err(format!( - "strategy must be one of {{'forward', 'backward', 'min', 'max', 'mean', 'zero', 'one'}}, got {e}", + "`strategy` must be one of {{'forward', 'backward', 'min', 'max', 'mean', 'zero', 'one'}}, got {e}", ))) } }; @@ -1518,7 +1519,7 @@ pub(crate) fn parse_parquet_compression( ), e => { return Err(PyValueError::new_err(format!( - "compression must be one of {{'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'lz4', 'zstd'}}, got {e}", + "parquet `compression` must be one of {{'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'lz4', 'zstd'}}, got {e}", ))) } }; diff --git a/py-polars/src/error.rs b/py-polars/src/error.rs index 777b8f5ed33b..6e017bf566b7 100644 --- a/py-polars/src/error.rs +++ b/py-polars/src/error.rs @@ -40,6 +40,7 @@ impl std::convert::From for PyErr { }, PolarsError::Io(err) => PyIOError::new_err(err.to_string()), PolarsError::NoData(err) => NoDataError::new_err(err.to_string()), + PolarsError::OutOfBounds(err) => OutOfBoundsError::new_err(err.to_string()), PolarsError::SchemaFieldNotFound(name) => { SchemaFieldNotFoundError::new_err(name.to_string()) }, @@ -75,6 +76,7 @@ 
create_exception!(exceptions, ComputeError, PyException); create_exception!(exceptions, DuplicateError, PyException); create_exception!(exceptions, InvalidOperationError, PyException); create_exception!(exceptions, NoDataError, PyException); +create_exception!(exceptions, OutOfBoundsError, PyException); create_exception!(exceptions, SchemaError, PyException); create_exception!(exceptions, SchemaFieldNotFoundError, PyException); create_exception!(exceptions, ShapeError, PyException); diff --git a/py-polars/src/expr/binary.rs b/py-polars/src/expr/binary.rs index ba97435f9bca..5db2a9fcb4bf 100644 --- a/py-polars/src/expr/binary.rs +++ b/py-polars/src/expr/binary.rs @@ -5,16 +5,20 @@ use crate::PyExpr; #[pymethods] impl PyExpr { - fn bin_contains(&self, lit: Vec) -> Self { - self.inner.clone().binary().contains_literal(lit).into() + fn bin_contains(&self, lit: PyExpr) -> Self { + self.inner + .clone() + .binary() + .contains_literal(lit.inner) + .into() } - fn bin_ends_with(&self, sub: Vec) -> Self { - self.inner.clone().binary().ends_with(sub).into() + fn bin_ends_with(&self, sub: PyExpr) -> Self { + self.inner.clone().binary().ends_with(sub.inner).into() } - fn bin_starts_with(&self, sub: Vec) -> Self { - self.inner.clone().binary().starts_with(sub).into() + fn bin_starts_with(&self, sub: PyExpr) -> Self { + self.inner.clone().binary().starts_with(sub.inner).into() } #[cfg(feature = "binary_encoding")] diff --git a/py-polars/src/expr/datetime.rs b/py-polars/src/expr/datetime.rs index 0ffab6019a10..66bded2b990c 100644 --- a/py-polars/src/expr/datetime.rs +++ b/py-polars/src/expr/datetime.rs @@ -50,11 +50,11 @@ impl PyExpr { .into() } - fn dt_truncate(&self, every: String, offset: String, ambiguous: Self) -> Self { + fn dt_truncate(&self, every: Self, offset: String, ambiguous: Self) -> Self { self.inner .clone() .dt() - .truncate(TruncateOptions { every, offset }, ambiguous.inner) + .truncate(every.inner, offset, ambiguous.inner) .into() } diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index c1f24f1ca5cf..102241a7b2db 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -387,12 +387,12 @@ impl PyExpr { self.clone().inner.approx_n_unique().into() } - fn is_first(&self) -> Self { - self.clone().inner.is_first().into() + fn is_first_distinct(&self) -> Self { + self.clone().inner.is_first_distinct().into() } - fn is_last(&self) -> Self { - self.clone().inner.is_last().into() + fn is_last_distinct(&self) -> Self { + self.clone().inner.is_last_distinct().into() } fn explode(&self) -> Self { @@ -854,4 +854,47 @@ impl PyExpr { }; self.inner.clone().set_sorted_flag(is_sorted).into() } + + #[cfg(feature = "ffi_plugin")] + #[allow(clippy::too_many_arguments)] + fn register_plugin( + &self, + lib: &str, + symbol: &str, + args: Vec, + is_elementwise: bool, + input_wildcard_expansion: bool, + auto_explode: bool, + cast_to_supertypes: bool, + ) -> Self { + use polars_plan::prelude::*; + let inner = self.inner.clone(); + + let collect_groups = if is_elementwise { + ApplyOptions::ApplyFlat + } else { + ApplyOptions::ApplyGroups + }; + let mut input = Vec::with_capacity(args.len() + 1); + input.push(inner); + for a in args { + input.push(a.inner) + } + + Expr::Function { + input, + function: FunctionExpr::FfiPlugin { + lib: Arc::from(lib), + symbol: Arc::from(symbol), + }, + options: FunctionOptions { + collect_groups, + input_wildcard_expansion, + auto_explode, + cast_to_supertypes, + ..Default::default() + }, + } + .into() + } } diff --git 
a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 3495718d817d..261a96fa6259 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -49,8 +49,8 @@ impl PyExpr { self.inner.clone().list().get(index.inner).into() } - fn list_join(&self, separator: &str) -> Self { - self.inner.clone().list().join(separator).into() + fn list_join(&self, separator: PyExpr) -> Self { + self.inner.clone().list().join(separator.inner).into() } fn list_lengths(&self) -> Self { diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index ed542b6c0fb0..e466e5256254 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -275,12 +275,12 @@ impl PyExpr { .into() } - fn str_split(&self, by: &str) -> Self { - self.inner.clone().str().split(by).into() + fn str_split(&self, by: Self) -> Self { + self.inner.clone().str().split(by.inner).into() } - fn str_split_inclusive(&self, by: &str) -> Self { - self.inner.clone().str().split_inclusive(by).into() + fn str_split_inclusive(&self, by: Self) -> Self { + self.inner.clone().str().split_inclusive(by.inner).into() } fn str_split_exact(&self, by: &str, n: usize) -> Self { diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 0fb9243e0896..a5ea89da0c80 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -71,7 +71,7 @@ pub fn arg_where(condition: PyExpr) -> PyExpr { #[pyfunction] pub fn as_struct(exprs: Vec) -> PyExpr { let exprs = exprs.to_exprs(); - dsl::as_struct(&exprs).into() + dsl::as_struct(exprs).into() } #[pyfunction] diff --git a/py-polars/src/lazyframe.rs b/py-polars/src/lazyframe.rs index fec73fa40c3e..215435fcb2fc 100644 --- a/py-polars/src/lazyframe.rs +++ b/py-polars/src/lazyframe.rs @@ -11,9 +11,8 @@ use polars::lazy::frame::LazyCsvReader; use polars::lazy::frame::LazyJsonLineReader; use polars::lazy::frame::{AllowedOptimizations, LazyFrame}; use polars::lazy::prelude::col; -use polars::prelude::{ClosedWindow, CsvEncoding, Field, JoinType, Schema}; +use polars::prelude::{cloud, ClosedWindow, CsvEncoding, Field, JoinType, Schema}; use polars::time::*; -use polars_core::cloud; use polars_core::frame::explode::MeltArgs; use polars_core::frame::hash_join::JoinValidation; use polars_core::frame::UniqueKeepStrategy; diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index 0f4859887a58..3a9ba4f1e2a8 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -51,7 +51,8 @@ use crate::conversion::Wrap; use crate::dataframe::PyDataFrame; use crate::error::{ ArrowErrorException, ColumnNotFoundError, ComputeError, DuplicateError, InvalidOperationError, - NoDataError, PyPolarsErr, SchemaError, SchemaFieldNotFoundError, StructFieldNotFoundError, + NoDataError, OutOfBoundsError, PyPolarsErr, SchemaError, SchemaFieldNotFoundError, + StructFieldNotFoundError, }; use crate::expr::PyExpr; use crate::lazyframe::PyLazyFrame; @@ -243,6 +244,8 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { ) .unwrap(); m.add("NoDataError", py.get_type::()).unwrap(); + m.add("OutOfBoundsError", py.get_type::()) + .unwrap(); m.add("PolarsPanicError", py.get_type::()) .unwrap(); m.add("SchemaError", py.get_type::()).unwrap(); diff --git a/py-polars/src/map/lazy.rs b/py-polars/src/map/lazy.rs index 7138e8c00309..817840f08bef 100644 --- a/py-polars/src/map/lazy.rs +++ b/py-polars/src/map/lazy.rs @@ -157,7 +157,7 @@ pub(crate) fn call_lambda_with_series_slice( // call the lambda and get a python side Series wrapper match lambda.call1(py, (wrapped_s,)) 
{ Ok(pyobj) => pyobj, - Err(e) => panic!("python apply failed: {}", e.value(py)), + Err(e) => panic!("python function failed: {}", e.value(py)), } } diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 6917f1d0203e..d523576e2ffa 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -11,7 +11,7 @@ use polars_algo::hist; use polars_core::series::IsSorted; use polars_core::utils::flatten::flatten_series; use polars_core::with_match_physical_numeric_polars_type; -use pyo3::exceptions::{PyRuntimeError, PyValueError}; +use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::Python; @@ -154,8 +154,15 @@ impl PySeries { } } - fn get_idx(&self, py: Python, idx: usize) -> PyResult { - let av = self.series.get(idx).map_err(PyPolarsErr::from)?; + fn get_index(&self, py: Python, index: usize) -> PyResult { + let av = match self.series.get(index) { + Ok(v) => v, + Err(PolarsError::OutOfBounds(err)) => { + return Err(PyIndexError::new_err(err.to_string())) + }, + Err(e) => return Err(PyPolarsErr::from(e).into()), + }; + if let AnyValue::List(s) = av { let pyseries = PySeries::new(s); let out = POLARS @@ -163,11 +170,27 @@ impl PySeries { .unwrap() .call1(py, (pyseries,)) .unwrap(); + return Ok(out.into_py(py)); + } - Ok(out.into_py(py)) + Ok(Wrap(av).into_py(py)) + } + + /// Get index but allow negative indices + fn get_index_signed(&self, py: Python, index: i64) -> PyResult { + let index = if index < 0 { + match self.len().checked_sub(index.unsigned_abs() as usize) { + Some(v) => v, + None => { + return Err(PyIndexError::new_err( + polars_err!(oob = index, self.len()).to_string(), + )); + }, + } } else { - Ok(Wrap(self.series.get(idx).map_err(PyPolarsErr::from)?).into_py(py)) - } + index as usize + }; + self.get_index(py, index) } fn bitand(&self, other: &PySeries) -> PyResult { @@ -552,12 +575,8 @@ impl PySeries { } fn get_list(&self, index: usize) -> Option { - if let Ok(ca) = &self.series.list() { - let s = ca.get(index); - s.map(|s| s.into()) - } else { - None - } + let ca = self.series.list().ok()?; + Some(ca.get_as_series(index)?.into()) } fn peak_max(&self) -> Self { diff --git a/py-polars/tests/docs/test_user_guide.py b/py-polars/tests/docs/test_user_guide.py new file mode 100644 index 000000000000..032961dd936a --- /dev/null +++ b/py-polars/tests/docs/test_user_guide.py @@ -0,0 +1,32 @@ +"""Run all Python code snippets.""" +import os +import runpy +from pathlib import Path +from typing import Iterator + +import matplotlib +import pytest + +# Do not show plots +matplotlib.use("Agg") + +# Get paths to Python code snippets +repo_root = Path(__file__).parent.parent.parent.parent +python_snippets_dir = repo_root / "docs" / "src" / "python" +snippet_paths = list(python_snippets_dir.rglob("*.py")) + + +@pytest.fixture(scope="module") +def _change_test_dir() -> Iterator[None]: + """Change path to repo root to accommodate data paths in code snippets.""" + current_path = Path() + os.chdir(repo_root) + yield + os.chdir(current_path) + + +@pytest.mark.docs() +@pytest.mark.parametrize("path", snippet_paths) +@pytest.mark.usefixtures("_change_test_dir") +def test_run_python_snippets(path: Path) -> None: + runpy.run_path(str(path)) diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index b718c2b3eee3..9e074dd6b72a 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -3402,31 +3402,6 @@ def 
test_glimpse(capsys: Any) -> None: assert result == expected -def test_item() -> None: - df = pl.DataFrame({"a": [1]}) - assert df.item() == 1 - - df = pl.DataFrame({"a": [1, 2]}) - with pytest.raises(ValueError, match=r".* frame has shape \(2, 1\)"): - df.item() - - assert df.item(0, 0) == 1 - assert df.item(1, "a") == 2 - - df = pl.DataFrame({"a": [1], "b": [2]}) - with pytest.raises(ValueError, match=r".* frame has shape \(1, 2\)"): - df.item() - - assert df.item(0, "a") == 1 - assert df.item(0, "b") == 2 - - df = pl.DataFrame({}) - with pytest.raises(ValueError, match=r".* frame has shape \(0, 0\)"): - df.item() - with pytest.raises(IndexError, match="column index 10 is out of bounds"): - df.item(0, 10) - - @pytest.mark.parametrize( ("subset", "keep", "expected_mask"), [ diff --git a/py-polars/tests/unit/dataframe/test_item.py b/py-polars/tests/unit/dataframe/test_item.py new file mode 100644 index 000000000000..12f9d87c913f --- /dev/null +++ b/py-polars/tests/unit/dataframe/test_item.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_df_item() -> None: + df = pl.DataFrame({"a": [1]}) + assert df.item() == 1 + + +def test_df_item_empty() -> None: + df = pl.DataFrame() + with pytest.raises(ValueError, match=r".* frame has shape \(0, 0\)"): + df.item() + + +def test_df_item_incorrect_shape_rows() -> None: + df = pl.DataFrame({"a": [1, 2]}) + with pytest.raises(ValueError, match=r".* frame has shape \(2, 1\)"): + df.item() + + +def test_df_item_incorrect_shape_columns() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + with pytest.raises(ValueError, match=r".* frame has shape \(1, 2\)"): + df.item() + + +@pytest.fixture(scope="module") +def df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@pytest.mark.parametrize( + ("row", "col", "expected"), + [ + (0, 0, 1), + (1, "a", 2), + (-1, 1, 6), + (-2, "b", 5), + ], +) +def test_df_item_with_indices( + row: int, col: int | str, expected: int, df: pl.DataFrame +) -> None: + assert df.item(row, col) == expected + + +def test_df_item_with_single_index(df: pl.DataFrame) -> None: + with pytest.raises(ValueError): + df.item(0) + with pytest.raises(ValueError): + df.item(column="b") + with pytest.raises(ValueError): + df.item(None, 0) + + +@pytest.mark.parametrize( + ("row", "col"), [(0, 10), (10, 0), (10, 10), (-10, 0), (-10, 10)] +) +def test_df_item_out_of_bounds(row: int, col: int, df: pl.DataFrame) -> None: + with pytest.raises(IndexError, match="out of bounds"): + df.item(row, col) diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index 06d7a6f266b4..b17ce6304ad0 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -590,8 +590,13 @@ def test_truncate_negative_offset(tzinfo: ZoneInfo | None) -> None: "idx", every="2i", period="3i", include_boundaries=True ).agg(pl.col("A")) - assert out.shape == (3, 4) - assert out["A"].to_list() == [["A", "A", "B"], ["B", "B", "B"], ["B", "C"]] + assert out.shape == (4, 4) + assert out["A"].to_list() == [ + ["A"], + ["A", "A", "B"], + ["B", "B", "B"], + ["B", "C"], + ] def test_to_arrow() -> None: @@ -658,13 +663,14 @@ def test_groupy_by_dynamic_median_10695() -> None: pl.col("foo").median() ).to_dict(False) == { "timestamp": [ + datetime(2023, 8, 22, 15, 43), datetime(2023, 8, 22, 15, 44), datetime(2023, 8, 22, 15, 45), datetime(2023, 8, 22, 15, 46), datetime(2023, 8, 22, 15, 47), 
datetime(2023, 8, 22, 15, 48), ], - "foo": [1.0, 1.0, 1.0, 1.0, 1.0], + "foo": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], } @@ -2570,6 +2576,125 @@ def test_asof_join_by_forward() -> None: } +def test_truncate_expr() -> None: + df = pl.DataFrame( + { + "date": [ + datetime(2022, 11, 14), + datetime(2023, 10, 11), + datetime(2022, 3, 20, 5, 7, 18), + datetime(2022, 4, 3, 13, 30, 32), + ], + "every": ["1y", "1mo", "1m", "1s"], + "ambiguous": ["earliest", "latest", "latest", "raise"], + } + ) + + every_expr = df.select( + pl.col("date").dt.truncate(every=pl.col("every"), ambiguous=pl.lit("raise")) + ) + assert every_expr.to_dict(False) == { + "date": [ + datetime(2022, 1, 1), + datetime(2023, 10, 1), + datetime(2022, 3, 20, 5, 7), + datetime(2022, 4, 3, 13, 30, 32), + ] + } + + all_lit = df.select( + pl.col("date").dt.truncate(every=pl.lit("1mo"), ambiguous=pl.lit("raise")) + ) + assert all_lit.to_dict(False) == { + "date": [ + datetime(2022, 11, 1), + datetime(2023, 10, 1), + datetime(2022, 3, 1), + datetime(2022, 4, 1), + ] + } + + df = pl.DataFrame( + { + "date": pl.datetime_range( + date(2020, 10, 25), + datetime(2020, 10, 25, 2), + "30m", + eager=True, + time_zone="Europe/London", + ).dt.offset_by("15m"), + "every": ["30m", "15m", "30m", "15m", "30m", "15m", "30m"], + "ambiguous": [ + "raise", + "earliest", + "earliest", + "latest", + "latest", + "latest", + "raise", + ], + } + ) + + ambiguous_expr = df.select( + pl.col("date").dt.truncate(every=pl.lit("30m"), ambiguous=pl.col("ambiguous")) + ) + assert ambiguous_expr.to_dict(False) == { + "date": [ + datetime(2020, 10, 25, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 0, 30, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 30, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo(key="Europe/London")), + ] + } + + all_expr = df.select( + pl.col("date").dt.truncate(every=pl.col("every"), ambiguous=pl.col("ambiguous")) + ) + assert all_expr.to_dict(False) == { + "date": [ + datetime(2020, 10, 25, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 0, 45, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 0, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 1, 45, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 25, 2, 0, tzinfo=ZoneInfo(key="Europe/London")), + ] + } + + +def test_truncate_propagate_null() -> None: + df = pl.DataFrame( + { + "date": [ + None, + datetime(2022, 11, 14), + datetime(2022, 3, 20, 5, 7, 18), + ], + "every": ["1y", None, "1m"], + "ambiguous": ["earliest", "latest", None], + } + ) + assert df.select( + pl.col("date").dt.truncate(every=pl.col("every"), ambiguous="raise") + ).to_dict(False) == {"date": [None, None, datetime(2022, 3, 20, 5, 7, 0)]} + assert df.select( + pl.col("date").dt.truncate(every="1mo", ambiguous=pl.col("ambiguous")) + ).to_dict(False) == {"date": [None, datetime(2022, 11, 1), None]} + assert df.select( + pl.col("date").dt.truncate(every=pl.col("every"), ambiguous=pl.col("ambiguous")) + ).to_dict(False) == {"date": [None, None, None]} + assert df.select( + pl.col("date").dt.truncate( + every=pl.lit(None, dtype=pl.Utf8), 
ambiguous=pl.lit(None, dtype=pl.Utf8) + ) + ).to_dict(False) == {"date": [None, None, None]} + + def test_truncate_by_calendar_weeks() -> None: # 5557 start = datetime(2022, 11, 14, 0, 0, 0) diff --git a/py-polars/tests/unit/functions/test_as_datatype.py b/py-polars/tests/unit/functions/test_as_datatype.py index a039a14e4a19..6a92f0effd4f 100644 --- a/py-polars/tests/unit/functions/test_as_datatype.py +++ b/py-polars/tests/unit/functions/test_as_datatype.py @@ -514,6 +514,16 @@ def test_concat_str_wildcard_expansion() -> None: ).to_series().to_list() == ["xs", "yo", "zs"] +def test_concat_str_with_non_utf8_col() -> None: + out = ( + pl.LazyFrame({"a": [0], "b": ["x"]}) + .select(pl.concat_str(["a", "b"], separator="-").fill_null(pl.col("a"))) + .collect() + ) + expected = pl.Series("a", ["0-x"], dtype=pl.Utf8) + assert_series_equal(out.to_series(), expected) + + def test_format() -> None: df = pl.DataFrame({"a": ["a", "b", "c"], "b": [1, 2, 3]}) diff --git a/py-polars/tests/unit/functions/test_repeat.py b/py-polars/tests/unit/functions/test_repeat.py index d31131548a6f..e2e39669240e 100644 --- a/py-polars/tests/unit/functions/test_repeat.py +++ b/py-polars/tests/unit/functions/test_repeat.py @@ -74,7 +74,7 @@ def test_repeat_n_non_integer(n: Any) -> None: def test_repeat_n_empty() -> None: df = pl.DataFrame(schema={"a": pl.Int32}) - with pytest.raises(pl.ComputeError, match="index 0 is out of bounds"): + with pytest.raises(pl.OutOfBoundsError, match="index 0 is out of bounds"): df.select(pl.repeat(1, n=pl.col("a"))) diff --git a/py-polars/tests/unit/io/files/empty.ods b/py-polars/tests/unit/io/files/empty.ods new file mode 100644 index 000000000000..80e9cafdc19d Binary files /dev/null and b/py-polars/tests/unit/io/files/empty.ods differ diff --git a/py-polars/tests/unit/io/files/example.ods b/py-polars/tests/unit/io/files/example.ods new file mode 100644 index 000000000000..7f217f1d86dc Binary files /dev/null and b/py-polars/tests/unit/io/files/example.ods differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet.crc b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet.crc new file mode 100644 index 000000000000..a9285317b405 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet.crc differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet new file mode 100644 index 000000000000..0bbb8ba707a5 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-01/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00001.parquet differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet.crc b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet.crc new file mode 100644 index 000000000000..258a5cd76cc1 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/.00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet.crc differ diff --git 
a/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet new file mode 100644 index 000000000000..8da5caa43195 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/data/ts_day=2023-03-02/00000-1-6bc54766-6e8a-4fd5-8c00-c6bacbdcaeeb-00002.parquet differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/metadata/aef5d952-7e24-4764-9b30-3483be37240f-m0.avro b/py-polars/tests/unit/io/files/iceberg-table/metadata/aef5d952-7e24-4764-9b30-3483be37240f-m0.avro new file mode 100644 index 000000000000..a5e28a04bcff Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/metadata/aef5d952-7e24-4764-9b30-3483be37240f-m0.avro differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro b/py-polars/tests/unit/io/files/iceberg-table/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro new file mode 100644 index 000000000000..21f503266051 Binary files /dev/null and b/py-polars/tests/unit/io/files/iceberg-table/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro differ diff --git a/py-polars/tests/unit/io/files/iceberg-table/metadata/v2.metadata.json b/py-polars/tests/unit/io/files/iceberg-table/metadata/v2.metadata.json new file mode 100644 index 000000000000..04e078f8a63b --- /dev/null +++ b/py-polars/tests/unit/io/files/iceberg-table/metadata/v2.metadata.json @@ -0,0 +1,109 @@ +{ + "format-version" : 1, + "table-uuid" : "c470045a-5d75-48aa-9d4b-86a86a9b1fce", + "location" : "/tmp/iceberg/t1", + "last-updated-ms" : 1694547405299, + "last-column-id" : 3, + "schema" : { + "type" : "struct", + "schema-id" : 0, + "fields" : [ { + "id" : 1, + "name" : "id", + "required" : false, + "type" : "int" + }, { + "id" : 2, + "name" : "str", + "required" : false, + "type" : "string" + }, { + "id" : 3, + "name" : "ts", + "required" : false, + "type" : "timestamp" + } ] + }, + "current-schema-id" : 0, + "schemas" : [ { + "type" : "struct", + "schema-id" : 0, + "fields" : [ { + "id" : 1, + "name" : "id", + "required" : false, + "type" : "int" + }, { + "id" : 2, + "name" : "str", + "required" : false, + "type" : "string" + }, { + "id" : 3, + "name" : "ts", + "required" : false, + "type" : "timestamp" + } ] + } ], + "partition-spec" : [ { + "name" : "ts_day", + "transform" : "day", + "source-id" : 3, + "field-id" : 1000 + } ], + "default-spec-id" : 0, + "partition-specs" : [ { + "spec-id" : 0, + "fields" : [ { + "name" : "ts_day", + "transform" : "day", + "source-id" : 3, + "field-id" : 1000 + } ] + } ], + "last-partition-id" : 1000, + "default-sort-order-id" : 0, + "sort-orders" : [ { + "order-id" : 0, + "fields" : [ ] + } ], + "properties" : { + "owner" : "fokkodriesprong" + }, + "current-snapshot-id" : 7051579356916758811, + "refs" : { + "main" : { + "snapshot-id" : 7051579356916758811, + "type" : "branch" + } + }, + "snapshots" : [ { + "snapshot-id" : 7051579356916758811, + "timestamp-ms" : 1694547405299, + "summary" : { + "operation" : "append", + "spark.app.id" : "local-1694547283063", + "added-data-files" : "2", + "added-records" : "3", + "added-files-size" : "1788", + "changed-partition-count" : "2", + "total-records" : "3", + "total-files-size" : "1788", + "total-data-files" : "2", + "total-delete-files" : "0", + "total-position-deletes" : "0", + 
"total-equality-deletes" : "0" + }, + "manifest-list" : "/tmp/iceberg/t1/metadata/snap-7051579356916758811-1-aef5d952-7e24-4764-9b30-3483be37240f.avro", + "schema-id" : 0 + } ], + "statistics" : [ ], + "snapshot-log" : [ { + "timestamp-ms" : 1694547405299, + "snapshot-id" : 7051579356916758811 + } ], + "metadata-log" : [ { + "timestamp-ms" : 1694547211303, + "metadata-file" : "/tmp/iceberg/t1/metadata/v1.metadata.json" + } ] +} \ No newline at end of file diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 2c5be46ca422..d3eed2ca7a9c 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -1463,23 +1463,27 @@ def test_csv_quote_styles() -> None: df = pl.DataFrame( { "float": [1.0, 2.0, None], - "string": ["a", "abc", '"hello'], + "string": ["a", "a,bc", '"hello'], "int": [1, 2, 3], "bool": [True, False, None], } ) + assert ( + df.write_csv(quote_style="always") + == '"float","string","int","bool"\n"1.0","a","1","true"\n"2.0","a,bc","2","false"\n"","""hello","3",""\n' + ) assert ( df.write_csv(quote_style="necessary") - == 'float,string,int,bool\n1.0,a,1,true\n2.0,abc,2,false\n,"""hello",3,\n' + == 'float,string,int,bool\n1.0,a,1,true\n2.0,"a,bc",2,false\n,"""hello",3,\n' ) assert ( - df.write_csv(quote_style="always") - == '"float","string","int","bool"\n"1.0","a","1","true"\n"2.0","abc","2","false"\n"","""hello","3",""\n' + df.write_csv(quote_style="never") + == 'float,string,int,bool\n1.0,a,1,true\n2.0,a,bc,2,false\n,"hello,3,\n' ) assert ( df.write_csv(quote_style="non_numeric", quote="8") - == '8float8,8string8,8int8,8bool8\n1.0,8a8,1,8true8\n2.0,8abc8,2,8false8\n,8"hello8,3,\n' + == '8float8,8string8,8int8,8bool8\n1.0,8a8,1,8true8\n2.0,8a,bc8,2,8false8\n,8"hello8,3,\n' ) diff --git a/py-polars/tests/unit/io/test_iceberg.py b/py-polars/tests/unit/io/test_iceberg.py new file mode 100644 index 000000000000..692bf790a3b7 --- /dev/null +++ b/py-polars/tests/unit/io/test_iceberg.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import contextlib +import os +from datetime import datetime +from pathlib import Path + +import pytest + +import polars as pl +from polars.io.iceberg import _convert_predicate, _to_ast + + +@pytest.fixture() +def iceberg_path(io_files_path: Path) -> str: + # Iceberg requires absolute paths, so we'll symlink + # the test table into /tmp/iceberg/t1/ + Path("/tmp/iceberg").mkdir(parents=True, exist_ok=True) + current_path = Path(__file__).parent.resolve() + + with contextlib.suppress(FileExistsError): + os.symlink(f"{current_path}/files/iceberg-table", "/tmp/iceberg/t1") + + iceberg_path = io_files_path / "iceberg-table" / "metadata" / "v2.metadata.json" + return f"file://{iceberg_path.resolve()}" + + +@pytest.mark.filterwarnings( + "ignore:No preferred file implementation for scheme*:UserWarning" +) +def test_scan_iceberg_plain(iceberg_path: str) -> None: + df = pl.scan_iceberg(iceberg_path) + assert len(df.collect()) == 3 + assert df.schema == { + "id": pl.Int32, + "str": pl.Utf8, + "ts": pl.Datetime(time_unit="us", time_zone=None), + } + + +@pytest.mark.filterwarnings( + "ignore:No preferred file implementation for scheme*:UserWarning" +) +def test_scan_iceberg_filter_on_partition(iceberg_path: str) -> None: + ts1 = datetime(2023, 3, 1, 18, 15) + ts2 = datetime(2023, 3, 1, 19, 25) + ts3 = datetime(2023, 3, 2, 22, 0) + + lf = pl.scan_iceberg(iceberg_path) + + res = lf.filter(pl.col("ts") >= ts2) + assert len(res.collect()) == 2 + + res = lf.filter(pl.col("ts") > 
ts2).select(pl.col("id")) + assert res.collect().rows() == [(3,)] + + res = lf.filter(pl.col("ts") <= ts2).select("id", "ts") + assert res.collect().rows(named=True) == [ + {"id": 1, "ts": ts1}, + {"id": 2, "ts": ts2}, + ] + + res = lf.filter(pl.col("ts") > ts3) + assert len(res.collect()) == 0 + + for constraint in ( + (pl.col("ts") == ts1) | (pl.col("ts") == ts3), + pl.col("ts").is_in([ts1, ts3]), + ): + res = lf.filter(constraint).select("id") + assert res.collect().rows() == [(1,), (3,)] + + +@pytest.mark.filterwarnings( + "ignore:No preferred file implementation for scheme*:UserWarning" +) +def test_scan_iceberg_filter_on_column(iceberg_path: str) -> None: + lf = pl.scan_iceberg(iceberg_path) + res = lf.filter(pl.col("id") < 2) + assert res.collect().rows() == [(1, "1", datetime(2023, 3, 1, 18, 15))] + + res = lf.filter(pl.col("id") == 2) + assert res.collect().rows() == [(2, "2", datetime(2023, 3, 1, 19, 25))] + + res = lf.filter(pl.col("id").is_in([1, 3])) + assert res.collect().rows() == [ + (1, "1", datetime(2023, 3, 1, 18, 15)), + (3, "3", datetime(2023, 3, 2, 22, 0)), + ] + + +def test_is_null_expression() -> None: + from pyiceberg.expressions import IsNull + + expr = _to_ast("(pa.compute.field('id')).is_null()") + assert _convert_predicate(expr) == IsNull("id") + + +def test_is_not_null_expression() -> None: + from pyiceberg.expressions import IsNull, Not + + expr = _to_ast("~(pa.compute.field('id')).is_null()") + assert _convert_predicate(expr) == Not(IsNull("id")) + + +def test_isin_expression() -> None: + from pyiceberg.expressions import In, literal # type: ignore[attr-defined] + + expr = _to_ast("(pa.compute.field('id')).isin([1,2,3])") + assert _convert_predicate(expr) == In("id", {literal(1), literal(2), literal(3)}) + + +def test_parse_combined_expression() -> None: + from pyiceberg.expressions import ( # type: ignore[attr-defined] + And, + EqualTo, + GreaterThan, + In, + Or, + Reference, + literal, + ) + + expr = _to_ast( + "(((pa.compute.field('str') == '2') & (pa.compute.field('id') > 10)) | (pa.compute.field('id')).isin([1,2,3]))" + ) + assert _convert_predicate(expr) == Or( + left=And( + left=EqualTo(term=Reference(name="str"), literal=literal("2")), + right=GreaterThan(term="id", literal=literal(10)), + ), + right=In("id", {literal(1), literal(2), literal(3)}), + ) + + +def test_parse_gt() -> None: + from pyiceberg.expressions import GreaterThan + + expr = _to_ast("(pa.compute.field('ts') > '2023-08-08')") + assert _convert_predicate(expr) == GreaterThan("ts", "2023-08-08") + + +def test_parse_gteq() -> None: + from pyiceberg.expressions import GreaterThanOrEqual + + expr = _to_ast("(pa.compute.field('ts') >= '2023-08-08')") + assert _convert_predicate(expr) == GreaterThanOrEqual("ts", "2023-08-08") + + +def test_parse_eq() -> None: + from pyiceberg.expressions import EqualTo + + expr = _to_ast("(pa.compute.field('ts') == '2023-08-08')") + assert _convert_predicate(expr) == EqualTo("ts", "2023-08-08") + + +def test_parse_lt() -> None: + from pyiceberg.expressions import LessThan + + expr = _to_ast("(pa.compute.field('ts') < '2023-08-08')") + assert _convert_predicate(expr) == LessThan("ts", "2023-08-08") + + +def test_parse_lteq() -> None: + from pyiceberg.expressions import LessThanOrEqual + + expr = _to_ast("(pa.compute.field('ts') <= '2023-08-08')") + assert _convert_predicate(expr) == LessThanOrEqual("ts", "2023-08-08") diff --git a/py-polars/tests/unit/io/test_lazy_parquet.py b/py-polars/tests/unit/io/test_lazy_parquet.py index 02e51bc93765..77b829ae23a0 
100644 --- a/py-polars/tests/unit/io/test_lazy_parquet.py +++ b/py-polars/tests/unit/io/test_lazy_parquet.py @@ -389,3 +389,13 @@ def test_parquet_statistics_filter_9925(tmp_path: Path) -> None: (pl.col("code").floordiv(100_000)).is_in([0, 3]) ) assert q.collect().to_dict(False) == {"code": [300964, 300972, 26]} + + +@pytest.mark.write_disk() +def test_parquet_statistics_filter_11069(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + file_path = tmp_path / "foo.parquet" + pl.DataFrame({"x": [1, None]}).write_parquet(file_path, statistics=False) + assert pl.scan_parquet(file_path).filter(pl.col("x").is_null()).collect().to_dict( + False + ) == {"x": [None]} diff --git a/py-polars/tests/unit/io/test_parquet.py b/py-polars/tests/unit/io/test_parquet.py index 2b4f4fa6ddd4..f3de8823b6a4 100644 --- a/py-polars/tests/unit/io/test_parquet.py +++ b/py-polars/tests/unit/io/test_parquet.py @@ -2,7 +2,6 @@ import io from datetime import datetime, timezone -from pathlib import Path from typing import TYPE_CHECKING import numpy as np @@ -15,6 +14,8 @@ from polars.testing import assert_frame_equal, assert_series_equal if TYPE_CHECKING: + from pathlib import Path + from polars.type_aliases import ParquetCompression @@ -486,10 +487,8 @@ def test_parquet_string_cache() -> None: assert_series_equal(pl.read_parquet(f)["a"].cast(str), df["a"].cast(str)) -def test_tz_aware_parquet_9586() -> None: - result = pl.read_parquet( - Path("tests") / "unit" / "io" / "files" / "tz_aware.parquet" - ) +def test_tz_aware_parquet_9586(io_files_path: Path) -> None: + result = pl.read_parquet(io_files_path / "tz_aware.parquet") expected = pl.DataFrame( {"UTC_DATETIME_ID": [datetime(2023, 6, 26, 14, 15, 0, tzinfo=timezone.utc)]} ).select(pl.col("*").cast(pl.Datetime("ns", "UTC"))) diff --git a/py-polars/tests/unit/io/test_pyarrow_dataset.py b/py-polars/tests/unit/io/test_pyarrow_dataset.py index 8a5c85e2ef1c..4f2b1d2ed90e 100644 --- a/py-polars/tests/unit/io/test_pyarrow_dataset.py +++ b/py-polars/tests/unit/io/test_pyarrow_dataset.py @@ -17,6 +17,7 @@ def helper_dataset_test( file_path: Path, query: Callable[[pl.LazyFrame], pl.DataFrame], batch_size: int | None = None, + n_expected: int | None = None, ) -> None: dset = ds.dataset(file_path, format="ipc") expected = query(pl.scan_ipc(file_path)) @@ -24,6 +25,8 @@ def helper_dataset_test( pl.scan_pyarrow_dataset(dset, batch_size=batch_size), ) assert_frame_equal(out, expected) + if n_expected is not None: + assert len(out) == n_expected @pytest.mark.write_disk() @@ -34,30 +37,35 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: helper_dataset_test( file_path, lambda lf: lf.filter("bools").select(["bools", "floats", "date"]).collect(), + n_expected=1, ) helper_dataset_test( file_path, lambda lf: lf.filter(~pl.col("bools")) .select(["bools", "floats", "date"]) .collect(), + n_expected=2, ) helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("int_nulls").is_null()) .select(["bools", "floats", "date"]) .collect(), + n_expected=1, ) helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("int_nulls").is_not_null()) .select(["bools", "floats", "date"]) .collect(), + n_expected=2, ) helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("int_nulls").is_not_null() == pl.col("bools")) .select(["bools", "floats", "date"]) .collect(), + n_expected=0, ) # this equality on a column with nulls fails as pyarrow has different # handling kleene logic. We leave it for now and document it in the function. 
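A brief illustration of the Kleene-logic point in the comment above (plain Polars usage, not part of the diff): comparing against a null with `==` yields null rather than False, and `filter` does not keep rows whose predicate is null, which is why the equivalent predicate pushed down to pyarrow can disagree.

import polars as pl

df = pl.DataFrame({"a": [1, None, 3]})

# Equality against a null evaluates to null under Kleene (three-valued) logic.
print(df.select(pl.col("a") == 1).to_series().to_list())  # [True, None, False]

# Rows whose predicate is null are dropped by filter, so only one row survives.
print(df.filter(pl.col("a") == 1).height)  # 1
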
@@ -66,12 +74,14 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: lambda lf: lf.filter(pl.col("int") == 10) .select(["bools", "floats", "int_nulls"]) .collect(), + n_expected=0, ) helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("int") != 10) .select(["bools", "floats", "int_nulls"]) .collect(), + n_expected=3, ) # this predicate is not supported by pyarrow # check if we still do it on our side @@ -80,20 +90,22 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: lambda lf: lf.filter(pl.col("floats").sum().over("date") == 10) .select(["bools", "floats", "date"]) .collect(), + n_expected=0, ) - # temporal types helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("date") < date(1972, 1, 1)) .select(["bools", "floats", "date"]) .collect(), + n_expected=1, ) helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("datetime") > datetime(1970, 1, 1, second=13)) .select(["bools", "floats", "date"]) .collect(), + n_expected=1, ) # not yet supported in pyarrow helper_dataset_test( @@ -101,20 +113,45 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: lambda lf: lf.filter(pl.col("time") >= time(microsecond=100)) .select(["bools", "time", "date"]) .collect(), + n_expected=3, ) - # pushdown is_in helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("int").is_in([1, 3, 20])) .select(["bools", "floats", "date"]) .collect(), + n_expected=2, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter( + pl.col("date").is_in([date(1973, 8, 17), date(1973, 5, 19)]) + ) + .select(["bools", "floats", "date"]) + .collect(), + n_expected=2, + ) + helper_dataset_test( + file_path, + lambda lf: lf.filter( + pl.col("datetime").is_in( + [ + datetime(1970, 1, 1, 0, 0, 12, 341234), + datetime(1970, 1, 1, 0, 0, 13, 241324), + ] + ) + ) + .select(["bools", "floats", "date"]) + .collect(), + n_expected=2, ) helper_dataset_test( file_path, lambda lf: lf.filter(pl.col("int").is_in(list(range(120)))) .select(["bools", "floats", "date"]) .collect(), + n_expected=3, ) # TODO: remove string cache with pl.StringCache(): @@ -123,11 +160,13 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: lambda lf: lf.filter(pl.col("cat").is_in([])) .select(["bools", "floats", "date"]) .collect(), + n_expected=0, ) helper_dataset_test( file_path, lambda lf: lf.collect(), batch_size=2, + n_expected=3, ) # direct filter @@ -136,6 +175,7 @@ def test_dataset_foo(df: pl.DataFrame, tmp_path: Path) -> None: lambda lf: lf.filter(pl.Series([True, False, True])) .select(["bools", "floats", "date"]) .collect(), + n_expected=2, ) diff --git a/py-polars/tests/unit/io/test_excel.py b/py-polars/tests/unit/io/test_spreadsheet.py similarity index 67% rename from py-polars/tests/unit/io/test_excel.py rename to py-polars/tests/unit/io/test_spreadsheet.py index be6e26399cc3..9f872f678f73 100644 --- a/py-polars/tests/unit/io/test_excel.py +++ b/py-polars/tests/unit/io/test_spreadsheet.py @@ -1,20 +1,21 @@ from __future__ import annotations +import warnings from datetime import date, datetime from io import BytesIO -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Callable, Literal import pytest import polars as pl import polars.selectors as cs -from polars.exceptions import NoDataError +from polars.exceptions import NoDataError, ParameterCollisionError from polars.testing import assert_frame_equal if TYPE_CHECKING: from pathlib import Path - from polars.type_aliases import SelectorType + from polars.type_aliases import SchemaDict, 
SelectorType @pytest.fixture() @@ -27,25 +28,66 @@ def empty_excel_file_path(io_files_path: Path) -> Path: return io_files_path / "empty.xlsx" -@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"]) -def test_read_excel( - excel_file_path: Path, engine: Literal["xlsx2csv", "openpyxl"] -) -> None: - df = pl.read_excel(excel_file_path, sheet_name="test1", sheet_id=None) +@pytest.fixture() +def openoffice_file_path(io_files_path: Path) -> Path: + return io_files_path / "example.ods" + +@pytest.fixture() +def empty_openoffice_file_path(io_files_path: Path) -> Path: + return io_files_path / "empty.ods" + + +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + (pl.read_excel, "excel_file_path", {"engine": "xlsx2csv"}), + (pl.read_excel, "excel_file_path", {"engine": "openpyxl"}), + (pl.read_ods, "openoffice_file_path", {}), + ], +) +def test_read_spreadsheet( + read_spreadsheet: Callable[..., pl.DataFrame], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + df = read_spreadsheet( + source=request.getfixturevalue(source), + sheet_name="test1", + sheet_id=None, + **params, + ) expected = pl.DataFrame({"hello": ["Row 1", "Row 2"]}) assert_frame_equal(df, expected) -@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"]) +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + (pl.read_excel, "excel_file_path", {"engine": "xlsx2csv"}), + (pl.read_excel, "excel_file_path", {"engine": "openpyxl"}), + (pl.read_ods, "openoffice_file_path", {}), + ], +) def test_read_excel_multi_sheets( - excel_file_path: Path, engine: Literal["xlsx2csv", "openpyxl"] + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, ) -> None: - frames_by_id = pl.read_excel( - excel_file_path, sheet_id=[1, 2], sheet_name=None, engine=engine + spreadsheet_path = request.getfixturevalue(source) + frames_by_id = read_spreadsheet( + spreadsheet_path, + sheet_id=[1, 2], + sheet_name=None, + **params, ) - frames_by_name = pl.read_excel( - excel_file_path, sheet_id=None, sheet_name=["test1", "test2"], engine=engine + frames_by_name = read_spreadsheet( + spreadsheet_path, + sheet_id=None, + sheet_name=["test1", "test2"], + **params, ) for frames in (frames_by_id, frames_by_name): assert len(frames) == 2 @@ -57,9 +99,27 @@ def test_read_excel_multi_sheets( assert_frame_equal(frames["test2"], expected2) -def test_read_excel_all_sheets_openpyxl(excel_file_path: Path) -> None: - frames = pl.read_excel(excel_file_path, sheet_id=0, engine="openpyxl") - assert len(frames) == 4 +@pytest.mark.parametrize( + ("read_spreadsheet", "source", "params"), + [ + (pl.read_excel, "excel_file_path", {"engine": "xlsx2csv"}), + (pl.read_excel, "excel_file_path", {"engine": "openpyxl"}), + (pl.read_ods, "openoffice_file_path", {}), + ], +) +def test_read_excel_all_sheets( + read_spreadsheet: Callable[..., dict[str, pl.DataFrame]], + source: str, + params: dict[str, str], + request: pytest.FixtureRequest, +) -> None: + spreadsheet_path = request.getfixturevalue(source) + frames = read_spreadsheet( + spreadsheet_path, + sheet_id=0, + **params, + ) + assert len(frames) == (3 if str(spreadsheet_path).endswith("ods") else 4) expected1 = pl.DataFrame({"hello": ["Row 1", "Row 2"]}) expected2 = pl.DataFrame({"world": ["Row 3", "Row 4"]}) @@ -72,13 +132,22 @@ def test_read_excel_all_sheets_openpyxl(excel_file_path: Path) -> None: ) assert_frame_equal(frames["test1"], expected1) 
assert_frame_equal(frames["test2"], expected2)
-    assert_frame_equal(frames["test3"], expected3)
+    if params.get("engine") == "openpyxl":
+        # TODO: flag that trims trailing all-null rows?
+        assert_frame_equal(frames["test3"], expected3)
+        assert_frame_equal(frames["test4"].drop_nulls(), expected3)
 
-    # TODO: trim trailing all-null rows?
-    assert_frame_equal(frames["test4"].drop_nulls(), expected3)
-
 
-def test_basic_datatypes_openpyxl_read_excel() -> None:
+@pytest.mark.parametrize(
+    ("engine", "schema_overrides"),
+    [
+        ("xlsx2csv", {"datetime": pl.Datetime}),
+        ("openpyxl", None),
+    ],
+)
+def test_basic_datatypes_read_excel(
+    engine: Literal["xlsx2csv", "openpyxl"], schema_overrides: SchemaDict | None
+) -> None:
     df = pl.DataFrame(
         {
             "A": [1, 2, 3, 4, 5],
@@ -92,39 +161,97 @@ def test_basic_datatypes_openpyxl_read_excel() -> None:
     df.write_excel(xls, position="C5")
 
     # check if can be read as it was written
-    # we use openpyxl because type inference is better
-    df_by_default = pl.read_excel(xls, engine="openpyxl")
-    df_by_sheet_id = pl.read_excel(xls, sheet_id=1, engine="openpyxl")
-    df_by_sheet_name = pl.read_excel(xls, sheet_name="Sheet1", engine="openpyxl")
+    for sheet_id, sheet_name in ((None, None), (1, None), (None, "Sheet1")):
+        # read back into a separate frame so the round-trip assertion is not
+        # comparing the read result against itself
+        read_df = pl.read_excel(
+            xls,
+            sheet_id=sheet_id,
+            sheet_name=sheet_name,
+            engine=engine,
+            schema_overrides=schema_overrides,
+        )
+        assert_frame_equal(df, read_df)
 
-    assert_frame_equal(df, df_by_default)
-    assert_frame_equal(df, df_by_sheet_id)
-    assert_frame_equal(df, df_by_sheet_name)
 
+@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"])
+def test_write_excel_bytes(engine: Literal["xlsx2csv", "openpyxl"]) -> None:
+    df = pl.DataFrame({"A": [1, 2, 3, 4, 5]})
 
-def test_write_excel_bytes() -> None:
-    df = pl.DataFrame(
-        {
-            "A": [1, 2, 3, 4, 5],
-        }
-    )
     excel_bytes = BytesIO()
     df.write_excel(excel_bytes)
-    df_read = pl.read_excel(excel_bytes)
+
+    df_read = pl.read_excel(excel_bytes, engine=engine)
     assert_frame_equal(df, df_read)
 
 
+def test_schema_overrides_11161(excel_file_path: Path) -> None:
+    df1 = pl.read_excel(
+        excel_file_path,
+        sheet_name="test4",
+        schema_overrides={"cardinality": pl.UInt16},
+    ).drop_nulls()
+    assert df1.schema == {
+        "cardinality": pl.UInt16,
+        "rows_by_key": pl.Float64,
+        "iter_groups": pl.Float64,
+    }
+
+    df2 = pl.read_excel(
+        excel_file_path,
+        sheet_name="test4",
+        read_csv_options={"dtypes": {"cardinality": pl.UInt16}},
+    ).drop_nulls()
+    assert df2.schema == {
+        "cardinality": pl.UInt16,
+        "rows_by_key": pl.Float64,
+        "iter_groups": pl.Float64,
+    }
+
+    df3 = pl.read_excel(
+        excel_file_path,
+        sheet_name="test4",
+        schema_overrides={"cardinality": pl.UInt16},
+        read_csv_options={
+            "dtypes": {
+                "rows_by_key": pl.Float32,
+                "iter_groups": pl.Float32,
+            },
+        },
+    ).drop_nulls()
+    assert df3.schema == {
+        "cardinality": pl.UInt16,
+        "rows_by_key": pl.Float32,
+        "iter_groups": pl.Float32,
+    }
+
+    with pytest.raises(ParameterCollisionError):
+        # cannot specify 'cardinality' in both schema_overrides and read_csv_options
+        pl.read_excel(
+            excel_file_path,
+            sheet_name="test4",
+            schema_overrides={"cardinality": pl.UInt16},
+            read_csv_options={"dtypes": {"cardinality": pl.Int32}},
+        )
+
+
 def test_unsupported_engine() -> None:
     with pytest.raises(NotImplementedError):
         pl.read_excel(None, engine="foo")  # type: ignore[call-overload]
 
 
-def test_read_excel_all_sheets_with_sheet_name(excel_file_path: Path) -> None:
+@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"])
+def test_read_excel_all_sheets_with_sheet_name(
+    
excel_file_path: Path, engine: str +) -> None: with pytest.raises( ValueError, match=r"cannot specify both `sheet_name` \('Sheet1'\) and `sheet_id` \(1\)", ): - pl.read_excel(excel_file_path, sheet_id=1, sheet_name="Sheet1") + pl.read_excel( # type: ignore[call-overload] + excel_file_path, + sheet_id=1, + sheet_name="Sheet1", + engine=engine, + ) # the parameters don't change the data, only the formatting, so we expect @@ -258,7 +385,8 @@ def test_excel_round_trip(write_params: dict[str, Any]) -> None: assert_frame_equal(df, xldf) -def test_excel_compound_types() -> None: +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"]) +def test_excel_compound_types(engine: Literal["xlsx2csv", "openpyxl"]) -> None: df = pl.DataFrame( {"x": [[1, 2], [3, 4], [5, 6]], "y": ["a", "b", "c"], "z": [9, 8, 7]} ).select("x", pl.struct(["y", "z"])) @@ -266,7 +394,7 @@ def test_excel_compound_types() -> None: xls = BytesIO() df.write_excel(xls, worksheet="data") - xldf = pl.read_excel(xls, sheet_name="data") + xldf = pl.read_excel(xls, sheet_name="data", engine=engine) assert xldf.rows() == [ ("[1, 2]", "{'y': 'a', 'z': 9}"), ("[3, 4]", "{'y': 'b', 'z': 8}"), @@ -274,7 +402,8 @@ def test_excel_compound_types() -> None: ] -def test_excel_sparklines() -> None: +@pytest.mark.parametrize("engine", ["xlsx2csv", "openpyxl"]) +def test_excel_sparklines(engine: Literal["xlsx2csv", "openpyxl"]) -> None: from xlsxwriter import Workbook # note that we don't (quite) expect sparkline export to round-trip as we @@ -326,7 +455,10 @@ def test_excel_sparklines() -> None: tables = {tbl["name"] for tbl in wb.get_worksheet_by_name("frame_data").tables} assert "Frame0" in tables - xldf = pl.read_excel(xls, sheet_name="frame_data") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + xldf = pl.read_excel(xls, sheet_name="frame_data", engine=engine) + # ┌─────┬──────┬─────┬─────┬─────┬─────┬───────┬─────┬─────┐ # │ id ┆ +/- ┆ q1 ┆ q2 ┆ q3 ┆ q4 ┆ trend ┆ h1 ┆ h2 │ # │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ @@ -410,18 +542,43 @@ def test_excel_freeze_panes() -> None: assert pl.read_excel(xls, sheet_name="sheet3").rows() == [] -def test_excel_empty_sheet(empty_excel_file_path: Path) -> None: +@pytest.mark.parametrize( + ("read_spreadsheet", "source"), + [ + (pl.read_excel, "empty_excel_file_path"), + (pl.read_ods, "empty_openoffice_file_path"), + ], +) +def test_excel_empty_sheet( + read_spreadsheet: Callable[..., pl.DataFrame], + source: str, + request: pytest.FixtureRequest, +) -> None: + empty_spreadsheet_path = request.getfixturevalue(source) with pytest.raises(NoDataError, match="Empty Excel sheet"): - pl.read_excel(empty_excel_file_path) + pl.read_excel(empty_spreadsheet_path) - df = pl.read_excel(empty_excel_file_path, raise_if_empty=False) + df = pl.read_excel(empty_spreadsheet_path, raise_if_empty=False) assert_frame_equal(df, pl.DataFrame()) -@pytest.mark.parametrize("hidden_columns", [["a"], ["a", "b"], cs.numeric(), cs.last()]) -def test_excel_hidden_columns(hidden_columns: list[str] | SelectorType) -> None: +@pytest.mark.parametrize( + ("engine", "hidden_columns"), + [ + ("xlsx2csv", ["a"]), + ("openpyxl", ["a", "b"]), + ("xlsx2csv", cs.numeric()), + ("openpyxl", cs.last()), + ], +) +def test_excel_hidden_columns( + hidden_columns: list[str] | SelectorType, + engine: Literal["xlsx2csv", "openpyxl"], +) -> None: df = pl.DataFrame({"a": [1, 2], "b": ["x", "y"]}) + xls = BytesIO() df.write_excel(xls, hidden_columns=hidden_columns) + read_df = pl.read_excel(xls) 
assert_frame_equal(df, read_df) diff --git a/py-polars/tests/unit/namespaces/test_binary.py b/py-polars/tests/unit/namespaces/test_binary.py index a09e33cf5ab2..1519ca10bcba 100644 --- a/py-polars/tests/unit/namespaces/test_binary.py +++ b/py-polars/tests/unit/namespaces/test_binary.py @@ -24,15 +24,16 @@ def test_contains() -> None: (1, b"some * * text"), (2, b"(with) special\n * chars"), (3, b"**etc...?$"), + (4, None), ], schema=["idx", "bin"], ) for pattern, expected in ( - (b"e * ", [True, False, False]), - (b"text", [True, False, False]), - (b"special", [False, True, False]), - (b"", [True, True, True]), - (b"qwe", [False, False, False]), + (b"e * ", [True, False, False, None]), + (b"text", [True, False, False, None]), + (b"special", [False, True, False, None]), + (b"", [True, True, True, None]), + (b"qwe", [False, False, False, None]), ): # series assert expected == df["bin"].bin.contains(pattern).to_list() @@ -41,16 +42,59 @@ def test_contains() -> None: expected == df.select(pl.col("bin").bin.contains(pattern))["bin"].to_list() ) # frame filter - assert sum(expected) == len(df.filter(pl.col("bin").bin.contains(pattern))) + assert sum([e for e in expected if e is True]) == len( + df.filter(pl.col("bin").bin.contains(pattern)) + ) + + +def test_contains_with_expr() -> None: + df = pl.DataFrame( + { + "bin": [b"some * * text", b"(with) special\n * chars", b"**etc...?$", None], + "lit1": [b"e * ", b"", b"qwe", b"None"], + "lit2": [None, b"special\n", b"?!", None], + } + ) + + assert df.select( + [ + pl.col("bin").bin.contains(pl.col("lit1")).alias("contains_1"), + pl.col("bin").bin.contains(pl.col("lit2")).alias("contains_2"), + pl.col("bin").bin.contains(pl.lit(None)).alias("contains_3"), + ] + ).to_dict(False) == { + "contains_1": [True, True, False, None], + "contains_2": [None, True, False, None], + "contains_3": [None, None, None, None], + } def test_starts_ends_with() -> None: - assert pl.DataFrame({"a": [b"hamburger", b"nuts", b"lollypop"]}).select( + assert pl.DataFrame( + { + "a": [b"hamburger", b"nuts", b"lollypop", None], + "end": [b"ger", b"tg", None, b"anything"], + "start": [b"ha", b"nga", None, b"anything"], + } + ).select( [ - pl.col("a").bin.ends_with(b"pop").alias("pop"), - pl.col("a").bin.starts_with(b"ham").alias("ham"), + pl.col("a").bin.ends_with(b"pop").alias("end_lit"), + pl.col("a").bin.ends_with(pl.lit(None)).alias("end_none"), + pl.col("a").bin.ends_with(pl.col("end")).alias("end_expr"), + pl.col("a").bin.starts_with(b"ham").alias("start_lit"), + pl.col("a").bin.ends_with(pl.lit(None)).alias("start_none"), + pl.col("a").bin.starts_with(pl.col("start")).alias("start_expr"), ] - ).to_dict(False) == {"pop": [False, False, True], "ham": [True, False, False]} + ).to_dict( + False + ) == { + "end_lit": [False, False, True, None], + "end_none": [None, None, None, None], + "end_expr": [True, False, None, None], + "start_lit": [True, False, False, None], + "start_none": [None, None, None, None], + "start_expr": [True, False, None, None], + } def test_base64_encode() -> None: diff --git a/py-polars/tests/unit/namespaces/test_datetime.py b/py-polars/tests/unit/namespaces/test_datetime.py index a31f6a65379c..7fa36bf12f07 100644 --- a/py-polars/tests/unit/namespaces/test_datetime.py +++ b/py-polars/tests/unit/namespaces/test_datetime.py @@ -690,6 +690,80 @@ def test_offset_by_truncate_sorted_flag() -> None: assert s2.flags["SORTED_ASC"] +def test_offset_by_broadcasting() -> None: + # test broadcast lhs + df = pl.DataFrame( + { + "offset": ["1d", "10d", "3d", None], + } + 
) + result = df.select( + d1=pl.lit(datetime(2020, 10, 25)).dt.offset_by(pl.col("offset")), + d2=pl.lit(datetime(2020, 10, 25)) + .dt.cast_time_unit("ms") + .dt.offset_by(pl.col("offset")), + d3=pl.lit(datetime(2020, 10, 25)) + .dt.replace_time_zone("Europe/London") + .dt.offset_by(pl.col("offset")), + d4=pl.lit(datetime(2020, 10, 25)).dt.date().dt.offset_by(pl.col("offset")), + d5=pl.lit(None, dtype=pl.Datetime).dt.offset_by(pl.col("offset")), + ) + expected_dict = { + "d1": [ + datetime(2020, 10, 26), + datetime(2020, 11, 4), + datetime(2020, 10, 28), + None, + ], + "d2": [ + datetime(2020, 10, 26), + datetime(2020, 11, 4), + datetime(2020, 10, 28), + None, + ], + "d3": [ + datetime(2020, 10, 26, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 11, 4, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2020, 10, 28, tzinfo=ZoneInfo(key="Europe/London")), + None, + ], + "d4": [ + datetime(2020, 10, 26).date(), + datetime(2020, 11, 4).date(), + datetime(2020, 10, 28).date(), + None, + ], + "d5": [None, None, None, None], + } + assert result.to_dict(False) == expected_dict + + # test broadcast rhs + df = pl.DataFrame({"dt": [datetime(2020, 10, 25), datetime(2021, 1, 2), None]}) + result = df.select( + d1=pl.col("dt").dt.offset_by(pl.lit("1mo3d")), + d2=pl.col("dt").dt.cast_time_unit("ms").dt.offset_by(pl.lit("1y1mo")), + d3=pl.col("dt") + .dt.replace_time_zone("Europe/London") + .dt.offset_by(pl.lit("3d")), + d4=pl.col("dt").dt.date().dt.offset_by(pl.lit("1y1mo1d")), + ) + expected_dict = { + "d1": [datetime(2020, 11, 28), datetime(2021, 2, 5), None], + "d2": [datetime(2021, 11, 25), datetime(2022, 2, 2), None], + "d3": [ + datetime(2020, 10, 28, tzinfo=ZoneInfo(key="Europe/London")), + datetime(2021, 1, 5, tzinfo=ZoneInfo(key="Europe/London")), + None, + ], + "d4": [datetime(2021, 11, 26).date(), datetime(2022, 2, 3).date(), None], + } + assert result.to_dict(False) == expected_dict + + # test all literal + result = df.select(d=pl.lit(datetime(2021, 11, 26)).dt.offset_by("1mo1d")) + assert result.to_dict(False) == {"d": [datetime(2021, 12, 27)]} + + def test_offset_by_expressions() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index d84184fdf6d1..100581c80b49 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -94,6 +94,19 @@ def test_list_concat() -> None: assert out_s[0].to_list() == [1, 2, 4, 1] +def test_list_join() -> None: + df = pl.DataFrame( + { + "a": [["ab", "c", "d"], ["e", "f"], ["g"], [], None], + "separator": ["&", None, "*", "_", "*"], + } + ) + out = df.select(pl.col("a").list.join("-")) + assert out.to_dict(False) == {"a": ["ab-c-d", "e-f", "g", "", None]} + out = df.select(pl.col("a").list.join(pl.col("separator"))) + assert out.to_dict(False) == {"a": ["ab&c&d", None, "g", "", None]} + + def test_list_arr_empty() -> None: df = pl.DataFrame({"cars": [[1, 2, 3], [2, 3], [4], []]}) diff --git a/py-polars/tests/unit/namespaces/test_meta.py b/py-polars/tests/unit/namespaces/test_meta.py index 850a867ab6ed..c3c109e0d21a 100644 --- a/py-polars/tests/unit/namespaces/test_meta.py +++ b/py-polars/tests/unit/namespaces/test_meta.py @@ -43,6 +43,8 @@ def test_root_and_output_names() -> None: ): pl.all().suffix("_").meta.output_name() + assert pl.all().suffix("_").meta.output_name(raise_if_undetermined=False) is None + def test_undo_aliases() -> None: e = pl.col("foo").alias("bar") diff --git 
a/py-polars/tests/unit/namespaces/test_string.py b/py-polars/tests/unit/namespaces/test_string.py index 6529a6afebbb..c8777d562dc5 100644 --- a/py-polars/tests/unit/namespaces/test_string.py +++ b/py-polars/tests/unit/namespaces/test_string.py @@ -325,6 +325,29 @@ def test_json_extract_lazy_expr() -> None: assert_frame_equal(ldf, expected) +def test_json_extract_primitive_to_list_11053() -> None: + df = pl.DataFrame( + { + "json": [ + '{"col1": ["123"], "col2": "123"}', + '{"col1": ["xyz"], "col2": null}', + ] + } + ) + schema = pl.Struct( + { + "col1": pl.List(pl.Utf8), + "col2": pl.List(pl.Utf8), + } + ) + + output = df.select( + pl.col("json").str.json_extract(schema).alias("casted_json") + ).unnest("casted_json") + expected = pl.DataFrame({"col1": [["123"], ["xyz"]], "col2": [["123"], None]}) + assert_frame_equal(output, expected) + + def test_jsonpath_single() -> None: s = pl.Series(['{"a":"1"}', None, '{"a":2}', '{"a":2.1}', '{"a":true}']) expected = pl.Series(["1", None, "2", "2.1", "true"]) @@ -434,8 +457,8 @@ def test_contains_expr() -> None: .alias("contains_lit"), ] ).to_dict(False) == { - "contains": [True, True, False, False, False, None], - "contains_lit": [False, True, False, False, False, False], + "contains": [True, True, False, None, None, None], + "contains_lit": [False, True, False, None, None, False], } with pytest.raises(pl.ComputeError): @@ -739,7 +762,10 @@ def test_ljust_and_rjust() -> None: def test_starts_ends_with() -> None: df = pl.DataFrame( - {"a": ["hamburger", "nuts", "lollypop"], "sub": ["ham", "ts", None]} + { + "a": ["hamburger", "nuts", "lollypop", None], + "sub": ["ham", "ts", None, "anything"], + } ) assert df.select( @@ -752,12 +778,12 @@ def test_starts_ends_with() -> None: pl.col("a").str.starts_with(pl.col("sub")).alias("starts_sub"), ] ).to_dict(False) == { - "ends_pop": [False, False, True], - "ends_None": [False, False, False], - "ends_sub": [False, True, False], - "starts_ham": [True, False, False], - "starts_None": [False, False, False], - "starts_sub": [True, False, False], + "ends_pop": [False, False, True, None], + "ends_None": [None, None, None, None], + "ends_sub": [False, True, None, None], + "starts_ham": [True, False, False, None], + "starts_None": [None, None, None, None], + "starts_sub": [True, False, None, None], } @@ -820,6 +846,31 @@ def test_split() -> None: assert_frame_equal(df["x"].str.split("_", inclusive=True).to_frame(), expected) +def test_split_expr() -> None: + df = pl.DataFrame({"x": ["a_a", None, "b", "c*c*c"], "by": ["_", "#", "^", "*"]}) + out = df.select([pl.col("x").str.split(pl.col("by"))]) + expected = pl.DataFrame( + [ + {"x": ["a", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c", "c", "c"]}, + ] + ) + assert_frame_equal(out, expected) + + out = df.select([pl.col("x").str.split(pl.col("by"), inclusive=True)]) + expected = pl.DataFrame( + [ + {"x": ["a_", "a"]}, + {"x": None}, + {"x": ["b"]}, + {"x": ["c*", "c*", "c"]}, + ] + ) + assert_frame_equal(out, expected) + + def test_split_exact() -> None: df = pl.DataFrame({"x": ["a_a", None, "b", "c_c"]}) out = df.select([pl.col("x").str.split_exact("_", 2, inclusive=False)]).unnest("x") diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 6eef6580017d..4ed0356dbfb0 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -48,7 +48,7 @@ def example_df() -> pl.DataFrame: def 
test_rolling_kernels_and_group_by_rolling( example_df: pl.DataFrame, period: str | timedelta, closed: ClosedInterval ) -> None: - out1 = example_df.select( + out1 = example_df.set_sorted("dt").select( [ pl.col("dt"), # this differs from group_by aggregation because the empty window is @@ -761,11 +761,11 @@ def test_rolling_kernels_group_by_dynamic_7548() -> None: ).to_dict( False ) == { - "time": [0, 1, 2, 3], - "value": [[0, 1, 2], [1, 2, 3], [2, 3], [3]], - "min_value": [0, 1, 2, 3], - "max_value": [2, 3, 3, 3], - "sum_value": [3, 6, 5, 3], + "time": [-1, 0, 1, 2, 3], + "value": [[0, 1], [0, 1, 2], [1, 2, 3], [2, 3], [3]], + "min_value": [0, 0, 1, 2, 3], + "max_value": [1, 2, 3, 3, 3], + "sum_value": [1, 3, 6, 5, 3], } @@ -790,7 +790,7 @@ def test_rolling_empty_window_9406(time_unit: TimeUnit) -> None: "d", [datetime(2019, 1, x) for x in [16, 17, 18, 22, 23]], dtype=pl.Datetime(time_unit=time_unit, time_zone=None), - ) + ).set_sorted() rawdata = pl.Series("x", [1.1, 1.2, 1.3, 1.15, 1.25], dtype=pl.Float64) rmin = pl.Series("x", [None, 1.1, 1.1, None, 1.15], dtype=pl.Float64) rmax = pl.Series("x", [None, 1.1, 1.2, None, 1.15], dtype=pl.Float64) @@ -835,6 +835,20 @@ def test_rolling_weighted_quantile_10031() -> None: ) +def test_rolling_aggregations_unsorted_raise_10991() -> None: + df = pl.DataFrame( + { + "dt": [datetime(2020, 1, 3), datetime(2020, 1, 1), datetime(2020, 1, 2)], + "val": [1, 2, 3], + } + ) + with pytest.raises( + pl.InvalidOperationError, + match="argument in operation 'rolling_sum' is not explicitly sorted", + ): + df.with_columns(roll=pl.col("val").rolling_sum("2d", by="dt", closed="right")) + + def test_rolling() -> None: a = pl.Series("a", [1, 2, 3, 2, 1]) assert_series_equal(a.rolling_min(2), pl.Series("a", [None, 1, 2, 2, 1])) @@ -914,7 +928,7 @@ def test_rolling_nanoseconds_11003() -> None: "val": [1, 2, 3], } ) - df = df.with_columns(pl.col("dt").str.to_datetime(time_unit="ns")) + df = df.with_columns(pl.col("dt").str.to_datetime(time_unit="ns")).set_sorted("dt") result = df.with_columns( pl.col("val").rolling_sum("500ns", by="dt", closed="right") ) diff --git a/py-polars/tests/unit/operations/test_filter.py b/py-polars/tests/unit/operations/test_filter.py index 72ca0a0dfe93..6f1c48b05c4a 100644 --- a/py-polars/tests/unit/operations/test_filter.py +++ b/py-polars/tests/unit/operations/test_filter.py @@ -1,4 +1,7 @@ +import pytest + import polars as pl +from polars import PolarsDataType from polars.testing import assert_frame_equal @@ -12,6 +15,11 @@ def test_simplify_expression_lit_true_4376() -> None: ).rows() == [(1, 2, 3), (4, 5, 6), (7, 8, 9)] +def test_filter_contains_nth_11205() -> None: + df = pl.DataFrame({"x": [False]}) + assert df.filter(pl.first()).is_empty() + + def test_melt_values_predicate_pushdown() -> None: lf = pl.DataFrame( { @@ -50,6 +58,15 @@ def test_filter_is_in_4572() -> None: assert_frame_equal(result, expected) +@pytest.mark.parametrize( + "dtype", [pl.Int32, pl.Boolean, pl.Utf8, pl.Binary, pl.List(pl.Int64), pl.Object] +) +def test_filter_on_empty(dtype: PolarsDataType) -> None: + df = pl.DataFrame({"a": []}, schema={"a": dtype}) + out = df.filter(pl.col("a").is_null()) + assert out.is_empty() + + def test_filter_aggregation_any() -> None: df = pl.DataFrame( { diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index cf9e4efb91c8..956e6cc5d994 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -348,10 +348,10 
@@ def test_group_by_dynamic_flat_agg_4814() -> None: (pl.col("b") / pl.col("a")).last().alias("last_ratio_2"), ] ).to_dict(False) == { - "a": [1, 2], - "sum_ratio_1": [4.2, 5.0], - "last_ratio_1": [6.0, 6.0], - "last_ratio_2": [6.0, 6.0], + "a": [0, 1, 2], + "sum_ratio_1": [1.0, 4.2, 5.0], + "last_ratio_1": [1.0, 6.0, 6.0], + "last_ratio_2": [1.0, 6.0, 6.0], } @@ -388,7 +388,7 @@ def test_group_by_dynamic_overlapping_groups_flat_apply_multiple_5038( .to_dict(False) ) - assert res["corr"] == pytest.approx([6.988674024215477]) + assert res["corr"] == pytest.approx([9.148920923684765]) assert res["a"] == [None] @@ -576,10 +576,11 @@ def test_group_by_dynamic_elementwise_following_mean_agg_6904( pl.DataFrame( { "a": [ + datetime(2020, 12, 31, 23, 59, 50), datetime(2021, 1, 1, 0, 0), datetime(2021, 1, 1, 0, 0, 10), ], - "c": [0.9092974268256817, -0.7568024953079282], + "c": [0.9092974268256817, 0.9092974268256817, -0.7568024953079282], } ).with_columns(pl.col("a").dt.replace_time_zone(time_zone)), ) diff --git a/py-polars/tests/unit/operations/test_is_first_last_distinct.py b/py-polars/tests/unit/operations/test_is_first_last_distinct.py new file mode 100644 index 000000000000..997778572059 --- /dev/null +++ b/py-polars/tests/unit/operations/test_is_first_last_distinct.py @@ -0,0 +1,105 @@ +import pytest + +import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_is_first_distinct() -> None: + lf = pl.LazyFrame({"a": [4, 1, 4]}) + result = lf.select(pl.col("a").is_first_distinct()).collect()["a"] + expected = pl.Series("a", [True, True, False]) + assert_series_equal(result, expected) + + +def test_is_first_distinct_struct() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3, 2, None, 2, 1], "b": [0, 2, 3, 2, None, 2, 0]}) + result = lf.select(pl.struct("a", "b").is_first_distinct()) + expected = pl.LazyFrame({"a": [True, True, True, False, True, False, False]}) + assert_frame_equal(result, expected) + + +def test_is_first_distinct_list() -> None: + lf = pl.LazyFrame({"a": [[1, 2], [3], [1, 2], [4, 5], [4, 5]]}) + result = lf.select(pl.col("a").is_first_distinct()) + expected = pl.LazyFrame({"a": [True, True, False, True, False]}) + assert_frame_equal(result, expected) + + +def test_is_first_distinct_various() -> None: + # numeric + s = pl.Series([1, 1, None, 2, None, 3, 3]) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + # str + s = pl.Series(["x", "x", None, "y", None, "z", "z"]) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + # boolean + s = pl.Series([True, True, None, False, None, False, False]) + expected = [True, False, True, True, False, False, False] + assert s.is_first_distinct().to_list() == expected + # struct + s = pl.Series( + [ + {"x": 1, "y": 2}, + {"x": 1, "y": 2}, + None, + {"x": 2, "y": 1}, + None, + {"x": 3, "y": 2}, + {"x": 3, "y": 2}, + ] + ) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + # list + s = pl.Series([[1, 2], [1, 2], None, [2, 3], None, [3, 4], [3, 4]]) + expected = [True, False, True, True, False, True, False] + assert s.is_first_distinct().to_list() == expected + + +def test_is_last_distinct() -> None: + # numeric + s = pl.Series([1, 1, None, 2, None, 3, 3]) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + # str + s = pl.Series(["x", "x", None, "y", None, "z", 
"z"]) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + # boolean + s = pl.Series([True, True, None, False, None, False, False]) + expected = [False, True, False, False, True, False, True] + assert s.is_last_distinct().to_list() == expected + # struct + s = pl.Series( + [ + {"x": 1, "y": 2}, + {"x": 1, "y": 2}, + None, + {"x": 2, "y": 1}, + None, + {"x": 3, "y": 2}, + {"x": 3, "y": 2}, + ] + ) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + # list + s = pl.Series([[1, 2], [1, 2], None, [2, 3], None, [3, 4], [3, 4]]) + expected = [False, True, False, True, True, False, True] + assert s.is_last_distinct().to_list() == expected + + +@pytest.mark.parametrize("dtypes", [pl.Int32, pl.Utf8, pl.Boolean, pl.List(pl.Int32)]) +def test_is_first_last_distinct_all_null(dtypes: pl.PolarsDataType) -> None: + s = pl.Series([None, None, None], dtype=dtypes) + assert s.is_first_distinct().to_list() == [True, False, False] + assert s.is_last_distinct().to_list() == [False, False, True] + + +def test_is_first_last_deprecated() -> None: + with pytest.deprecated_call(): + pl.col("a").is_first() + with pytest.deprecated_call(): + pl.col("a").is_last() diff --git a/py-polars/tests/unit/series/test_item.py b/py-polars/tests/unit/series/test_item.py new file mode 100644 index 000000000000..8dd715ca43ab --- /dev/null +++ b/py-polars/tests/unit/series/test_item.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import pytest + +import polars as pl + + +def test_series_item() -> None: + s = pl.Series("a", [1]) + assert s.item() == 1 + + +def test_series_item_empty() -> None: + s = pl.Series("a", []) + with pytest.raises(ValueError): + s.item() + + +def test_series_item_incorrect_shape() -> None: + s = pl.Series("a", [1, 2]) + with pytest.raises(ValueError): + s.item() + + +@pytest.fixture(scope="module") +def s() -> pl.Series: + return pl.Series("a", [1, 2]) + + +@pytest.mark.parametrize(("index", "expected"), [(0, 1), (1, 2), (-1, 2), (-2, 1)]) +def test_series_item_with_index(index: int, expected: int, s: pl.Series) -> None: + assert s.item(index) == expected + + +@pytest.mark.parametrize("index", [-10, 10]) +def test_df_item_out_of_bounds(index: int, s: pl.Series) -> None: + with pytest.raises(IndexError, match="out of bounds"): + s.item(index) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 219a589050d1..7f20a3e9da76 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1220,69 +1220,6 @@ def test_apply_list_out() -> None: assert out[2].to_list() == [2, 2] -def test_is_first() -> None: - # numeric - s = pl.Series([1, 1, None, 2, None, 3, 3]) - assert s.is_first().to_list() == [True, False, True, True, False, True, False] - # str - s = pl.Series(["x", "x", None, "y", None, "z", "z"]) - assert s.is_first().to_list() == [True, False, True, True, False, True, False] - # boolean - s = pl.Series([True, True, None, False, None, False, False]) - assert s.is_first().to_list() == [True, False, True, True, False, False, False] - # struct - s = pl.Series( - [ - {"x": 1, "y": 2}, - {"x": 1, "y": 2}, - None, - {"x": 2, "y": 1}, - None, - {"x": 3, "y": 2}, - {"x": 3, "y": 2}, - ] - ) - assert s.is_first().to_list() == [True, False, True, True, False, True, False] - # list - s = pl.Series([[1, 2], [1, 2], None, [2, 3], None, [3, 4], [3, 4]]) - assert s.is_first().to_list() == [True, 
False, True, True, False, True, False] - - -def test_is_last() -> None: - # numeric - s = pl.Series([1, 1, None, 2, None, 3, 3]) - assert s.is_last().to_list() == [False, True, False, True, True, False, True] - # str - s = pl.Series(["x", "x", None, "y", None, "z", "z"]) - assert s.is_last().to_list() == [False, True, False, True, True, False, True] - # boolean - s = pl.Series([True, True, None, False, None, False, False]) - assert s.is_last().to_list() == [False, True, False, False, True, False, True] - # struct - s = pl.Series( - [ - {"x": 1, "y": 2}, - {"x": 1, "y": 2}, - None, - {"x": 2, "y": 1}, - None, - {"x": 3, "y": 2}, - {"x": 3, "y": 2}, - ] - ) - assert s.is_last().to_list() == [False, True, False, True, True, False, True] - # list - s = pl.Series([[1, 2], [1, 2], None, [2, 3], None, [3, 4], [3, 4]]) - assert s.is_last().to_list() == [False, True, False, True, True, False, True] - - -@pytest.mark.parametrize("dtypes", [pl.Int32, pl.Utf8, pl.Boolean, pl.List(pl.Int32)]) -def test_is_first_last_all_null(dtypes: pl.PolarsDataType) -> None: - s = pl.Series([None, None, None], dtype=dtypes) - assert s.is_first().to_list() == [True, False, False] - assert s.is_last().to_list() == [False, False, True] - - def test_reinterpret() -> None: s = pl.Series("a", [1, 1, 2], dtype=pl.UInt64) assert s.reinterpret(signed=True).dtype == pl.Int64 @@ -2471,7 +2408,7 @@ def test_set_at_idx() -> None: a[-5] = None assert a.to_list() == [None, 1, 2, None, 4] - with pytest.raises(pl.ComputeError): + with pytest.raises(pl.OutOfBoundsError): a[-100] = None @@ -2550,22 +2487,6 @@ def test_get_chunks() -> None: assert_series_equal(chunks[1], b) -def test_item() -> None: - s = pl.Series("a", [1]) - assert s.item() == 1 - - s = pl.Series("a", [1, 2]) - with pytest.raises(ValueError): - s.item() - - assert s.item(0) == 1 - assert s.item(-1) == 2 - - s = pl.Series("a", []) - with pytest.raises(ValueError): - s.item() - - def test_ptr() -> None: # not much to test on the ptr value itself. s = pl.Series([1, None, 3]) @@ -2811,3 +2732,19 @@ def test_symmetry_for_max_in_names() -> None: # TODO: time arithmetic support? 
# a = pl.Series("a", [1], dtype=pl.Time) # assert (a - a.max()).name == (a.max() - a).name == a.name + + +def test_series_getitem_out_of_bounds_positive() -> None: + s = pl.Series([1, 2]) + with pytest.raises( + IndexError, match="index 10 is out of bounds for sequence of length 2" + ): + s[10] + + +def test_series_getitem_out_of_bounds_negative() -> None: + s = pl.Series([1, 2]) + with pytest.raises( + IndexError, match="index -10 is out of bounds for sequence of length 2" + ): + s[-10] diff --git a/py-polars/tests/unit/test_async.py b/py-polars/tests/unit/test_async.py new file mode 100644 index 000000000000..85ac2e7ced3c --- /dev/null +++ b/py-polars/tests/unit/test_async.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import asyncio +import time +from functools import partial +from typing import Any, Callable + +import pytest + +import polars as pl +from polars.dependencies import gevent + + +async def _aio_collect_async(raises: bool = False) -> pl.DataFrame: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + return await lf.collect_async() + + +async def _aio_collect_all_async(raises: bool = False) -> list[pl.DataFrame]: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + + lf2 = pl.LazyFrame({"a": [1, 2], "b": [1, 2]}).group_by("a").sum() + + return await pl.collect_all_async([lf, lf2]) + + +_aio_collect = pytest.mark.parametrize( + ("collect", "raises"), + [ + (_aio_collect_async, None), + (_aio_collect_all_async, None), + (partial(_aio_collect_async, True), pl.ColumnNotFoundError), + (partial(_aio_collect_all_async, True), pl.ColumnNotFoundError), + ], +) + + +def _aio_run(coroutine: Any, raises: Exception | None = None) -> None: + if raises is not None: + with pytest.raises(raises): # type: ignore[call-overload] + asyncio.run(coroutine) + else: + assert len(asyncio.run(coroutine)) > 0 + + +@_aio_collect +def test_collect_async_switch( + collect: Callable[[], Any], + raises: Exception | None, +) -> None: + async def main() -> Any: + df = collect() + await asyncio.sleep(0.3) + return await df + + _aio_run(main(), raises) + + +@_aio_collect +def test_collect_async_task( + collect: Callable[[], Any], raises: Exception | None +) -> None: + async def main() -> Any: + df = asyncio.create_task(collect()) + await asyncio.sleep(0.3) + return await df + + _aio_run(main(), raises) + + +def _gevent_collect_async(raises: bool = False) -> Any: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + return lf.collect_async(gevent=True) + + +def _gevent_collect_all_async(raises: bool = False) -> Any: + lf = ( + pl.LazyFrame( + { + "a": ["a", "b", "a", "b", "b", "c"], + "b": [1, 2, 3, 4, 5, 6], + "c": [6, 5, 4, 3, 2, 1], + } + ) + .group_by("a", maintain_order=True) + .agg(pl.all().sum()) + ) + if raises: + lf = lf.select(pl.col("foo_bar")) + return pl.collect_all_async([lf], gevent=True) + + +_gevent_collect = pytest.mark.parametrize( + ("get_result", "raises"), + [ + (_gevent_collect_async, None), + 
(_gevent_collect_all_async, None), + (partial(_gevent_collect_async, True), pl.ColumnNotFoundError), + (partial(_gevent_collect_all_async, True), pl.ColumnNotFoundError), + ], +) + + +def _gevent_run(callback: Callable[[], Any], raises: Exception | None = None) -> None: + if raises is not None: + with pytest.raises(raises): # type: ignore[call-overload] + callback() + else: + assert len(callback()) > 0 + + +@_gevent_collect +def test_gevent_collect_async_without_hub( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + return get_result().get() + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_with_hub( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + _hub = gevent.get_hub() + + def main() -> Any: + return get_result().get() + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_switch( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + result = get_result() + gevent.sleep(0.1) + return result.get(block=False, timeout=3) + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_no_switch( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + result = get_result() + time.sleep(1) + return result.get(block=False, timeout=None) + + _gevent_run(main, raises) + + +@_gevent_collect +def test_gevent_collect_async_spawn( + get_result: Callable[[], Any], raises: Exception | None +) -> None: + def main() -> Any: + result_greenlet = gevent.spawn(get_result) + gevent.spawn(gevent.sleep, 0.1) + return result_greenlet.get().get() + + _gevent_run(main, raises) diff --git a/py-polars/tests/unit/test_constructors.py b/py-polars/tests/unit/test_constructors.py index ac03d211b28a..2832d5809e0e 100644 --- a/py-polars/tests/unit/test_constructors.py +++ b/py-polars/tests/unit/test_constructors.py @@ -14,7 +14,7 @@ import polars as pl from polars.dependencies import _ZONEINFO_AVAILABLE, dataclasses, pydantic -from polars.exceptions import ShapeError, TimeZoneAwareConstructorWarning +from polars.exceptions import TimeZoneAwareConstructorWarning from polars.testing import assert_frame_equal, assert_series_equal from polars.utils._construction import type_hints @@ -821,9 +821,29 @@ def test_init_records() -> None: assert_frame_equal(df, expected) assert df.to_dicts() == dicts - df_cd = pl.DataFrame(dicts, schema=["c", "d"]) - expected = pl.DataFrame({"c": [1, 2, 1], "d": [2, 1, 2]}) - assert_frame_equal(df_cd, expected) + df_cd = pl.DataFrame(dicts, schema=["a", "c", "d"]) + expected_values = { + "a": [1, 2, 1], + "c": [None, None, None], + "d": [None, None, None], + } + assert df_cd.to_dict(False) == expected_values + + data = {"a": 1, "b": 2, "c": 3} + + df1 = pl.from_dicts([data]) + assert df1.columns == ["a", "b", "c"] + + df1.columns = ["x", "y", "z"] + assert df1.columns == ["x", "y", "z"] + + df2 = pl.from_dicts([data], schema=["c", "b", "a"]) + assert df2.columns == ["c", "b", "a"] + + for colname in ("c", "b", "a"): + assert pl.from_dicts([data], schema=[colname]).to_dict(False) == { + colname: [data[colname]] + } def test_init_records_schema_order() -> None: @@ -956,13 +976,10 @@ def test_from_dicts_missing_columns() -> None: ] assert pl.from_dicts(data).to_dict(False) == {"a": [1, None], "b": [None, 2]} - # missing columns in the schema; only load the declared keys + # partial schema with some columns missing; only load the declared keys data = [{"a": 1, "b": 2}] assert 
pl.from_dicts(data, schema=["a"]).to_dict(False) == {"a": [1]} - - # invalid - with pytest.raises(ShapeError): - pl.from_dicts([{"a": 1, "b": 2}], schema=["xyz"]) + assert pl.from_dicts(data, schema=["x"]).to_dict(False) == {"x": [None]} def test_from_rows_dtype() -> None: diff --git a/py-polars/tests/unit/test_cse.py b/py-polars/tests/unit/test_cse.py index 0b9b351c76f8..e1efcdecdf4b 100644 --- a/py-polars/tests/unit/test_cse.py +++ b/py-polars/tests/unit/test_cse.py @@ -1,5 +1,5 @@ import re -from datetime import date +from datetime import date, datetime from tempfile import NamedTemporaryFile from typing import Any @@ -40,6 +40,25 @@ def test_union_duplicates() -> None: ) +def test_cse_with_struct_expr_11116() -> None: + df = pl.DataFrame([{"s": {"a": 1, "b": 4}, "c": 3}]).lazy() + out = df.with_columns( + pl.col("s").struct.field("a").alias("s_a"), + pl.col("s").struct.field("b").alias("s_b"), + ( + (pl.col("s").struct.field("a") <= pl.col("c")) + & (pl.col("s").struct.field("b") > pl.col("c")) + ).alias("c_between_a_and_b"), + ).collect(comm_subexpr_elim=True) + assert out.to_dict(False) == { + "s": [{"a": 1, "b": 4}], + "c": [3], + "s_a": [1], + "s_b": [4], + "c_between_a_and_b": [True], + } + + def test_cse_schema_6081() -> None: df = pl.DataFrame( data=[ @@ -419,3 +438,46 @@ def test_cse_count_in_group_by() -> None: "b": [[1], []], "c": [[40], []], } + + +def test_no_cse_in_with_context() -> None: + df1 = pl.DataFrame( + { + "timestamp": [ + datetime(2023, 1, 1, 0, 0), + datetime(2023, 5, 1, 0, 0), + datetime(2023, 10, 1, 0, 0), + ], + "value": [2, 5, 9], + } + ) + df2 = pl.DataFrame( + { + "date_start": [ + datetime(2022, 12, 31, 0, 0), + datetime(2023, 1, 2, 0, 0), + ], + "date_end": [ + datetime(2023, 4, 30, 0, 0), + datetime(2023, 5, 5, 0, 0), + ], + "label": [0, 1], + } + ) + + assert ( + df1.lazy() + .with_context(df2.lazy()) + .select( + pl.col("date_start", "label").take( + pl.col("date_start").search_sorted("timestamp") - 1 + ), + ) + ).collect().to_dict(False) == { + "date_start": [ + datetime(2022, 12, 31, 0, 0), + datetime(2023, 1, 2, 0, 0), + datetime(2023, 1, 2, 0, 0), + ], + "label": [0, 1, 1], + } diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py index 8f10655a20a1..9bc7ceeaf0e6 100644 --- a/py-polars/tests/unit/test_exprs.py +++ b/py-polars/tests/unit/test_exprs.py @@ -85,12 +85,6 @@ def test_filter_where() -> None: assert_frame_equal(result_filter, expected) -def test_list_join_strings() -> None: - s = pl.Series("a", [["ab", "c", "d"], ["e", "f"], ["g"], []]) - expected = pl.Series("a", ["ab-c-d", "e-f", "g", ""]) - assert_series_equal(s.list.join("-"), expected) - - def test_count_expr() -> None: df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]}) @@ -265,6 +259,22 @@ def test_null_count_expr() -> None: assert df.select([pl.all().null_count()]).to_dict(False) == {"key": [0], "val": [1]} +def test_pos_neg() -> None: + df = pl.DataFrame( + { + "x": [3, 2, 1], + "y": [6, 7, 8], + } + ).with_columns(-pl.col("x"), +pl.col("y"), -pl.lit(1)) + + # #11149: ensure that we preserve the output name (where available) + assert df.to_dict(False) == { + "x": [-3, -2, -1], + "y": [6, 7, 8], + "literal": [-1, -1, -1], + } + + def test_power_by_expression() -> None: out = pl.DataFrame( {"a": [1, None, None, 4, 5, 6], "b": [1, 2, None, 4, None, 6]} diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 1f9ca2c78077..3aac664a8f88 100644 --- a/py-polars/tests/unit/test_lazy.py +++ 
b/py-polars/tests/unit/test_lazy.py @@ -287,24 +287,6 @@ def test_is_unique() -> None: assert_series_equal(result, pl.Series("a", [False, True, False])) -def test_is_first() -> None: - ldf = pl.LazyFrame({"a": [4, 1, 4]}) - result = ldf.select(pl.col("a").is_first()).collect()["a"] - assert_series_equal(result, pl.Series("a", [True, True, False])) - - # struct - ldf = pl.LazyFrame({"a": [1, 2, 3, 2, None, 2, 1], "b": [0, 2, 3, 2, None, 2, 0]}) - - assert ldf.select(pl.struct(["a", "b"]).is_first()).collect().to_dict(False) == { - "a": [True, True, True, False, True, False, False] - } - - ldf = pl.LazyFrame({"a": [[1, 2], [3], [1, 2], [4, 5], [4, 5]]}) - assert ldf.select(pl.col("a").is_first()).collect().to_dict(False) == { - "a": [True, True, False, True, False] - } - - def test_is_duplicated() -> None: ldf = pl.LazyFrame({"a": [4, 1, 4]}).select(pl.col("a").is_duplicated()) assert_series_equal(ldf.collect()["a"], pl.Series("a", [True, False, True])) diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 6c2a361c3268..6abc676027c0 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -270,7 +270,8 @@ def test_schema_owned_arithmetic_5669() -> None: .with_columns(-pl.col("A").alias("B")) .collect() ) - assert df.columns == ["A", "literal"], df.columns + assert df.columns == ["A", "B"] + assert df.rows() == [(3, -3)] def test_fill_null_f32_with_lit() -> None: diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 820fafaf1abd..22293206fb91 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -17,6 +17,12 @@ def test_pickling_simple_expression() -> None: assert str(pickle.loads(buf)) == str(e) +def test_pickling_as_struct_11100() -> None: + e = pl.struct("a") + buf = pickle.dumps(e) + assert str(pickle.loads(buf)) == str(e) + + def test_lazyframe_serde() -> None: lf = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).lazy().select(pl.col("a"))