diff --git a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs index 69568d55462d..4d922b72c9f5 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/builder.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/builder.rs @@ -512,12 +512,12 @@ impl CategoricalChunked { mod test { use crate::chunked_array::categorical::CategoricalChunkedBuilder; use crate::prelude::*; - use crate::{enable_string_cache, reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; #[test] fn test_categorical_rev() -> PolarsResult<()> { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); + disable_string_cache(); let slice = &[ Some("foo"), None, @@ -532,7 +532,7 @@ mod test { assert_eq!(out.get_rev_map().len(), 2); // test the global branch - enable_string_cache(true); + enable_string_cache(); // empty global cache let out = ca.cast(&DataType::Categorical(None))?; let out = out.categorical().unwrap().clone(); @@ -556,11 +556,13 @@ mod test { #[test] fn test_categorical_builder() { - use crate::{enable_string_cache, reset_string_cache}; + use crate::{disable_string_cache, enable_string_cache}; let _lock = crate::SINGLE_LOCK.lock(); - for b in &[false, true] { - reset_string_cache(); - enable_string_cache(*b); + for use_string_cache in [false, true] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } // Use 2 builders to check if the global string cache // does not interfere with the index mapping diff --git a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs index 2e53a4e5fddb..a5251097bf96 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/merge.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/merge.rs @@ -153,13 +153,13 @@ impl CategoricalChunked { mod test { use super::*; use crate::chunked_array::categorical::CategoricalChunkedBuilder; - use crate::{enable_string_cache, reset_string_cache, IUseStringCache}; + use crate::{disable_string_cache, enable_string_cache, StringCacheHolder}; #[test] fn test_merge_rev_map() { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); - let _sc = IUseStringCache::hold(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); let mut builder1 = CategoricalChunkedBuilder::new("foo", 10); let mut builder2 = CategoricalChunkedBuilder::new("foo", 10); diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index ec6cd04704ca..d66e6318b5ef 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -2,7 +2,7 @@ mod builder; mod from; mod merge; mod ops; -pub mod stringcache; +pub mod string_cache; use bitflags::bitflags; pub use builder::*; @@ -265,12 +265,12 @@ mod test { use std::convert::TryFrom; use super::*; - use crate::{enable_string_cache, reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; #[test] fn test_categorical_round_trip() -> PolarsResult<()> { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); + disable_string_cache(); let slice = &[ Some("foo"), None, @@ -295,8 +295,8 @@ mod test { #[test] fn test_append_categorical() { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); - enable_string_cache(true); + disable_string_cache(); + enable_string_cache(); let mut s1 = Series::new("1", vec!["a", "b", "c"]) .cast(&DataType::Categorical(None)) @@ -329,8 +329,7 @@ mod test { #[test] fn test_categorical_flow() -> PolarsResult<()> { let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); - enable_string_cache(false); + disable_string_cache(); // tests several things that may lose the dtype information let s = Series::new("a", vec!["a", "b", "c"]).cast(&DataType::Categorical(None))?; diff --git a/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs similarity index 62% rename from crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs rename to crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs index be590fa40066..f39a1523446c 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/stringcache.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/string_cache.rs @@ -1,6 +1,6 @@ use std::hash::{Hash, Hasher}; -use std::sync::atomic::{AtomicU32, Ordering}; -use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; +use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; +use std::sync::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard}; use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; @@ -11,80 +11,103 @@ use crate::datatypes::PlIdHashMap; use crate::hashing::_HASHMAP_INIT_SIZE; use crate::prelude::InitHashMaps; -/// We use atomic reference counting -/// to determine how many threads use the string cache -/// if the refcount is zero, we may clear the string cache. -pub(crate) static USE_STRING_CACHE: AtomicU32 = AtomicU32::new(0); +/// We use atomic reference counting to determine how many threads use the +/// string cache. If the refcount is zero, we may clear the string cache. +static STRING_CACHE_REFCOUNT: Mutex = Mutex::new(0); +static STRING_CACHE_ENABLED_GLOBALLY: AtomicBool = AtomicBool::new(false); static STRING_CACHE_UUID_CTR: AtomicU32 = AtomicU32::new(0); -/// RAII for the string cache -/// If an operation creates categoricals and uses them in a join -/// or comparison that operation must hold this cache via -/// `let handle = IUseStringCache::hold()` -/// The cache is valid until `handle` is dropped. +/// Enable the global string cache as long as the object is alive ([RAII]). +/// +/// # Examples +/// +/// Enable the string cache by initializing the object: +/// +/// ``` +/// use polars_core::StringCacheHolder; +/// +/// let _sc = StringCacheHolder::hold(); +/// ``` +/// +/// The string cache is enabled until `handle` is dropped. /// /// # De-allocation +/// /// Multiple threads can hold the string cache at the same time. -/// The contents of the cache will only get dropped when no -/// thread holds it. -pub struct IUseStringCache { +/// The contents of the cache will only get dropped when no thread holds it. +/// +/// [RAII]: https://en.wikipedia.org/wiki/Resource_acquisition_is_initialization +pub struct StringCacheHolder { // only added so that it will never be constructed directly #[allow(dead_code)] private_zst: (), } -impl Default for IUseStringCache { +impl Default for StringCacheHolder { fn default() -> Self { Self::hold() } } -impl IUseStringCache { +impl StringCacheHolder { /// Hold the StringCache - pub fn hold() -> IUseStringCache { - enable_string_cache(true); - IUseStringCache { private_zst: () } + pub fn hold() -> StringCacheHolder { + increment_string_cache_refcount(); + StringCacheHolder { private_zst: () } } } -impl Drop for IUseStringCache { +impl Drop for StringCacheHolder { fn drop(&mut self) { - enable_string_cache(false) + decrement_string_cache_refcount(); } } -pub fn with_string_cache T, T>(func: F) -> T { - enable_string_cache(true); - let out = func(); - enable_string_cache(false); - out +fn increment_string_cache_refcount() { + let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount += 1; +} +fn decrement_string_cache_refcount() { + let mut refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount -= 1; + if *refcount == 0 { + STRING_CACHE.clear() + } } -/// Use a global string cache for the Categorical Types. +/// Enable the global string cache. /// -/// This is used to cache the string categories locally. -/// This allows join operations on categorical types. -pub fn enable_string_cache(toggle: bool) { - if toggle { - USE_STRING_CACHE.fetch_add(1, Ordering::Release); - } else { - let previous = USE_STRING_CACHE.fetch_sub(1, Ordering::Release); - if previous == 0 || previous == 1 { - USE_STRING_CACHE.store(0, Ordering::Release); - STRING_CACHE.clear() - } +/// [`Categorical`] columns created under the same global string cache have the +/// same underlying physical value when string values are equal. This allows the +/// columns to be concatenated or used in a join operation, for example. +/// +/// Note that enabling the global string cache introduces some overhead. +/// The amount of overhead depends on the number of categories in your data. +/// It is advised to enable the global string cache only when strictly necessary. +/// +/// [`Categorical`]: crate::datatypes::DataType::Categorical +pub fn enable_string_cache() { + let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(true, Ordering::AcqRel); + if !was_enabled { + increment_string_cache_refcount(); } } -/// Reset the global string cache used for the Categorical Types. -pub fn reset_string_cache() { - USE_STRING_CACHE.store(0, Ordering::Release); - STRING_CACHE.clear() +/// Disable and clear the global string cache. +/// +/// Note: Consider using [`StringCacheHolder`] for a more reliable way of +/// enabling and disabling the string cache. +pub fn disable_string_cache() { + let was_enabled = STRING_CACHE_ENABLED_GLOBALLY.swap(false, Ordering::AcqRel); + if was_enabled { + decrement_string_cache_refcount(); + } } -/// Check if string cache is set. +/// Check whether the global string cache is enabled. pub fn using_string_cache() -> bool { - USE_STRING_CACHE.load(Ordering::Acquire) > 0 + let refcount = STRING_CACHE_REFCOUNT.lock().unwrap(); + *refcount > 0 } // This is the hash and the Index offset in the linear buffer diff --git a/crates/polars-core/src/chunked_array/mod.rs b/crates/polars-core/src/chunked_array/mod.rs index d7518ecec4b7..f93ba54d5c46 100644 --- a/crates/polars-core/src/chunked_array/mod.rs +++ b/crates/polars-core/src/chunked_array/mod.rs @@ -827,9 +827,9 @@ pub(crate) mod test { #[test] #[cfg(feature = "dtype-categorical")] fn test_iter_categorical() { - use crate::{reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, SINGLE_LOCK}; let _lock = SINGLE_LOCK.lock(); - reset_string_cache(); + disable_string_cache(); let ca = Utf8Chunked::new("", &[Some("foo"), None, Some("bar"), Some("ham")]); let ca = ca.cast(&DataType::Categorical(None)).unwrap(); let ca = ca.categorical().unwrap(); diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index c87d7fce9130..337dee580b2a 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -120,7 +120,7 @@ impl CategoricalChunked { #[cfg(test)] mod test { use crate::prelude::*; - use crate::{enable_string_cache, reset_string_cache, SINGLE_LOCK}; + use crate::{disable_string_cache, enable_string_cache, SINGLE_LOCK}; fn assert_order(ca: &CategoricalChunked, cmp: &[&str]) { let s = ca.cast(&DataType::Utf8).unwrap(); @@ -133,9 +133,12 @@ mod test { let init = &["c", "b", "a", "d"]; let _lock = SINGLE_LOCK.lock(); - for toggle in [true, false] { - reset_string_cache(); - enable_string_cache(toggle); + for use_string_cache in [true, false] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + let s = Series::new("", init).cast(&DataType::Categorical(None))?; let ca = s.categorical()?; let mut ca_lexical = ca.clone(); @@ -157,13 +160,16 @@ mod test { } #[test] - fn test_cat_lexical_sort_multiple() -> PolarsResult<()> { let init = &["c", "b", "a", "a"]; let _lock = SINGLE_LOCK.lock(); - for enable in [true, false] { - enable_string_cache(enable); + for use_string_cache in [true, false] { + disable_string_cache(); + if use_string_cache { + enable_string_cache(); + } + let s = Series::new("", init).cast(&DataType::Categorical(None))?; let ca = s.categorical()?; let mut ca_lexical: CategoricalChunked = ca.clone(); diff --git a/crates/polars-core/src/lib.rs b/crates/polars-core/src/lib.rs index a734db66e38b..8bc12845f01f 100644 --- a/crates/polars-core/src/lib.rs +++ b/crates/polars-core/src/lib.rs @@ -38,7 +38,7 @@ use once_cell::sync::Lazy; use rayon::{ThreadPool, ThreadPoolBuilder}; #[cfg(feature = "dtype-categorical")] -pub use crate::chunked_array::logical::categorical::stringcache::*; +pub use crate::chunked_array::logical::categorical::string_cache::*; pub static PROCESS_ID: Lazy = Lazy::new(|| { SystemTime::now() diff --git a/crates/polars-error/src/lib.rs b/crates/polars-error/src/lib.rs index db190d63370b..db89125643ec 100644 --- a/crates/polars-error/src/lib.rs +++ b/crates/polars-error/src/lib.rs @@ -200,7 +200,7 @@ Help: if you're using Python, this may look something like: Alternatively, if the performance cost is acceptable, you could just set: import polars as pl - pl.enable_string_cache(True) + pl.enable_string_cache() on startup."#.trim_start()) }; diff --git a/crates/polars-io/src/csv/read.rs b/crates/polars-io/src/csv/read.rs index 5f0b3a228596..8d5b4f67dc90 100644 --- a/crates/polars-io/src/csv/read.rs +++ b/crates/polars-io/src/csv/read.rs @@ -584,7 +584,7 @@ where #[cfg(feature = "dtype-categorical")] if _has_cat { - _cat_lock = Some(polars_core::IUseStringCache::hold()) + _cat_lock = Some(polars_core::StringCacheHolder::hold()) } let mut csv_reader = self.core_reader(Some(Arc::new(schema)), to_cast)?; @@ -602,7 +602,7 @@ where }) .unwrap_or(false); if has_cat { - _cat_lock = Some(polars_core::IUseStringCache::hold()) + _cat_lock = Some(polars_core::StringCacheHolder::hold()) } } let mut csv_reader = self.core_reader(self.schema.clone(), vec![])?; diff --git a/crates/polars-io/src/csv/read_impl/batched_mmap.rs b/crates/polars-io/src/csv/read_impl/batched_mmap.rs index 18824d5e08f1..93251de658bf 100644 --- a/crates/polars-io/src/csv/read_impl/batched_mmap.rs +++ b/crates/polars-io/src/csv/read_impl/batched_mmap.rs @@ -136,7 +136,7 @@ impl<'a> CoreReader<'a> { // RAII structure that will ensure we maintain a global stringcache #[cfg(feature = "dtype-categorical")] let _cat_lock = if _has_cat { - Some(polars_core::IUseStringCache::hold()) + Some(polars_core::StringCacheHolder::hold()) } else { None }; @@ -196,7 +196,7 @@ pub struct BatchedCsvReaderMmap<'a> { schema: SchemaRef, rows_read: IdxSize, #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, + _cat_lock: Option, #[cfg(not(feature = "dtype-categorical"))] _cat_lock: Option, } diff --git a/crates/polars-io/src/csv/read_impl/batched_read.rs b/crates/polars-io/src/csv/read_impl/batched_read.rs index af3831f00b70..7f6b94c579f1 100644 --- a/crates/polars-io/src/csv/read_impl/batched_read.rs +++ b/crates/polars-io/src/csv/read_impl/batched_read.rs @@ -219,7 +219,7 @@ impl<'a> CoreReader<'a> { // RAII structure that will ensure we maintain a global stringcache #[cfg(feature = "dtype-categorical")] let _cat_lock = if _has_cat { - Some(polars_core::IUseStringCache::hold()) + Some(polars_core::StringCacheHolder::hold()) } else { None }; @@ -279,7 +279,7 @@ pub struct BatchedCsvReaderRead<'a> { schema: SchemaRef, rows_read: IdxSize, #[cfg(feature = "dtype-categorical")] - _cat_lock: Option, + _cat_lock: Option, #[cfg(not(feature = "dtype-categorical"))] _cat_lock: Option, } diff --git a/crates/polars-io/src/parquet/read_impl.rs b/crates/polars-io/src/parquet/read_impl.rs index aadbf06e6f84..a5612751557f 100644 --- a/crates/polars-io/src/parquet/read_impl.rs +++ b/crates/polars-io/src/parquet/read_impl.rs @@ -251,10 +251,10 @@ pub fn read_parquet( // if there are multiple row groups and categorical data // we need a string cache // we keep it alive until the end of the function - let _string_cache = if n_row_groups > 1 { + let _sc = if n_row_groups > 1 { #[cfg(feature = "dtype-categorical")] { - Some(polars_core::IUseStringCache::hold()) + Some(polars_core::StringCacheHolder::hold()) } #[cfg(not(feature = "dtype-categorical"))] { diff --git a/crates/polars-lazy/src/physical_plan/expressions/window.rs b/crates/polars-lazy/src/physical_plan/expressions/window.rs index cbae8d370524..7a0e0c7f80a6 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/window.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/window.rs @@ -495,7 +495,7 @@ impl PhysicalExpr for WindowExpr { // Worst case is that a categorical is created with indexes from the string // cache which is fine, as the physical representation is undefined. #[cfg(feature = "dtype-categorical")] - let _sc = polars_core::IUseStringCache::hold(); + let _sc = polars_core::StringCacheHolder::hold(); let mut ac = self.run_aggregation(df, state, &gb)?; use MapStrategy::*; diff --git a/crates/polars-plan/src/logical_plan/functions/mod.rs b/crates/polars-plan/src/logical_plan/functions/mod.rs index fd7ef573ea0c..72efcf252c39 100644 --- a/crates/polars-plan/src/logical_plan/functions/mod.rs +++ b/crates/polars-plan/src/logical_plan/functions/mod.rs @@ -10,7 +10,7 @@ use std::sync::Arc; use polars_core::prelude::*; #[cfg(feature = "dtype-categorical")] -use polars_core::IUseStringCache; +use polars_core::StringCacheHolder; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; use smartstring::alias::String as SmartString; @@ -332,7 +332,7 @@ impl FunctionNode { // we use a global string cache here as streaming chunks all have different rev maps #[cfg(feature = "dtype-categorical")] { - let _hold = IUseStringCache::hold(); + let _sc = StringCacheHolder::hold(); Arc::get_mut(function).unwrap().call_udf(df) } diff --git a/crates/polars/src/docs/performance.rs b/crates/polars/src/docs/performance.rs index b27d12f95028..e8f5da30c538 100644 --- a/crates/polars/src/docs/performance.rs +++ b/crates/polars/src/docs/performance.rs @@ -67,7 +67,7 @@ //! //! fn example(mut df_a: DataFrame, mut df_b: DataFrame) -> PolarsResult { //! // Set a global string cache -//! enable_string_cache(true); +//! enable_string_cache(); //! //! df_a.try_apply("a", |s| s.categorical().cloned())?; //! df_b.try_apply("b", |s| s.categorical().cloned())?; diff --git a/crates/polars/tests/it/core/joins.rs b/crates/polars/tests/it/core/joins.rs index 9077d5cff21d..5f3550abfa79 100644 --- a/crates/polars/tests/it/core/joins.rs +++ b/crates/polars/tests/it/core/joins.rs @@ -1,6 +1,6 @@ use polars_core::utils::{accumulate_dataframes_vertical, split_df}; #[cfg(feature = "dtype-categorical")] -use polars_core::{reset_string_cache, IUseStringCache}; +use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; use super::*; @@ -256,8 +256,9 @@ fn test_join_multiple_columns() { #[cfg_attr(miri, ignore)] #[cfg(feature = "dtype-categorical")] fn test_join_categorical() { - let _lock = IUseStringCache::hold(); - let _lock = polars_core::SINGLE_LOCK.lock(); + let _guard = SINGLE_LOCK.lock(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); let (mut df_a, mut df_b) = get_dfs(); @@ -294,11 +295,10 @@ fn test_join_categorical() { let (mut df_a, mut df_b) = get_dfs(); df_a.try_apply("b", |s| s.cast(&DataType::Categorical(None))) .unwrap(); - // create a new cache - reset_string_cache(); - // _sc is needed to ensure we hold the string cache. - let _sc = IUseStringCache::hold(); + // Create a new string cache + drop(_sc); + let _sc = StringCacheHolder::hold(); df_b.try_apply("bar", |s| s.cast(&DataType::Categorical(None))) .unwrap(); diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs index 290bd9f3efca..a4cfc7796a66 100644 --- a/crates/polars/tests/it/lazy/expressions/arity.rs +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -116,7 +116,7 @@ fn includes_null_predicate_3038() -> PolarsResult<()> { #[test] #[cfg(feature = "dtype-categorical")] fn test_when_then_otherwise_cats() -> PolarsResult<()> { - polars::enable_string_cache(true); + polars::enable_string_cache(); let lf = df!["book" => [Some("bookA"), None, diff --git a/crates/polars/tests/it/lazy/predicate_queries.rs b/crates/polars/tests/it/lazy/predicate_queries.rs index 36e63d64773d..d9aa60870e58 100644 --- a/crates/polars/tests/it/lazy/predicate_queries.rs +++ b/crates/polars/tests/it/lazy/predicate_queries.rs @@ -1,6 +1,6 @@ // used only if feature="is_in", feature="dtype-categorical" #[allow(unused_imports)] -use polars_core::{with_string_cache, SINGLE_LOCK}; +use polars_core::{disable_string_cache, StringCacheHolder, SINGLE_LOCK}; use super::*; @@ -132,24 +132,22 @@ fn test_is_in_categorical_3420() -> PolarsResult<()> { ]?; let _guard = SINGLE_LOCK.lock(); + disable_string_cache(); + let _sc = StringCacheHolder::hold(); - let _: PolarsResult<_> = with_string_cache(|| { - let s = Series::new("x", ["a", "b", "c"]).strict_cast(&DataType::Categorical(None))?; - let out = df - .lazy() - .with_column(col("a").strict_cast(DataType::Categorical(None))) - .filter(col("a").is_in(lit(s).alias("x"))) - .collect()?; - - let mut expected = df![ - "a" => ["a", "b", "c"], - "b" => [1, 2, 3] - ]?; - expected.try_apply("a", |s| s.cast(&DataType::Categorical(None)))?; - assert!(out.frame_equal(&expected)); + let s = Series::new("x", ["a", "b", "c"]).strict_cast(&DataType::Categorical(None))?; + let out = df + .lazy() + .with_column(col("a").strict_cast(DataType::Categorical(None))) + .filter(col("a").is_in(lit(s).alias("x"))) + .collect()?; - Ok(()) - }); + let mut expected = df![ + "a" => ["a", "b", "c"], + "b" => [1, 2, 3] + ]?; + expected.try_apply("a", |s| s.cast(&DataType::Categorical(None)))?; + assert!(out.frame_equal(&expected)); Ok(()) } diff --git a/py-polars/docs/source/reference/functions.rst b/py-polars/docs/source/reference/functions.rst index a454c4fa3be4..1200c2d94d74 100644 --- a/py-polars/docs/source/reference/functions.rst +++ b/py-polars/docs/source/reference/functions.rst @@ -51,4 +51,5 @@ and a decorator, in order to explicitly scope cache lifetime. StringCache enable_string_cache + disable_string_cache using_string_cache diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 8500a34fd0de..9abe028cbe25 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -182,7 +182,12 @@ from polars.lazyframe import LazyFrame from polars.series import Series from polars.sql import SQLContext -from polars.string_cache import StringCache, enable_string_cache, using_string_cache +from polars.string_cache import ( + StringCache, + disable_string_cache, + enable_string_cache, + using_string_cache, +) from polars.type_aliases import PolarsDataType from polars.utils import build_info, get_index_type, show_versions, threadpool_size @@ -277,6 +282,7 @@ "scan_pyarrow_dataset", # polars.stringcache "StringCache", + "disable_string_cache", "enable_string_cache", "using_string_cache", # polars.config diff --git a/py-polars/polars/string_cache.py b/py-polars/polars/string_cache.py index c9f94c792971..8e5ab3ea26dd 100644 --- a/py-polars/polars/string_cache.py +++ b/py-polars/polars/string_cache.py @@ -3,9 +3,11 @@ import contextlib from typing import TYPE_CHECKING +from polars.utils.deprecation import issue_deprecation_warning + with contextlib.suppress(ImportError): # Module not available when building docs - from polars.polars import enable_string_cache as _enable_string_cache - from polars.polars import using_string_cache as _using_string_cache + import polars.polars as plr + from polars.polars import PyStringCacheHolder if TYPE_CHECKING: from types import TracebackType @@ -13,52 +15,59 @@ class StringCache(contextlib.ContextDecorator): """ - Context manager that allows data sources to share the same categorical features. + Context manager for enabling and disabling the global string cache. + + :class:`Categorical` columns created under the same global string cache have + the same underlying physical value when string values are equal. This allows the + columns to be concatenated or used in a join operation, for example. + + Notes + ----- + Enabling the global string cache introduces some overhead. + The amount of overhead depends on the number of categories in your data. + It is advised to enable the global string cache only when strictly necessary. - This will temporarily cache the string categories until the context manager is - finished. If StringCaches are nested, the global cache will only be invalidated - when the outermost context exits. + If ``StringCache`` calls are nested, the global string cache will only be disabled + and cleared when the outermost context exits. Examples -------- + Construct two Series using the same global string cache. + >>> with pl.StringCache(): - ... df1 = pl.DataFrame( - ... data={ - ... "color": ["red", "green", "blue", "orange"], - ... "value": [1, 2, 3, 4], - ... }, - ... schema={"color": pl.Categorical, "value": pl.UInt8}, - ... ) - ... df2 = pl.DataFrame( - ... data={ - ... "color": ["yellow", "green", "orange", "black", "red"], - ... "char": ["a", "b", "c", "d", "e"], - ... }, - ... schema={"color": pl.Categorical, "char": pl.Utf8}, - ... ) + ... s1 = pl.Series("color", ["red", "green", "red"], dtype=pl.Categorical) + ... s2 = pl.Series("color", ["blue", "red", "green"], dtype=pl.Categorical) ... - ... # Both dataframes use the same string cache for the categorical column, - ... # so the join operation on that column will succeed. - ... df_join = df1.join(df2, how="inner", on="color") + + As both Series are constructed under the same global string cache, + they can be concatenated. + + >>> pl.concat([s1, s2]) + shape: (6,) + Series: 'color' [cat] + [ + "red" + "green" + "red" + "blue" + "red" + "green" + ] + + The class can also be used as a function decorator, in which case the string cache + is enabled during function execution, and disabled afterwards. + + >>> @pl.StringCache() + ... def construct_categoricals() -> pl.Series: + ... s1 = pl.Series("color", ["red", "green", "red"], dtype=pl.Categorical) + ... s2 = pl.Series("color", ["blue", "red", "green"], dtype=pl.Categorical) + ... return pl.concat([s1, s2]) ... - >>> df_join - shape: (3, 3) - ┌────────┬───────┬──────┐ - │ color ┆ value ┆ char │ - │ --- ┆ --- ┆ --- │ - │ cat ┆ u8 ┆ str │ - ╞════════╪═══════╪══════╡ - │ green ┆ 2 ┆ b │ - │ orange ┆ 4 ┆ c │ - │ red ┆ 1 ┆ e │ - └────────┴───────┴──────┘ """ def __enter__(self) -> StringCache: - self._already_enabled = _using_string_cache() - if not self._already_enabled: - _enable_string_cache(True) + self._string_cache = PyStringCacheHolder() return self def __exit__( @@ -67,57 +76,125 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - # note: if global string cache was already enabled - # on __enter__, do NOT reset it on __exit__ - if not self._already_enabled: - _enable_string_cache(False) + del self._string_cache -# TODO: Rename/redesign this function -# https://github.com/pola-rs/polars/issues/10425 -def enable_string_cache(enable: bool) -> None: # noqa: FBT001 +def enable_string_cache(enable: bool | None = None) -> None: """ - Enable (or disable) the global string cache. + Enable the global string cache. - This ensures that casts to Categorical dtypes will have - the same category values when string values are equal. + :class:`Categorical` columns created under the same global string cache have + the same underlying physical value when string values are equal. This allows the + columns to be concatenated or used in a join operation, for example. Parameters ---------- enable Enable or disable the global string cache. + .. deprecated:: 0.19.3 + ``enable_string_cache`` no longer accepts an argument. + Call ``enable_string_cache()`` to enable the string cache + and ``disable_string_cache()`` to disable the string cache. + + See Also + -------- + StringCache : Context manager for enabling and disabling the string cache. + disable_string_cache : Function to disable the string cache. + + Notes + ----- + Enabling the global string cache introduces some overhead. + The amount of overhead depends on the number of categories in your data. + It is advised to enable the global string cache only when strictly necessary. + + Consider using the :class:`StringCache` context manager for a more reliable way of + enabling and disabling the string cache. + + Examples + -------- + Construct two Series using the same global string cache. + + >>> pl.enable_string_cache() + >>> s1 = pl.Series("color", ["red", "green", "red"], dtype=pl.Categorical) + >>> s2 = pl.Series("color", ["blue", "red", "green"], dtype=pl.Categorical) + >>> pl.disable_string_cache() + + As both Series are constructed under the same global string cache, + they can be concatenated. + + >>> pl.concat([s1, s2]) + shape: (6,) + Series: 'color' [cat] + [ + "red" + "green" + "red" + "blue" + "red" + "green" + ] + + """ + if enable is not None: + issue_deprecation_warning( + "`enable_string_cache` no longer accepts an argument." + " Call `enable_string_cache()` to enable the string cache" + " and `disable_string_cache()` to disable the string cache.", + version="0.19.3", + ) + if enable is False: + plr.disable_string_cache() + return + + plr.enable_string_cache() + + +def disable_string_cache() -> bool: + """ + Disable and clear the global string cache. + + See Also + -------- + enable_string_cache : Function to enable the string cache. + StringCache : Context manager for enabling and disabling the string cache. + + Notes + ----- + Consider using the :class:`StringCache` context manager for a more reliable way of + enabling and disabling the string cache. + + When used in conjunction with the :class:`StringCache` context manager, the string + cache will not be disabled until the context manager exits. + Examples -------- - >>> pl.enable_string_cache(True) - >>> df1 = pl.DataFrame( - ... data={"color": ["red", "green", "blue", "orange"], "value": [1, 2, 3, 4]}, - ... schema={"color": pl.Categorical, "value": pl.UInt8}, - ... ) - >>> df2 = pl.DataFrame( - ... data={ - ... "color": ["yellow", "green", "orange", "black", "red"], - ... "char": ["a", "b", "c", "d", "e"], - ... }, - ... schema={"color": pl.Categorical, "char": pl.Utf8}, - ... ) - >>> df_join = df1.join(df2, how="inner", on="color") - >>> df_join - shape: (3, 3) - ┌────────┬───────┬──────┐ - │ color ┆ value ┆ char │ - │ --- ┆ --- ┆ --- │ - │ cat ┆ u8 ┆ str │ - ╞════════╪═══════╪══════╡ - │ green ┆ 2 ┆ b │ - │ orange ┆ 4 ┆ c │ - │ red ┆ 1 ┆ e │ - └────────┴───────┴──────┘ + Construct two Series using the same global string cache. + + >>> pl.enable_string_cache() + >>> s1 = pl.Series("color", ["red", "green", "red"], dtype=pl.Categorical) + >>> s2 = pl.Series("color", ["blue", "red", "green"], dtype=pl.Categorical) + >>> pl.disable_string_cache() + + As both Series are constructed under the same global string cache, + they can be concatenated. + + >>> pl.concat([s1, s2]) + shape: (6,) + Series: 'color' [cat] + [ + "red" + "green" + "red" + "blue" + "red" + "green" + ] """ - _enable_string_cache(enable) + return plr.disable_string_cache() def using_string_cache() -> bool: - """Return the current state of the global string cache (enabled/disabled).""" - return _using_string_cache() + """Check whether the global string cache is enabled.""" + return plr.using_string_cache() diff --git a/py-polars/src/functions/meta.rs b/py-polars/src/functions/meta.rs index e824e6f19bef..467c65ffc133 100644 --- a/py-polars/src/functions/meta.rs +++ b/py-polars/src/functions/meta.rs @@ -23,16 +23,6 @@ pub fn threadpool_size() -> usize { POOL.current_num_threads() } -#[pyfunction] -pub fn enable_string_cache(toggle: bool) { - polars_core::enable_string_cache(toggle) -} - -#[pyfunction] -pub fn using_string_cache() -> bool { - polars_core::using_string_cache() -} - #[pyfunction] pub fn set_float_fmt(fmt: &str) -> PyResult<()> { let fmt = match fmt { diff --git a/py-polars/src/functions/mod.rs b/py-polars/src/functions/mod.rs index e9648cb4f586..b9470f8a6550 100644 --- a/py-polars/src/functions/mod.rs +++ b/py-polars/src/functions/mod.rs @@ -6,4 +6,5 @@ pub mod meta; pub mod misc; pub mod random; pub mod range; +pub mod string_cache; pub mod whenthen; diff --git a/py-polars/src/functions/string_cache.rs b/py-polars/src/functions/string_cache.rs new file mode 100644 index 000000000000..617273802898 --- /dev/null +++ b/py-polars/src/functions/string_cache.rs @@ -0,0 +1,33 @@ +use polars_core; +use polars_core::StringCacheHolder; +use pyo3::prelude::*; + +#[pyfunction] +pub fn enable_string_cache() { + polars_core::enable_string_cache() +} + +#[pyfunction] +pub fn disable_string_cache() { + polars_core::disable_string_cache() +} + +#[pyfunction] +pub fn using_string_cache() -> bool { + polars_core::using_string_cache() +} + +#[pyclass] +pub struct PyStringCacheHolder { + _inner: StringCacheHolder, +} + +#[pymethods] +impl PyStringCacheHolder { + #[new] + fn new() -> Self { + Self { + _inner: StringCacheHolder::hold(), + } + } +} diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index c41d14628f80..3f3887da8d69 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -55,6 +55,7 @@ use crate::error::{ StructFieldNotFoundError, }; use crate::expr::PyExpr; +use crate::functions::string_cache::PyStringCacheHolder; use crate::lazyframe::PyLazyFrame; use crate::lazygroupby::PyLazyGroupBy; use crate::series::PySeries; @@ -75,6 +76,7 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::().unwrap(); m.add_class::().unwrap(); m.add_class::().unwrap(); + m.add_class::().unwrap(); #[cfg(feature = "csv")] m.add_class::().unwrap(); #[cfg(feature = "sql")] @@ -211,10 +213,18 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::meta::threadpool_size)) .unwrap(); - m.add_wrapped(wrap_pyfunction!(functions::meta::enable_string_cache)) - .unwrap(); - m.add_wrapped(wrap_pyfunction!(functions::meta::using_string_cache)) - .unwrap(); + m.add_wrapped(wrap_pyfunction!( + functions::string_cache::enable_string_cache + )) + .unwrap(); + m.add_wrapped(wrap_pyfunction!( + functions::string_cache::disable_string_cache + )) + .unwrap(); + m.add_wrapped(wrap_pyfunction!( + functions::string_cache::using_string_cache + )) + .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::meta::set_float_fmt)) .unwrap(); m.add_wrapped(wrap_pyfunction!(functions::meta::get_float_fmt)) diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 9c758bafff3b..1d279088b2b4 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -47,7 +47,7 @@ def doctest_teardown(d: doctest.DocTest) -> None: # don't let config changes or string cache state leak between tests polars.Config.restore_defaults() - polars.enable_string_cache(False) + polars.disable_string_cache() def modules_in_path(p: Path) -> Iterator[ModuleType]: diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index d3dbe40ebbb1..605c36fb5791 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -1290,45 +1290,6 @@ def test_duration_arithmetic() -> None: ) -def test_string_cache_eager_lazy() -> None: - # tests if the global string cache is really global and not interfered by the lazy - # execution. first the global settings was thread-local and this breaks with the - # parallel execution of lazy - with pl.StringCache(): - df1 = pl.DataFrame( - {"region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"]} - ).select([pl.col("region_ids").cast(pl.Categorical)]) - - df2 = pl.DataFrame( - {"seq_name": ["reg4", "reg2", "reg1"], "score": [3.0, 1.0, 2.0]} - ).select([pl.col("seq_name").cast(pl.Categorical), pl.col("score")]) - - expected = pl.DataFrame( - { - "region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"], - "score": [2.0, 1.0, None, 3.0, None], - } - ).with_columns(pl.col("region_ids").cast(pl.Categorical)) - - result = df1.join(df2, left_on="region_ids", right_on="seq_name", how="left") - assert_frame_equal(result, expected) - - # also check row-wise categorical insert. - # (column-wise is preferred, but this shouldn't fail) - for params in ( - {"schema": [("region_ids", pl.Categorical)]}, - { - "schema": ["region_ids"], - "schema_overrides": {"region_ids": pl.Categorical}, - }, - ): - df3 = pl.DataFrame( # type: ignore[arg-type] - data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]], - **params, - ) - assert_frame_equal(df1, df3) - - def test_assign() -> None: # check if can assign in case of a single column df = pl.DataFrame({"a": [1, 2, 3]}) diff --git a/py-polars/tests/unit/test_cfg.py b/py-polars/tests/unit/test_cfg.py index 78e2e694ce76..72a9e22d276a 100644 --- a/py-polars/tests/unit/test_cfg.py +++ b/py-polars/tests/unit/test_cfg.py @@ -8,8 +8,6 @@ import polars as pl from polars.config import _POLARS_CFG_ENV_VARS, _get_float_fmt -from polars.exceptions import StringCacheMismatchError -from polars.testing import assert_frame_equal @pytest.fixture(autouse=True) @@ -511,33 +509,6 @@ def test_shape_format_for_big_numbers() -> None: ) -def test_string_cache() -> None: - df1 = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) - df2 = pl.DataFrame({"a": ["foo", "spam", "eggs"], "c": [3, 2, 2]}) - - # ensure cache is off when casting to categorical; the join will fail - pl.enable_string_cache(False) - assert pl.using_string_cache() is False - - df1a = df1.with_columns(pl.col("a").cast(pl.Categorical)) - df2a = df2.with_columns(pl.col("a").cast(pl.Categorical)) - with pytest.raises(StringCacheMismatchError): - _ = df1a.join(df2a, on="a", how="inner") - - # now turn on the cache - pl.enable_string_cache(True) - assert pl.using_string_cache() is True - - df1b = df1.with_columns(pl.col("a").cast(pl.Categorical)) - df2b = df2.with_columns(pl.col("a").cast(pl.Categorical)) - out = df1b.join(df2b, on="a", how="inner") - - expected = pl.DataFrame( - {"a": ["foo"], "b": [1], "c": [3]}, schema_overrides={"a": pl.Categorical} - ) - assert_frame_equal(out, expected) - - @pytest.mark.write_disk() def test_config_load_save(tmp_path: Path) -> None: for file in ( diff --git a/py-polars/tests/unit/test_string_cache.py b/py-polars/tests/unit/test_string_cache.py new file mode 100644 index 000000000000..809b355bffda --- /dev/null +++ b/py-polars/tests/unit/test_string_cache.py @@ -0,0 +1,183 @@ +from typing import Iterator + +import pytest + +import polars as pl +from polars.exceptions import StringCacheMismatchError +from polars.testing import assert_frame_equal + + +@pytest.fixture(autouse=True) +def _disable_string_cache() -> Iterator[None]: + """Fixture to make sure the string cache is disabled before and after each test.""" + pl.disable_string_cache() + yield + pl.disable_string_cache() + + +def sc(set: bool) -> None: + """Short syntax for asserting whether the global string cache is being used.""" + assert pl.using_string_cache() is set + + +def test_string_cache_enable_disable() -> None: + sc(False) + pl.enable_string_cache() + sc(True) + pl.disable_string_cache() + sc(False) + + +def test_string_cache_enable_disable_repeated() -> None: + sc(False) + pl.enable_string_cache() + sc(True) + pl.enable_string_cache() + sc(True) + pl.disable_string_cache() + sc(False) + pl.disable_string_cache() + sc(False) + + +def test_string_cache_context_manager() -> None: + sc(False) + with pl.StringCache(): + sc(True) + sc(False) + + +def test_string_cache_context_manager_nested() -> None: + sc(False) + with pl.StringCache(): + sc(True) + with pl.StringCache(): + sc(True) + sc(True) + sc(False) + + +def test_string_cache_context_manager_mixed_with_enable_disable() -> None: + sc(False) + with pl.StringCache(): + sc(True) + pl.enable_string_cache() + sc(True) + sc(True) + + with pl.StringCache(): + sc(True) + sc(True) + + with pl.StringCache(): + sc(True) + with pl.StringCache(): + sc(True) + pl.disable_string_cache() + sc(True) + sc(True) + sc(False) + + with pl.StringCache(): + sc(True) + pl.disable_string_cache() + sc(True) + sc(False) + + +def test_string_cache_decorator() -> None: + @pl.StringCache() + def my_function() -> None: + sc(True) + + sc(False) + my_function() + sc(False) + + +def test_string_cache_decorator_mixed_with_enable() -> None: + @pl.StringCache() + def my_function() -> None: + sc(True) + pl.enable_string_cache() + sc(True) + + sc(False) + my_function() + sc(True) + + +def test_string_cache_enable_arg_deprecated() -> None: + sc(False) + with pytest.deprecated_call(): + pl.enable_string_cache(True) + sc(True) + with pytest.deprecated_call(): + pl.enable_string_cache(False) + sc(False) + + +def test_string_cache_join() -> None: + df1 = pl.DataFrame({"a": ["foo", "bar", "ham"], "b": [1, 2, 3]}) + df2 = pl.DataFrame({"a": ["foo", "spam", "eggs"], "c": [3, 2, 2]}) + + # ensure cache is off when casting to categorical; the join will fail + pl.disable_string_cache() + assert pl.using_string_cache() is False + + df1a = df1.with_columns(pl.col("a").cast(pl.Categorical)) + df2a = df2.with_columns(pl.col("a").cast(pl.Categorical)) + with pytest.raises(StringCacheMismatchError): + _ = df1a.join(df2a, on="a", how="inner") + + # now turn on the cache + pl.enable_string_cache() + assert pl.using_string_cache() is True + + df1b = df1.with_columns(pl.col("a").cast(pl.Categorical)) + df2b = df2.with_columns(pl.col("a").cast(pl.Categorical)) + out = df1b.join(df2b, on="a", how="inner") + + expected = pl.DataFrame( + {"a": ["foo"], "b": [1], "c": [3]}, schema_overrides={"a": pl.Categorical} + ) + assert_frame_equal(out, expected) + + +def test_string_cache_eager_lazy() -> None: + # tests if the global string cache is really global and not interfered by the lazy + # execution. first the global settings was thread-local and this breaks with the + # parallel execution of lazy + with pl.StringCache(): + df1 = pl.DataFrame( + {"region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"]} + ).select([pl.col("region_ids").cast(pl.Categorical)]) + + df2 = pl.DataFrame( + {"seq_name": ["reg4", "reg2", "reg1"], "score": [3.0, 1.0, 2.0]} + ).select([pl.col("seq_name").cast(pl.Categorical), pl.col("score")]) + + expected = pl.DataFrame( + { + "region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"], + "score": [2.0, 1.0, None, 3.0, None], + } + ).with_columns(pl.col("region_ids").cast(pl.Categorical)) + + result = df1.join(df2, left_on="region_ids", right_on="seq_name", how="left") + assert_frame_equal(result, expected) + + # also check row-wise categorical insert. + # (column-wise is preferred, but this shouldn't fail) + for params in ( + {"schema": [("region_ids", pl.Categorical)]}, + { + "schema": ["region_ids"], + "schema_overrides": {"region_ids": pl.Categorical}, + }, + ): + df3 = pl.DataFrame( # type: ignore[arg-type] + data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]], + **params, + ) + assert_frame_equal(df1, df3)