From 07086179bacb67c816aa5d7ee4399b61fd8c514d Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sat, 9 Sep 2023 21:08:59 +0200 Subject: [PATCH] Use string cache holder in context manager --- .../logical/categorical/string_cache.rs | 18 ++++++-------- py-polars/polars/string_cache.py | 10 ++++---- py-polars/src/functions/string_cache.rs | 21 ++++++++++++---- py-polars/src/lib.rs | 4 ++-- py-polars/tests/unit/test_string_cache.py | 24 ++++++++++++++++++- 5 files changed, 54 insertions(+), 23 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs index 9b3d02d685e48..1b7b41c5b7813 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs @@ -43,14 +43,14 @@ impl Default for IUseStringCache { impl IUseStringCache { /// Hold the StringCache pub fn hold() -> IUseStringCache { - _set_string_cache(true); + set_string_cache(true); IUseStringCache { private_zst: () } } } impl Drop for IUseStringCache { fn drop(&mut self) { - _set_string_cache(false) + set_string_cache(false) } } @@ -66,7 +66,7 @@ impl Drop for IUseStringCache { /// /// [`Categorical`]: crate::datatypes::DataType::Categorical pub fn enable_string_cache() { - _set_string_cache(true) + set_string_cache(true) } /// Disable and clear the global string cache. @@ -77,9 +77,9 @@ pub fn disable_string_cache() { /// Execute a function with the global string cache enabled. pub fn with_string_cache T, T>(func: F) -> T { - _set_string_cache(true); + set_string_cache(true); let out = func(); - _set_string_cache(false); + set_string_cache(false); out } @@ -88,12 +88,8 @@ pub fn using_string_cache() -> bool { USE_STRING_CACHE.load(Ordering::Acquire) > 0 } -/// Incrementing or decrement the number of string cache uses. -/// -/// WARNING: Do not use this function directly. This is a private function -/// intended for creating RAII objects. It is technically public because it is -/// used directly by the Python implementation to create a context manager. -pub fn _set_string_cache(active: bool) { +/// Increment or decrement the number of string cache uses. +fn set_string_cache(active: bool) { if active { USE_STRING_CACHE.fetch_add(1, Ordering::Release); } else { diff --git a/py-polars/polars/string_cache.py b/py-polars/polars/string_cache.py index 5c0dc3c411c7a..5c0a6317ffab0 100644 --- a/py-polars/polars/string_cache.py +++ b/py-polars/polars/string_cache.py @@ -7,6 +7,7 @@ with contextlib.suppress(ImportError): # Module not available when building docs import polars.polars as plr + from polars.polars import PyStringCacheHolder if TYPE_CHECKING: from types import TracebackType @@ -56,7 +57,7 @@ class StringCache(contextlib.ContextDecorator): """ def __enter__(self) -> StringCache: - plr._set_string_cache(True) + self._string_cache = PyStringCacheHolder() return self def __exit__( @@ -65,7 +66,7 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - plr._set_string_cache(False) + del self._string_cache def enable_string_cache(enable: bool | None = None) -> None: @@ -132,8 +133,9 @@ def enable_string_cache(enable: bool | None = None) -> None: " and `disable_string_cache()` to disable the string cache.", version="0.19.3", ) - plr._set_string_cache(enable) - return + if enable is False: + plr.disable_string_cache() + return plr.enable_string_cache() diff --git a/py-polars/src/functions/string_cache.rs b/py-polars/src/functions/string_cache.rs index 78bbab8cd94b8..7613953e11661 100644 --- a/py-polars/src/functions/string_cache.rs +++ b/py-polars/src/functions/string_cache.rs @@ -1,11 +1,7 @@ use polars_core; +use polars_core::IUseStringCache; use pyo3::prelude::*; -#[pyfunction] -pub fn _set_string_cache(active: bool) { - polars_core::_set_string_cache(active) -} - #[pyfunction] pub fn enable_string_cache() { polars_core::enable_string_cache() @@ -20,3 +16,18 @@ pub fn disable_string_cache() { pub fn using_string_cache() -> bool { polars_core::using_string_cache() } + +#[pyclass] +pub struct PyStringCacheHolder { + _inner: IUseStringCache, +} + +#[pymethods] +impl PyStringCacheHolder { + #[new] + fn new() -> Self { + Self { + _inner: IUseStringCache::hold(), + } + } +} diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index ec0f1ee49ed1d..3f3887da8d693 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -55,6 +55,7 @@ use crate::error::{ StructFieldNotFoundError, }; use crate::expr::PyExpr; +use crate::functions::string_cache::PyStringCacheHolder; use crate::lazyframe::PyLazyFrame; use crate::lazygroupby::PyLazyGroupBy; use crate::series::PySeries; @@ -75,6 +76,7 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); + m.add_class::().unwrap(); #[cfg(feature = "csv")] m.add_class::().unwrap(); #[cfg(feature = "sql")] @@ -211,8 +213,6 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::meta::threadpool_size)) .unwrap(); - m.add_wrapped(wrap_pyfunction!(functions::string_cache::_set_string_cache)) - .unwrap(); m.add_wrapped(wrap_pyfunction!( functions::string_cache::enable_string_cache )) diff --git a/py-polars/tests/unit/test_string_cache.py b/py-polars/tests/unit/test_string_cache.py index fc73de291fba3..c0478b3b15345 100644 --- a/py-polars/tests/unit/test_string_cache.py +++ b/py-polars/tests/unit/test_string_cache.py @@ -16,7 +16,7 @@ def _disable_string_cache() -> Iterator[None]: def sc(set: bool) -> None: - """Short syntax for checking whether string cache is set.""" + """Short syntax for asserting whether the global string cache is being used.""" assert pl.using_string_cache() is set @@ -85,6 +85,28 @@ def test_string_cache_context_manager_mixed_with_enable_disable() -> None: sc(False) +def test_string_cache_decorator() -> None: + @pl.StringCache() + def my_function() -> None: + sc(True) + + sc(False) + my_function() + sc(False) + + +def test_string_cache_decorator_mixed_with_enable() -> None: + @pl.StringCache() + def my_function() -> None: + sc(True) + pl.enable_string_cache() + sc(True) + + sc(False) + my_function() + sc(True) + + def test_string_cache_enable_arg_deprecated() -> None: sc(False) with pytest.deprecated_call():