Skip to content

Commit

Permalink
Make USED_PARALLELISM atomic (#1532)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-daniel authored Jun 6, 2024
1 parent 25aee8b commit bfefcf6
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tokenizers/src/utils/parallelism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
use rayon::iter::IterBridge;
use rayon::prelude::*;
use rayon_cond::CondIterator;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;

// Re-export rayon current_num_threads
pub use rayon::current_num_threads;

pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";

// Reading/Writing this variable should always happen on the main thread
static mut USED_PARALLELISM: bool = false;
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);

/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
pub fn is_parallelism_configured() -> bool {
Expand All @@ -21,7 +22,7 @@ pub fn is_parallelism_configured() -> bool {

/// Check if at some point we used a parallel iterator
pub fn has_parallelism_been_used() -> bool {
unsafe { USED_PARALLELISM }
USED_PARALLELISM.load(Ordering::SeqCst)
}

/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
Expand Down Expand Up @@ -70,7 +71,7 @@ where
fn into_maybe_par_iter(self) -> CondIterator<P, S> {
let parallelism = get_parallelism();
if parallelism {
unsafe { USED_PARALLELISM = true };
USED_PARALLELISM.store(true, Ordering::SeqCst);
}
CondIterator::new(self, parallelism)
}
Expand Down Expand Up @@ -159,7 +160,7 @@ where
let iter = CondIterator::from_serial(self);

if get_parallelism() {
unsafe { USED_PARALLELISM = true };
USED_PARALLELISM.store(true, Ordering::SeqCst);
CondIterator::from_parallel(iter.into_parallel().right().unwrap())
} else {
iter
Expand Down

0 comments on commit bfefcf6

Please sign in to comment.