use rayon::iter::IterBridge;
use rayon::prelude::*;
use rayon_cond::CondIterator;
pub use rayon::current_num_threads;
/// Name of the environment variable this module reads (`is_parallelism_configured`,
/// `get_parallelism`) and writes (`set_parallelism`) to control parallelism.
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";
/// Flag recording that a parallel path has been taken. It starts as `false` and, in the code
/// visible here, is only ever assigned `true`; `has_parallelism_been_used` reads it back.
// NOTE(review): this is a `static mut` accessed with plain reads/writes inside `unsafe`
// blocks; nothing visible here synchronizes those accesses across threads. An `AtomicBool`
// would make them well-defined, but every access site would have to change together.
static mut USED_PARALLELISM: bool = false;
/// Reports whether the parallelism environment variable is set.
///
/// Returns `true` when reading [`ENV_VARIABLE`] through `std::env::var` succeeds, whatever the
/// value is, and `false` when that read fails.
pub fn is_parallelism_configured() -> bool {
    matches!(std::env::var(ENV_VARIABLE), Ok(_))
}
/// Returns the current value of the `USED_PARALLELISM` flag, i.e. whether one of the adapters
/// in this file that set the flag has taken its parallel path so far.
pub fn has_parallelism_been_used() -> bool {
// NOTE(review): plain, unsynchronized read of a `static mut`; this is only well-defined if no
// other thread writes the flag at the same time — confirm against callers.
unsafe { USED_PARALLELISM }
}
/// Reads the current parallelism setting from the environment.
///
/// Returns `true` when [`ENV_VARIABLE`] cannot be read. Otherwise returns `false` only if the
/// value equals, ignoring ASCII case, one of `""`, `"off"`, `"false"`, `"f"`, `"no"`, `"n"` or
/// `"0"`; any other value yields `true`.
pub fn get_parallelism() -> bool {
    // Lowercase spellings that switch parallelism off.
    const DISABLING_VALUES: [&str; 7] = ["", "off", "false", "f", "no", "n", "0"];

    std::env::var(ENV_VARIABLE).map_or(true, |value| {
        !DISABLING_VALUES
            .iter()
            .any(|disabled| value.eq_ignore_ascii_case(disabled))
    })
}
/// Stores the requested parallelism setting in the environment.
///
/// Writes `"true"` or `"false"` to [`ENV_VARIABLE`] depending on `val`, which is what
/// `get_parallelism` reads back afterwards.
pub fn set_parallelism(val: bool) {
    let flag = match val {
        true => "true",
        false => "false",
    };
    std::env::set_var(ENV_VARIABLE, flag);
}
/// Conversion of a value into a [`CondIterator`], which wraps either a parallel iterator `P`
/// or a serial iterator `S` yielding the same item type.
pub trait MaybeParallelIterator<P, S>
where
P: ParallelIterator,
S: Iterator<Item = P::Item>,
{
/// Consumes `self` and returns a `CondIterator`. The blanket implementation in this file
/// picks the parallel side when `get_parallelism()` returns `true`.
fn into_maybe_par_iter(self) -> CondIterator<P, S>;
/// Like `into_maybe_par_iter`, with an extra `cond` switch. In the blanket implementation
/// in this file, `cond == false` always yields the serial side.
fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
}
/// Blanket implementation for every type that can be turned into both a rayon parallel
/// iterator (`P`) and a standard iterator (`S`) over the same items.
impl<P, S, I> MaybeParallelIterator<P, S> for I
where
    I: IntoParallelIterator<Iter = P, Item = P::Item> + IntoIterator<IntoIter = S, Item = S::Item>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
{
    /// Builds a `CondIterator` whose parallel/serial choice is the current result of
    /// `get_parallelism()`, setting `USED_PARALLELISM` when the parallel side is chosen.
    fn into_maybe_par_iter(self) -> CondIterator<P, S> {
        let go_parallel = get_parallelism();
        if go_parallel {
            // NOTE(review): unsynchronized write to a `static mut`; only well-defined if no
            // other thread touches the flag concurrently.
            unsafe { USED_PARALLELISM = true };
        }
        CondIterator::new(self, go_parallel)
    }

    /// Returns a serial `CondIterator` when `cond` is `false` (without consulting
    /// `get_parallelism()`); otherwise defers to `into_maybe_par_iter`.
    fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
        if !cond {
            return CondIterator::from_serial(self);
        }
        self.into_maybe_par_iter()
    }
}
/// Borrowing counterpart of `MaybeParallelIterator`: produces a [`CondIterator`] from
/// `&'data self` instead of consuming the value.
pub trait MaybeParallelRefIterator<'data, P, S>
where
P: ParallelIterator,
S: Iterator<Item = P::Item>,
P::Item: 'data,
{
/// Returns a `CondIterator` built from a shared borrow of `self`.
fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
/// Same as `maybe_par_iter`, with an extra `cond` switch that the implementation in this
/// file forwards to `into_maybe_par_iter_cond`.
fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>;
}
/// Implemented for every `I` whose shared reference `&'data I` implements
/// `MaybeParallelIterator`; both methods simply delegate to that implementation.
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
where
    &'data I: MaybeParallelIterator<P, S>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
    P::Item: 'data,
{
    /// Delegates to `into_maybe_par_iter` on `&'data I`.
    fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
        <&'data I as MaybeParallelIterator<P, S>>::into_maybe_par_iter(self)
    }

    /// Delegates to `into_maybe_par_iter_cond` on `&'data I`, forwarding `cond` unchanged.
    fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> {
        <&'data I as MaybeParallelIterator<P, S>>::into_maybe_par_iter_cond(self, cond)
    }
}
/// Mutably-borrowing counterpart of `MaybeParallelIterator`: produces a [`CondIterator`] from
/// `&'data mut self`.
pub trait MaybeParallelRefMutIterator<'data, P, S>
where
P: ParallelIterator,
S: Iterator<Item = P::Item>,
P::Item: 'data,
{
/// Returns a `CondIterator` built from an exclusive borrow of `self`.
fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
/// Same as `maybe_par_iter_mut`, with an extra `cond` switch that the implementation in
/// this file forwards to `into_maybe_par_iter_cond`.
fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>;
}
/// Implemented for every `I` whose exclusive reference `&'data mut I` implements
/// `MaybeParallelIterator`; both methods simply delegate to that implementation.
impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
where
    &'data mut I: MaybeParallelIterator<P, S>,
    P: ParallelIterator,
    S: Iterator<Item = P::Item>,
    P::Item: 'data,
{
    /// Delegates to `into_maybe_par_iter` on `&'data mut I`.
    fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
        <&'data mut I as MaybeParallelIterator<P, S>>::into_maybe_par_iter(self)
    }

    /// Delegates to `into_maybe_par_iter_cond` on `&'data mut I`, forwarding `cond` unchanged.
    fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> {
        <&'data mut I as MaybeParallelIterator<P, S>>::into_maybe_par_iter_cond(self, cond)
    }
}
/// Adapter for plain serial iterators: wraps an iterator `S` in a [`CondIterator`] whose
/// parallel side is rayon's `IterBridge<S>`.
pub trait MaybeParallelBridge<T, S>
where
S: Iterator<Item = T> + Send,
T: Send,
{
/// Consumes the iterator and returns it either bridged to a parallel iterator or unchanged
/// as the serial side. The implementation in this file decides via `get_parallelism()`.
fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
/// Like `maybe_par_bridge`, with an extra `cond` switch. In the implementation in this
/// file, `cond == false` always yields the serial side.
fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>;
}
/// Implementation for every `Send` iterator over `Send` items.
impl<T, S> MaybeParallelBridge<T, S> for S
where
    S: Iterator<Item = T> + Send,
    T: Send,
{
    /// Bridges the iterator onto rayon when `get_parallelism()` returns `true` (setting
    /// `USED_PARALLELISM`), and otherwise returns it as the serial side.
    fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
        if !get_parallelism() {
            return CondIterator::from_serial(self);
        }
        // NOTE(review): unsynchronized write to a `static mut`; only well-defined if no other
        // thread touches the flag concurrently.
        unsafe { USED_PARALLELISM = true };
        CondIterator::from_parallel(self.par_bridge())
    }

    /// Returns the serial side when `cond` is `false` (without consulting
    /// `get_parallelism()`); otherwise defers to `maybe_par_bridge`.
    fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> {
        match cond {
            true => self.maybe_par_bridge(),
            false => CondIterator::from_serial(self),
        }
    }
}
/// Chunked iteration over a slice-like value that is either parallel
/// (`rayon::slice::Chunks`) or serial (`std::slice::Chunks`).
///
/// The `'data` parameter is not referenced by the method signatures below; they use the
/// anonymous lifetime of the `&self` borrow instead.
pub trait MaybeParallelSlice<'data, T>
where
T: Sync,
{
/// Returns an iterator over chunks of `chunk_size` elements. The implementation in this
/// file chooses the parallel or serial side via `get_parallelism()`.
fn maybe_par_chunks(
&'_ self,
chunk_size: usize,
) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>;
/// Like `maybe_par_chunks`, with an extra `cond` switch. In the implementation in this
/// file, `cond == false` always yields the serial side.
fn maybe_par_chunks_cond(
&'_ self,
cond: bool,
chunk_size: usize,
) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>;
}
impl<T> MaybeParallelSlice<'_, T> for [T]
where
T: Sync,
{
fn maybe_par_chunks(
&'_ self,
chunk_size: usize,
) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> {
let parallelism = get_parallelism();
if parallelism {
CondIterator::from_parallel(self.par_chunks(chunk_size))
} else {
CondIterator::from_serial(self.chunks(chunk_size))
}
}
fn maybe_par_chunks_cond(
&'_ self,
cond: bool,
chunk_size: usize,
) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> {
if cond {
self.maybe_par_chunks(chunk_size)
} else {
CondIterator::from_serial(self.chunks(chunk_size))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the by-reference, by-mutable-reference and by-value adapters on a `Vec`.
    #[test]
    fn test_maybe_parallel_iterator() {
        let mut numbers = vec![1u32, 2, 3, 4, 5, 6];

        let initial_sum = numbers.maybe_par_iter().sum::<u32>();
        assert_eq!(initial_sum, 21);

        // Double every element in place and sum the doubled values as they are produced.
        let doubled_sum = numbers
            .maybe_par_iter_mut()
            .map(|n| {
                *n *= 2;
                *n
            })
            .sum::<u32>();
        assert_eq!(doubled_sum, 42);

        // The in-place doubling must be visible through both later views.
        let sum_by_ref = numbers.maybe_par_iter().sum::<u32>();
        assert_eq!(sum_by_ref, 42);
        let sum_by_value = numbers.into_maybe_par_iter().sum::<u32>();
        assert_eq!(sum_by_value, 42);
    }

    /// Checks that chunking keeps order and leaves a shorter final chunk.
    #[test]
    fn test_maybe_parallel_slice() {
        let data = [1, 2, 3, 4, 5];
        let expected = vec![&[1, 2][..], &[3, 4], &[5]];
        let pairs: Vec<_> = data.maybe_par_chunks(2).collect();
        assert_eq!(pairs, expected);
    }
}