From 2f85df5743ee2082ac308509bf9352a79bb3d12a Mon Sep 17 00:00:00 2001 From: Tao He Date: Mon, 7 Aug 2023 14:04:54 +0800 Subject: [PATCH] Arrow array resolvers and builders. Signed-off-by: Tao He --- .github/workflows/rust-ci.yml | 11 +- rust/Cargo.toml | 35 +- rust/README.md | 171 +++ rust/vineyard-integration-testing/Cargo.toml | 29 + .../src/ds/arrow_test.rs | 152 +++ .../src/ds/mod.rs | 18 + .../src/ds/numpy_test.rs | 91 ++ .../src/ds/pandas_test.rs | 75 ++ .../src/ds/polars_test.rs | 69 + rust/vineyard-integration-testing/src/lib.rs | 17 + rust/vineyard-polars/Cargo.toml | 32 + rust/vineyard-polars/src/ds/dataframe.rs | 347 +++++ rust/vineyard-polars/src/ds/dataframe_test.rs | 16 + rust/vineyard-polars/src/ds/mod.rs | 16 + rust/vineyard-polars/src/lib.rs | 17 + rust/vineyard/Cargo.toml | 34 +- rust/vineyard/src/client/client.rs | 85 +- rust/vineyard/src/client/ds/blob.rs | 192 ++- rust/vineyard/src/client/ds/blob_test.rs | 99 +- rust/vineyard/src/client/ds/object.rs | 79 +- rust/vineyard/src/client/ds/object_factory.rs | 15 +- rust/vineyard/src/client/ds/object_meta.rs | 71 +- rust/vineyard/src/client/ipc_client.rs | 100 +- rust/vineyard/src/client/ipc_client_test.rs | 9 +- rust/vineyard/src/client/rpc_client.rs | 42 +- rust/vineyard/src/client/rpc_client_test.rs | 9 +- rust/vineyard/src/common/util/arrow.rs | 24 +- rust/vineyard/src/common/util/json.rs | 2 +- rust/vineyard/src/common/util/protocol.rs | 89 +- rust/vineyard/src/common/util/status.rs | 441 +++---- rust/vineyard/src/common/util/typename.rs | 18 +- rust/vineyard/src/common/util/uuid.rs | 5 +- rust/vineyard/src/ds/array.rs | 81 +- rust/vineyard/src/ds/array_test.rs | 55 +- rust/vineyard/src/ds/arrow.rs | 1160 ++++++++++++++++- rust/vineyard/src/ds/arrow_test.rs | 248 +++- rust/vineyard/src/ds/arrow_utils.rs | 81 ++ rust/vineyard/src/ds/dataframe.rs | 148 +++ rust/vineyard/src/ds/dataframe_test.rs | 16 + rust/vineyard/src/ds/hashmap.rs | 1 + rust/vineyard/src/ds/hashmap_test.rs | 3 +- 
rust/vineyard/src/ds/mod.rs | 5 + rust/vineyard/src/ds/tensor.rs | 536 ++++++++ rust/vineyard/src/ds/tensor_test.rs | 16 + rust/vineyard/src/lib.rs | 11 + 45 files changed, 4145 insertions(+), 626 deletions(-) create mode 100644 rust/vineyard-integration-testing/Cargo.toml create mode 100644 rust/vineyard-integration-testing/src/ds/arrow_test.rs create mode 100644 rust/vineyard-integration-testing/src/ds/mod.rs create mode 100644 rust/vineyard-integration-testing/src/ds/numpy_test.rs create mode 100644 rust/vineyard-integration-testing/src/ds/pandas_test.rs create mode 100644 rust/vineyard-integration-testing/src/ds/polars_test.rs create mode 100644 rust/vineyard-integration-testing/src/lib.rs create mode 100644 rust/vineyard-polars/Cargo.toml create mode 100644 rust/vineyard-polars/src/ds/dataframe.rs create mode 100644 rust/vineyard-polars/src/ds/dataframe_test.rs create mode 100644 rust/vineyard-polars/src/ds/mod.rs create mode 100644 rust/vineyard-polars/src/lib.rs create mode 100644 rust/vineyard/src/ds/arrow_utils.rs create mode 100644 rust/vineyard/src/ds/dataframe.rs create mode 100644 rust/vineyard/src/ds/dataframe_test.rs create mode 100644 rust/vineyard/src/ds/tensor.rs create mode 100644 rust/vineyard/src/ds/tensor_test.rs diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 58d226ff35..25bcfe877a 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -45,10 +45,15 @@ jobs: override: true components: rustfmt, clippy - - name: Test + - name: Check run: | cd rust cargo fmt --all -- --check - # cargo clippy -- -D warnings + cargo clippy -- -D warnings cargo check - # cargo test --lib + + - name: Unittest + if: false + run: | + cd rust + cargo test --lib diff --git a/rust/Cargo.toml b/rust/Cargo.toml index d2182234ea..efc76a1dbf 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -1,4 +1,35 @@ +[workspace] +members = [ + "vineyard", + "vineyard-polars", + "vineyard-integration-testing", +] +resolver = 
"2" + +[workspace.package] +version = "0.16.5" +homepage = "https://v6d.io" +repository = "https://github.com/v6d-io/v6d.git" +authors = ["Vineyard "] +license = "Apache-2.0" +keywords = ["vineyard"] +include = [ + "src/**/*.rs", + "Cargo.toml", +] edition = "2021" +readme = "README.md" -[workspace] -members = ["vineyard"] +[workspace.dependencies] +arrow-array = ">=40, <44" +arrow-buffer = ">=40, <44" +arrow-ipc = ">=40, <44" +arrow-schema = ">=40, <44" +arrow2 = { version = "0.17", features = ["arrow"] } +inline-python = "0.12" +polars-core = "0.32" +spectral = "0.6" + +vineyard = { version = "0.16.5", path = "./vineyard" } +vineyard-polars = { version = "0.16.5", path = "./vineyard-polars" } +vineyard-integration-testing = { version = "0.16.5", path = "./vineyard-integration-testing" } diff --git a/rust/README.md b/rust/README.md index 2a2986b691..ef63ff2499 100644 --- a/rust/README.md +++ b/rust/README.md @@ -1 +1,172 @@ # Vineyard Rust SDK + +> [!NOTE] +> Rust nightly is required. The vineyard Rust SDK is still under development. +> The API may change in the future. 
+ +[![crates.io](https://img.shields.io/crates/v/vineyard.svg)](https://crates.io/crates/vineyard) +[![Downloads](https://img.shields.io/crates/d/vineyard)](https://crates.io/crates/vineyard) +[![Docs.rs](https://img.shields.io/docsrs/vineyard/latest)](https://docs.rs/vineyard/latest/vineyard/) + +Connecting to Vineyard +---------------------- + +- Resolve the UNIX-domain socket from the environment variable `VINEYARD_IPC_SOCKET`: + + ```rust + use vineyard::client::*; + + let mut client = vineyard::default().unwrap(); + ``` + +- Or, using explicit parameter: + + ```rust + use vineyard::client::*; + + let mut client = vineyard::connect("/var/run/vineyard.sock").unwrap(); + ``` + +Interact with Vineyard +---------------------- + +- Creating blob: + + ```rust + let mut blob_writer = client.create_blob(N)?; + ``` + +- Get object: + + ```rust + let mut meta_writer = client.get::(object_id)?; + ``` + +Inter-op with Python: `numpy.ndarray` +------------------------------------- + +- Python: + + ```python + import numpy as np + import vineyard + + client = vineyard.connect() + + np_array = np.random.rand(10, 20).astype(np.int32) + object_id = int(client.put(np_array)) + ``` + +- Rust: + + ```rust + let mut client = IPCClient::default()?; + let tensor = client.get::(object_id)?; + assert_that!(tensor.shape().to_vec()).is_equal_to(vec![10, 20]); + ``` + +Inter-op with Python: `pandas.DataFrame` +---------------------------------------- + +- Python + + ```python + import pandas as pd + import vineyard + + client = vineyard.connect() + + df = pd.DataFrame({'a': ["1", "2", "3", "4"], 'b': ["5", "6", "7", "8"]}) + object_id = int(client.put(df)) + ``` + +- Rust + + ```rust + let mut client = IPCClient::default()?; + let dataframe = client.get::(object_id)?; + assert_that!(dataframe.num_columns()).is_equal_to(2); + assert_that!(dataframe.names().to_vec()).is_equal_to(vec!["a".into(), "b".into()]); + for index in 0..dataframe.num_columns() { + let column = 
dataframe.column(index); + assert_that!(column.len()).is_equal_to(4); + } + ``` + +Inter-op with Python: `pyarrow.RecordBatch` +------------------------------------- + +- Python + + ```python + import pandas as pd + import pyarrow as pa + import vineyard + + client = vineyard.connect() + + arrays = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", "qux"]), + pa.array([3.0, 5.0, 7.0, 9.0]), + ] + batch = pa.RecordBatch.from_arrays(arrays, ["f0", "f1", "f2"]) + object_id = int(client.put(batch)) + ``` + +- Rust + + ```rust + let batch = client.get::(object_id)?; + assert_that!(batch.num_columns()).is_equal_to(3); + assert_that!(batch.num_rows()).is_equal_to(4); + let schema = batch.schema(); + let names = ["f0", "f1", "f2"]; + let recordbatch = batch.as_ref().as_ref(); + ``` + +Inter-op with Python: `pyarrow.Table` +------------------------------------- + +- Python + + ```python + batches = [batch] * 5 + table = pa.Table.from_batches(batches) + object_id = int(client.put(table)) + ``` + +- Rust + + ```rust + let mut client = IPCClient::default()?; + let table = client.get::(object_id)?; + assert_that!(table.num_batches()).is_equal_to(5); + for batch in table.batches().iter() { + // ... + } + ``` + +Inter-op with Python: `polars.DataFrame` +---------------------------------------- + +- Python + + ```python + import polars + + dataframe = polars.DataFrame(table) + object_id = int(client.put(dataframe)) + ``` + +- Rust + + ```rust + let mut client = IPCClient::default()?; + let batch = client.get::(object_id)?; + let dataframe = batch.as_ref().as_ref(); + assert_that!(dataframe.width()).is_equal_to(3); + for column in dataframe.get_columns() { + // ... 
+ } + ``` diff --git a/rust/vineyard-integration-testing/Cargo.toml b/rust/vineyard-integration-testing/Cargo.toml new file mode 100644 index 0000000000..aa9d81684a --- /dev/null +++ b/rust/vineyard-integration-testing/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "vineyard-integration-testing" +version = { workspace = true } +description = "Vineyard Rust SDK: integration testing" +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +readme = { workspace = true } + +[features] +default = ["nightly"] +nightly = [] + +[lib] +name = "vineyard_integration_testing" +path = "src/lib.rs" + +[dependencies] + +[dev-dependencies] +arrow-array = { workspace = true } +inline-python = { workspace = true } +polars-core = { workspace = true } +spectral = { workspace = true } +vineyard = { workspace = true } +vineyard-polars = { workspace = true } diff --git a/rust/vineyard-integration-testing/src/ds/arrow_test.rs b/rust/vineyard-integration-testing/src/ds/arrow_test.rs new file mode 100644 index 0000000000..5bfc03a2bc --- /dev/null +++ b/rust/vineyard-integration-testing/src/ds/arrow_test.rs @@ -0,0 +1,152 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#[cfg(test)] +mod tests { + use inline_python::{python, Context}; + use spectral::prelude::*; + + use vineyard::client::*; + use vineyard::ds::arrow::*; + + #[test] + fn test_arrow_recordbatch() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! { + import pandas as pd + import pyarrow as pa + import vineyard + + client = vineyard.connect() + + arrays = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", "qux"]), + pa.array([3.0, 5.0, 7.0, 9.0]), + ] + batch = pa.RecordBatch.from_arrays(arrays, ["f0", "f1", "f2"]) + object_id = int(client.put(batch)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let batch = client.get::(object_id)?; + assert_that!(batch.num_columns()).is_equal_to(3); + assert_that!(batch.num_rows()).is_equal_to(4); + let schema = batch.schema(); + let names = ["f0", "f1", "f2"]; + let recordbatch = batch.as_ref().as_ref(); + for (index, name) in names.into_iter().enumerate() { + assert_that!(schema.field(index).name()).is_equal_to(&name.to_string()); + + let column = recordbatch.column(index); + match index { + 0 => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + let expected: arrow_array::Int64Array = vec![1, 2, 3, 4].into(); + assert_that!(array).is_equal_to(&expected); + } + 1 => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + let expected: arrow_array::LargeStringArray = + vec!["foo", "bar", "baz", "qux"].into(); + assert_that!(array).is_equal_to(&expected); + } + 2 => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + let expected: arrow_array::Float64Array = vec![3.0, 5.0, 7.0, 9.0].into(); + assert_that!(array).is_equal_to(&expected); + } + _ => unreachable!(), + } + } + return Ok(()); + } + + #[test] + fn test_arrow_table() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! 
{ + import pandas as pd + import pyarrow as pa + import vineyard + client = vineyard.connect() + + arrays = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", "qux"]), + pa.array([3.0, 5.0, 7.0, 9.0]), + ] + batch = pa.RecordBatch.from_arrays(arrays, ["f0", "f1", "f2"]) + batches = [batch] * 5 + table = pa.Table.from_batches(batches) + object_id = int(client.put(table)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let table = client.get::
(object_id)?; + assert_that!(table.num_batches()).is_equal_to(5); + for batch in table.batches().iter() { + assert_that!(batch.num_columns()).is_equal_to(3); + assert_that!(batch.num_rows()).is_equal_to(4); + let schema = batch.schema(); + let names = ["f0", "f1", "f2"]; + let recordbatch = batch.as_ref().as_ref(); + for (index, name) in names.into_iter().enumerate() { + assert_that!(schema.field(index).name()).is_equal_to(&name.to_string()); + + let column = recordbatch.column(index); + match index { + 0 => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + let expected: arrow_array::Int64Array = vec![1, 2, 3, 4].into(); + assert_that!(array).is_equal_to(&expected); + } + 1 => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + let expected: arrow_array::LargeStringArray = + vec!["foo", "bar", "baz", "qux"].into(); + assert_that!(array).is_equal_to(&expected); + } + 2 => { + let array = column + .as_any() + .downcast_ref::() + .unwrap(); + let expected: arrow_array::Float64Array = vec![3.0, 5.0, 7.0, 9.0].into(); + assert_that!(array).is_equal_to(&expected); + } + _ => unreachable!(), + } + } + } + return Ok(()); + } +} diff --git a/rust/vineyard-integration-testing/src/ds/mod.rs b/rust/vineyard-integration-testing/src/ds/mod.rs new file mode 100644 index 0000000000..af2cd6e62a --- /dev/null +++ b/rust/vineyard-integration-testing/src/ds/mod.rs @@ -0,0 +1,18 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod arrow_test; +pub mod numpy_test; +pub mod pandas_test; +pub mod polars_test; diff --git a/rust/vineyard-integration-testing/src/ds/numpy_test.rs b/rust/vineyard-integration-testing/src/ds/numpy_test.rs new file mode 100644 index 0000000000..82e103d694 --- /dev/null +++ b/rust/vineyard-integration-testing/src/ds/numpy_test.rs @@ -0,0 +1,91 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(test)] +mod tests { + use inline_python::{python, Context}; + use spectral::prelude::*; + + use vineyard::client::*; + use vineyard::ds::tensor::{Float64Tensor, Int32Tensor, StringTensor}; + + #[test] + fn test_numpy_int32() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! { + import numpy as np + import vineyard + + client = vineyard.connect() + + np_array = np.random.rand(10, 20).astype(np.int32) + object_id = int(client.put(np_array)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let tensor = client.get::(object_id)?; + assert_that!(tensor.shape().to_vec()).is_equal_to(vec![10, 20]); + return Ok(()); + } + + #[test] + fn test_numpy_float64() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! 
{ + import numpy as np + import vineyard + + client = vineyard.connect() + + np_array = np.random.rand(10, 20) + object_id = int(client.put(np_array)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let tensor = client.get::(object_id)?; + assert_that!(tensor.shape().to_vec()).is_equal_to(vec![10, 20]); + return Ok(()); + } + + #[ignore = "ndarray with string type in python side needs to be fixed"] + #[test] + fn test_numpy_string() -> Result<()> { + use arrow_array::array::Array; + + let ctx = Context::new(); + ctx.run(python! { + import numpy as np + import vineyard + + client = vineyard.connect() + + np_array = np.array(['a', 'b', 'c', 'd', 'e']) + object_id = int(client.put(np_array)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let tensor = client.get::(object_id)?; + assert_that!(tensor.shape().to_vec()).is_equal_to(vec![5]); + let array = tensor.as_ref().as_ref(); + assert_that!(array.len()).is_equal_to(5); + for index in 0..array.len() { + assert_that!(array.value(index).to_string()) + .is_equal_to(format!("{}", (index as u8 + b'a') as char)); + } + return Ok(()); + } +} diff --git a/rust/vineyard-integration-testing/src/ds/pandas_test.rs b/rust/vineyard-integration-testing/src/ds/pandas_test.rs new file mode 100644 index 0000000000..6b51347838 --- /dev/null +++ b/rust/vineyard-integration-testing/src/ds/pandas_test.rs @@ -0,0 +1,75 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(test)] +mod tests { + use inline_python::{python, Context}; + use spectral::prelude::*; + + use vineyard::client::*; + use vineyard::ds::dataframe::DataFrame; + + #[test] + fn test_pandas_int() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! { + import numpy as np + import pandas as pd + import vineyard + + client = vineyard.connect() + + df = pd.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + object_id = int(client.put(df)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let dataframe = client.get::(object_id)?; + assert_that!(dataframe.num_columns()).is_equal_to(2); + assert_that!(dataframe.names().to_vec()).is_equal_to(vec!["a".into(), "b".into()]); + for index in 0..dataframe.num_columns() { + let column = dataframe.column(index); + assert_that!(column.len()).is_equal_to(4); + } + return Ok(()); + } + + #[ignore = "ndarray with string type in python side needs to be fixed"] + #[test] + fn test_pandas_string() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! 
{ + import numpy as np + import pandas as pd + import vineyard + + client = vineyard.connect() + + df = pd.DataFrame({'a': ["1", "2", "3", "4"], 'b': ["5", "6", "7", "8"]}) + object_id = int(client.put(df)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let dataframe = client.get::(object_id)?; + assert_that!(dataframe.num_columns()).is_equal_to(2); + assert_that!(dataframe.names().to_vec()).is_equal_to(vec!["a".into(), "b".into()]); + for index in 0..dataframe.num_columns() { + let column = dataframe.column(index); + assert_that!(column.len()).is_equal_to(4); + } + return Ok(()); + } +} diff --git a/rust/vineyard-integration-testing/src/ds/polars_test.rs b/rust/vineyard-integration-testing/src/ds/polars_test.rs new file mode 100644 index 0000000000..1a4d76e41e --- /dev/null +++ b/rust/vineyard-integration-testing/src/ds/polars_test.rs @@ -0,0 +1,69 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(test)] +mod tests { + use inline_python::{python, Context}; + use polars_core::prelude::NamedFrom; + use polars_core::series::Series; + use spectral::prelude::*; + + use vineyard::client::*; + use vineyard_polars::ds::dataframe::DataFrame; + + #[test] + fn test_polars_dataframe() -> Result<()> { + let ctx = Context::new(); + ctx.run(python! 
{ + import pandas as pd + import pyarrow as pa + import polars + + import vineyard + + client = vineyard.connect() + + arrays = [ + pa.array([1, 2, 3, 4]), + pa.array(["foo", "bar", "baz", "qux"]), + pa.array([3.0, 5.0, 7.0, 9.0]), + ] + batch = pa.RecordBatch.from_arrays(arrays, ["f0", "f1", "f2"]) + batches = [batch] * 5 + table = pa.Table.from_batches(batches) + dataframe = polars.DataFrame(table) + object_id = int(client.put(dataframe)) + }); + let object_id = ctx.get::("object_id"); + + let mut client = IPCClient::default()?; + let batch = client.get::(object_id)?; + let dataframe = batch.as_ref().as_ref(); + assert_that!(dataframe.width()).is_equal_to(3); + let mut names = Vec::with_capacity(dataframe.width()); + for column in dataframe.get_columns() { + names.push(column.name()); + } + assert_that!(names).is_equal_to(vec!["f0", "f1", "f2"]); + + // check column values + assert_that!(dataframe.column("f0").unwrap().head(Some(4))) + .is_equal_to(&Series::new("f0", [1, 2, 3, 4])); + assert_that!(dataframe.column("f1").unwrap().head(Some(4))) + .is_equal_to(&Series::new("f1", ["foo", "bar", "baz", "qux"])); + assert_that!(dataframe.column("f2").unwrap().head(Some(4))) + .is_equal_to(&Series::new("f2", [3.0, 5.0, 7.0, 9.0])); + return Ok(()); + } +} diff --git a/rust/vineyard-integration-testing/src/lib.rs b/rust/vineyard-integration-testing/src/lib.rs new file mode 100644 index 0000000000..d6d203d50a --- /dev/null +++ b/rust/vineyard-integration-testing/src/lib.rs @@ -0,0 +1,17 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![allow(clippy::needless_return)] + +pub mod ds; diff --git a/rust/vineyard-polars/Cargo.toml b/rust/vineyard-polars/Cargo.toml new file mode 100644 index 0000000000..7b4301390c --- /dev/null +++ b/rust/vineyard-polars/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "vineyard-polars" +version = { workspace = true } +description = "Vineyard Rust SDK: polars integration for DataFrame" +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +readme = { workspace = true } + +[features] +default = ["nightly"] +nightly = [] + +[lib] +name = "vineyard_polars" +path = "src/lib.rs" + +[dependencies] +arrow-array = { workspace = true } +arrow-schema = { workspace = true } +arrow2 = { workspace = true } +itertools = "0.11" +ndarray = "0.15" +polars-core = "0.32" +serde_json = "1.0" +vineyard = { workspace = true } + +[dev-dependencies] +spectral = { workspace = true } diff --git a/rust/vineyard-polars/src/ds/dataframe.rs b/rust/vineyard-polars/src/ds/dataframe.rs new file mode 100644 index 0000000000..dc0eb3843a --- /dev/null +++ b/rust/vineyard-polars/src/ds/dataframe.rs @@ -0,0 +1,347 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow2::array; +use arrow2::datatypes; +use itertools::izip; +use polars_core::prelude as polars; +use serde_json::Value; + +use vineyard::client::*; +use vineyard::ds::arrow::{Table, TableBuilder}; +use vineyard::ds::dataframe::DataFrame as VineyardDataFrame; + +/// Convert a Polars error to a Vineyard error, as orphan impls are not allowed +/// in Rust +/// +/// Usage: +/// +/// ```no_run +/// let x = polars::DataFrame::new(...).map_err(error)?; +/// ``` +fn error(error: polars::PolarsError) -> VineyardError { + VineyardError::invalid(format!("{}", error)) +} + +#[derive(Debug, Default)] +pub struct DataFrame { + meta: ObjectMeta, + dataframe: polars::DataFrame, +} + +impl_typename!(DataFrame, "vineyard::Table"); + +impl Object for DataFrame { + fn construct(&mut self, meta: ObjectMeta) -> Result<()> { + let ty = meta.get_typename()?; + if ty == typename::() { + return self.construct_from_pandas_dataframe(meta); + } else if ty == typename::
() { + return self.construct_from_arrow_table(meta); + } else { + return Err(VineyardError::type_error(format!( + "cannot construct DataFrame from this metadata: {}", + ty + ))); + } + } +} + +register_vineyard_object!(DataFrame); + +impl DataFrame { + pub fn new_boxed(meta: ObjectMeta) -> Result> { + let mut object = Box::::default(); + object.construct(meta)?; + Ok(object) + } + + fn construct_from_pandas_dataframe(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::(), meta.get_typename()?)?; + let dataframe = downcast_object::(VineyardDataFrame::new_boxed(meta)?)?; + let names = dataframe.names().to_vec(); + let columns: Vec> = dataframe + .columns() + .iter() + .map(|c| array::from_data(&c.array().to_data())) + .collect(); + let series: Vec = names + .iter() + .zip(columns) + .map(|(name, column)| { + let datatype = polars::DataType::from(column.data_type()); + unsafe { + polars_core::series::Series::from_chunks_and_dtype_unchecked( + name, + vec![column], + &datatype, + ) + } + }) + .collect::>(); + self.dataframe = polars::DataFrame::new(series).map_err(error)?; + return Ok(()); + } + + fn construct_from_arrow_table(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::
(), meta.get_typename()?)?; + let table = downcast_object::
(Table::new_boxed(meta)?)?; + let schema = table.schema(); + let names = schema + .fields() + .iter() + .map(|f| f.name().clone()) + .collect::>(); + let types = schema + .fields() + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); + let mut columns: Vec>> = Vec::with_capacity(table.num_columns()); + for index in 0..table.num_columns() { + let mut chunks = Vec::with_capacity(table.num_batches()); + for batch in table.batches() { + let batch = batch.as_ref().as_ref(); + let chunk = batch.column(index); + chunks.push(array::from_data(&chunk.to_data())); + } + columns.push(chunks); + } + let series: Vec = izip!(&names, types, columns) + .map(|(name, datatype, chunks)| unsafe { + polars_core::series::Series::from_chunks_and_dtype_unchecked( + name, + chunks, + &polars::DataType::from(&datatypes::DataType::from(datatype)), + ) + }) + .collect::>(); + self.dataframe = polars::DataFrame::new(series).map_err(error)?; + return Ok(()); + } +} + +impl AsRef for DataFrame { + fn as_ref(&self) -> &polars::DataFrame { + &self.dataframe + } +} + +/// Building a polars dataframe into a pandas-compatible dataframe. 
+pub struct PandasDataFrameBuilder { + sealed: bool, + names: Vec, + columns: Vec>, +} + +impl ObjectBuilder for PandasDataFrameBuilder { + fn sealed(&self) -> bool { + self.sealed + } + + fn set_sealed(&mut self, sealed: bool) { + self.sealed = sealed; + } +} + +impl ObjectBase for PandasDataFrameBuilder { + fn build(&mut self, _client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); + } + self.set_sealed(true); + return Ok(()); + } + + fn seal(mut self, client: &mut IPCClient) -> Result> { + self.build(client)?; + let mut meta = ObjectMeta::new_from_typename(typename::()); + meta.add_usize("__values_-size", self.names.len()); + meta.add_isize("partition_index_row_", -1); + meta.add_isize("partition_index_column_", -1); + meta.add_isize("row_batch_index_", -1); + for (index, (name, column)) in self.names.iter().zip(self.columns).enumerate() { + meta.add_value( + &format!("__values_-key-{}", index), + Value::String(name.into()), + ); + meta.add_member(&format!("__values_-value-{}", index), column)?; + } + let metadata = client.create_metadata(&meta)?; + return DataFrame::new_boxed(metadata); + } +} + +impl PandasDataFrameBuilder { + pub fn new(names: Vec, columns: Vec>) -> Result { + return Ok(PandasDataFrameBuilder { + sealed: false, + names, + columns, + }); + } + + pub fn new_from_arrays( + client: &mut IPCClient, + names: Vec, + arrays: Vec>, + ) -> Result { + use vineyard::ds::tensor::build_tensor; + + let mut columns = Vec::with_capacity(arrays.len()); + for array in arrays { + columns.push(build_tensor(client, array.into())?); + } + return Ok(PandasDataFrameBuilder { + sealed: false, + names, + columns, + }); + } + + pub fn new_from_dataframe( + client: &mut IPCClient, + dataframe: &polars::DataFrame, + ) -> Result { + let mut names = Vec::with_capacity(dataframe.width()); + let mut columns = Vec::with_capacity(dataframe.width()); + for column in dataframe.get_columns() { + let column = column.rechunk(); // FIXME(avoid copying) + 
names.push(column.name().into()); + columns.push(column.chunks()[0].clone()); + } + return Self::new_from_arrays(client, names, columns); + } +} + +/// Building a polars dataframe into a arrow's table-compatible dataframe. +pub struct ArrowDataFrameBuilder(pub TableBuilder); + +impl ObjectBuilder for ArrowDataFrameBuilder { + fn sealed(&self) -> bool { + self.0.sealed() + } + + fn set_sealed(&mut self, sealed: bool) { + self.0.set_sealed(sealed) + } +} + +impl ObjectBase for ArrowDataFrameBuilder { + fn build(&mut self, client: &mut IPCClient) -> Result<()> { + self.0.build(client) + } + + fn seal(self, client: &mut IPCClient) -> Result> { + let table = downcast_object::
(self.0.seal(client)?)?; + return DataFrame::new_boxed(table.metadata()); + } +} + +impl ArrowDataFrameBuilder { + /// batches[0]: the first record batch + /// batches[0][0]: the first column of the first record batch + pub fn new( + client: &mut IPCClient, + names: Vec, + datatypes: Vec, + num_rows: Vec, + num_columns: usize, + batches: Vec>>, + ) -> Result { + let schema = arrow_schema::Schema::new( + izip!(names, datatypes) + .map(|(name, datatype)| { + arrow_schema::Field::from(datatypes::Field::new(name, datatype, false)) + }) + .collect::>(), + ); + return Ok(ArrowDataFrameBuilder(TableBuilder::new_from_bathes( + client, + &schema, + num_rows, + num_columns, + batches, + )?)); + } + + /// Builds the table from row-major batches of polars (arrow2) arrays. + /// + /// batches[0]: the first record batch + /// batches[0][0]: the first column of the first record batch + /// + /// The row count of each batch is taken from its first column (0 for a + /// batch without columns); the column count is taken from the last batch. + pub fn new_from_batches( + client: &mut IPCClient, + names: Vec, + datatypes: Vec, + batches: Vec>>, + ) -> Result { + use vineyard::ds::arrow::build_array; + + let mut num_rows = Vec::with_capacity(batches.len()); + let mut num_columns = 0; + let mut chunks = Vec::with_capacity(batches.len()); + for batch in batches { + let mut columns = Vec::with_capacity(batch.len()); + // Count the columns from the incoming batch: `columns` was only + // pre-allocated above and is still empty at this point. + num_columns = batch.len(); + if num_columns == 0 { + num_rows.push(0); + } else { + num_rows.push(batch[0].len()); + } + for array in batch { + columns.push(build_array(client, array.into())?); + } + chunks.push(columns); + } + return Self::new(client, names, datatypes, num_rows, num_columns, chunks); + } + + /// columns[0]: the first column + /// columns[0][0]: the first chunk of the first column + pub fn new_from_columns( + client: &mut IPCClient, + names: Vec, + datatypes: Vec, + columns: Vec>>, + ) -> Result { + use vineyard::ds::arrow::build_array; + + let mut num_rows = Vec::new(); + let num_columns = columns.len(); + let mut chunks = Vec::new(); + for (column_index, column) in columns.into_iter().enumerate() { + for (chunk_index, chunk) in column.into_iter().enumerate() { + if column_index == 0 { + 
chunks.push(Vec::new()); + num_rows.push(chunk.len()); + } + chunks[chunk_index].push(build_array(client, chunk.into())?); + } + } + return Self::new(client, names, datatypes, num_rows, num_columns, chunks); + } + + pub fn new_from_dataframe( + client: &mut IPCClient, + dataframe: &polars::DataFrame, + ) -> Result { + let mut names = Vec::with_capacity(dataframe.width()); + let mut datatypes = Vec::with_capacity(dataframe.width()); + let mut columns = Vec::with_capacity(dataframe.width()); + for column in dataframe.get_columns() { + names.push(column.name().into()); + datatypes.push(column.dtype().to_arrow()); + columns.push(column.chunks().clone()); + } + return Self::new_from_columns(client, names, datatypes, columns); + } +} diff --git a/rust/vineyard-polars/src/ds/dataframe_test.rs b/rust/vineyard-polars/src/ds/dataframe_test.rs new file mode 100644 index 0000000000..f54b6c5ad4 --- /dev/null +++ b/rust/vineyard-polars/src/ds/dataframe_test.rs @@ -0,0 +1,16 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[allow(unused_imports)] +use std::any::Any; diff --git a/rust/vineyard-polars/src/ds/mod.rs b/rust/vineyard-polars/src/ds/mod.rs new file mode 100644 index 0000000000..b180bb9d15 --- /dev/null +++ b/rust/vineyard-polars/src/ds/mod.rs @@ -0,0 +1,16 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod dataframe; +pub mod dataframe_test; diff --git a/rust/vineyard-polars/src/lib.rs b/rust/vineyard-polars/src/lib.rs new file mode 100644 index 0000000000..d6d203d50a --- /dev/null +++ b/rust/vineyard-polars/src/lib.rs @@ -0,0 +1,17 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#![allow(clippy::needless_return)] + +pub mod ds; diff --git a/rust/vineyard/Cargo.toml b/rust/vineyard/Cargo.toml index a645cbe0c9..cfd463b99b 100644 --- a/rust/vineyard/Cargo.toml +++ b/rust/vineyard/Cargo.toml @@ -1,16 +1,14 @@ [package] name = "vineyard" -version = "0.16.2" -description = "Vineyard Rust SDK" -homepage = "https://github.com/v6d-io/v6d" -repository = "https://github.com/v6d-io/v6d" -license = "Apache-2.0" -keywords = [ "vineyard" ] -include = [ - "src/**/*.rs", - "Cargo.toml", -] -edition = "2021" +version = { workspace = true } +description = "Vineyard Rust SDK: core library" +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +keywords = { workspace = true } +include = { workspace = true } +edition = { workspace = true } +readme = { workspace = true } [features] default = ["nightly"] @@ -21,21 +19,27 @@ name = "vineyard" path = "src/lib.rs" [dependencies] -arrow = "44" -const_format = "0.2" +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-ipc = { workspace = true } +arrow-schema = { workspace = true } +ctor = "0.2" downcast-rs = "1.2" env_logger = "0.9" +gensym = "0.1" +itertools = "0.11" lazy_static = "1" log = "0.4" memmap2 = "0.7" num-traits = "0.2" num-derive = "0.4" +parking_lot = "0.12" rand = "0.8" sendfd = "0.4" serde = "1.0" serde_derive = "1.0" serde_json = "1.0" -thiserror = "1.0" +static_str_ops = "0.1.2" [dev-dependencies] -spectral = "*" +spectral = { workspace = true } diff --git a/rust/vineyard/src/client/client.rs b/rust/vineyard/src/client/client.rs index 84978ce548..04e2d5278f 100644 --- a/rust/vineyard/src/client/client.rs +++ b/rust/vineyard/src/client/client.rs @@ -14,6 +14,8 @@ use std::collections::HashMap; +use parking_lot::ReentrantMutexGuard; + use crate::common::util::json::*; use crate::common::util::protocol::*; use crate::common::util::status::*; @@ -55,7 +57,7 @@ impl InstanceStatus { pub trait Client { /// Disconnect this client. 
- fn disconnect(&mut self) -> (); + fn disconnect(&mut self); fn connected(&mut self) -> bool; @@ -69,16 +71,14 @@ pub trait Client { fn get_metadata(&mut self, id: ObjectID) -> Result; - fn get_metadata_batch(&mut self, ids: &Vec) -> Result>; + fn get_metadata_batch(&mut self, ids: &[ObjectID]) -> Result>; fn fetch_and_get_metadata(&mut self, id: ObjectID) -> Result { - self.ensure_connect()?; let local_id = self.migrate(id)?; return self.get_metadata(local_id); } - fn fetch_and_get_metadata_batch(&mut self, ids: &Vec) -> Result> { - self.ensure_connect()?; + fn fetch_and_get_metadata_batch(&mut self, ids: &[ObjectID]) -> Result> { let mut local_ids = Vec::new(); for id in ids { local_ids.push(self.migrate(*id)?); @@ -87,28 +87,28 @@ pub trait Client { } fn drop_buffer(&mut self, id: ObjectID) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_drop_buffer_request(id)?; self.do_write(&message_out)?; return read_drop_buffer_reply(&self.do_read()?); } fn seal_buffer(&mut self, id: ObjectID) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_seal_request(id)?; self.do_write(&message_out)?; return read_seal_reply(&self.do_read()?); } fn get_data(&mut self, id: ObjectID, sync_remote: bool, wait: bool) -> Result { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_get_data_request(id, sync_remote, wait)?; self.do_write(&message_out)?; return read_get_data_reply(&self.do_read()?); } - fn get_data_batch(&mut self, ids: &Vec) -> Result> { - self.ensure_connect()?; + fn get_data_batch(&mut self, ids: &[ObjectID]) -> Result> { + let _ = self.ensure_connect()?; let message_out = write_get_data_batch_request(&ids, false, false)?; self.do_write(&message_out)?; let reply = read_get_data_batch_reply(&self.do_read()?)?; @@ -128,7 +128,7 @@ pub trait Client { } fn create_data(&mut self, data: &JSON) -> Result<(ObjectID, Signature, InstanceID)> { - 
self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_create_data_request(data)?; self.do_write(&message_out)?; let reply = read_create_data_reply(&self.do_read()?)?; @@ -145,14 +145,14 @@ pub trait Client { } fn delete(&mut self, id: ObjectID, force: bool, deep: bool) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_delete_data_request(id, force, deep, false)?; self.do_write(&message_out)?; return read_delete_data_reply(&self.do_read()?); } - fn delete_batch(&mut self, ids: &Vec, force: bool, deep: bool) -> Result<()> { - self.ensure_connect()?; + fn delete_batch(&mut self, ids: &[ObjectID], force: bool, deep: bool) -> Result<()> { + let _ = self.ensure_connect()?; let message_out = write_delete_data_batch_request(ids, force, deep, false)?; self.do_write(&message_out)?; return read_delete_data_reply(&self.do_read()?); @@ -160,7 +160,7 @@ pub trait Client { /// @param pattern: The pattern of typename. fn list_data(&mut self, pattern: &str, regex: bool, limit: usize) -> Result> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_list_data_request(pattern, regex, limit)?; self.do_write(&message_out)?; return read_list_data_reply(&self.do_read()?); @@ -173,70 +173,70 @@ pub trait Client { regex: bool, limit: usize, ) -> Result> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_list_name_request(pattern, regex, limit)?; self.do_write(&message_out)?; return read_list_name_reply(&self.do_read()?); } fn persist(&mut self, id: ObjectID) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_persist_request(id)?; self.do_write(&message_out)?; return read_persist_reply(&self.do_read()?); } fn if_persist(&mut self, id: ObjectID) -> Result { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_if_persist_request(id)?; self.do_write(&message_out)?; return 
read_if_persist_reply(&self.do_read()?); } fn exists(&mut self, id: ObjectID) -> Result { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_exists_request(id)?; self.do_write(&message_out)?; return read_exists_reply(&self.do_read()?); } fn put_name(&mut self, id: ObjectID, name: &str) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_put_name_request(id, name)?; self.do_write(&message_out)?; return read_put_name_reply(&self.do_read()?); } - fn get_name(&mut self, name: &String, wait: bool) -> Result { - self.ensure_connect()?; + fn get_name(&mut self, name: &str, wait: bool) -> Result { + let _ = self.ensure_connect()?; let message_out = write_get_name_request(name, wait)?; self.do_write(&message_out)?; return read_get_name_reply(&self.do_read()?); } - fn drop_name(&mut self, name: &String) -> Result<()> { - self.ensure_connect()?; + fn drop_name(&mut self, name: &str) -> Result<()> { + let _ = self.ensure_connect()?; let message_out = write_drop_name_request(name)?; self.do_write(&message_out)?; return read_drop_name_reply(&self.do_read()?); } fn migrate(&mut self, id: ObjectID) -> Result { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_migrate_object_request(id)?; self.do_write(&message_out)?; return read_migrate_object_reply(&self.do_read()?); } fn clear(&mut self) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let message_out = write_clear_request()?; self.do_write(&message_out)?; return read_clear_reply(&self.do_read()?); } fn label(&mut self, id: ObjectID, key: &str, value: &str) -> Result<()> { - self.ensure_connect()?; + let _ = self.ensure_connect()?; let keys: Vec = vec![key.into()]; let values: Vec = vec![value.into()]; let message_out = write_label_request(id, &keys, &values)?; @@ -245,51 +245,46 @@ pub trait Client { } fn evict(&mut self, id: ObjectID) -> Result<()> { - self.ensure_connect()?; - let 
message_out = write_evict_request(&vec![id])?; + let _ = self.ensure_connect()?; + let message_out = write_evict_request(&[id])?; self.do_write(&message_out)?; return read_evict_reply(&self.do_read()?); } - fn evict_batch(&mut self, ids: &Vec) -> Result<()> { - self.ensure_connect()?; + fn evict_batch(&mut self, ids: &[ObjectID]) -> Result<()> { + let _ = self.ensure_connect()?; let message_out = write_evict_request(ids)?; self.do_write(&message_out)?; return read_evict_reply(&self.do_read()?); } fn load(&mut self, id: ObjectID, pin: bool) -> Result<()> { - self.ensure_connect()?; - let message_out = write_load_request(&vec![id], pin)?; + let _ = self.ensure_connect()?; + let message_out = write_load_request(&[id], pin)?; self.do_write(&message_out)?; return read_load_reply(&self.do_read()?); } - fn load_batch(&mut self, ids: &Vec, pin: bool) -> Result<()> { - self.ensure_connect()?; + fn load_batch(&mut self, ids: &[ObjectID], pin: bool) -> Result<()> { + let _ = self.ensure_connect()?; let message_out = write_load_request(ids, pin)?; self.do_write(&message_out)?; return read_load_reply(&self.do_read()?); } fn unpin(&mut self, id: ObjectID) -> Result<()> { - self.ensure_connect()?; - let message_out = write_unpin_request(&vec![id])?; + let _ = self.ensure_connect()?; + let message_out = write_unpin_request(&[id])?; self.do_write(&message_out)?; return read_unpin_reply(&self.do_read()?); } - fn unpin_batch(&mut self, ids: &Vec) -> Result<()> { - self.ensure_connect()?; + fn unpin_batch(&mut self, ids: &[ObjectID]) -> Result<()> { + let _ = self.ensure_connect()?; let message_out = write_unpin_request(ids)?; self.do_write(&message_out)?; return read_unpin_reply(&self.do_read()?); } - fn ensure_connect(&mut self) -> Result<()> { - if !self.connected() { - return Err(VineyardError::io_error("client not connected".into())); - } - return Ok(()); - } + fn ensure_connect(&mut self) -> Result>; } diff --git a/rust/vineyard/src/client/ds/blob.rs 
b/rust/vineyard/src/client/ds/blob.rs index f28a6245d2..15ee945b0f 100644 --- a/rust/vineyard/src/client/ds/blob.rs +++ b/rust/vineyard/src/client/ds/blob.rs @@ -13,11 +13,11 @@ // limitations under the License. use std::collections::{HashMap, HashSet}; - use std::fmt::{Debug, Display, Formatter}; -use std::rc::Rc; +use std::mem::ManuallyDrop; +use std::ops::{Deref, DerefMut}; -use arrow::buffer as arrow; +use arrow_buffer::Buffer; use crate::common::util::arrow::*; use crate::common::util::status::*; @@ -36,7 +36,7 @@ use super::object_meta::ObjectMeta; pub struct Blob { meta: ObjectMeta, size: usize, - buffer: Option>, + buffer: Option, } impl_typename!(Blob, "vineyard::Blob"); @@ -46,17 +46,17 @@ impl Default for Blob { Blob { meta: ObjectMeta::default(), size: usize::MAX, - buffer: None as Option>, + buffer: None as Option, } } } impl Object for Blob { fn construct(&mut self, meta: ObjectMeta) -> Result<()> { - vineyard_assert_typename(meta.get_typename()?, &typename::())?; + vineyard_assert_typename(typename::(), meta.get_typename()?)?; self.meta = meta; - if let Some(_) = self.buffer { + if self.buffer.is_some() { return Ok(()); } if self.meta.get_id() == empty_blob_id() { @@ -87,7 +87,7 @@ impl Display for Blob { } impl Blob { - pub fn new(meta: ObjectMeta, size: usize, buffer: Option>) -> Self { + pub fn new(meta: ObjectMeta, size: usize, buffer: Option) -> Self { Blob { meta: meta, size: size, @@ -99,19 +99,24 @@ impl Blob { self.size } - pub fn empty(client: *mut IPCClient) -> Box { - let mut blob = Blob::default(); - blob.size = 0; + pub fn empty(client: *mut IPCClient) -> Result> { + let mut blob = Blob { + size: 0, + ..Blob::default() + }; blob.meta.set_id(empty_blob_id()); blob.meta.set_signature(empty_blob_id() as Signature); - blob.meta.set_typename(&typename::()); + blob.meta.set_typename(typename::()); blob.meta.add_int("length", 0); blob.meta.set_nbytes(0); blob.meta .add_uint("instance_id", unsafe { &*client }.instance_id()); 
blob.meta.add_bool("transient", true); blob.meta.set_client(client); - return Box::new(blob); + blob.buffer = Some(arrow_buffer_null()); + blob.meta + .set_or_add_buffer(empty_blob_id(), Some(arrow_buffer_null()))?; + return Ok(Box::new(blob)); } pub fn as_ptr(&self) -> Result<*const u8> { @@ -119,10 +124,19 @@ impl Blob { return Ok(buffer.as_ptr()); } + pub fn as_typed_ptr(&self) -> Result<*const T> { + let ptr = self.as_ptr()?; + return Ok(ptr as *const T); + } + pub fn as_ptr_unchecked(&self) -> *const u8 { return self.buffer_unchecked().as_ptr(); } + pub fn as_typed_ptr_unchecked(&self) -> *const T { + return self.as_ptr_unchecked() as *const T; + } + pub fn as_slice(&self) -> Result<&[u8]> { return unsafe { Ok(std::slice::from_raw_parts(self.as_ptr()?, self.size)) }; } @@ -131,7 +145,7 @@ impl Blob { return unsafe { std::slice::from_raw_parts(self.as_ptr_unchecked(), self.size) }; } - pub fn buffer(&self) -> Result> { + pub fn buffer(&self) -> Result { match &self.buffer { None => { if self.size > 0 { @@ -141,11 +155,11 @@ impl Blob { object_id_to_string(self.meta().get_id()) ))); } - let buffer = to_buffer_null(); - return Ok(Rc::new(buffer)); + let buffer = arrow_buffer_null(); + return Ok(buffer); } Some(buffer) => { - if self.size > 0 && buffer.len() == 0 { + if self.size > 0 && buffer.is_empty() { return Err(VineyardError::invalid(format!( "The object might be a (partially) remote object and the payload data is not locally available: {}", @@ -157,11 +171,11 @@ impl Blob { } } - pub fn buffer_unchecked(&self) -> Rc { + pub fn buffer_unchecked(&self) -> Buffer { match &self.buffer { None => { - let buffer = to_buffer_null(); - return Rc::new(buffer); + let buffer = arrow_buffer_null(); + return buffer; } Some(buffer) => { return buffer.clone(); @@ -170,11 +184,26 @@ impl Blob { } } +impl Deref for Blob { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + return self.as_slice_unchecked(); + } +} + +impl AsRef<[u8]> for Blob { + fn as_ref(&self) 
-> &[u8] { + return self.as_slice_unchecked(); + } +} + #[derive(Debug)] pub struct BlobWriter { sealed: bool, + client: *mut IPCClient, object_id: ObjectID, - buffer: Option>, + buffer: ManuallyDrop>, metadata: HashMap, } @@ -190,25 +219,28 @@ impl ObjectBuilder for BlobWriter { impl ObjectBase for BlobWriter { fn build(&mut self, _client: &mut IPCClient) -> Result<()> { - if !self.sealed { - self.set_sealed(true); + if self.sealed { + return Ok(()); } + self.set_sealed(true); return Ok(()); } - fn seal(self: Self, client: &mut IPCClient) -> Result> { + fn seal(self, client: &mut IPCClient) -> Result> { client.seal_buffer(self.object_id)?; - let mut blob = Blob::default(); - blob.size = self.size(); + let mut blob = Blob { + size: self.size(), + ..Blob::default() + }; blob.meta.set_id(self.object_id); - blob.meta.set_typename(&typename::()); + blob.meta.set_typename(typename::()); blob.meta.set_nbytes(self.size()); blob.meta .add_uint("length", TryInto::::try_into(self.size())?); blob.meta.add_uint("instance_id", client.instance_id()); blob.meta.add_bool("transient", true); - blob.buffer = Some(Rc::new(to_buffer(self.as_ptr(), self.size()))); + blob.buffer = Some(arrow_buffer(self.as_ptr(), self.size())); blob.meta .set_or_add_buffer(self.object_id, blob.buffer.clone())?; return Ok(Box::new(blob)); @@ -216,11 +248,22 @@ impl ObjectBase for BlobWriter { } impl BlobWriter { - pub fn new(id: ObjectID, buffer: Option>) -> Self { + pub fn new(id: ObjectID, buffer: Option) -> Self { BlobWriter { sealed: false, + client: std::ptr::null_mut(), object_id: id, - buffer: buffer, + buffer: ManuallyDrop::new(buffer), + metadata: HashMap::new(), + } + } + + pub fn new_with_client(client: *mut IPCClient, id: ObjectID, buffer: Option) -> Self { + BlobWriter { + sealed: false, + client: client, + object_id: id, + buffer: ManuallyDrop::new(buffer), metadata: HashMap::new(), } } @@ -230,59 +273,102 @@ impl BlobWriter { } pub fn size(&self) -> usize { - match &self.buffer { + match 
&self.buffer.deref() { None => 0, Some(buf) => buf.len(), } } pub fn as_ptr(&self) -> *const u8 { - return match &self.buffer { + return match &self.buffer.deref() { None => std::ptr::null(), Some(buf) => buf.as_ptr(), }; } - pub fn as_mut_ptr(&self) -> *mut u8 { + pub fn as_typed_ptr(&self) -> *const T { + return self.as_ptr() as *const T; + } + + pub fn as_mut_ptr(&mut self) -> *mut u8 { return self.as_ptr() as *mut u8; } + pub fn as_typed_mut_ptr(&mut self) -> *mut T { + return self.as_mut_ptr() as *mut T; + } + pub fn as_slice(&self) -> &[u8] { return unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()) }; } - pub fn as_mut_slice(&self) -> &mut [u8] { + pub fn as_mut_slice(&mut self) -> &mut [u8] { return unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size()) }; } - pub fn buffer(&self) -> Option> { - return self.buffer.clone(); + pub fn buffer(&self) -> Option<&Buffer> { + return self.buffer.as_ref(); } - pub fn abort(&self, mut client: IPCClient) -> Result<()> { + pub fn release(mut self) -> Option { + return unsafe { ManuallyDrop::take(&mut self.buffer) }; + } + + pub fn abort(&self) -> Result<()> { if self.sealed { return Err(VineyardError::object_sealed( - "The blob write has already been sealed and cannot be aborted".into(), + "The blob write has already been sealed and cannot be aborted", )); } - return client.drop_buffer(self.object_id); + if let Some(client) = unsafe { self.client.as_mut() } { + return client.drop_buffer(self.object_id); + } + return Ok(()); + } + + pub fn add_key_value(&mut self, key: &str, value: &str) { + self.metadata.insert(key.into(), value.into()); + } +} + +impl Drop for BlobWriter { + fn drop(&mut self) { + if let Err(err) = self.abort() { + error!("Failed to abort blob writer: {}, {}", self.object_id, err); + } + } +} + +impl Deref for BlobWriter { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + return self.as_slice(); + } +} + +impl DerefMut for BlobWriter { + fn deref_mut(&mut 
self) -> &mut Self::Target { + return self.as_mut_slice(); } +} - pub fn add_key_value(&mut self, key: &String, value: &String) { - self.metadata.insert(key.to_string(), value.to_string()); +impl AsRef<[u8]> for BlobWriter { + fn as_ref(&self) -> &[u8] { + return self.as_slice(); } } pub struct BufferSet { buffer_ids: HashSet, - buffers: HashMap>>, + buffers: HashMap>, } impl Default for BufferSet { fn default() -> BufferSet { BufferSet { buffer_ids: HashSet::new() as HashSet, - buffers: HashMap::new() as HashMap>>, + buffers: HashMap::new() as HashMap>, } } } @@ -313,11 +399,11 @@ impl BufferSet { return &self.buffer_ids; } - pub fn buffers(&self) -> &HashMap>> { + pub fn buffers(&self) -> &HashMap> { return &self.buffers; } - pub fn buffers_mut(&mut self) -> &mut HashMap>> { + pub fn buffers_mut(&mut self) -> &mut HashMap> { return &mut self.buffers; } @@ -337,11 +423,7 @@ impl BufferSet { } } - pub fn emplace_buffer( - &mut self, - id: ObjectID, - buffer: Option>, - ) -> Result<()> { + pub fn emplace_buffer(&mut self, id: ObjectID, buffer: Option) -> Result<()> { match self.buffers.get(&id) { Some(Some(_)) => { return Err(VineyardError::invalid(format!( @@ -366,12 +448,12 @@ impl BufferSet { for (key, value) in others.buffers.iter() { match value { None => { - self.buffer_ids.insert(key.clone()); - self.buffers.insert(key.clone(), None); + self.buffer_ids.insert(*key); + self.buffers.insert(*key, None); } Some(buffer) => { - self.buffer_ids.insert(key.clone()); - self.buffers.insert(key.clone(), Some(Rc::clone(buffer))); + self.buffer_ids.insert(*key); + self.buffers.insert(*key, Some(buffer.clone())); } } } @@ -381,7 +463,7 @@ impl BufferSet { return self.buffers.get(&id).is_some(); } - pub fn get(&self, id: ObjectID) -> Result>> { + pub fn get(&self, id: ObjectID) -> Result> { return self .buffers .get(&id) diff --git a/rust/vineyard/src/client/ds/blob_test.rs b/rust/vineyard/src/client/ds/blob_test.rs index 6f2856450c..d7c6ca7f47 100644 --- 
a/rust/vineyard/src/client/ds/blob_test.rs +++ b/rust/vineyard/src/client/ds/blob_test.rs @@ -14,7 +14,8 @@ #[cfg(test)] mod tests { - use std::rc::Rc; + use std::mem::ManuallyDrop; + use std::sync::atomic::{AtomicUsize, Ordering}; use spectral::prelude::*; @@ -22,46 +23,102 @@ mod tests { use super::super::*; #[test] - fn test_blob() { + fn test_manually_drop() { + static mut drop_a_called: AtomicUsize = AtomicUsize::new(0); + static mut drop_b_called: AtomicUsize = AtomicUsize::new(0); + + struct A {} + + impl Drop for A { + fn drop(&mut self) { + // record a's dtor + unsafe { + drop_a_called.fetch_add(1, Ordering::SeqCst); + } + } + } + + impl A { + pub fn tell(&self) {} + } + + struct B { + a: ManuallyDrop, + } + + impl Drop for B { + fn drop(&mut self) { + // a should live + assert!(unsafe { drop_a_called.load(Ordering::SeqCst) } == 0); + // record b's dtor + unsafe { + drop_b_called.fetch_add(1, Ordering::SeqCst); + } + } + } + + impl B { + pub fn release(mut self) -> A { + return unsafe { ManuallyDrop::take(&mut self.a) }; + } + } + + let b = B { + a: ManuallyDrop::new(A {}), + }; + assert!(unsafe { drop_a_called.load(Ordering::SeqCst) } == 0); + assert!(unsafe { drop_b_called.load(Ordering::SeqCst) } == 0); + let a = b.release(); + assert!(unsafe { drop_a_called.load(Ordering::SeqCst) } == 0); + assert!(unsafe { drop_b_called.load(Ordering::SeqCst) } == 1); + a.tell(); + drop(a); + assert!(unsafe { drop_a_called.load(Ordering::SeqCst) } == 1); + assert!(unsafe { drop_b_called.load(Ordering::SeqCst) } == 1); + } + + #[test] + fn test_blob() -> Result<()> { const N: usize = 1024; - let mut conn = IPCClient::default().unwrap(); - let client = Rc::get_mut(&mut conn).unwrap(); + let mut client = IPCClient::default()?; - let blob_writer = client.create_blob(N).unwrap(); + let mut blob_writer = client.create_blob(N)?; let blob_writer_id = blob_writer.id(); - assert_that(&blob_writer_id).is_greater_than(0); + assert_that!(blob_writer_id).is_greater_than(0); let 
slice_mut = blob_writer.as_mut_slice(); - for i in 0..N { - slice_mut[i] = i as u8; + for (idx, item) in slice_mut.iter_mut().enumerate() { + *item = idx as u8; } // test seal { - let object = blob_writer.seal(client).unwrap(); - let blob = downcast_object::(object).unwrap(); + let object = blob_writer.seal(&mut client)?; + let blob = downcast_object::(object)?; let blob_id = blob.id(); - assert_that(&blob_id).is_greater_than(0); - assert_that(&blob_id).is_equal_to(blob_writer_id); + assert_that!(blob_id).is_greater_than(0); + assert_that!(blob_id).is_equal_to(blob_writer_id); - let slice = blob.as_slice().unwrap(); - for i in 0..N { - assert_that(&slice[i]).is_equal_to(i as u8); + let slice = blob.as_slice()?; + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to(idx as u8); } } // test get blob { - let blob = client.get_blob(blob_writer_id).unwrap(); + let blob = client.get_blob(blob_writer_id)?; let blob_id = blob.id(); - assert_that(&blob_id).is_greater_than(0); - assert_that(&blob_id).is_equal_to(blob_writer_id); + assert_that!(blob_id).is_greater_than(0); + assert_that!(blob_id).is_equal_to(blob_writer_id); - let slice = blob.as_slice().unwrap(); - for i in 0..N { - assert_that(&slice[i]).is_equal_to(i as u8); + let slice = blob.as_slice()?; + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to(idx as u8); } } + + return Ok(()); } } diff --git a/rust/vineyard/src/client/ds/object.rs b/rust/vineyard/src/client/ds/object.rs index a0aa8b654b..225f9feba7 100644 --- a/rust/vineyard/src/client/ds/object.rs +++ b/rust/vineyard/src/client/ds/object.rs @@ -25,6 +25,11 @@ use crate::common::util::uuid::*; use super::super::IPCClient; use super::object_meta::ObjectMeta; +/// Re-export the gensym and ctor symbol to avoid introducing new +/// dependencies (gensym, ctor) in callers. 
+pub use ctor::ctor; +pub use gensym::gensym; + pub trait Create: TypeName { fn create() -> Box; } @@ -34,7 +39,7 @@ pub trait ObjectBase: Downcast { return Ok(()); } - fn seal(self: Self, client: &mut IPCClient) -> Result>; + fn seal(self, client: &mut IPCClient) -> Result>; } impl_downcast!(ObjectBase); @@ -44,8 +49,12 @@ pub trait ObjectMetaAttr { self.meta().get_id() } + /// Return a reference to the meta data of the object. fn meta(&self) -> &ObjectMeta; + /// Return a move of the inner metadata. + fn metadata(self) -> ObjectMeta; + fn nbytes(&self) -> usize { self.meta().get_nbytes() } @@ -60,7 +69,7 @@ pub trait ObjectMetaAttr { } pub trait Object: ObjectBase + ObjectMetaAttr { - fn as_any(self: &'_ Self) -> &'_ dyn Any + fn as_any(&'_ self) -> &'_ dyn Any where Self: Sized + 'static, { @@ -72,6 +81,12 @@ pub trait Object: ObjectBase + ObjectMetaAttr { impl_downcast!(Object); +pub fn downcast( + object: Box, +) -> std::result::Result, Box> { + return object.downcast::(); +} + pub fn downcast_object(object: Box) -> Result> { return object.downcast::().map_err(|_| { VineyardError::invalid(format!( @@ -90,6 +105,12 @@ pub fn downcast_object_ref(object: &dyn Object) -> Result< ))); } +pub fn downcast_rc( + object: Rc, +) -> std::result::Result, Rc> { + return object.downcast_rc::(); +} + pub fn downcast_object_rc(object: Rc) -> Result> { return object.downcast_rc::().map_err(|_| { VineyardError::invalid(format!( @@ -116,20 +137,47 @@ pub trait ObjectBuilder: ObjectBase { fn set_sealed(&mut self, sealed: bool); fn ensure_not_sealed(&mut self) -> Result<()> { - return vineyard_assert(!self.sealed(), "The builder has already been sealed".into()); + return vineyard_assert(!self.sealed(), "The builder has already been sealed"); } } impl_downcast!(ObjectBuilder); +#[doc(hidden)] +#[macro_export] +macro_rules! 
register_vineyard_type_impl { + ($gensym:ident, $type:ty) => { + #[$crate::client::ds::object::ctor] + static $gensym: Result = ObjectFactory::register::<$type>(); + }; +} + +#[macro_export] +macro_rules! register_vineyard_type { + ($type:ty) => { + $crate::client::ds::object::gensym! { $crate::client::ds::object::register_vineyard_type_impl!{ $type } } + } +} + +#[macro_export] +macro_rules! register_vineyard_types { + { + $( $type:ty; )+ + } => { + $( + $crate::client::ds::object::register_vineyard_type!($type); + )+ + }; +} + +#[macro_export] macro_rules! register_vineyard_object { // match when no type parameters are present ($t:tt) => { + $crate::client::ds::object::register_vineyard_type!($t); + impl Create for $t { fn create() -> Box { - lazy_static! { - static ref __BLOB_REGISTERED: Result = ObjectFactory::register::<$t>(); - } return Box::new(Self::default()); } } @@ -138,6 +186,10 @@ macro_rules! register_vineyard_object { fn meta(&self) -> &ObjectMeta { return &self.meta; } + + fn metadata(self) -> ObjectMeta { + return self.meta; + } } impl ObjectBase for $t { @@ -156,12 +208,6 @@ macro_rules! register_vineyard_object { { impl< $( $N $(: $b0 $(+$bs)* $(+$lts)* )? ),* > Create for $t< $( $N ),* > { fn create() -> Box { - // As generic template type parameter cannot be used in static methods, - // we skip the registration here util we found a better way. - // - // lazy_static! { - // static ref __BLOB_REGISTERED: Result = ObjectFactory::register::<$t< $( $N ),* >>(); - // } return Box::new(Self::default()); } } @@ -170,6 +216,10 @@ macro_rules! register_vineyard_object { fn meta(&self) -> &ObjectMeta { return &self.meta; } + + fn metadata(self) -> ObjectMeta { + return self.meta; + } } impl< $( $N $(: $b0 $(+$bs)* $(+$lts)* )? ),* > ObjectBase for $t< $( $N ),* > { @@ -185,4 +235,7 @@ macro_rules! 
register_vineyard_object { }; } -pub(crate) use register_vineyard_object; +pub use register_vineyard_object; +pub use register_vineyard_type; +pub use register_vineyard_type_impl; +pub use register_vineyard_types; diff --git a/rust/vineyard/src/client/ds/object_factory.rs b/rust/vineyard/src/client/ds/object_factory.rs index a6947d13e9..3cbaf02f26 100644 --- a/rust/vineyard/src/client/ds/object_factory.rs +++ b/rust/vineyard/src/client/ds/object_factory.rs @@ -13,9 +13,10 @@ // limitations under the License. use std::collections::HashMap; - use std::sync::{Arc, Mutex}; +use ctor::ctor; + use crate::common::util::status::*; use crate::common::util::typename::typename; @@ -26,6 +27,10 @@ pub struct ObjectFactory {} type ObjectInitializer = fn() -> Box; +#[ctor] +static KNOWN_TYPES: Arc>> = + Arc::new(Mutex::new(HashMap::new())); + impl ObjectFactory { pub fn register() -> Result { let typename = typename::(); @@ -63,13 +68,13 @@ impl ObjectFactory { return Ok(object); } - pub fn factory_ref() -> &'static Mutex> { - return &**ObjectFactory::get_known_types(); + pub fn factory_ref() -> &'static Mutex> { + return ObjectFactory::get_known_types(); } - fn get_known_types() -> &'static Arc>> { + fn get_known_types() -> &'static Arc>> { lazy_static! 
{ - static ref KNOWN_TYPES: Arc>> = + static ref KNOWN_TYPES: Arc>> = Arc::new(Mutex::new(HashMap::new())); } return &KNOWN_TYPES; diff --git a/rust/vineyard/src/client/ds/object_meta.rs b/rust/vineyard/src/client/ds/object_meta.rs index da74c097ff..a4a17d2e27 100644 --- a/rust/vineyard/src/client/ds/object_meta.rs +++ b/rust/vineyard/src/client/ds/object_meta.rs @@ -16,7 +16,7 @@ use std::collections::{HashMap, HashSet}; use std::hash::Hash; use std::rc::Rc; -use arrow::buffer as arrow; +use arrow_buffer::Buffer; use serde::de::DeserializeOwned; use serde::Serialize; use serde_json::{json, Value}; @@ -61,13 +61,13 @@ impl ObjectMeta { return Ok(meta); } - pub fn from_typename(typename: &str) -> Self { + pub fn new_from_typename(typename: &str) -> Self { let mut meta = ObjectMeta::default(); meta.set_typename(typename); return meta; } - pub fn from_metadata(metadata: JSON) -> Result { + pub fn new_from_metadata(metadata: JSON) -> Result { let mut meta = ObjectMeta::default(); meta.set_meta_data(std::ptr::null_mut(), metadata)?; return Ok(meta); @@ -80,7 +80,7 @@ impl ObjectMeta { pub fn get_client(&self) -> Result<&mut IPCClient> { if self.client.is_null() { return Err(VineyardError::invalid( - "the associated client is not available".into(), + "the associated client is not available", )); } else { return Ok(unsafe { &mut *self.client }); @@ -251,7 +251,7 @@ impl ObjectMeta { return get_usize(&self.meta, key); } - pub fn add_string(&mut self, key: &str, value: &str) { + pub fn add_string>(&mut self, key: &str, value: T) { self.meta .insert(key.into(), serde_json::Value::String(value.into())); } @@ -260,7 +260,7 @@ impl ObjectMeta { return get_string(&self.meta, key); } - pub fn add_vector(&mut self, key: &str, value: &Vec) -> Result<()> { + pub fn add_vector(&mut self, key: &str, value: &[T]) -> Result<()> { self.add_value(key, serde_json::to_value(value)?); return Ok(()); } @@ -271,7 +271,11 @@ impl ObjectMeta { }); } - pub fn add_set(&mut self, key: &str, value: 
&HashSet) -> Result<()> { + pub fn add_set( + &mut self, + key: &str, + value: &HashSet, + ) -> Result<()> { self.add_value(key, serde_json::to_value(value)?); return Ok(()); } @@ -298,9 +302,10 @@ impl ObjectMeta { } pub fn add_member_meta(&mut self, key: &str, member: &ObjectMeta) -> Result<()> { - if let Some(_) = self + if self .meta .insert(key.into(), Value::Object(member.meta.clone())) + .is_some() { return Err(VineyardError::invalid(format!( "key '{}' already exists", @@ -348,11 +353,13 @@ impl ObjectMeta { } } - pub fn get_member(&self, name: &str) -> Result> { + pub fn get_member(&self, name: &str) -> Result> { + use crate::client::downcast_object; + let meta = self.get_member_meta(name)?; let mut object = T::create(); object.construct(meta)?; - return Ok(object); + return downcast_object::(object); } pub fn get_member_untyped(&self, name: &str) -> Result> { @@ -364,16 +371,13 @@ impl ObjectMeta { match self.meta.get(key) { Some(Value::Object(value)) => { let mut meta = ObjectMeta::default(); - meta.set_meta_data(self.client.clone(), value.clone())?; + meta.set_meta_data(self.client, value.clone())?; let buffers = meta.get_buffers_mut()?; for (id, buffer) in buffers.buffers_mut() { - match self.buffers.get(*id) { - Ok(Some(buf)) => { - let _ = buffer.insert(buf); - } - // for remote object, the blob may not present here - _ => {} + // for remote object, the blob may not present here + if let Ok(Some(buf)) = self.buffers.get(*id) { + let _ = buffer.insert(buf); } } if self.force_local { @@ -396,7 +400,28 @@ impl ObjectMeta { } } - pub fn get_buffer(&self, blob_id: ObjectID) -> Result>> { + pub fn get_member_id(&self, key: &str) -> Result { + match self.meta.get(key) { + Some(Value::Object(value)) => { + let id = object_id_from_string(get_string(value, "id")?)?; + return Ok(id); + } + Some(_) => { + return Err(VineyardError::invalid(format!( + "Invalid json value at key {}: not an object", + key + ))); + } + _ => { + return 
Err(VineyardError::invalid(format!( + "Invalid json value: key {} not found", + key + ))); + } + } + } + + pub fn get_buffer(&self, blob_id: ObjectID) -> Result> { return self.buffers.get(blob_id).map_err(|err| { VineyardError::invalid(format!( "The target blob {} doesn't exist: {}", @@ -406,7 +431,7 @@ impl ObjectMeta { }); } - pub fn set_buffer(&mut self, id: ObjectID, buffer: Option>) -> Result<()> { + pub fn set_buffer(&mut self, id: ObjectID, buffer: Option) -> Result<()> { match self.get_buffers_mut() { Ok(buffers) => { buffers.emplace_buffer(id, buffer)?; @@ -418,11 +443,7 @@ impl ObjectMeta { return Ok(()); } - pub fn set_or_add_buffer( - &mut self, - id: ObjectID, - buffer: Option>, - ) -> Result<()> { + pub fn set_or_add_buffer(&mut self, id: ObjectID, buffer: Option) -> Result<()> { match self.get_buffers_mut() { Ok(buffers) => { let _ = buffers.emplace(id); // ensure the id exists @@ -482,7 +503,7 @@ impl ObjectMeta { None => { warn!("Cannot extend buffers of a shared object meta."); return Err(VineyardError::invalid( - "Cannot manipulate buffers of a shared object meta.".into(), + "Cannot manipulate buffers of a shared object meta.", )); } }; diff --git a/rust/vineyard/src/client/ipc_client.rs b/rust/vineyard/src/client/ipc_client.rs index 0bbbb5f2e6..5f305c68d9 100644 --- a/rust/vineyard/src/client/ipc_client.rs +++ b/rust/vineyard/src/client/ipc_client.rs @@ -16,9 +16,10 @@ use std::collections::HashMap; use std::io; use std::net::Shutdown; use std::os::unix::net::UnixStream; -use std::rc::Rc; -use arrow::buffer as arrow; +use arrow_buffer::Buffer; +use parking_lot::ReentrantMutex; +use parking_lot::ReentrantMutexGuard; use crate::common::util::arrow::*; use crate::common::util::protocol::*; @@ -34,7 +35,7 @@ use super::io::*; mod memory { - use std::collections::{HashMap, HashSet}; + use std::collections::{hash_map, HashMap, HashSet}; use std::fs::File; use std::os::fd::{AsRawFd, FromRawFd}; use std::os::unix::net::UnixStream; @@ -110,14 +111,13 @@ 
mod memory { pub fn mmap( &mut self, - stream: &mut UnixStream, + stream: &UnixStream, fd: i32, map_size: usize, realign: bool, ) -> Result<*const u8> { - if !self.entries.contains_key(&fd) { - self.entries - .insert(fd, MmapEntry::new(recv_fd(stream)?, map_size, realign)); + if let hash_map::Entry::Vacant(entry) = self.entries.entry(fd) { + entry.insert(MmapEntry::new(recv_fd(stream)?, map_size, realign)); } match self.entries.get_mut(&fd) { Some(entry) => { @@ -134,14 +134,13 @@ mod memory { pub fn mmap_mut( &mut self, - stream: &mut UnixStream, + stream: &UnixStream, fd: i32, map_size: usize, realign: bool, ) -> Result<*mut u8> { - if !self.entries.contains_key(&fd) { - self.entries - .insert(fd, MmapEntry::new(recv_fd(stream)?, map_size, realign)); + if let hash_map::Entry::Vacant(entry) = self.entries.entry(fd) { + entry.insert(MmapEntry::new(recv_fd(stream)?, map_size, realign)); } match self.entries.get_mut(&fd) { Some(entry) => { @@ -185,11 +184,18 @@ pub struct IPCClient { pub support_rpc_compression: bool, stream: UnixStream, + lock: ReentrantMutex<()>, mmap: memory::MmapManager, } +impl Drop for IPCClient { + fn drop(&mut self) { + self.disconnect(); + } +} + impl Client for IPCClient { - fn disconnect(&mut self) -> () { + fn disconnect(&mut self) { if !self.connected() { return; } @@ -211,7 +217,7 @@ impl Client for IPCClient { #[cfg(feature = "nightly")] fn connected(&mut self) -> bool { - if let Err(_) = self.stream.set_nonblocking(true) { + if self.stream.set_nonblocking(true).is_err() { return false; } match self.stream.peek(&mut [0]) { @@ -230,6 +236,13 @@ impl Client for IPCClient { } } + fn ensure_connect(&mut self) -> Result> { + if !self.connected() { + return Err(VineyardError::io_error("client not connected")); + } + return Ok(self.lock.lock()); + } + fn do_read(&mut self) -> Result { return do_read(&mut self.stream); } @@ -243,7 +256,6 @@ impl Client for IPCClient { } fn create_metadata(&mut self, metadata: &ObjectMeta) -> Result { - 
self.ensure_connect()?; let mut meta = metadata.clone(); meta.set_instance_id(self.instance_id()); meta.set_transient(true); @@ -277,7 +289,7 @@ impl Client for IPCClient { return Ok(meta); } - fn get_metadata_batch(&mut self, ids: &Vec) -> Result> { + fn get_metadata_batch(&mut self, ids: &[ObjectID]) -> Result> { let data_vec = self.get_data_batch(ids)?; let mut metadatas = Vec::new(); let mut buffer_id_vec: Vec = Vec::new(); @@ -300,16 +312,17 @@ impl Client for IPCClient { } impl IPCClient { - pub fn default() -> Result> { + #[allow(clippy::should_implement_trait)] + pub fn default() -> Result { let default_ipc_socket = std::env::var(VINEYARD_IPC_SOCKET_KEY)?; return IPCClient::connect(&default_ipc_socket); } - pub fn connect(socket: &str) -> Result> { + pub fn connect(socket: &str) -> Result { let mut stream = connect_ipc_socket_retry(&socket)?; let message_out = write_register_request(RegisterRequest { - version: VERSION.to_string(), - store_type: "Normal".to_string(), + version: VERSION.into(), + store_type: "Normal".into(), session_id: 0, username: String::new(), password: String::new(), @@ -317,7 +330,7 @@ impl IPCClient { })?; do_write(&mut stream, &message_out)?; let reply = read_register_reply(&do_read(&mut stream)?)?; - return Ok(Rc::new(IPCClient { + return Ok(IPCClient { connected: true, ipc_socket: reply.ipc_socket, rpc_endpoint: reply.rpc_endpoint, @@ -325,12 +338,12 @@ impl IPCClient { server_version: reply.version, support_rpc_compression: reply.support_rpc_compression, stream: stream, + lock: ReentrantMutex::new(()), mmap: memory::MmapManager::new(), - })); + }); } pub fn create_blob(&mut self, size: usize) -> Result { - self.ensure_connect()?; let (id, buffer) = self.create_buffer(size)?; return Ok(BlobWriter::new(id, buffer)); } @@ -341,36 +354,40 @@ impl IPCClient { Some(buffer) => buffer.len(), None => 0, }; - let mut meta = ObjectMeta::from_typename(&typename::()); + let mut meta = ObjectMeta::new_from_typename(typename::()); 
meta.set_id(id); meta.set_instance_id(self.instance_id()); meta.set_or_add_buffer(id, buffer.clone())?; return Ok(Blob::new(meta, size, buffer)); } - fn create_buffer(&mut self, size: usize) -> Result<(ObjectID, Option>)> { - self.ensure_connect()?; + fn create_buffer(&mut self, size: usize) -> Result<(ObjectID, Option)> { + if size == 0 { + return Ok((empty_blob_id(), Some(arrow_buffer_null()))); + } + let _ = self.ensure_connect()?; let message_out = write_create_buffer_request(size)?; self.do_write(&message_out)?; let reply = read_create_buffer_reply(&self.do_read()?)?; if reply.payload.data_size == 0 { - return Ok((reply.id, None)); + return Ok((reply.id, Some(arrow_buffer_null()))); } let pointer = self.mmap.mmap_mut( - &mut self.stream, + &self.stream, reply.payload.store_fd, reply.payload.map_size, true, )?; - let buffer = to_buffer_offset(pointer, reply.payload.data_offset, reply.payload.data_size); - return Ok((reply.id, Some(Rc::new(buffer)))); + let buffer = + arrow_buffer_with_offset(pointer, reply.payload.data_offset, reply.payload.data_size); + return Ok((reply.id, Some(buffer))); } - fn get_buffer(&mut self, id: ObjectID, unsafe_: bool) -> Result>> { - let buffers = self.get_buffers(&vec![id], unsafe_)?; + fn get_buffer(&mut self, id: ObjectID, unsafe_: bool) -> Result> { + let buffers = self.get_buffers(&[id], unsafe_)?; return buffers .get(&id) - .map(|v| v.clone()) + .cloned() .ok_or(VineyardError::object_not_exists(format!( "buffer {} doesn't exist", id @@ -379,10 +396,10 @@ impl IPCClient { fn get_buffers( &mut self, - ids: &Vec, + ids: &[ObjectID], unsafe_: bool, - ) -> Result>>> { - self.ensure_connect()?; + ) -> Result>> { + let _ = self.ensure_connect()?; let message_out = write_get_buffers_request(&ids, unsafe_)?; self.do_write(&message_out)?; let reply = read_get_buffers_reply(&self.do_read()?)?; @@ -390,13 +407,14 @@ impl IPCClient { let mut buffers = HashMap::new(); for payload in reply.payloads { if payload.data_size == 0 { - 
buffers.insert(payload.object_id, None); + buffers.insert(payload.object_id, Some(arrow_buffer_null())); + continue; } - let pointer = - self.mmap - .mmap(&mut self.stream, payload.store_fd, payload.map_size, true)?; - let buffer = to_buffer_offset(pointer, payload.data_offset, payload.data_size); - buffers.insert(payload.object_id, Some(Rc::new(buffer))); + let pointer = self + .mmap + .mmap(&self.stream, payload.store_fd, payload.map_size, true)?; + let buffer = arrow_buffer_with_offset(pointer, payload.data_offset, payload.data_size); + buffers.insert(payload.object_id, Some(buffer)); } return Ok(buffers); } diff --git a/rust/vineyard/src/client/ipc_client_test.rs b/rust/vineyard/src/client/ipc_client_test.rs index 20d01c4120..113a2e17f0 100644 --- a/rust/vineyard/src/client/ipc_client_test.rs +++ b/rust/vineyard/src/client/ipc_client_test.rs @@ -14,14 +14,13 @@ #[cfg(test)] mod tests { - use std::rc::Rc; - use super::super::*; #[test] - fn test_ipc_connect() { - let mut conn = IPCClient::default().unwrap(); - let client = Rc::get_mut(&mut conn).unwrap(); + fn test_ipc_connect() -> Result<()> { + let mut client = IPCClient::default()?; assert!(client.connected()); + + return Ok(()); } } diff --git a/rust/vineyard/src/client/rpc_client.rs b/rust/vineyard/src/client/rpc_client.rs index f46af69f2b..e98a65b53b 100644 --- a/rust/vineyard/src/client/rpc_client.rs +++ b/rust/vineyard/src/client/rpc_client.rs @@ -14,7 +14,8 @@ use std::io; use std::net::{Shutdown, TcpStream}; -use std::rc::Rc; + +use parking_lot::{ReentrantMutex, ReentrantMutexGuard}; use crate::common::util::protocol::*; use crate::common::util::status::*; @@ -34,10 +35,17 @@ pub struct RPCClient { pub support_rpc_compression: bool, stream: TcpStream, + lock: ReentrantMutex<()>, +} + +impl Drop for RPCClient { + fn drop(&mut self) { + self.disconnect(); + } } impl Client for RPCClient { - fn disconnect(&mut self) -> () { + fn disconnect(&mut self) { if !self.connected() { return; } @@ -59,7 +67,7 @@ 
impl Client for RPCClient { #[cfg(feature = "nightly")] fn connected(&mut self) -> bool { - if let Err(_) = self.stream.set_nonblocking(true) { + if self.stream.set_nonblocking(true).is_err() { return false; } match self.stream.peek(&mut [0]) { @@ -78,6 +86,13 @@ impl Client for RPCClient { } } + fn ensure_connect(&mut self) -> Result> { + if !self.connected() { + return Err(VineyardError::io_error("client not connected")); + } + return Ok(self.lock.lock()); + } + fn do_read(&mut self) -> Result { return do_read(&mut self.stream); } @@ -91,7 +106,6 @@ impl Client for RPCClient { } fn create_metadata(&mut self, metadata: &ObjectMeta) -> Result { - self.ensure_connect()?; let mut meta = metadata.clone(); meta.set_instance_id(self.instance_id()); meta.set_transient(true); @@ -113,15 +127,15 @@ impl Client for RPCClient { fn get_metadata(&mut self, id: ObjectID) -> Result { let data = self.get_data(id, false, false)?; - let meta = ObjectMeta::from_metadata(data)?; + let meta = ObjectMeta::new_from_metadata(data)?; return Ok(meta); } - fn get_metadata_batch(&mut self, ids: &Vec) -> Result> { + fn get_metadata_batch(&mut self, ids: &[ObjectID]) -> Result> { let data_vec = self.get_data_batch(ids)?; let mut metadatas = Vec::new(); for data in data_vec { - let meta = ObjectMeta::from_metadata(data)?; + let meta = ObjectMeta::new_from_metadata(data)?; metadatas.push(meta); } return Ok(metadatas); @@ -129,7 +143,8 @@ impl Client for RPCClient { } impl RPCClient { - pub fn default() -> Result> { + #[allow(clippy::should_implement_trait)] + pub fn default() -> Result { let rpc_endpoint = std::env::var(VINEYARD_RPC_ENDPOINT_KEY)?; let (host, port) = match rpc_endpoint.rfind(':') { Some(idx) => ( @@ -141,11 +156,11 @@ impl RPCClient { return RPCClient::connect(host, port); } - pub fn connect(host: &str, port: u16) -> Result> { + pub fn connect(host: &str, port: u16) -> Result { let mut stream = connect_rpc_endpoint_retry(host, port)?; let message_out = 
write_register_request(RegisterRequest { - version: VERSION.to_string(), - store_type: "Normal".to_string(), + version: VERSION.into(), + store_type: "Normal".into(), session_id: 0, username: String::new(), password: String::new(), @@ -153,7 +168,7 @@ impl RPCClient { })?; do_write(&mut stream, &message_out)?; let reply = read_register_reply(&do_read(&mut stream)?)?; - return Ok(Rc::new(RPCClient { + return Ok(RPCClient { connected: true, ipc_socket: reply.ipc_socket, rpc_endpoint: reply.rpc_endpoint, @@ -161,6 +176,7 @@ impl RPCClient { server_version: reply.version, support_rpc_compression: reply.support_rpc_compression, stream: stream, - })); + lock: ReentrantMutex::new(()), + }); } } diff --git a/rust/vineyard/src/client/rpc_client_test.rs b/rust/vineyard/src/client/rpc_client_test.rs index dacd5f20b5..e2d699f03f 100644 --- a/rust/vineyard/src/client/rpc_client_test.rs +++ b/rust/vineyard/src/client/rpc_client_test.rs @@ -14,14 +14,13 @@ #[cfg(test)] mod tests { - use std::rc::Rc; - use super::super::*; #[test] - fn test_rpc_connect() { - let mut conn = RPCClient::default().unwrap(); - let client = Rc::get_mut(&mut conn).unwrap(); + fn test_rpc_connect() -> Result<()> { + let mut client = RPCClient::default()?; assert!(client.connected()); + + return Ok(()); } } diff --git a/rust/vineyard/src/common/util/arrow.rs b/rust/vineyard/src/common/util/arrow.rs index 2f8a86a161..af7ec953c0 100644 --- a/rust/vineyard/src/common/util/arrow.rs +++ b/rust/vineyard/src/common/util/arrow.rs @@ -15,6 +15,8 @@ use std::ptr::NonNull; use std::sync::Arc; +use arrow_buffer::{alloc, Buffer}; + /// An `arrow::alloc::Allocation` implementation to prevent the pointer /// been freed by `arrow::Buffer`. /// @@ -25,27 +27,27 @@ use std::sync::Arc; /// /// Instead, we cast pointers (`*const u8` and `*mut u8`) to the expected type. pub struct MmapAllocation {} -impl arrow::alloc::Allocation for MmapAllocation {} +impl alloc::Allocation for MmapAllocation {} lazy_static! 
{ static ref MMAP_ALLOCATION: Arc = Arc::new(MmapAllocation {}); } -pub fn to_buffer(pointer: *const u8, len: usize) -> arrow::buffer::Buffer { - return to_buffer_mut(pointer as *mut u8, len); +pub fn arrow_buffer(pointer: *const u8, len: usize) -> Buffer { + return arrow_buffer_mut(pointer as *mut u8, len); } -pub fn to_buffer_offset(pointer: *const u8, offset: isize, len: usize) -> arrow::buffer::Buffer { - return to_buffer_offset_mut(pointer as *mut u8, offset, len); +pub fn arrow_buffer_with_offset(pointer: *const u8, offset: isize, len: usize) -> Buffer { + return arrow_buffer_with_offset_mut(pointer as *mut u8, offset, len); } -pub fn to_buffer_mut(pointer: *mut u8, len: usize) -> arrow::buffer::Buffer { - return to_buffer_offset_mut(pointer as *mut u8, 0, len); +pub fn arrow_buffer_mut(pointer: *mut u8, len: usize) -> Buffer { + return arrow_buffer_with_offset_mut(pointer as *mut u8, 0, len); } -pub fn to_buffer_offset_mut(pointer: *mut u8, offset: isize, len: usize) -> arrow::buffer::Buffer { +pub fn arrow_buffer_with_offset_mut(pointer: *mut u8, offset: isize, len: usize) -> Buffer { return unsafe { - arrow::buffer::Buffer::from_custom_allocation( + Buffer::from_custom_allocation( NonNull::new_unchecked(pointer.offset(offset) as *mut u8), len, MMAP_ALLOCATION.clone(), @@ -53,6 +55,6 @@ pub fn to_buffer_offset_mut(pointer: *mut u8, offset: isize, len: usize) -> arro }; } -pub fn to_buffer_null() -> arrow::buffer::Buffer { - return to_buffer_mut(std::ptr::null_mut(), 0); +pub fn arrow_buffer_null() -> Buffer { + return arrow_buffer_mut(std::ptr::null_mut(), 0); } diff --git a/rust/vineyard/src/common/util/json.rs b/rust/vineyard/src/common/util/json.rs index 0fc2ff29ee..cc4af9bd5b 100644 --- a/rust/vineyard/src/common/util/json.rs +++ b/rust/vineyard/src/common/util/json.rs @@ -22,7 +22,7 @@ pub type JSONResult = serde_json::Result; pub fn parse_json_object<'a>(root: &'a Value) -> Result<&'a JSON> { return root.as_object().ok_or(VineyardError::io_error( - 
"incoming message is not a JSON object".into(), + "incoming message is not a JSON object", )); } diff --git a/rust/vineyard/src/common/util/protocol.rs b/rust/vineyard/src/common/util/protocol.rs index 3254e746c3..4d9eee9e42 100644 --- a/rust/vineyard/src/common/util/protocol.rs +++ b/rust/vineyard/src/common/util/protocol.rs @@ -182,7 +182,7 @@ fn check_ipc_error<'a>(root: &'a JSON, reply_type: &str) -> Result<()> { format!("unexpected reply type: '{}'", message_type), ); } else { - return vineyard_assert(false, "no 'type' field in the response".into()); + return vineyard_assert(false, "no 'type' field in the response"); } } @@ -288,7 +288,7 @@ pub fn read_create_disk_buffer_reply(message: &str) -> Result fd: get_int::(root, "fd")? .to_i32() .ok_or(VineyardError::io_error( - "fd received from server must be a 32-bit integter".into(), + "fd received from server must be a 32-bit integer", ))?, }); } @@ -317,11 +317,11 @@ pub fn read_create_gpu_buffer_reply(message: &str) -> Result>>()?; @@ -354,7 +354,7 @@ pub struct GetBuffersReply { pub compress: bool, } -pub fn write_get_buffers_request(ids: &Vec, unsafe_: bool) -> JSONResult { +pub fn write_get_buffers_request(ids: &[ObjectID], unsafe_: bool) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::GET_BUFFERS_REQUEST, "ids": ids, @@ -368,26 +368,18 @@ pub fn read_get_buffers_reply(message: &str) -> Result { check_ipc_error(&root, Command::GET_BUFFERS_REPLY)?; let mut reply = GetBuffersReply::default(); - let mut parsed = false; - - match root["payloads"] { - Value::Array(ref payloads) => { - parsed = true; - for payload in payloads { - reply - .payloads - .push(Payload::from_json( - payload.as_object().ok_or(VineyardError::io_error( - "invalid get_buffers reply: payload in message is not a JSON object" - .into(), - ))?, - )?); - } - } - _ => {} - } - if !parsed { + if let Some(Value::Array(ref payloads)) = root.get("payloads") { + for payload in payloads { + reply + .payloads + 
.push(Payload::from_json(payload.as_object().ok_or( + VineyardError::io_error( + "invalid get_buffers reply: payload in message is not a JSON object", + ), + )?)?); + } + } else { let num: i64 = get_int(root, "num")?; for i in 0..num { match root[&i.to_string()] { @@ -396,27 +388,25 @@ pub fn read_get_buffers_reply(message: &str) -> Result { } _ => { return Err(VineyardError::io_error( - "invalid get_buffers reply: payload in message is not a JSON object" - .to_string(), + "invalid get_buffers reply: payload in message is not a JSON object", )); } } } } - if root.contains_key("fds") && root["fds"].is_array() { - for fd in root["fds"].as_array().unwrap() { + if let Some(Value::Array(ref fds)) = root.get("fds") { + for fd in fds { reply.fds.push( fd.as_i64() - .ok_or(VineyardError::io_error("fd is not an integer".into()))? + .ok_or(VineyardError::io_error("fd is not an integer"))? .to_i32() .ok_or(VineyardError::io_error( - "fd received from server must be a 32-bit integter".into(), + "fd received from server must be a 32-bit integer", ))?, ); } } - return Ok(reply); } @@ -447,7 +437,7 @@ pub fn read_create_remote_buffer_reply(message: &str) -> Result) -> JSONResult { +pub fn write_get_remote_buffers_request(ids: &[ObjectID]) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::GET_REMOTE_BUFFERS_REQUEST, "ids": ids, @@ -458,7 +448,7 @@ pub fn read_get_remote_buffers_reply(message: &str) -> Result { return read_get_buffers_reply(message); } -pub fn write_increase_reference_count_request(id: &Vec) -> JSONResult { +pub fn write_increase_reference_count_request(id: &[ObjectID]) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::INCREASE_REFERENCE_COUNT_REQUEST, "ids": id, @@ -532,7 +522,7 @@ pub fn read_get_data_reply(message: &str) -> Result { Value::Array(ref content) => { if content.len() != 1 { return Err(VineyardError::io_error( - "failed to read get_data reply: content array's length is not 1".into(), + "failed to read 
get_data reply: content array's length is not 1", )); } return Ok(parse_json_object(&content[0])?.clone()); @@ -540,7 +530,7 @@ pub fn read_get_data_reply(message: &str) -> Result { Value::Object(ref content) => match content.iter().next() { None => { return Err(VineyardError::io_error( - "failed to read get_data reply: content dict's length is not 1".into(), + "failed to read get_data reply: content dict's length is not 1", )); } Some((_, meta)) => { @@ -549,14 +539,14 @@ pub fn read_get_data_reply(message: &str) -> Result { }, _ => { return Err(VineyardError::io_error( - "failed to read get_data reply: content is not an array or a dict".into(), + "failed to read get_data reply: content is not an array or a dict", )); } } } pub fn write_get_data_batch_request( - ids: &Vec, + ids: &[ObjectID], sync_remote: bool, wait: bool, ) -> JSONResult { @@ -597,7 +587,7 @@ pub fn read_get_data_batch_reply(message: &str) -> Result { return Err(VineyardError::io_error( - "failed to read get_data reply: content is not an array or a dict".into(), + "failed to read get_data reply: content is not an array or a dict", )); } } @@ -627,7 +617,7 @@ pub fn read_list_data_reply(message: &str) -> Result> { } _ => { return Err(VineyardError::io_error( - "failed to read list_data reply: data is not an array".into(), + "failed to read list_data reply: data is not an array", )); } } @@ -649,7 +639,7 @@ pub fn write_delete_data_request( } pub fn write_delete_data_batch_request( - ids: &Vec, + ids: &[ObjectID], force: bool, deep: bool, fastpath: bool, @@ -716,11 +706,7 @@ pub fn read_if_persist_reply(message: &str) -> Result { return Ok(get_bool_or(root, "persist", false)); } -pub fn write_label_request( - id: ObjectID, - keys: &Vec, - values: &Vec, -) -> JSONResult { +pub fn write_label_request(id: ObjectID, keys: &[String], values: &[String]) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::LABEL_REQUEST, "id": id, @@ -797,9 +783,10 @@ pub fn read_list_name_reply(message: 
&str) -> Result> let root = parse_json_object(&root)?; check_ipc_error(&root, Command::LIST_NAME_REPLY)?; - let names = parse_json_object(root.get("names").ok_or(VineyardError::io_error( - "message does not contain names".into(), - ))?)?; + let names = parse_json_object( + root.get("names") + .ok_or(VineyardError::io_error("message does not contain names"))?, + )?; let mut result = HashMap::new(); for (name, value) in names { match value.as_u64() { @@ -827,7 +814,7 @@ pub fn read_drop_name_reply(message: &str) -> Result<()> { return Ok(()); } -pub fn write_evict_request(ids: &Vec) -> JSONResult { +pub fn write_evict_request(ids: &[ObjectID]) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::EVICT_REQUEST, "ids": ids, @@ -842,7 +829,7 @@ pub fn read_evict_reply(message: &str) -> Result<()> { return Ok(()); } -pub fn write_load_request(ids: &Vec, pin: bool) -> JSONResult { +pub fn write_load_request(ids: &[ObjectID], pin: bool) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::LOAD_REQUEST, "ids": ids, @@ -858,7 +845,7 @@ pub fn read_load_reply(message: &str) -> Result<()> { return Ok(()); } -pub fn write_unpin_request(ids: &Vec) -> JSONResult { +pub fn write_unpin_request(ids: &[ObjectID]) -> JSONResult { return serde_json::to_string(&json!({ "type": Command::UNPIN_REQUEST, "ids": ids, diff --git a/rust/vineyard/src/common/util/status.rs b/rust/vineyard/src/common/util/status.rs index dbde47188a..ec27695f4b 100644 --- a/rust/vineyard/src/common/util/status.rs +++ b/rust/vineyard/src/common/util/status.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::backtrace::{Backtrace, BacktraceStatus}; use std::env::VarError as EnvVarError; use std::io::Error as IOError; use std::num::{ParseFloatError, ParseIntError, TryFromIntError}; @@ -19,12 +20,12 @@ use std::sync::PoisonError; use num_derive::{FromPrimitive, ToPrimitive}; use serde_json::Error as JSONError; -use thiserror::Error; use super::uuid::ObjectID; -#[derive(Debug, Clone, PartialEq, Eq, FromPrimitive, ToPrimitive)] +#[derive(Debug, Clone, Default, PartialEq, Eq, FromPrimitive, ToPrimitive)] pub enum StatusCode { + #[default] OK = 0, Invalid = 1, KeyError = 2, @@ -71,512 +72,420 @@ pub enum StatusCode { UnknownError = 255, } -#[derive(Error, Debug, Clone)] pub struct VineyardError { pub code: StatusCode, pub message: String, + pub backtrace: Backtrace, +} + +impl std::error::Error for VineyardError {} + +impl std::fmt::Debug for VineyardError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.backtrace.status() == BacktraceStatus::Captured { + write!( + f, + "{:?}: {}\nBacktrace: {:?}", + self.code, self.message, self.backtrace + ) + } else { + write!(f, "{:?}: {}", self.code, self.message) + } + } +} + +impl std::fmt::Display for VineyardError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Vineyard error {:?}: {}", self.code, self.message) + } } impl From for VineyardError { fn from(error: IOError) -> Self { - VineyardError { - code: StatusCode::IOError, - message: format!("internal io error: {}", error), - } + VineyardError::new(StatusCode::IOError, format!("internal io error: {}", error)) } } impl From for VineyardError { fn from(error: EnvVarError) -> Self { - VineyardError { - code: StatusCode::IOError, - message: format!("env var error: {}", error), - } + VineyardError::new(StatusCode::IOError, format!("env var error: {}", error)) } } impl From for VineyardError { fn from(error: ParseIntError) -> Self { - VineyardError { - code: StatusCode::IOError, - message: format!("parse 
int error: {}", error), - } + VineyardError::new(StatusCode::IOError, format!("parse int error: {}", error)) } } impl From for VineyardError { fn from(error: ParseFloatError) -> Self { - VineyardError { - code: StatusCode::IOError, - message: format!("parse float error: {}", error), - } + VineyardError::new(StatusCode::IOError, format!("parse float error: {}", error)) } } impl From for VineyardError { fn from(error: TryFromIntError) -> Self { - VineyardError { - code: StatusCode::IOError, - message: format!("try from int error: {}", error), - } + VineyardError::new( + StatusCode::IOError, + format!("try from int error: {}", error), + ) } } impl From> for VineyardError { fn from(error: PoisonError) -> Self { - VineyardError { - code: StatusCode::Invalid, - message: format!("lock poison error: {}", error), - } + VineyardError::new(StatusCode::Invalid, format!("lock poison error: {}", error)) } } impl From for VineyardError { fn from(error: JSONError) -> Self { - VineyardError { - code: StatusCode::MetaTreeInvalid, - message: error.to_string(), - } - } -} - -impl std::fmt::Display for VineyardError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Vineyard error {:?}: {}", self.code, self.message) - } -} - -impl std::default::Default for StatusCode { - fn default() -> Self { - StatusCode::OK + VineyardError::new(StatusCode::IOError, format!("json error: {}", error)) } } pub type Result = std::result::Result; impl VineyardError { - pub fn new(code: StatusCode, message: String) -> Self { - VineyardError { code, message } + pub fn new>(code: StatusCode, message: T) -> Self { + VineyardError { + code, + message: message.into(), + backtrace: Backtrace::capture(), + } } - pub fn wrap(self: &Self, message: String) -> Self { + pub fn wrap>(self, message: T) -> Self { if self.ok() { - return self.clone(); - } - VineyardError { - code: self.code.clone(), - message: format!("{}: {}", self.message, message), + return self; } + 
VineyardError::new(self.code, format!("{}: {}", self.message, message.into())) } - pub fn invalid(message: String) -> Self { - VineyardError { - code: StatusCode::Invalid, - message: message, - } + pub fn invalid>(message: T) -> Self { + VineyardError::new(StatusCode::Invalid, message.into()) } - pub fn key_error(message: String) -> Self { - VineyardError { - code: StatusCode::KeyError, - message: message, - } + pub fn key_error>(message: T) -> Self { + VineyardError::new(StatusCode::KeyError, message) } - pub fn type_error(message: String) -> Self { - VineyardError { - code: StatusCode::TypeError, - message: message, - } + pub fn type_error>(message: T) -> Self { + VineyardError::new(StatusCode::TypeError, message) } - pub fn io_error(message: String) -> Self { - VineyardError { - code: StatusCode::IOError, - message: message, - } + pub fn io_error>(message: T) -> Self { + VineyardError::new(StatusCode::IOError, message) } - pub fn end_of_file(message: String) -> Self { - VineyardError { - code: StatusCode::EndOfFile, - message: message, - } + pub fn end_of_file>(message: T) -> Self { + VineyardError::new(StatusCode::EndOfFile, message) } - pub fn not_implemented(message: String) -> Self { - VineyardError { - code: StatusCode::NotImplemented, - message: message, - } + pub fn not_implemented>(message: T) -> Self { + VineyardError::new(StatusCode::NotImplemented, message) } - pub fn assertion_failed(message: String) -> Self { - VineyardError { - code: StatusCode::AssertionFailed, - message: message, - } + pub fn assertion_failed>(message: T) -> Self { + VineyardError::new(StatusCode::AssertionFailed, message) } - pub fn user_input_error(message: String) -> Self { - VineyardError { - code: StatusCode::UserInputError, - message: message, - } + pub fn user_input_error>(message: T) -> Self { + VineyardError::new(StatusCode::UserInputError, message) } - pub fn object_exists(message: String) -> Self { - VineyardError { - code: StatusCode::ObjectExists, - message: message, 
- } + pub fn object_exists>(message: T) -> Self { + VineyardError::new(StatusCode::ObjectExists, message) } - pub fn object_not_exists(message: String) -> Self { - VineyardError { - code: StatusCode::ObjectNotExists, - message: message, - } + pub fn object_not_exists>(message: T) -> Self { + VineyardError::new(StatusCode::ObjectNotExists, message) } - pub fn object_sealed(message: String) -> Self { - VineyardError { - code: StatusCode::ObjectSealed, - message: message, - } + pub fn object_sealed>(message: T) -> Self { + VineyardError::new(StatusCode::ObjectSealed, message) } - pub fn object_not_sealed(message: String) -> Self { - VineyardError { - code: StatusCode::ObjectNotSealed, - message: message, - } + pub fn object_not_sealed>(message: T) -> Self { + VineyardError::new(StatusCode::ObjectNotSealed, message) } - pub fn object_is_blob(message: String) -> Self { - VineyardError { - code: StatusCode::ObjectIsBlob, - message: message, - } + pub fn object_is_blob>(message: T) -> Self { + VineyardError::new(StatusCode::ObjectIsBlob, message) } - pub fn object_type_error(expected: String, actual: String) -> Self { - VineyardError { - code: StatusCode::ObjectTypeError, - message: format!("expect typename '{}', but got '{}'", expected, actual), - } + pub fn object_type_error, V: Into>(expected: U, actual: V) -> Self { + VineyardError::new( + StatusCode::ObjectTypeError, + format!( + "expect typename '{}', but got '{}'", + expected.into(), + actual.into() + ), + ) } pub fn object_spilled(object_id: ObjectID) -> Self { - VineyardError { - code: StatusCode::ObjectSpilled, - message: format!("object '{}' has already been spilled", object_id), - } + VineyardError::new( + StatusCode::ObjectSpilled, + format!("object '{}' has already been spilled", object_id), + ) } pub fn object_not_spilled(object_id: ObjectID) -> Self { - VineyardError { - code: StatusCode::ObjectNotSpilled, - message: format!("object '{}' hasn't been spilled yet", object_id), - } + VineyardError::new( + 
StatusCode::ObjectNotSpilled, + format!("object '{}' hasn't been spilled yet", object_id), + ) } - pub fn meta_tree_invalid(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeInvalid, - message: message, - } + pub fn meta_tree_invalid>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeInvalid, message) } - pub fn meta_tree_type_invalid(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeTypeInvalid, - message: message, - } + pub fn meta_tree_type_invalid>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeTypeInvalid, message) } - pub fn meta_tree_type_not_exists(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeTypeNotExists, - message: message, - } + pub fn meta_tree_type_not_exists>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeTypeNotExists, message) } - pub fn meta_tree_name_invalid(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeNameInvalid, - message: message, - } + pub fn meta_tree_name_invalid>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeNameInvalid, message) } - pub fn meta_tree_name_not_exists(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeNameNotExists, - message: message, - } + pub fn meta_tree_name_not_exists>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeNameNotExists, message) } - pub fn meta_tree_link_invalid(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeLinkInvalid, - message: message, - } + pub fn meta_tree_link_invalid>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeLinkInvalid, message) } - pub fn meta_tree_subtree_not_exists(message: String) -> Self { - VineyardError { - code: StatusCode::MetaTreeSubtreeNotExists, - message: message, - } + pub fn meta_tree_subtree_not_exists>(message: T) -> Self { + VineyardError::new(StatusCode::MetaTreeSubtreeNotExists, message) } - pub fn vineyard_server_not_ready(message: 
String) -> Self { - VineyardError { - code: StatusCode::VineyardServerNotReady, - message: message, - } + pub fn vineyard_server_not_ready>(message: T) -> Self { + VineyardError::new(StatusCode::VineyardServerNotReady, message) } - pub fn connection_failed(message: String) -> Self { - VineyardError { - code: StatusCode::ConnectionFailed, - message: message, - } + pub fn arrow_error>(message: T) -> Self { + VineyardError::new(StatusCode::ArrowError, message) } - pub fn etcd_error(message: String) -> Self { - VineyardError { - code: StatusCode::EtcdError, - message: message, - } + pub fn connection_failed>(message: T) -> Self { + VineyardError::new(StatusCode::ConnectionFailed, message) } - pub fn redis_error(message: String) -> Self { - VineyardError { - code: StatusCode::RedisError, - message: message, - } + pub fn etcd_error>(message: T) -> Self { + VineyardError::new(StatusCode::EtcdError, message) } - pub fn already_stopped(message: String) -> Self { - VineyardError { - code: StatusCode::AlreadyStopped, - message: message, - } + pub fn redis_error>(message: T) -> Self { + VineyardError::new(StatusCode::RedisError, message) } - pub fn not_enough_memory(message: String) -> Self { - VineyardError { - code: StatusCode::NotEnoughMemory, - message: message, - } + pub fn already_stopped>(message: T) -> Self { + VineyardError::new(StatusCode::AlreadyStopped, message) } - pub fn stream_drained(message: String) -> Self { - VineyardError { - code: StatusCode::StreamDrained, - message: message, - } + pub fn not_enough_memory>(message: T) -> Self { + VineyardError::new(StatusCode::NotEnoughMemory, message) } - pub fn stream_failed(message: String) -> Self { - VineyardError { - code: StatusCode::StreamFailed, - message: message, - } + pub fn stream_drained>(message: T) -> Self { + VineyardError::new(StatusCode::StreamDrained, message) } - pub fn invalid_stream_state(message: String) -> Self { - VineyardError { - code: StatusCode::InvalidStreamState, - message: message, - } + 
pub fn stream_failed>(message: T) -> Self { + VineyardError::new(StatusCode::StreamFailed, message) } - pub fn stream_opened(message: String) -> Self { - VineyardError { - code: StatusCode::StreamOpened, - message: message, - } + pub fn invalid_stream_state>(message: T) -> Self { + VineyardError::new(StatusCode::InvalidStreamState, message) } - pub fn global_object_invalid(message: String) -> Self { - VineyardError { - code: StatusCode::GlobalObjectInvalid, - message: message, - } + pub fn stream_opened>(message: T) -> Self { + VineyardError::new(StatusCode::StreamOpened, message) } - pub fn unknown_error(message: String) -> Self { - VineyardError { - code: StatusCode::UnknownError, - message: message, - } + pub fn global_object_invalid>(message: T) -> Self { + VineyardError::new(StatusCode::GlobalObjectInvalid, message) + } + + pub fn unknown_error>(message: T) -> Self { + VineyardError::new(StatusCode::UnknownError, message) } - pub fn ok(self: &Self) -> bool { + pub fn ok(&self) -> bool { return self.code == StatusCode::OK; } - pub fn is_invalid(self: &Self) -> bool { + pub fn is_invalid(&self) -> bool { return self.code == StatusCode::Invalid; } - pub fn is_key_error(self: &Self) -> bool { + pub fn is_key_error(&self) -> bool { return self.code == StatusCode::KeyError; } - pub fn is_type_error(self: &Self) -> bool { + pub fn is_type_error(&self) -> bool { return self.code == StatusCode::TypeError; } - pub fn is_io_error(self: &Self) -> bool { + pub fn is_io_error(&self) -> bool { return self.code == StatusCode::IOError; } - pub fn is_end_of_file(self: &Self) -> bool { + pub fn is_end_of_file(&self) -> bool { return self.code == StatusCode::EndOfFile; } - pub fn is_not_implemented(self: &Self) -> bool { + pub fn is_not_implemented(&self) -> bool { return self.code == StatusCode::NotImplemented; } - pub fn is_assertion_failed(self: &Self) -> bool { + pub fn is_assertion_failed(&self) -> bool { return self.code == StatusCode::AssertionFailed; } - pub fn 
is_user_input_error(self: &Self) -> bool { + pub fn is_user_input_error(&self) -> bool { return self.code == StatusCode::UserInputError; } - pub fn is_object_exists(self: &Self) -> bool { + pub fn is_object_exists(&self) -> bool { return self.code == StatusCode::ObjectExists; } - pub fn is_object_not_exists(self: &Self) -> bool { + pub fn is_object_not_exists(&self) -> bool { return self.code == StatusCode::ObjectNotExists; } - pub fn is_object_sealed(self: &Self) -> bool { + pub fn is_object_sealed(&self) -> bool { return self.code == StatusCode::ObjectSealed; } - pub fn is_object_not_sealed(self: &Self) -> bool { + pub fn is_object_not_sealed(&self) -> bool { return self.code == StatusCode::ObjectNotSealed; } - pub fn is_object_is_blob(self: &Self) -> bool { + pub fn is_object_is_blob(&self) -> bool { return self.code == StatusCode::ObjectIsBlob; } - pub fn is_object_type_error(self: &Self) -> bool { + pub fn is_object_type_error(&self) -> bool { return self.code == StatusCode::ObjectTypeError; } - pub fn is_object_spilled(self: &Self) -> bool { + pub fn is_object_spilled(&self) -> bool { return self.code == StatusCode::ObjectSpilled; } - pub fn is_object_not_spilled(self: &Self) -> bool { + pub fn is_object_not_spilled(&self) -> bool { return self.code == StatusCode::ObjectNotSpilled; } - pub fn is_meta_tree_invalid(self: &Self) -> bool { + pub fn is_meta_tree_invalid(&self) -> bool { return self.code == StatusCode::MetaTreeInvalid || self.code == StatusCode::MetaTreeNameInvalid || self.code == StatusCode::MetaTreeTypeInvalid || self.code == StatusCode::MetaTreeLinkInvalid; } - pub fn is_meta_tree_element_not_exists(self: &Self) -> bool { + pub fn is_meta_tree_element_not_exists(&self) -> bool { return self.code == StatusCode::MetaTreeNameNotExists || self.code == StatusCode::MetaTreeTypeNotExists || self.code == StatusCode::MetaTreeSubtreeNotExists; } - pub fn is_vineyard_server_not_ready(self: &Self) -> bool { + pub fn is_vineyard_server_not_ready(&self) -> 
bool { return self.code == StatusCode::VineyardServerNotReady; } - pub fn is_arrow_error(self: &Self) -> bool { + pub fn is_arrow_error(&self) -> bool { return self.code == StatusCode::ArrowError; } - pub fn is_connection_failed(self: &Self) -> bool { + pub fn is_connection_failed(&self) -> bool { return self.code == StatusCode::ConnectionFailed; } - pub fn is_connection_error(self: &Self) -> bool { + pub fn is_connection_error(&self) -> bool { return self.code == StatusCode::ConnectionError; } - pub fn is_etcd_error(self: &Self) -> bool { + pub fn is_etcd_error(&self) -> bool { return self.code == StatusCode::EtcdError; } - pub fn is_already_stopped(self: &Self) -> bool { + pub fn is_already_stopped(&self) -> bool { return self.code == StatusCode::AlreadyStopped; } - pub fn is_not_enough_memory(self: &Self) -> bool { + pub fn is_not_enough_memory(&self) -> bool { return self.code == StatusCode::NotEnoughMemory; } - pub fn is_stream_drained(self: &Self) -> bool { + pub fn is_stream_drained(&self) -> bool { return self.code == StatusCode::StreamDrained; } - pub fn is_stream_failed(self: &Self) -> bool { + pub fn is_stream_failed(&self) -> bool { return self.code == StatusCode::StreamFailed; } - pub fn is_invalid_stream_state(self: &Self) -> bool { + pub fn is_invalid_stream_state(&self) -> bool { return self.code == StatusCode::InvalidStreamState; } - pub fn is_stream_opened(self: &Self) -> bool { + pub fn is_stream_opened(&self) -> bool { return self.code == StatusCode::StreamOpened; } - pub fn is_global_object_invalid(self: &Self) -> bool { + pub fn is_global_object_invalid(&self) -> bool { return self.code == StatusCode::GlobalObjectInvalid; } - pub fn is_unknown_error(self: &Self) -> bool { + pub fn is_unknown_error(&self) -> bool { return self.code == StatusCode::UnknownError; } - pub fn code(self: &Self) -> &StatusCode { + pub fn code(&self) -> &StatusCode { return &self.code; } - pub fn message(self: &Self) -> &String { + pub fn message(&self) -> &String { 
return &self.message; } } -pub fn vineyard_check_ok(status: Result) { - if let Err(_) = status { - panic!("Error occurs.") +pub fn vineyard_check_ok(status: Result) { + if status.is_err() { + panic!("Error occurs: {:?}.", status) } } -pub fn vineyard_assert(condition: bool, message: String) -> Result<()> { +pub fn vineyard_assert>(condition: bool, message: T) -> Result<()> { if !condition { return Err(VineyardError::assertion_failed(format!( "assertion failed: {}", - message + message.into() ))); } return Ok(()); } -pub fn vineyard_assert_typename(expect: &str, actual: &str) -> Result<()> { +pub fn vineyard_assert_typename + PartialEq, V: Into>( + expect: U, + actual: V, +) -> Result<()> { if expect != actual { return Err(VineyardError::object_type_error( - expect.to_string(), - actual.to_string(), + expect.into(), + actual.into(), )); } return Ok(()); diff --git a/rust/vineyard/src/common/util/typename.rs b/rust/vineyard/src/common/util/typename.rs index 9e25f0d247..59719a35b8 100644 --- a/rust/vineyard/src/common/util/typename.rs +++ b/rust/vineyard/src/common/util/typename.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use static_str_ops::*; + /// A trait to generate specialized type name for given Rust type. /// /// Note that the `typename()` method doesn't return `&'static str` for @@ -21,27 +23,31 @@ /// /// We temporarily use `String` as the return type and leave it as a TODO. pub trait TypeName { - fn typename() -> String { - return std::any::type_name::().into(); + fn typename() -> &'static str + where + Self: Sized, + { + return staticize_once!(std::any::type_name::()); } } /// Generate typename for given type in Rust. -pub fn typename() -> String { +pub fn typename() -> &'static str { return T::typename(); } +#[macro_export] macro_rules! 
impl_typename { ($t:ty, $name:expr) => { impl TypeName for $t { - fn typename() -> String { - return $name.into(); + fn typename() -> &'static str { + return $name; } } }; } -pub(crate) use impl_typename; +pub use impl_typename; impl_typename!(i8, "int8"); impl_typename!(u8, "uint8"); diff --git a/rust/vineyard/src/common/util/uuid.rs b/rust/vineyard/src/common/util/uuid.rs index af4dc54c5f..c4a3a2295e 100644 --- a/rust/vineyard/src/common/util/uuid.rs +++ b/rust/vineyard/src/common/util/uuid.rs @@ -40,7 +40,10 @@ pub fn is_blob(id: ObjectID) -> bool { pub fn object_id_from_string(s: &str) -> Result { if s.len() < 2 { - return Err(VineyardError::invalid("invalid object id".to_string())); + return Err(VineyardError::invalid(format!( + "invalid object id: '{}'", + s + ))); } return Ok(ObjectID::from_str_radix(&s[1..], 16)?); } diff --git a/rust/vineyard/src/ds/array.rs b/rust/vineyard/src/ds/array.rs index dba3350411..2164a85187 100644 --- a/rust/vineyard/src/ds/array.rs +++ b/rust/vineyard/src/ds/array.rs @@ -12,11 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::convert::AsRef; use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; + +use static_str_ops::*; use crate::client::*; -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Array { meta: ObjectMeta, size: usize, @@ -25,8 +29,8 @@ pub struct Array { } impl TypeName for Array { - fn typename() -> String { - return format!("vineyard::Array<{}>", T::typename()); + fn typename() -> &'static str { + return staticize(format!("vineyard::Array<{}>", T::typename())); } } @@ -43,7 +47,7 @@ impl Default for Array { impl Object for Array { fn construct(&mut self, meta: ObjectMeta) -> Result<()> { - vineyard_assert_typename(meta.get_typename()?, &typename::>())?; + vineyard_assert_typename(typename::(), meta.get_typename()?)?; self.meta = meta; self.size = self.meta.get_usize("size_")?; @@ -53,12 +57,24 @@ impl Object for Array { } register_vineyard_object!(Array); +register_vineyard_types! { + Array; + Array; + Array; + Array; + Array; + Array; + Array; + Array; + Array; + Array; +} impl Array { pub fn new_boxed(meta: ObjectMeta) -> Result> { - let mut array: Array = Array::default(); + let mut array = Box::::default(); array.construct(meta)?; - return Ok(Box::new(array)); + return Ok(array); } pub fn size(&self) -> usize { @@ -76,6 +92,20 @@ impl Array { } } +impl Deref for Array { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + return self.as_slice(); + } +} + +impl AsRef<[T]> for Array { + fn as_ref(&self) -> &[T] { + return self.as_slice(); + } +} + pub struct ArrayBuilder { sealed: bool, size: usize, @@ -94,16 +124,19 @@ impl ObjectBuilder for ArrayBuilder { } impl ObjectBase for ArrayBuilder { - fn build(&mut self, _client: &mut IPCClient) -> Result<()> { - if !self.sealed { - self.set_sealed(true); + fn build(&mut self, client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); } + self.set_sealed(true); + self.buffer.build(client)?; return Ok(()); } - fn seal(self: Self, client: &mut IPCClient) -> Result> { + fn seal(mut 
self, client: &mut IPCClient) -> Result> { + self.build(client)?; let buffer = self.buffer.seal(client)?; - let mut meta = ObjectMeta::from_typename(&typename::>()); + let mut meta = ObjectMeta::new_from_typename(typename::>()); meta.add_member("buffer_", buffer)?; meta.add_usize("size_", self.size); meta.set_nbytes(self.size * std::mem::size_of::()); @@ -124,11 +157,11 @@ impl ArrayBuilder { return Ok(builder); } - pub fn from_vec(client: &mut IPCClient, vec: &Vec) -> Result { + pub fn from_vec(client: &mut IPCClient, vec: &[T]) -> Result { let mut builder = ArrayBuilder::new(client, vec.len())?; let dest: *mut T = builder.as_mut_ptr(); unsafe { - std::ptr::copy_nonoverlapping(vec.as_ptr(), dest, vec.len() * std::mem::size_of::()); + std::ptr::copy_nonoverlapping(vec.as_ptr(), dest, vec.len()); } return Ok(builder); } @@ -137,7 +170,7 @@ impl ArrayBuilder { let mut builder = ArrayBuilder::new(client, size)?; let dest: *mut T = builder.as_mut_ptr(); unsafe { - std::ptr::copy_nonoverlapping(data, dest, size * std::mem::size_of::()); + std::ptr::copy_nonoverlapping(data, dest, size); } return Ok(builder); } @@ -162,3 +195,23 @@ impl ArrayBuilder { return unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.size) }; } } + +impl Deref for ArrayBuilder { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + return self.as_slice(); + } +} + +impl DerefMut for ArrayBuilder { + fn deref_mut(&mut self) -> &mut Self::Target { + return self.as_mut_slice(); + } +} + +impl AsRef<[T]> for ArrayBuilder { + fn as_ref(&self) -> &[T] { + return self.as_slice(); + } +} diff --git a/rust/vineyard/src/ds/array_test.rs b/rust/vineyard/src/ds/array_test.rs index acd5c625cf..9422721307 100644 --- a/rust/vineyard/src/ds/array_test.rs +++ b/rust/vineyard/src/ds/array_test.rs @@ -15,38 +15,39 @@ #[cfg(test)] mod tests { use std::fmt::Debug; - use std::rc::Rc; use num_traits::FromPrimitive; use spectral::prelude::*; - use super::super::super::client::*; use 
super::super::array::*; + use crate::client::*; - fn test_array_generic() { + fn test_array_generic() -> Result<()> + { const N: usize = 1024; - let mut conn = IPCClient::default().unwrap(); - let client = Rc::get_mut(&mut conn).unwrap(); + let mut client = IPCClient::default()?; - let mut builder = ArrayBuilder::::new(client, N).unwrap(); + let mut builder = ArrayBuilder::::new(&mut client, N)?; let slice_mut = builder.as_mut_slice(); - for i in 0..N { - slice_mut[i] = T::from_usize(i).unwrap(); + for (idx, item) in slice_mut.iter_mut().enumerate() { + *item = T::from_usize(idx).ok_or(VineyardError::invalid("cannot convert to T"))?; } let array_object_id: ObjectID; // test seal { - let object = builder.seal(client).unwrap(); - let array = downcast_object::>(object).unwrap(); + let object = builder.seal(&mut client)?; + let array = downcast_object::>(object)?; let blob_id = array.id(); - assert_that(&blob_id).is_greater_than(0); + assert_that!(blob_id).is_greater_than(0); let slice = array.as_slice(); - for i in 0..N { - assert_that(&slice[i]).is_equal_to(T::from_usize(i).unwrap()); + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to( + T::from_usize(idx).ok_or(VineyardError::invalid("cannot convert to T"))?, + ); } array_object_id = array.id(); } @@ -55,33 +56,37 @@ mod tests { { let array = client.get::>(array_object_id).unwrap(); let array_id = array.id(); - assert_that(&array_id).is_greater_than(0); - assert_that(&array_id).is_equal_to(array_object_id); + assert_that!(array_id).is_greater_than(0); + assert_that!(array_id).is_equal_to(array_object_id); let slice = array.as_slice(); - for i in 0..N { - assert_that(&slice[i]).is_equal_to(T::from_usize(i).unwrap()); + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to( + T::from_usize(idx).ok_or(VineyardError::invalid("cannot convert to T"))?, + ); } } + + return Ok(()); } #[test] - fn test_array_int32() { - test_array_generic::(); + fn test_array_int32() 
-> Result<()> { + return test_array_generic::(); } #[test] - fn test_array_int64() { - test_array_generic::(); + fn test_array_int64() -> Result<()> { + return test_array_generic::(); } #[test] - fn test_array_float() { - test_array_generic::(); + fn test_array_float() -> Result<()> { + return test_array_generic::(); } #[test] - fn test_array_double() { - test_array_generic::(); + fn test_array_double() -> Result<()> { + return test_array_generic::(); } } diff --git a/rust/vineyard/src/ds/arrow.rs b/rust/vineyard/src/ds/arrow.rs index 9742a59c99..02256da9bf 100644 --- a/rust/vineyard/src/ds/arrow.rs +++ b/rust/vineyard/src/ds/arrow.rs @@ -12,4 +12,1162 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::io::*; +use std::marker::PhantomData; +use std::rc::Rc; +use std::sync::Arc; + +use arrow_array as array; +use arrow_array::builder; +use arrow_array::builder::GenericStringBuilder; +use arrow_array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow_schema as schema; +use arrow_schema::ArrowError; +use downcast_rs::impl_downcast; +use itertools::izip; +use serde_json::{json, Value}; +use static_str_ops::*; + +use super::arrow_utils::*; +use crate::client::*; + +impl From for VineyardError { + fn from(error: ArrowError) -> Self { + VineyardError::new(StatusCode::ArrowError, format!("{}", error)) + } +} + +pub trait Array: Object { + fn array(&self) -> array::ArrayRef; +} + +impl_downcast!(Array); + +pub fn downcast_array(object: Box) -> Result> { + return object + .downcast::() + .map_err(|_| VineyardError::invalid(format!("downcast object to array failed",))); +} + +pub fn downcast_array_ref(object: &dyn Array) -> Result<&T> { + return object + .downcast_ref::() + .ok_or(VineyardError::invalid(format!( + "downcast object '{:?}' to array failed", + object.meta().get_typename()?, + ))); +} + +pub fn 
downcast_array_rc(object: Rc) -> Result> { + return object + .downcast_rc::() + .map_err(|_| VineyardError::invalid(format!("downcast object to array failed",))); +} + +pub trait NumericType = ToArrowType where ::Type: array::ArrowPrimitiveType; + +pub type TypedBuffer = + ScalarBuffer<<::Type as array::ArrowPrimitiveType>::Native>; + +pub type TypedArray = array::PrimitiveArray<::Type>; +pub type TypedBuilder = builder::PrimitiveBuilder<::Type>; + +#[derive(Debug)] +pub struct NumericArray { + meta: ObjectMeta, + array: Arc>, +} + +impl Array for NumericArray { + fn array(&self) -> array::ArrayRef { + return self.array.clone(); + } +} + +pub type Int8Array = NumericArray; +pub type UInt8Array = NumericArray; +pub type Int16Array = NumericArray; +pub type UInt16Array = NumericArray; +pub type Int32Array = NumericArray; +pub type UInt32Array = NumericArray; +pub type Int64Array = NumericArray; +pub type UInt64Array = NumericArray; +pub type Float32Array = NumericArray; +pub type Float64Array = NumericArray; + +impl TypeName for NumericArray { + fn typename() -> &'static str { + return staticize(format!("vineyard::NumericArray<{}>", T::typename())); + } +} + +impl Default for NumericArray { + fn default() -> Self { + NumericArray { + meta: ObjectMeta::default(), + array: Arc::new(TypedArray::::new(vec![].into(), None)), + } + } +} + +impl Object for NumericArray { + fn construct(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::(), meta.get_typename()?)?; + self.meta = meta; + let values = resolve_scalar_buffer::(&self.meta, "buffer_")?; + let nulls = resolve_null_bitmap_buffer(&self.meta, "null_bitmap_")?; + self.array = Arc::new(TypedArray::::new(values, nulls)); + return Ok(()); + } +} + +register_vineyard_object!(NumericArray); +register_vineyard_types! 
{ + Int8Array; + UInt8Array; + Int16Array; + UInt16Array; + Int32Array; + UInt32Array; + Int64Array; + UInt64Array; + Float32Array; + Float64Array; +} + +impl NumericArray { + pub fn new_boxed(meta: ObjectMeta) -> Result> { + let mut array = Box::::default(); + array.construct(meta)?; + return Ok(array); + } + + pub fn data(&self) -> Arc> { + return self.array.clone(); + } + + pub fn len(&self) -> usize { + return self.array.len(); + } + + pub fn is_empty(&self) -> bool { + return self.array.is_empty(); + } + + pub fn as_slice(&self) -> &[T] { + return unsafe { + std::slice::from_raw_parts(self.array.values().inner().as_ptr() as _, self.len()) + }; + } +} + +impl AsRef> for NumericArray { + fn as_ref(&self) -> &TypedArray { + return &self.array; + } +} + +pub struct NumericBuilder { + sealed: bool, + length: usize, + offset: usize, + null_count: usize, + buffer: BlobWriter, + null_bitmap: Option, + phantom: PhantomData, +} + +pub type Int8Builder = NumericBuilder; +pub type UInt8Builder = NumericBuilder; +pub type Int16Builder = NumericBuilder; +pub type UInt16Builder = NumericBuilder; +pub type Int32Builder = NumericBuilder; +pub type UInt32Builder = NumericBuilder; +pub type Int64Builder = NumericBuilder; +pub type UInt64Builder = NumericBuilder; +pub type Float32Builder = NumericBuilder; +pub type Float64Builder = NumericBuilder; + +impl ObjectBuilder for NumericBuilder { + fn sealed(&self) -> bool { + self.sealed + } + + fn set_sealed(&mut self, sealed: bool) { + self.sealed = sealed; + } +} + +impl ObjectBase for NumericBuilder { + fn build(&mut self, client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); + } + self.set_sealed(true); + self.buffer.build(client)?; + if let Some(ref mut null_bitmap) = self.null_bitmap { + null_bitmap.build(client)?; + } + return Ok(()); + } + + fn seal(mut self, client: &mut IPCClient) -> Result> { + self.build(client)?; + let mut nbytes = self.buffer.len(); + let buffer = self.buffer.seal(client)?; + let 
null_bitmap = match self.null_bitmap { + None => None, + Some(null_bitmap) => { + nbytes += null_bitmap.len(); + Some(null_bitmap.seal(client)?) + } + }; + let mut meta = ObjectMeta::new_from_typename(typename::>()); + meta.add_member("buffer_", buffer)?; + if let Some(null_bitmap) = null_bitmap { + meta.add_member("null_bitmap_", null_bitmap)?; + } else { + meta.add_member("null_bitmap_", Blob::empty(client)?)?; + } + meta.add_usize("length_", self.length); + meta.add_usize("offset_", self.offset); + meta.add_usize("null_count_", self.null_count); + meta.set_nbytes(nbytes); + let metadata = client.create_metadata(&meta)?; + return NumericArray::::new_boxed(metadata); + } +} + +impl NumericBuilder { + pub fn new(client: &mut IPCClient, length: usize) -> Result { + let buffer = client.create_blob(std::mem::size_of::() * length)?; + return Ok(NumericBuilder { + sealed: false, + length, + offset: 0, + null_count: 0, + buffer, + null_bitmap: None, + phantom: PhantomData, + }); + } + + pub fn new_from_array(client: &mut IPCClient, array: &TypedArray) -> Result { + use arrow_array::Array; + + let buffer = build_scalar_buffer::(client, array.values())?; + let null_bitmap = build_null_bitmap_buffer(client, array.nulls())?; + return Ok(NumericBuilder { + sealed: false, + length: array.len(), + offset: 0, + null_count: array.null_count(), + buffer, + null_bitmap, + phantom: PhantomData, + }); + } + + pub fn new_from_builder(client: &mut IPCClient, builder: &mut TypedBuilder) -> Result { + let array = builder.finish(); + return Self::new_from_array(client, &array); + } + + pub fn len(&self) -> usize { + return self.length; + } + + pub fn is_empty(&self) -> bool { + return self.length == 0; + } + + pub fn offset(&self) -> usize { + return self.offset; + } + + pub fn null_count(&self) -> usize { + return self.null_count; + } + + pub fn as_slice(&mut self) -> &[T] { + return unsafe { std::mem::transmute(self.buffer.as_slice()) }; + } + + pub fn as_mut_slice(&mut self) -> &mut 
[T] { + return unsafe { std::mem::transmute(self.buffer.as_mut_slice()) }; + } +} + +#[derive(Debug)] +pub struct BaseStringArray { + meta: ObjectMeta, + array: Arc>, +} + +impl Array for BaseStringArray { + fn array(&self) -> array::ArrayRef { + return self.array.clone(); + } +} + +pub type StringArray = BaseStringArray; +pub type LargeStringArray = BaseStringArray; + +impl TypeName for BaseStringArray { + fn typename() -> &'static str { + if std::mem::size_of::() == 4 { + return staticize("vineyard::BaseBinaryArray"); + } else { + return staticize("vineyard::BaseBinaryArray"); + } + } +} + +impl Default for BaseStringArray { + fn default() -> Self { + BaseStringArray { + meta: ObjectMeta::default(), + array: Arc::new(array::GenericStringArray::::new_null(0)), + } + } +} + +impl Object for BaseStringArray { + fn construct(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::(), meta.get_typename()?)?; + self.meta = meta; + let values = resolve_buffer(&self.meta, "buffer_data_")?; + let offsets = resolve_offsets_buffer::(&self.meta, "buffer_offsets_")?; + let nulls = resolve_null_bitmap_buffer(&self.meta, "null_bitmap_")?; + self.array = Arc::new(array::GenericStringArray::::new(offsets, values, nulls)); + return Ok(()); + } +} + +register_vineyard_object!(BaseStringArray); +register_vineyard_types! 
{ + StringArray; + LargeStringArray; +} + +impl BaseStringArray { + pub fn new_boxed(meta: ObjectMeta) -> Result> { + let mut array = Box::::default(); + array.construct(meta)?; + return Ok(array); + } + + pub fn data(&self) -> Arc> { + return self.array.clone(); + } + + pub fn len(&self) -> usize { + use arrow_array::Array; + return self.array.len(); + } + + pub fn is_empty(&self) -> bool { + use arrow_array::Array; + return self.array.is_empty(); + } + + pub fn as_slice(&self) -> &[u8] { + return self.array.value_data(); + } + + pub fn as_slice_offsets(&self) -> &[O] { + return self.array.value_offsets(); + } +} + +impl AsRef> for BaseStringArray { + fn as_ref(&self) -> &array::GenericStringArray { + return &self.array; + } +} + +pub struct BaseStringBuilder { + sealed: bool, + length: usize, + offset: usize, + null_count: usize, + value_data: BlobWriter, + value_offsets: BlobWriter, + null_bitmap: Option, + phantom: PhantomData, +} + +pub type StringBuilder = BaseStringBuilder; +pub type LargeStringBuilder = BaseStringBuilder; + +impl ObjectBuilder for BaseStringBuilder { + fn sealed(&self) -> bool { + self.sealed + } + + fn set_sealed(&mut self, sealed: bool) { + self.sealed = sealed; + } +} + +impl ObjectBase for BaseStringBuilder { + fn build(&mut self, client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); + } + self.set_sealed(true); + self.value_data.build(client)?; + self.value_offsets.build(client)?; + if let Some(ref mut null_bitmap) = self.null_bitmap { + null_bitmap.build(client)?; + } + return Ok(()); + } + + fn seal(mut self, client: &mut IPCClient) -> Result> { + self.build(client)?; + let mut nbytes = self.value_data.len(); + let value_data = self.value_data.seal(client)?; + nbytes += self.value_offsets.len(); + let value_offsets = self.value_offsets.seal(client)?; + let null_bitmap = match self.null_bitmap { + None => None, + Some(null_bitmap) => { + nbytes += null_bitmap.len(); + Some(null_bitmap.seal(client)?) 
+ } + }; + let mut meta = ObjectMeta::new_from_typename(typename::>()); + meta.add_member("buffer_data_", value_data)?; + meta.add_member("buffer_offsets_", value_offsets)?; + if let Some(null_bitmap) = null_bitmap { + meta.add_member("null_bitmap_", null_bitmap)?; + } else { + meta.add_member("null_bitmap_", Blob::empty(client)?)?; + } + meta.add_usize("length_", self.length); + meta.add_usize("offset_", self.offset); + meta.add_usize("null_count_", self.null_count); + meta.set_nbytes(nbytes); + let metadata = client.create_metadata(&meta)?; + return BaseStringArray::::new_boxed(metadata); + } +} + +impl BaseStringBuilder { + pub fn new_from_array(client: &mut IPCClient, array: &GenericStringArray) -> Result { + use arrow_array::Array; + + let value_data = build_buffer(client, array.values())?; + let value_offsets = build_offsets_buffer(client, array.offsets())?; + let null_bitmap = build_null_bitmap_buffer(client, array.nulls())?; + return Ok(BaseStringBuilder { + sealed: false, + length: array.len(), + offset: 0, + null_count: array.null_count(), + value_data, + value_offsets, + null_bitmap, + phantom: PhantomData, + }); + } + + pub fn new_from_builder( + client: &mut IPCClient, + builder: &mut GenericStringBuilder, + ) -> Result { + let array = builder.finish(); + return Self::new_from_array(client, &array); + } + + pub fn len(&self) -> usize { + return self.length; + } + + pub fn is_empty(&self) -> bool { + return self.length == 0; + } + + pub fn offset(&self) -> usize { + return self.offset; + } + + pub fn null_count(&self) -> usize { + return self.null_count; + } + + pub fn as_slice(&mut self) -> &[u8] { + return self.value_data.as_slice(); + } + + pub fn as_mut_slice(&mut self) -> &mut [u8] { + return unsafe { std::mem::transmute(self.value_data.as_mut_slice()) }; + } + + pub fn as_slice_offsets(&mut self) -> &[O] { + return unsafe { std::mem::transmute(self.value_offsets.as_slice()) }; + } + + pub fn as_mut_slice_offsets(&mut self) -> &mut [O] { + return 
unsafe { std::mem::transmute(self.value_offsets.as_mut_slice()) }; + } +} + +pub fn downcast_to_array(object: Box) -> Result> { + macro_rules! downcast { + ($object: ident, $ty: ty) => { + |$object| match $object.downcast::<$ty>() { + Ok(array) => Ok(array), + Err(original) => Err(original), + } + }; + } + + let mut object: std::result::Result, Box> = Err(object); + object = object + .or_else(downcast!(object, Int8Array)) + .or_else(downcast!(object, UInt8Array)) + .or_else(downcast!(object, Int16Array)) + .or_else(downcast!(object, UInt16Array)) + .or_else(downcast!(object, Int32Array)) + .or_else(downcast!(object, UInt32Array)) + .or_else(downcast!(object, Int64Array)) + .or_else(downcast!(object, UInt64Array)) + .or_else(downcast!(object, Float32Array)) + .or_else(downcast!(object, Float64Array)) + .or_else(downcast!(object, StringArray)) + .or_else(downcast!(object, LargeStringArray)); + + match object { + Ok(array) => return Ok(array), + Err(object) => { + return Err(VineyardError::invalid(format!( + "downcast object to array failed, object type is: '{}'", + object.meta().get_typename()?, + ))) + } + }; +} + +pub fn build_array(client: &mut IPCClient, array: ArrayRef) -> Result> { + macro_rules! 
build { + ($array: ident, $array_ty: ty, $builder_ty: ty) => { + |$array| match $array.as_any().downcast_ref::<$array_ty>() { + Some(array) => match <$builder_ty>::new_from_array(client, array) { + Ok(builder) => match builder.seal(client) { + Ok(object) => Ok(object), + Err(_) => Err(array as &dyn array::Array), + }, + Err(_) => Err(array as &dyn array::Array), + }, + None => Err($array), + } + }; + } + + let mut array: std::result::Result, &dyn array::Array> = Err(array.as_ref()); + array = array + .or_else(build!(array, array::Int8Array, Int8Builder)) + .or_else(build!(array, array::UInt8Array, UInt8Builder)) + .or_else(build!(array, array::Int16Array, Int16Builder)) + .or_else(build!(array, array::UInt16Array, UInt16Builder)) + .or_else(build!(array, array::Int32Array, Int32Builder)) + .or_else(build!(array, array::UInt32Array, UInt32Builder)) + .or_else(build!(array, array::Int64Array, Int64Builder)) + .or_else(build!(array, array::UInt64Array, UInt64Builder)) + .or_else(build!(array, array::Float32Array, Float32Builder)) + .or_else(build!(array, array::Float64Array, Float64Builder)) + .or_else(build!(array, array::StringArray, StringBuilder)) + .or_else(build!(array, array::LargeStringArray, LargeStringBuilder)); + + match array { + Ok(builder) => return Ok(builder), + Err(array) => { + return Err(VineyardError::invalid(format!( + "build array failed, array type is: '{}'", + array.data_type(), + ))) + } + }; +} + +#[derive(Debug)] +pub struct SchemaProxy { + meta: ObjectMeta, + schema: schema::Schema, +} + +impl TypeName for SchemaProxy { + fn typename() -> &'static str { + return staticize("vineyard::SchemaProxy"); + } +} + +impl Default for SchemaProxy { + fn default() -> Self { + SchemaProxy { + meta: ObjectMeta::default(), + schema: schema::Schema::empty(), + } + } +} + +impl Object for SchemaProxy { + fn construct(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::(), meta.get_typename()?)?; + self.meta = meta; + let 
schema: Vec = match self.meta.get_value("schema_binary_")? { + Value::Object(values) => { + let schema = values.get("bytes").ok_or(VineyardError::invalid( + "construct schema from binary failed: failed to get schema binary", + ))?; + match schema { + Value::Array(array) => { + let mut values = Vec::with_capacity(array.len()); + for v in array { + match v { + Value::Number(n) => { + if let Some(n) = n.as_u64() { + values.push(n as u8); + } else { + return Err(VineyardError::invalid( + format!("construct schema from binary failed: failed to get schema binary: not a positive number: {:?}", n), + )); + } + } + _ => return Err(VineyardError::invalid( + format!("construct schema from binary failed: failed to get schema binary: not a positive number: {:?}", v), + )), + } + } + Ok(values) + } + _ => Err(VineyardError::invalid( + "construct schema from binary failed: value is not an array", + )), + } + } + _ => Err(VineyardError::invalid( + "construct schema from binary failed: failed to get schema binary", + )), + }?; + self.schema = arrow_ipc::convert::try_schema_from_ipc_buffer(schema.as_slice())?; + return Ok(()); + } +} + +register_vineyard_object!(SchemaProxy); + +impl SchemaProxy { + pub fn new_boxed(meta: ObjectMeta) -> Result> { + let mut schema = Box::::default(); + schema.construct(meta)?; + return Ok(schema); + } +} + +impl AsRef for SchemaProxy { + fn as_ref(&self) -> &schema::Schema { + return &self.schema; + } +} + +pub struct SchemaProxyBuilder { + sealed: bool, + schema_binary: Vec, + schema_textual: String, +} + +impl ObjectBuilder for SchemaProxyBuilder { + fn sealed(&self) -> bool { + self.sealed + } + + fn set_sealed(&mut self, sealed: bool) { + self.sealed = sealed; + } +} + +impl ObjectBase for SchemaProxyBuilder { + fn build(&mut self, _client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); + } + self.set_sealed(true); + return Ok(()); + } + + fn seal(mut self, client: &mut IPCClient) -> Result> { + self.build(client)?; + let mut 
meta = ObjectMeta::new_from_typename(typename::()); + meta.add_value( + "schema_binary_", + json!( + { + "bytes": self.schema_binary, + } + ), + ); + meta.add_string("schema_textual_", self.schema_textual); + meta.set_nbytes(self.schema_binary.len()); + let metadata = client.create_metadata(&meta)?; + return SchemaProxy::new_boxed(metadata); + } +} + +impl SchemaProxyBuilder { + pub fn new(schema: &schema::Schema) -> Result { + let buffer: Vec = Vec::new(); + let writer = arrow_ipc::writer::StreamWriter::try_new(buffer, schema)?; + let schema_binary = writer.into_inner()?; + let schema_textual = schema.to_string(); + return Ok(SchemaProxyBuilder { + sealed: false, + schema_binary, + schema_textual, + }); + } + + pub fn new_from_builder(builder: schema::SchemaBuilder) -> Result { + return Self::new(&builder.finish()); + } +} + +#[derive(Debug)] +pub struct RecordBatch { + meta: ObjectMeta, + batch: array::RecordBatch, +} + +impl_typename!(RecordBatch, "vineyard::RecordBatch"); + +impl Default for RecordBatch { + fn default() -> Self { + RecordBatch { + meta: ObjectMeta::default(), + batch: array::RecordBatch::new_empty(Arc::new(schema::Schema::empty())), + } + } +} + +impl Object for RecordBatch { + fn construct(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::(), meta.get_typename()?)?; + self.meta = meta; + let schema = self.meta.get_member::("schema_")?; + let schema = schema.as_ref().as_ref().clone(); + let _num_rows = self.meta.get_usize("row_num_")?; + let _num_columns = self.meta.get_usize("column_num_")?; + let columns_size = self.meta.get_usize("__columns_-size")?; + let mut arrays = Vec::with_capacity(columns_size); + for i in 0..columns_size { + let column = self.meta.get_member_untyped(&format!("__columns_-{}", i))?; + arrays.push(downcast_to_array(column)?.array()); + } + self.batch = array::RecordBatch::try_new(Arc::new(schema), arrays)?; + return Ok(()); + } +} + +register_vineyard_object!(RecordBatch); + +impl 
RecordBatch { + pub fn new_boxed(meta: ObjectMeta) -> Result> { + let mut batch = Box::::default(); + batch.construct(meta)?; + return Ok(batch); + } + + pub fn schema(&self) -> Arc { + return self.batch.schema(); + } + + pub fn num_rows(&self) -> usize { + return self.batch.num_rows(); + } + + pub fn num_columns(&self) -> usize { + return self.batch.num_columns(); + } +} + +impl AsRef for RecordBatch { + fn as_ref(&self) -> &array::RecordBatch { + return &self.batch; + } +} + +pub struct RecordBatchBuilder { + sealed: bool, + schema: SchemaProxyBuilder, + row_num: usize, + column_num: usize, + columns: Vec>, +} + +impl ObjectBuilder for RecordBatchBuilder { + fn sealed(&self) -> bool { + self.sealed + } + + fn set_sealed(&mut self, sealed: bool) { + self.sealed = sealed; + } +} + +impl ObjectBase for RecordBatchBuilder { + fn build(&mut self, client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); + } + self.set_sealed(true); + self.schema.build(client)?; + return Ok(()); + } + + fn seal(mut self, client: &mut IPCClient) -> Result> { + self.build(client)?; + let mut meta = ObjectMeta::new_from_typename(typename::()); + meta.add_member("schema_", self.schema.seal(client)?)?; + meta.add_usize("row_num_", self.row_num); + meta.add_usize("column_num_", self.column_num); + meta.add_usize("__columns_-size", self.columns.len()); + for (i, column) in self.columns.into_iter().enumerate() { + meta.add_member(&format!("__columns_-{}", i), column)?; + } + let metadata = client.create_metadata(&meta)?; + return RecordBatch::new_boxed(metadata); + } +} + +impl RecordBatchBuilder { + pub fn new( + _client: &mut IPCClient, + schema: &arrow_schema::Schema, + row_num: usize, + column_num: usize, + columns: Vec>, + ) -> Result { + return Ok(RecordBatchBuilder { + sealed: false, + schema: SchemaProxyBuilder::new(schema)?, + row_num: row_num, + column_num: column_num, + columns: columns, + }); + } + + pub fn new_from_recordbatch( + client: &mut IPCClient, + batch: 
&array::RecordBatch, + ) -> Result { + let mut columns = Vec::with_capacity(batch.num_columns()); + for i in 0..batch.num_columns() { + let array = batch.column(i); + let array = build_array(client, array.clone())?; + columns.push(array); + } + return Self::new( + client, + batch.schema().as_ref(), + batch.num_rows(), + batch.num_columns(), + columns, + ); + } +} + +#[derive(Debug)] +pub struct Table { + meta: ObjectMeta, + schema: schema::Schema, + num_rows: usize, + num_columns: usize, + batches: Vec>, +} + +impl_typename!(Table, "vineyard::Table"); + +impl Default for Table { + fn default() -> Self { + Table { + meta: ObjectMeta::default(), + schema: schema::Schema::empty(), + num_rows: 0, + num_columns: 0, + batches: Vec::new(), + } + } +} + +impl Object for Table { + fn construct(&mut self, meta: ObjectMeta) -> Result<()> { + vineyard_assert_typename(typename::(), meta.get_typename()?)?; + self.meta = meta; + let schema = self.meta.get_member::("schema_")?; + let schema = schema.as_ref().as_ref().clone(); + self.num_rows = self.meta.get_usize("num_rows_")?; + self.num_columns = self.meta.get_usize("num_columns_")?; + let _batch_num = self.meta.get_usize("batch_num_")?; + let partitions_size = self.meta.get_usize("partitions_-size")?; + let mut batches = Vec::with_capacity(partitions_size); + for i in 0..partitions_size { + let batch = self + .meta + .get_member::(&format!("partitions_-{}", i))?; + batches.push(batch); + } + self.schema = schema; + self.batches = batches; + return Ok(()); + } +} + +register_vineyard_object!(Table); + +impl Table { + pub fn new_boxed(meta: ObjectMeta) -> Result> { + let mut table = Box::::default(); + table.construct(meta)?; + return Ok(table); + } + + pub fn schema(&self) -> &schema::Schema { + return &self.schema; + } + + pub fn num_rows(&self) -> usize { + return self.num_rows; + } + + pub fn num_columns(&self) -> usize { + return self.num_columns; + } + + pub fn num_batches(&self) -> usize { + return self.batches.len(); + } 
+ + pub fn batches(&self) -> &[Box] { + return &self.batches; + } +} + +impl AsRef<[Box]> for Table { + fn as_ref(&self) -> &[Box] { + return &self.batches; + } +} + +pub struct TableBuilder { + sealed: bool, + global: bool, + schema: SchemaProxyBuilder, + num_rows: usize, + num_columns: usize, + batches: Vec>, +} + +impl ObjectBuilder for TableBuilder { + fn sealed(&self) -> bool { + self.sealed + } + + fn set_sealed(&mut self, sealed: bool) { + self.sealed = sealed; + } +} + +impl ObjectBase for TableBuilder { + fn build(&mut self, client: &mut IPCClient) -> Result<()> { + if self.sealed { + return Ok(()); + } + self.set_sealed(true); + self.schema.build(client)?; + return Ok(()); + } + + fn seal(mut self, client: &mut IPCClient) -> Result> { + self.build(client)?; + let mut meta = ObjectMeta::new_from_typename(typename::
()); + meta.set_global(self.global); + meta.add_member("schema_", self.schema.seal(client)?)?; + meta.add_usize("num_rows_", self.num_rows); + meta.add_usize("num_columns_", self.num_columns); + meta.add_usize("batch_num_", self.batches.len()); + meta.add_usize("partitions_-size", self.batches.len()); + for (i, batch) in self.batches.into_iter().enumerate() { + meta.add_member(&format!("partitions_-{}", i), batch)?; + } + let metadata = client.create_metadata(&meta)?; + return Table::new_boxed(metadata); + } +} + +impl TableBuilder { + pub fn new( + _client: &mut IPCClient, + schema: &schema::Schema, + num_rows: usize, + num_columns: usize, + batches: Vec>, + ) -> Result { + return Ok(TableBuilder { + sealed: false, + global: false, + schema: SchemaProxyBuilder::new(schema)?, + num_rows: num_rows, + num_columns: num_columns, + batches: batches, + }); + } + + pub fn new_from_bathes( + client: &mut IPCClient, + schema: &schema::Schema, + num_rows: Vec, + num_columns: usize, + batches: Vec>>, + ) -> Result { + let mut chunks = Vec::with_capacity(batches.len()); + let mut total_num_rows = 0; + for (num_row, batch) in izip!(num_rows, batches) { + total_num_rows += num_row; + let batch = RecordBatchBuilder::new(client, schema, num_row, num_columns, batch)?; + chunks.push(batch.seal(client)?); + } + return Ok(TableBuilder { + sealed: false, + global: false, + schema: SchemaProxyBuilder::new(schema)?, + num_rows: total_num_rows, + num_columns: num_columns, + batches: chunks, + }); + } + + pub fn new_from_recordbatches( + client: &mut IPCClient, + schema: &schema::Schema, + table: &[array::RecordBatch], + ) -> Result { + let schema = SchemaProxyBuilder::new(schema)?; + + let mut batches = Vec::with_capacity(table.len()); + let mut num_rows = 0; + let mut num_columns = 0; + for batch in table { + num_rows += batch.num_rows(); + num_columns = batch.num_columns(); + let batch = RecordBatchBuilder::new_from_recordbatch(client, batch)?; + batches.push(batch.seal(client)?); + } + 
return Ok(TableBuilder { + sealed: false, + global: false, + schema: schema, + num_rows: num_rows, + num_columns: num_columns, + batches: batches, + }); + } +} + +pub(crate) fn resolve_buffer(meta: &ObjectMeta, key: &str) -> Result { + let id = meta.get_member_id(key)?; + match meta.get_buffer(id)? { + None => { + return Err(VineyardError::invalid(format!( + "buffer '{}' not exists in metadata", + key + ))); + } + Some(buffer) => { + return Ok(buffer); + } + } +} + +pub(crate) fn resolve_null_bitmap_buffer( + meta: &ObjectMeta, + key: &str, +) -> Result> { + let id = meta.get_member_id(key)?; + if is_blob(id) { + return Ok(None); + } + if let Ok(buffer) = resolve_buffer(meta, key) { + let length = meta.get_usize("length_")?; + let null_count = meta.get_usize("null_count_")?; + let offset = meta.get_usize("offset_")?; + let buffer = BooleanBuffer::new(buffer, offset, length); + return Ok(Some(unsafe { + NullBuffer::new_unchecked(buffer, null_count) + })); + } + return Ok(None); +} + +pub(crate) fn resolve_scalar_buffer( + meta: &ObjectMeta, + key: &str, +) -> Result> { + let buffer = resolve_buffer(meta, key)?; + let length = meta + .get_usize("length_") + .unwrap_or(buffer.len() / std::mem::size_of::()); + let offset = meta.get_usize("offset_").unwrap_or(0); + return Ok(TypedBuffer::::new(buffer, offset, length)); +} + +pub(crate) fn resolve_offsets_buffer( + meta: &ObjectMeta, + key: &str, +) -> Result> { + let buffer = resolve_buffer(meta, key)?; + let length = meta.get_usize("length_")? 
+ 1; + let offset = meta.get_usize("offset_")?; + let buffer = ScalarBuffer::::new(buffer, offset, length); + return Ok(unsafe { OffsetBuffer::new_unchecked(buffer) }); +} + +pub(crate) fn build_buffer(client: &mut IPCClient, buffer: &Buffer) -> Result { + let mut blob = client.create_blob(buffer.len())?; + unsafe { + std::ptr::copy_nonoverlapping(buffer.as_ptr(), blob.as_typed_mut_ptr::(), buffer.len()); + }; + return Ok(blob); +} + +pub(crate) fn build_null_bitmap_buffer( + client: &mut IPCClient, + buffer: Option<&NullBuffer>, +) -> Result> { + match buffer { + None => { + return Ok(None); + } + Some(buffer) => { + let null_bitmap = build_buffer(client, buffer.buffer())?; + return Ok(Some(null_bitmap)); + } + } +} + +pub(crate) fn build_scalar_buffer( + client: &mut IPCClient, + buffer: &TypedBuffer, +) -> Result { + let values = build_buffer(client, buffer.inner())?; + return Ok(values); +} + +pub(crate) fn build_offsets_buffer( + client: &mut IPCClient, + buffer: &OffsetBuffer, +) -> Result { + let offsets = build_buffer(client, buffer.inner().inner())?; + return Ok(offsets); +} diff --git a/rust/vineyard/src/ds/arrow_test.rs b/rust/vineyard/src/ds/arrow_test.rs index 9742a59c99..7ed5792529 100644 --- a/rust/vineyard/src/ds/arrow_test.rs +++ b/rust/vineyard/src/ds/arrow_test.rs @@ -12,4 +12,250 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::io::*; +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array as array; + use arrow_schema as schema; + use spectral::prelude::*; + + use super::super::arrow::*; + use crate::client::*; + + #[test] + fn test_int_array() -> Result<()> { + const N: usize = 1024; + let mut client = IPCClient::default()?; + + // prepare data + let vec = (0..N).map(|i| i as i32).collect::>(); + let array = array::Int32Array::from(vec); + + let array_object_id: ObjectID; + + // build into vineyard + { + let builder = Int32Builder::new_from_array(&mut client, &array)?; + let object = builder.seal(&mut client)?; + let array = downcast_object::(object)?; + array_object_id = array.id(); + assert_that!(array.len()).is_equal_to(N); + + let slice = array.as_slice(); + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to(idx as i32); + } + } + + // get from vineyard + { + let array = client.get::(array_object_id).unwrap(); + let array_id = array.id(); + assert_that!(array_id).is_greater_than(0); + assert_that!(array_id).is_equal_to(array_object_id); + assert_that!(array.len()).is_equal_to(N); + + let slice = array.as_slice(); + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to(idx as i32); + } + } + + return Ok(()); + } + + #[test] + fn test_double_array() -> Result<()> { + const N: usize = 1024; + let mut client = IPCClient::default()?; + + // prepare data + let vec = (0..N).map(|i| i as f64).collect::>(); + let array = array::Float64Array::from(vec); + + let array_object_id: ObjectID; + + // build into vineyard + { + let builder = Float64Builder::new_from_array(&mut client, &array)?; + let object = builder.seal(&mut client)?; + let array = downcast_object::(object)?; + array_object_id = array.id(); + assert_that!(array.len()).is_equal_to(N); + + let slice = array.as_slice(); + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to(idx as f64); + } + } + + // get from vineyard + { + let 
array = client.get::(array_object_id).unwrap(); + let array_id = array.id(); + assert_that!(array_id).is_greater_than(0); + assert_that!(array_id).is_equal_to(array_object_id); + assert_that!(array.len()).is_equal_to(N); + + let slice = array.as_slice(); + for (idx, item) in slice.iter().enumerate() { + assert_that!(*item).is_equal_to(idx as f64); + } + } + + return Ok(()); + } + + #[test] + fn test_string_array() -> Result<()> { + const N: usize = 1024; + let mut client = IPCClient::default()?; + + // prepare data + let vec = (0..N).map(|i| format!("{:?}", i)).collect::>(); + let strings = vec.join(""); + let array = array::LargeStringArray::from(vec); + let array_object_id: ObjectID; + + // build into vineyard + { + let builder = LargeStringBuilder::new_from_array(&mut client, &array)?; + let object = builder.seal(&mut client)?; + + let array = downcast_object::(object)?; + array_object_id = array.id(); + assert_that!(array.len()).is_equal_to(N); + + let slice = array.as_slice(); + for (item, expect) in slice.iter().zip(strings.as_bytes().iter()) { + assert_that!(item).is_equal_to(expect); + } + } + + // get from vineyard + { + let array = client.get::(array_object_id).unwrap(); + let array_id = array.id(); + assert_that!(array_id).is_greater_than(0); + assert_that!(array_id).is_equal_to(array_object_id); + assert_that!(array.len()).is_equal_to(N); + + let slice = array.as_slice(); + for (item, expect) in slice.iter().zip(strings.as_bytes().iter()) { + assert_that!(item).is_equal_to(expect); + } + } + + return Ok(()); + } + + #[test] + fn test_record_batch() -> Result<()> { + const N: usize = 1024; + let mut client = IPCClient::default()?; + + // prepare data + let vec0 = (0..N).map(|i| i as i32).collect::>(); + let vec1 = (0..N).map(|i| i as f64).collect::>(); + let vec2 = (0..N).map(|i| format!("{:?}", i)).collect::>(); + let array0 = array::Int32Array::from(vec0); + let array1 = array::Float64Array::from(vec1); + let array2 = 
array::LargeStringArray::from(vec2); + + let schema = schema::Schema::new(vec![ + schema::Field::new("f0", schema::DataType::Int32, false), + schema::Field::new("f1", schema::DataType::Float64, false), + schema::Field::new("f2", schema::DataType::LargeUtf8, false), + ]); + let batch = array::RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(array0), Arc::new(array1), Arc::new(array2)], + )?; + + let recordbatch_object_id: ObjectID; + + // build into vineyard + { + let builder = RecordBatchBuilder::new_from_recordbatch(&mut client, &batch)?; + let object = builder.seal(&mut client)?; + let recordbatch = downcast_object::(object)?; + recordbatch_object_id = recordbatch.id(); + assert_that!(recordbatch.num_rows()).is_equal_to(N); + assert_that!(recordbatch.num_columns()).is_equal_to(3); + + let recordbatch = recordbatch.as_ref().as_ref(); + + let column0 = recordbatch + .column(0) + .as_any() + .downcast_ref::() + .ok_or(VineyardError::type_error("downcast to Int32Array failed"))?; + for (idx, item) in column0.iter().enumerate() { + assert_that!(item).is_equal_to(Some(idx as i32)); + } + + let column1 = recordbatch + .column(1) + .as_any() + .downcast_ref::() + .ok_or(VineyardError::type_error("downcast to Float64Array failed"))?; + for (idx, item) in column1.iter().enumerate() { + assert_that!(item).is_equal_to(Some(idx as f64)); + } + + let column2 = recordbatch + .column(2) + .as_any() + .downcast_ref::() + .ok_or(VineyardError::type_error( + "downcast to LargeStringArray failed", + ))?; + for (idx, item) in column2.iter().enumerate() { + assert_that!(item).is_equal_to(Some(format!("{}", idx).as_str())); + } + } + + // get from vineyard + { + let recordbatch = client.get::(recordbatch_object_id).unwrap(); + let recordbatch_id = recordbatch.id(); + assert_that!(recordbatch_id).is_greater_than(0); + assert_that!(recordbatch_id).is_equal_to(recordbatch_object_id); + assert_that!(recordbatch.num_rows()).is_equal_to(N); + 
assert_that!(recordbatch.num_columns()).is_equal_to(3); + + let recordbatch = recordbatch.as_ref().as_ref(); + + let column0 = recordbatch + .column(0) + .as_any() + .downcast_ref::() + .ok_or(VineyardError::type_error("downcast to Int32Array failed"))?; + for (idx, item) in column0.iter().enumerate() { + assert_that!(item).is_equal_to(Some(idx as i32)); + } + + let column1 = recordbatch + .column(1) + .as_any() + .downcast_ref::() + .ok_or(VineyardError::type_error("downcast to Float64Array failed"))?; + for (idx, item) in column1.iter().enumerate() { + assert_that!(item).is_equal_to(Some(idx as f64)); + } + + let column2 = recordbatch + .column(2) + .as_any() + .downcast_ref::() + .ok_or(VineyardError::type_error( + "downcast to LargeStringArray failed", + ))?; + for (idx, item) in column2.iter().enumerate() { + assert_that!(item).is_equal_to(Some(format!("{}", idx).as_str())); + } + } + return Ok(()); + } +} diff --git a/rust/vineyard/src/ds/arrow_utils.rs b/rust/vineyard/src/ds/arrow_utils.rs new file mode 100644 index 0000000000..e0989cea7a --- /dev/null +++ b/rust/vineyard/src/ds/arrow_utils.rs @@ -0,0 +1,81 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use arrow_array::array::*; +use arrow_array::builder::*; +use arrow_array::types::*; +use arrow_schema::DataType; + +pub trait ToArrowType { + type Type; + type ArrayType; + type BuilderType; + + fn datatype() -> DataType; +} + +macro_rules! impl_to_arrow_type { + ($ty:ty, $datatype:expr, $type:ty, $array_ty:ty, $builder_ty:ty) => { + impl ToArrowType for $ty { + type Type = $type; + type ArrayType = $array_ty; + type BuilderType = $builder_ty; + + fn datatype() -> DataType { + return $datatype; + } + } + }; +} + +impl_to_arrow_type!(i8, DataType::Int8, Int8Type, Int8Array, Int8Builder); +impl_to_arrow_type!(u8, DataType::UInt8, UInt8Type, UInt8Array, UInt8Builder); +impl_to_arrow_type!(i16, DataType::Int16, Int16Type, Int16Array, Int16Builder); +impl_to_arrow_type!( + u16, + DataType::UInt16, + UInt16Type, + UInt16Array, + UInt16Builder +); +impl_to_arrow_type!(i32, DataType::Int32, Int32Type, Int32Array, Int32Builder); +impl_to_arrow_type!( + u32, + DataType::UInt32, + UInt32Type, + UInt32Array, + UInt32Builder +); +impl_to_arrow_type!(i64, DataType::Int64, Int64Type, Int64Array, Int64Builder); +impl_to_arrow_type!( + u64, + DataType::UInt64, + UInt64Type, + UInt64Array, + UInt64Builder +); +impl_to_arrow_type!( + f32, + DataType::Float32, + Float32Type, + Float32Array, + Float32Builder +); +impl_to_arrow_type!( + f64, + DataType::Float64, + Float64Type, + Float64Array, + Float64Builder +); diff --git a/rust/vineyard/src/ds/dataframe.rs b/rust/vineyard/src/ds/dataframe.rs new file mode 100644 index 0000000000..6269572b0b --- /dev/null +++ b/rust/vineyard/src/ds/dataframe.rs @@ -0,0 +1,148 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use arrow_array::ArrayRef;
+use serde_json::Value;
+
+use super::tensor::*;
+use crate::client::*;
+
+/// A vineyard dataframe: an ordered list of named columns, each column being
+/// a tensor object resolved from the dataframe's metadata.
+#[derive(Default)]
+pub struct DataFrame {
+    meta: ObjectMeta,
+    names: Vec<String>,
+    columns: Vec<Box<dyn Tensor>>,
+}
+
+impl_typename!(DataFrame, "vineyard::DataFrame");
+
+impl Object for DataFrame {
+    /// Populates the dataframe from `meta`.
+    ///
+    /// Reads `__values_-size`, then for every index `i` the column name from
+    /// `__values_-key-{i}` and the column member from `__values_-value-{i}`.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the typename does not match, a key is missing, or a member
+    /// cannot be downcast to a tensor. On failure `self` is left untouched.
+    fn construct(&mut self, meta: ObjectMeta) -> Result<()> {
+        vineyard_assert_typename(typename::<Self>(), meta.get_typename()?)?;
+        let size = meta.get_usize("__values_-size")?;
+        // Resolve into locals first so an error does not leave `self`
+        // half-populated.
+        let mut names = Vec::with_capacity(size);
+        let mut columns = Vec::with_capacity(size);
+        for i in 0..size {
+            // Non-string keys are kept as their JSON text.
+            let name = match meta.get_value(&format!("__values_-key-{}", i))? {
+                Value::String(name) => name,
+                other => other.to_string(),
+            };
+            names.push(name);
+            let column = meta.get_member_untyped(&format!("__values_-value-{}", i))?;
+            columns.push(downcast_to_tensor(column)?);
+        }
+        // The metadata was previously dropped here, leaving `self.meta` at its
+        // default value; keep it, as the tensor types in this crate do.
+        self.meta = meta;
+        self.names = names;
+        self.columns = columns;
+        return Ok(());
+    }
+}
+
+register_vineyard_object!(DataFrame);
+
+impl DataFrame {
+    /// Constructs a boxed dataframe from its metadata.
+    pub fn new_boxed(meta: ObjectMeta) -> Result<Box<dyn Object>> {
+        let mut object = Box::<Self>::default();
+        object.construct(meta)?;
+        Ok(object)
+    }
+
+    /// Number of columns in the dataframe.
+    pub fn num_columns(&self) -> usize {
+        self.columns.len()
+    }
+
+    /// Column names, in column order.
+    pub fn names(&self) -> &[String] {
+        &self.names
+    }
+
+    /// Name of the column at `index`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `index` is out of range.
+    pub fn name(&self, index: usize) -> &str {
+        &self.names[index]
+    }
+
+    /// All columns, in column order.
+    pub fn columns(&self) -> &[Box<dyn Tensor>] {
+        &self.columns
+    }
+
+    /// Arrow array backing the column at `index`.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `index` is out of range.
+    pub fn column(&self, index: usize) -> ArrayRef {
+        self.columns[index].array()
+    }
+}
+
+/// Builder that assembles already-built column objects into a [`DataFrame`].
+pub struct DataFrameBuilder {
+    sealed: bool,
+    names: Vec<String>,
+    columns: Vec<Box<dyn Object>>,
+}
+
+impl ObjectBuilder for DataFrameBuilder {
+    fn sealed(&self) -> bool {
+        self.sealed
+    }
+
+    fn set_sealed(&mut self, sealed: bool) {
+        self.sealed = sealed;
+    }
+}
+
+impl ObjectBase for DataFrameBuilder {
+    /// Marks the builder as sealed; the columns are already built objects, so
+    /// there is nothing else to do here.
+    fn build(&mut self, _client: &mut IPCClient) -> Result<()> {
+        if self.sealed {
+            return Ok(());
+        }
+        self.set_sealed(true);
+        return Ok(());
+    }
+
+    /// Writes the dataframe metadata (size, partition/batch indices set to
+    /// `-1`, and one key/value pair per column) and returns the resulting
+    /// [`DataFrame`].
+    fn seal(mut self, client: &mut IPCClient) -> Result<Box<dyn Object>> {
+        self.build(client)?;
+        let mut meta = ObjectMeta::new_from_typename(typename::<DataFrame>());
+        meta.add_usize("__values_-size", self.names.len());
+        meta.add_isize("partition_index_row_", -1);
+        meta.add_isize("partition_index_column_", -1);
+        meta.add_isize("row_batch_index_", -1);
+        for (index, (name, column)) in self.names.iter().zip(self.columns).enumerate() {
+            meta.add_value(
+                &format!("__values_-key-{}", index),
+                Value::String(name.into()),
+            );
+            meta.add_member(&format!("__values_-value-{}", index), column)?;
+        }
+        let metadata = client.create_metadata(&meta)?;
+        return DataFrame::new_boxed(metadata);
+    }
+}
+
+impl DataFrameBuilder {
+    /// Creates a builder from column names and already-built column objects.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `names` and `columns` differ in length; `seal` would otherwise
+    /// record `names.len()` as the size while `zip` silently drops the
+    /// unmatched entries.
+    pub fn new(names: Vec<String>, columns: Vec<Box<dyn Object>>) -> Result<Self> {
+        if names.len() != columns.len() {
+            return Err(VineyardError::invalid(format!(
+                "dataframe builder: got {} column names but {} columns",
+                names.len(),
+                columns.len(),
+            )));
+        }
+        return Ok(DataFrameBuilder {
+            sealed: false,
+            names,
+            columns,
+        });
+    }
+
+    /// Creates a builder by first building one tensor per arrow array.
+    ///
+    /// # Errors
+    ///
+    /// Fails if any array cannot be built into a tensor, or if `names` and
+    /// `arrays` differ in length.
+    pub fn new_from_arrays(
+        client: &mut IPCClient,
+        names: Vec<String>,
+        arrays: Vec<ArrayRef>,
+    ) -> Result<Self> {
+        let mut columns = Vec::with_capacity(arrays.len());
+        for array in arrays {
+            columns.push(build_tensor(client, array)?);
+        }
+        return Self::new(names, columns);
+    }
+}
diff --git a/rust/vineyard/src/ds/dataframe_test.rs b/rust/vineyard/src/ds/dataframe_test.rs
new file mode 100644
index 0000000000..f54b6c5ad4
--- /dev/null
+++ b/rust/vineyard/src/ds/dataframe_test.rs
@@ -0,0 +1,16 @@
+// Copyright 2020-2023 Alibaba Group Holding Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[allow(unused_imports)] +use std::any::Any; diff --git a/rust/vineyard/src/ds/hashmap.rs b/rust/vineyard/src/ds/hashmap.rs index 1472bd56fe..f54b6c5ad4 100644 --- a/rust/vineyard/src/ds/hashmap.rs +++ b/rust/vineyard/src/ds/hashmap.rs @@ -12,4 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. +#[allow(unused_imports)] use std::any::Any; diff --git a/rust/vineyard/src/ds/hashmap_test.rs b/rust/vineyard/src/ds/hashmap_test.rs index 9742a59c99..f54b6c5ad4 100644 --- a/rust/vineyard/src/ds/hashmap_test.rs +++ b/rust/vineyard/src/ds/hashmap_test.rs @@ -12,4 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::io::*; +#[allow(unused_imports)] +use std::any::Any; diff --git a/rust/vineyard/src/ds/mod.rs b/rust/vineyard/src/ds/mod.rs index 49a23bdcd9..1b87c8e86e 100644 --- a/rust/vineyard/src/ds/mod.rs +++ b/rust/vineyard/src/ds/mod.rs @@ -16,5 +16,10 @@ pub mod array; pub mod array_test; pub mod arrow; pub mod arrow_test; +pub mod arrow_utils; +pub mod dataframe; +pub mod dataframe_test; pub mod hashmap; pub mod hashmap_test; +pub mod tensor; +pub mod tensor_test; diff --git a/rust/vineyard/src/ds/tensor.rs b/rust/vineyard/src/ds/tensor.rs new file mode 100644 index 0000000000..62716cb518 --- /dev/null +++ b/rust/vineyard/src/ds/tensor.rs @@ -0,0 +1,536 @@ +// Copyright 2020-2023 Alibaba Group Holding Limited. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::fmt::Debug;
+use std::marker::PhantomData;
+use std::rc::Rc;
+use std::sync::Arc;
+
+use array::{ArrayRef, OffsetSizeTrait};
+use arrow_array as array;
+use arrow_array::builder;
+use downcast_rs::impl_downcast;
+use static_str_ops::*;
+
+use crate::client::*;
+
+use super::arrow::*;
+
+/// Marker trait for vineyard objects that are tensors; every tensor can be
+/// viewed as an arrow array through the `Array` supertrait.
+pub trait Tensor: Array {}
+
+impl_downcast!(Tensor);
+
+/// Downcasts a boxed tensor to the concrete tensor type `T`.
+pub fn downcast_tensor<T: Tensor>(object: Box<dyn Tensor>) -> Result<Box<T>> {
+    return object
+        .downcast::<T>()
+        .map_err(|_| VineyardError::invalid("downcast object to tensor failed".to_string()));
+}
+
+/// Downcasts a tensor reference to the concrete tensor type `T`.
+///
+/// The error message (and the typename lookup it needs) is only produced when
+/// the downcast actually fails.
+pub fn downcast_tensor_ref<T: Tensor>(object: &dyn Tensor) -> Result<&T> {
+    return match object.downcast_ref::<T>() {
+        Some(tensor) => Ok(tensor),
+        None => Err(VineyardError::invalid(format!(
+            "downcast object '{:?}' to tensor failed",
+            object.meta().get_typename()?,
+        ))),
+    };
+}
+
+/// Downcasts a reference-counted tensor to the concrete tensor type `T`.
+pub fn downcast_tensor_rc<T: Tensor>(object: Rc<dyn Tensor>) -> Result<Rc<T>> {
+    return object
+        .downcast_rc::<T>()
+        .map_err(|_| VineyardError::invalid("downcast object to tensor failed".to_string()));
+}
+
+/// A tensor of numeric values backed by an arrow primitive array.
+// NOTE(review): the generic parameters in this file were reconstructed; confirm
+// the `NumericType` bound against `ds/arrow.rs`.
+#[derive(Debug)]
+pub struct NumericTensor<T: NumericType> {
+    meta: ObjectMeta,
+    shape: Vec<usize>,
+    tensor: Arc<TypedArray<T>>,
+}
+
+pub type Int8Tensor = NumericTensor<i8>;
+pub type UInt8Tensor = NumericTensor<u8>;
+pub type Int16Tensor = NumericTensor<i16>;
+pub type UInt16Tensor = NumericTensor<u16>;
+pub type Int32Tensor = NumericTensor<i32>;
+pub type UInt32Tensor = NumericTensor<u32>;
+pub type Int64Tensor = NumericTensor<i64>;
+pub type UInt64Tensor = NumericTensor<u64>;
+pub type Float32Tensor = NumericTensor<f32>;
+pub type Float64Tensor = NumericTensor<f64>;
+
+impl<T: TypeName + NumericType> TypeName for NumericTensor<T> {
+    fn typename() -> &'static str {
+        return staticize(format!("vineyard::Tensor<{}>", T::typename()));
+    }
+}
+
+impl<T: NumericType> Array for NumericTensor<T> {
+    fn array(&self) -> array::ArrayRef {
+        return self.tensor.clone();
+    }
+}
+
+impl<T: NumericType> Tensor for NumericTensor<T> {}
+
+impl<T: NumericType> Default for NumericTensor<T> {
+    fn default() -> Self {
+        NumericTensor {
+            meta: ObjectMeta::default(),
+            shape: vec![],
+            tensor: Arc::new(TypedArray::<T>::new(vec![].into(), None)),
+        }
+    }
+}
+
+impl<T: TypeName + NumericType> Object for NumericTensor<T> {
+    /// Populates the tensor from `meta`: the shape comes from `shape_` and the
+    /// values from the `buffer_` member. No validity bitmap is attached.
+    fn construct(&mut self, meta: ObjectMeta) -> Result<()> {
+        vineyard_assert_typename(typename::<Self>(), meta.get_typename()?)?;
+        self.meta = meta;
+        self.shape = self.meta.get_vector("shape_")?;
+        let values: arrow_buffer::ScalarBuffer<_> =
+            resolve_scalar_buffer::<T>(&self.meta, "buffer_")?;
+        self.tensor = Arc::new(TypedArray::<T>::new(values, None));
+        return Ok(());
+    }
+}
+
+register_vineyard_object!(NumericTensor<T: TypeName + NumericType>);
+register_vineyard_types! {
+    Int8Tensor;
+    UInt8Tensor;
+    Int16Tensor;
+    UInt16Tensor;
+    Int32Tensor;
+    UInt32Tensor;
+    Int64Tensor;
+    UInt64Tensor;
+    Float32Tensor;
+    Float64Tensor;
+}
+
+impl<T: TypeName + NumericType> NumericTensor<T> {
+    /// Constructs a boxed tensor from its metadata.
+    pub fn new_boxed(meta: ObjectMeta) -> Result<Box<dyn Object>> {
+        let mut array = Box::<Self>::default();
+        array.construct(meta)?;
+        return Ok(array);
+    }
+
+    /// Shared handle to the underlying arrow array.
+    pub fn data(&self) -> Arc<TypedArray<T>> {
+        return self.tensor.clone();
+    }
+
+    /// Tensor dimensions as stored in `shape_`.
+    pub fn shape(&self) -> &[usize] {
+        return &self.shape;
+    }
+
+    /// Number of elements implied by the shape (product of all dimensions).
+    pub fn len(&self) -> usize {
+        return self.shape.iter().product::<usize>();
+    }
+
+    /// Whether the shape implies zero elements.
+    pub fn is_empty(&self) -> bool {
+        return self.len() == 0;
+    }
+
+    /// The tensor values as a flat slice of `self.tensor.len()` elements.
+    pub fn as_slice(&self) -> &[T] {
+        // SAFETY: the pointer and the length both come from the same arrow
+        // array, which `self` keeps alive for the returned lifetime.
+        // NOTE(review): assumes `T` is the array's native element type.
+        return unsafe {
+            std::slice::from_raw_parts(
+                self.tensor.values().inner().as_ptr() as _,
+                self.tensor.len(),
+            )
+        };
+    }
+}
+
+impl<T: NumericType> AsRef<TypedArray<T>> for NumericTensor<T> {
+    fn as_ref(&self) -> &TypedArray<T> {
+        return &self.tensor;
+    }
+}
+
+/// Builder for [`NumericTensor`], writing the values into a blob.
+pub struct NumericTensorBuilder<T: NumericType> {
+    sealed: bool,
+    shape: Vec<usize>,
+    buffer: BlobWriter,
+    phantom: PhantomData<T>,
+}
+
+pub type Int8TensorBuilder = NumericTensorBuilder<i8>;
+pub type UInt8TensorBuilder = NumericTensorBuilder<u8>;
+pub type Int16TensorBuilder = NumericTensorBuilder<i16>;
+pub type UInt16TensorBuilder = NumericTensorBuilder<u16>;
+pub type Int32TensorBuilder = NumericTensorBuilder<i32>;
+pub type UInt32TensorBuilder = NumericTensorBuilder<u32>;
+pub type Int64TensorBuilder = NumericTensorBuilder<i64>;
+pub type UInt64TensorBuilder = NumericTensorBuilder<u64>;
+pub type Float32TensorBuilder = NumericTensorBuilder<f32>;
+pub type Float64TensorBuilder = NumericTensorBuilder<f64>;
+
+impl<T: TypeName + NumericType> ObjectBuilder for NumericTensorBuilder<T> {
+    fn sealed(&self) -> bool {
+        self.sealed
+    }
+
+    fn set_sealed(&mut self, sealed: bool) {
+        self.sealed = sealed;
+    }
+}
+
+impl<T: TypeName + NumericType> ObjectBase for NumericTensorBuilder<T> {
+    /// Builds the value blob once; later calls are no-ops.
+    fn build(&mut self, client: &mut IPCClient) -> Result<()> {
+        if self.sealed {
+            return Ok(());
+        }
+        self.set_sealed(true);
+        self.buffer.build(client)?;
+        return Ok(());
+    }
+
+    /// Seals the blob, writes the tensor metadata (`buffer_`, `shape_`,
+    /// nbytes) and returns the resulting [`NumericTensor`].
+    fn seal(mut self, client: &mut IPCClient) -> Result<Box<dyn Object>> {
+        self.build(client)?;
+        let nbytes = self.buffer.len();
+        let buffer = self.buffer.seal(client)?;
+        let mut meta = ObjectMeta::new_from_typename(typename::<NumericTensor<T>>());
+        meta.add_member("buffer_", buffer)?;
+        meta.add_vector("shape_", &self.shape)?;
+        meta.set_nbytes(nbytes);
+        let metadata = client.create_metadata(&meta)?;
+        return NumericTensor::<T>::new_boxed(metadata);
+    }
+}
+
+impl<T: TypeName + NumericType> NumericTensorBuilder<T> {
+    /// Allocates a blob large enough for a tensor of the given shape.
+    pub fn new(client: &mut IPCClient, shape: &[usize]) -> Result<Self> {
+        let length = shape.iter().product::<usize>();
+        let buffer = client.create_blob(std::mem::size_of::<T>() * length)?;
+        return Ok(NumericTensorBuilder {
+            sealed: false,
+            shape: shape.to_vec(),
+            buffer,
+            phantom: PhantomData,
+        });
+    }
+
+    /// Creates a builder from the values of an arrow array, with an explicit
+    /// shape. The shape is not checked against the array length here.
+    pub fn new_from_array(
+        client: &mut IPCClient,
+        shape: &[usize],
+        array: &TypedArray<T>,
+    ) -> Result<Self> {
+        let buffer = build_scalar_buffer::<T>(client, array.values())?;
+        return Ok(NumericTensorBuilder {
+            sealed: false,
+            shape: shape.to_vec(),
+            buffer,
+            phantom: PhantomData,
+        });
+    }
+
+    /// Creates a one-dimensional builder whose shape is `[array.len()]`.
+    pub fn new_from_array_1d(client: &mut IPCClient, array: &TypedArray<T>) -> Result<Self> {
+        return Self::new_from_array(client, &[array.len()], array);
+    }
+
+    /// Finishes `builder` and creates a tensor builder from the result.
+    pub fn new_from_builder(
+        client: &mut IPCClient,
+        shape: &[usize],
+        builder: &mut TypedBuilder<T>,
+    ) -> Result<Self> {
+        let array = builder.finish();
+        return Self::new_from_array(client, shape, &array);
+    }
+
+    /// Tensor dimensions.
+    pub fn shape(&self) -> &[usize] {
+        return &self.shape;
+    }
+
+    /// Number of elements implied by the shape (product of all dimensions).
+    pub fn len(&self) -> usize {
+        return self.shape.iter().product::<usize>();
+    }
+
+    /// Whether the shape implies zero elements.
+    pub fn is_empty(&self) -> bool {
+        return self.len() == 0;
+    }
+
+    /// The blob contents viewed as a slice of `T`.
+    ///
+    /// The element count is derived from the blob's byte size. Transmuting the
+    /// slice reference, as was done before, keeps the old length field, so a
+    /// byte slice would yield `size_of::<T>()` times too many elements.
+    pub fn as_slice(&mut self) -> &[T] {
+        let raw = self.buffer.as_slice();
+        let count = std::mem::size_of_val(raw) / std::mem::size_of::<T>();
+        if count == 0 {
+            return &[];
+        }
+        debug_assert_eq!(raw.as_ptr() as usize % std::mem::align_of::<T>(), 0);
+        // SAFETY: `count` elements of `T` fit in the bytes of `raw`, which
+        // `self` keeps alive for the returned lifetime.
+        // NOTE(review): assumes blob memory is aligned for `T` - confirm.
+        return unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const T, count) };
+    }
+
+    /// Mutable counterpart of [`Self::as_slice`].
+    pub fn as_mut_slice(&mut self) -> &mut [T] {
+        let raw = self.buffer.as_mut_slice();
+        let count = std::mem::size_of_val(&*raw) / std::mem::size_of::<T>();
+        if count == 0 {
+            return &mut [];
+        }
+        debug_assert_eq!(raw.as_ptr() as usize % std::mem::align_of::<T>(), 0);
+        // SAFETY: as in `as_slice`; exclusivity follows from `&mut self`.
+        return unsafe { std::slice::from_raw_parts_mut(raw.as_mut_ptr() as *mut T, count) };
+    }
+}
+
+/// A tensor of strings backed by an arrow string array with 64-bit offsets.
+#[derive(Debug)]
+pub struct StringTensor {
+    meta: ObjectMeta,
+    shape: Vec<usize>,
+    tensor: Arc<array::GenericStringArray<i64>>,
+}
+
+impl Array for StringTensor {
+    fn array(&self) -> array::ArrayRef {
+        return self.tensor.clone();
+    }
+}
+
+impl Tensor for StringTensor {}
+
+impl TypeName for StringTensor {
+    fn typename() -> &'static str {
+        return staticize("vineyard::Tensor<std::string>");
+    }
+}
+
+impl Default for StringTensor {
+    fn default() -> Self {
+        StringTensor {
+            meta: ObjectMeta::default(),
+            shape: vec![],
+            tensor: Arc::new(array::GenericStringArray::<i64>::new_null(0)),
+        }
+    }
+}
+
+impl Object for StringTensor {
+    /// Populates the tensor from `meta`: the shape comes from `shape_` and the
+    /// string data from the `buffer_` member.
+    fn construct(&mut self, meta: ObjectMeta) -> Result<()> {
+        vineyard_assert_typename(typename::<Self>(), meta.get_typename()?)?;
+        self.meta = meta;
+        self.shape = self.meta.get_vector("shape_")?;
+        // NOTE(review): member type reconstructed as the 64-bit-offset string
+        // array from `ds/arrow.rs` - confirm.
+        self.tensor = self.meta.get_member::<LargeStringArray>("buffer_")?.data();
+        return Ok(());
+    }
+}
+
+register_vineyard_object!(StringTensor);
+
+impl StringTensor {
+    /// Constructs a boxed tensor from its metadata.
+    pub fn new_boxed(meta: ObjectMeta) -> Result<Box<dyn Object>> {
+        let mut array = Box::<Self>::default();
+        array.construct(meta)?;
+        return Ok(array);
+    }
+
+    /// Shared handle to the underlying arrow string array.
+    pub fn data(&self) -> Arc<array::GenericStringArray<i64>> {
+        return self.tensor.clone();
+    }
+
+    /// Tensor dimensions as stored in `shape_`.
+    pub fn shape(&self) -> &[usize] {
+        return &self.shape;
+    }
+
+    /// Number of elements implied by the shape (product of all dimensions).
+    pub fn len(&self) -> usize {
+        return self.shape.iter().product::<usize>();
+    }
+
+    /// Whether the shape implies zero elements.
+    pub fn is_empty(&self) -> bool {
+        return self.len() == 0;
+    }
+
+    /// The concatenated string bytes of the underlying array.
+    pub fn as_slice(&self) -> &[u8] {
+        return self.tensor.value_data();
+    }
+
+    /// The value offsets into [`Self::as_slice`].
+    pub fn as_slice_offsets(&self) -> &[i64] {
+        return self.tensor.value_offsets();
+    }
+}
+
+impl AsRef<array::GenericStringArray<i64>> for StringTensor {
+    fn as_ref(&self) -> &array::GenericStringArray<i64> {
+        return &self.tensor;
+    }
+}
+
+/// Builder for [`StringTensor`], generic over the arrow offset width.
+pub struct BaseStringTensorBuilder<O: OffsetSizeTrait> {
+    sealed: bool,
+    shape: Vec<usize>,
+    tensor: BaseStringBuilder<O>,
+}
+
+pub type StringTensorBuilder = BaseStringTensorBuilder<i32>;
+pub type LargeStringTensorBuilder = BaseStringTensorBuilder<i64>;
+
+impl<O: OffsetSizeTrait> ObjectBuilder for BaseStringTensorBuilder<O> {
+    fn sealed(&self) -> bool {
+        self.sealed
+    }
+
+    fn set_sealed(&mut self, sealed: bool) {
+        self.sealed = sealed;
+    }
+}
+
+impl<O: OffsetSizeTrait> ObjectBase for BaseStringTensorBuilder<O> {
+    /// Builds the underlying string array once; later calls are no-ops.
+    fn build(&mut self, client: &mut IPCClient) -> Result<()> {
+        if self.sealed {
+            return Ok(());
+        }
+        self.set_sealed(true);
+        self.tensor.build(client)?;
+        return Ok(());
+    }
+
+    /// Seals the string array, writes the tensor metadata (`buffer_`,
+    /// `shape_`, `partition_index_`, nbytes) and returns the resulting
+    /// [`StringTensor`].
+    fn seal(mut self, client: &mut IPCClient) -> Result<Box<dyn Object>> {
+        self.build(client)?;
+        let nbytes = self.tensor.len();
+        let tensor = self.tensor.seal(client)?;
+        let mut meta = ObjectMeta::new_from_typename(typename::<StringTensor>());
+        meta.add_member("buffer_", tensor)?;
+        meta.add_vector("shape_", &self.shape)?;
+        // NOTE(review): element type reconstructed - confirm.
+        meta.add_vector::<isize>("partition_index_", &[-1, -1])?;
+        meta.set_nbytes(nbytes);
+        let metadata = client.create_metadata(&meta)?;
+        return StringTensor::new_boxed(metadata);
+    }
+}
+
+impl<O: OffsetSizeTrait> BaseStringTensorBuilder<O> {
+    /// Creates a builder from an arrow string array, with an explicit shape.
+    pub fn new_from_array(
+        client: &mut IPCClient,
+        shape: &[usize],
+        array: &array::GenericStringArray<O>,
+    ) -> Result<Self> {
+        return Ok(BaseStringTensorBuilder {
+            sealed: false,
+            shape: shape.to_vec(),
+            tensor: BaseStringBuilder::<O>::new_from_array(client, array)?,
+        });
+    }
+
+    /// Creates a one-dimensional builder whose shape is `[array.len()]`.
+    pub fn new_from_array_1d(
+        client: &mut IPCClient,
+        array: &array::GenericStringArray<O>,
+    ) -> Result<Self> {
+        use array::Array;
+        return Self::new_from_array(client, &[array.len()], array);
+    }
+
+    /// Finishes `builder` and creates a tensor builder from the result.
+    pub fn new_from_builder(
+        client: &mut IPCClient,
+        shape: &[usize],
+        builder: &mut builder::GenericStringBuilder<O>,
+    ) -> Result<Self> {
+        let array = builder.finish();
+        return Self::new_from_array(client, shape, &array);
+    }
+
+    /// Tensor dimensions.
+    pub fn shape(&self) -> &[usize] {
+        return &self.shape;
+    }
+
+    /// Number of elements implied by the shape (product of all dimensions).
+    pub fn len(&self) -> usize {
+        return self.shape.iter().product::<usize>();
+    }
+
+    /// Whether the shape implies zero elements.
+    pub fn is_empty(&self) -> bool {
+        return self.len() == 0;
+    }
+
+    /// The string bytes held by the underlying builder.
+    pub fn as_slice(&mut self) -> &[u8] {
+        return self.tensor.as_slice();
+    }
+
+    /// Mutable access to the string bytes held by the underlying builder.
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
+        return self.tensor.as_mut_slice();
+    }
+
+    /// The value offsets held by the underlying builder.
+    pub fn as_slice_offsets(&mut self) -> &[O] {
+        return self.tensor.as_slice_offsets();
+    }
+
+    /// Mutable access to the value offsets held by the underlying builder.
+    pub fn as_mut_slice_offsets(&mut self) -> &mut [O] {
+        return self.tensor.as_mut_slice_offsets();
+    }
+}
+
+/// Downcasts an untyped object to a tensor by trying each known tensor type
+/// in turn.
+///
+/// # Errors
+///
+/// Fails if the object is none of the known tensor types; the message carries
+/// the object's typename.
+pub fn downcast_to_tensor(object: Box<dyn Object>) -> Result<Box<dyn Tensor>> {
+    // A failed downcast hands the object back so the next type can be tried.
+    macro_rules! downcast {
+        ($object: ident, $ty: ty) => {
+            |$object| match $object.downcast::<$ty>() {
+                Ok(array) => Ok(array),
+                Err(original) => Err(original),
+            }
+        };
+    }
+
+    let mut object: std::result::Result<Box<dyn Tensor>, Box<dyn Object>> = Err(object);
+    object = object
+        .or_else(downcast!(object, Int8Tensor))
+        .or_else(downcast!(object, UInt8Tensor))
+        .or_else(downcast!(object, Int16Tensor))
+        .or_else(downcast!(object, UInt16Tensor))
+        .or_else(downcast!(object, Int32Tensor))
+        .or_else(downcast!(object, UInt32Tensor))
+        .or_else(downcast!(object, Int64Tensor))
+        .or_else(downcast!(object, UInt64Tensor))
+        .or_else(downcast!(object, Float32Tensor))
+        .or_else(downcast!(object, Float64Tensor))
+        .or_else(downcast!(object, StringTensor));
+
+    match object {
+        Ok(array) => return Ok(array),
+        Err(object) => {
+            return Err(VineyardError::invalid(format!(
+                "downcast object to tensor failed, object type is: '{}'",
+                object.meta().get_typename()?,
+            )))
+        }
+    };
+}
+
+/// Builds and seals a one-dimensional tensor from an arrow array.
+///
+/// The first supported array type that matches is used. A failure while
+/// creating or sealing the builder is returned as-is rather than being
+/// reported as an unsupported type.
+///
+/// # Errors
+///
+/// Fails if the array type is not supported, or if building or sealing the
+/// tensor fails.
+pub fn build_tensor(client: &mut IPCClient, array: ArrayRef) -> Result<Box<dyn Object>> {
+    let array: &dyn array::Array = array.as_ref();
+
+    macro_rules! build {
+        ($array_ty: ty, $builder_ty: ty) => {
+            if let Some(typed) = array.as_any().downcast_ref::<$array_ty>() {
+                return <$builder_ty>::new_from_array_1d(client, typed)?.seal(client);
+            }
+        };
+    }
+
+    build!(array::Int8Array, Int8TensorBuilder);
+    build!(array::UInt8Array, UInt8TensorBuilder);
+    build!(array::Int16Array, Int16TensorBuilder);
+    build!(array::UInt16Array, UInt16TensorBuilder);
+    build!(array::Int32Array, Int32TensorBuilder);
+    build!(array::UInt32Array, UInt32TensorBuilder);
+    build!(array::Int64Array, Int64TensorBuilder);
+    build!(array::UInt64Array, UInt64TensorBuilder);
+    build!(array::Float32Array, Float32TensorBuilder);
+    build!(array::Float64Array, Float64TensorBuilder);
+    build!(array::StringArray, StringTensorBuilder);
+    build!(array::LargeStringArray, LargeStringTensorBuilder);
+
+    return Err(VineyardError::invalid(format!(
+        "build array failed, array type is: '{}'",
+        array.data_type(),
+    )));
+}
diff --git a/rust/vineyard/src/ds/tensor_test.rs b/rust/vineyard/src/ds/tensor_test.rs
new file mode 100644
index 0000000000..f54b6c5ad4
--- /dev/null
+++ b/rust/vineyard/src/ds/tensor_test.rs
@@ -0,0 +1,16 @@
+// Copyright 2020-2023 Alibaba Group Holding Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[allow(unused_imports)] +use std::any::Any; diff --git a/rust/vineyard/src/lib.rs b/rust/vineyard/src/lib.rs index ab3e02a11d..e251e138e0 100644 --- a/rust/vineyard/src/lib.rs +++ b/rust/vineyard/src/lib.rs @@ -12,10 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![allow(clippy::box_default)] +#![allow(clippy::needless_borrow)] +#![allow(clippy::needless_lifetimes)] +#![allow(clippy::needless_return)] +#![allow(clippy::nonminimal_bool)] +#![allow(clippy::not_unsafe_ptr_arg_deref)] +#![allow(clippy::redundant_field_names)] +#![allow(clippy::unnecessary_cast)] +#![allow(clippy::vec_box)] #![allow(incomplete_features)] +#![allow(non_upper_case_globals)] #![cfg_attr(feature = "nightly", feature(associated_type_defaults))] #![cfg_attr(feature = "nightly", feature(box_into_inner))] #![cfg_attr(feature = "nightly", feature(specialization))] +#![cfg_attr(feature = "nightly", feature(trait_alias))] #![cfg_attr(feature = "nightly", feature(unix_socket_peek))] #[macro_use]