From f5a8c6cf070d6e269490833587fdb17bb532a97a Mon Sep 17 00:00:00 2001 From: Ray Zhang Date: Mon, 17 Jul 2023 01:46:15 -0400 Subject: [PATCH] feat(rust, python): Add cloudpickle for serializing python UDFs (#9921) --- .../polars-lazy/polars-plan/src/dsl/python_udf.rs | 14 +++++++++----- py-polars/polars/utils/show_versions.py | 1 + py-polars/pyproject.toml | 3 ++- py-polars/requirements-dev.txt | 1 + py-polars/tests/unit/test_serde.py | 14 ++++++++++++++ 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/polars/polars-lazy/polars-plan/src/dsl/python_udf.rs b/polars/polars-lazy/polars-plan/src/dsl/python_udf.rs index e5d73d696869..46bf0d97795b 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/python_udf.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/python_udf.rs @@ -56,8 +56,9 @@ impl Serialize for PythonFunction { S: Serializer, { Python::with_gil(|py| { - let pickle = PyModule::import(py, "pickle") - .expect("Unable to import 'pickle'") + let pickle = PyModule::import(py, "cloudpickle") + .or(PyModule::import(py, "pickle")) + .expect("Unable to import 'cloudpickle' or 'pickle'") .getattr("dumps") .unwrap(); @@ -83,7 +84,8 @@ impl<'a> Deserialize<'a> for PythonFunction { let bytes = Vec::<u8>::deserialize(deserializer)?; Python::with_gil(|py| { - let pickle = PyModule::import(py, "pickle") + let pickle = PyModule::import(py, "cloudpickle") + .or(PyModule::import(py, "pickle")) .expect("Unable to import 'pickle'") .getattr("loads") .unwrap(); @@ -122,7 +124,8 @@ impl PythonUdfExpression { let remainder = &buf[reader.position() as usize..]; Python::with_gil(|py| { - let pickle = PyModule::import(py, "pickle") + let pickle = PyModule::import(py, "cloudpickle") + .or(PyModule::import(py, "pickle")) .expect("Unable to import 'pickle'") .getattr("loads") .unwrap(); @@ -169,7 +172,8 @@ impl SeriesUdf for PythonUdfExpression { ciborium::ser::into_writer(&self.output_type, &mut *buf).unwrap(); Python::with_gil(|py| { - let pickle = PyModule::import(py, "pickle") 
+ let pickle = PyModule::import(py, "cloudpickle") + .or(PyModule::import(py, "pickle")) .expect("Unable to import 'pickle'") .getattr("dumps") .unwrap(); diff --git a/py-polars/polars/utils/show_versions.py b/py-polars/polars/utils/show_versions.py index f34db533ec26..2f7ff9ed42f6 100644 --- a/py-polars/polars/utils/show_versions.py +++ b/py-polars/polars/utils/show_versions.py @@ -59,6 +59,7 @@ def _get_dependency_info() -> dict[str, str]: # see the list of dependencies in pyproject.toml opt_deps = [ "adbc_driver_sqlite", + "cloudpickle", "connectorx", "deltalake", "fsspec", diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 724f6638ce6e..b6a615424d50 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -51,8 +51,9 @@ pydantic = ["pydantic"] sqlalchemy = ["sqlalchemy", "pandas"] xlsxwriter = ["xlsxwriter"] adbc = ["adbc_driver_sqlite"] +cloudpickle = ["cloudpickle"] all = [ - "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,deltalake,timezone,matplotlib,pydantic,sqlalchemy,xlsxwriter,adbc]", + "polars[pyarrow,pandas,numpy,fsspec,connectorx,xlsx2csv,deltalake,timezone,matplotlib,pydantic,sqlalchemy,xlsxwriter,adbc,cloudpickle]", ] [tool.mypy] diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 1d05bc42f0bc..96d05a38d370 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -17,6 +17,7 @@ xlsx2csv XlsxWriter adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows' connectorx==0.3.2a5; python_version >= '3.8' # Latest full release is broken - unpin when 0.3.2 released +cloudpickle # Tooling hypothesis==6.79.4; python_version < '3.8' diff --git a/py-polars/tests/unit/test_serde.py b/py-polars/tests/unit/test_serde.py index 0734a8eb9b5a..fe5939e1bc56 100644 --- a/py-polars/tests/unit/test_serde.py +++ b/py-polars/tests/unit/test_serde.py @@ -152,3 +152,17 @@ def test_pickle_lazyframe_udf() -> None: q = pickle.loads(b) assert 
q.collect()["a"].to_list() == [2, 4, 6] + + +def test_pickle_lazyframe_nested_function_udf() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + + # NOTE: This is only possible when we're using cloudpickle. + def inner_df_times2(df: pl.DataFrame) -> pl.DataFrame: + return df.select(pl.all() * 2) + + q = df.lazy().map(inner_df_times2) + b = pickle.dumps(q) + + q = pickle.loads(b) + assert q.collect()["a"].to_list() == [2, 4, 6]