use crate::parallelism::*;
use crate::tokenizer::{Encoding, Result};
use serde::{Deserialize, Serialize};
/// Side of an encoding on which padding is added.
///
/// The value is carried in [`PaddingParams::direction`] and handed unchanged to
/// `Encoding::pad` by [`pad_encodings`]; the actual placement of the padding is
/// implemented there, not in this file.
///
/// The `AsRef<str>` implementation yields the lowercase names `"left"` / `"right"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum PaddingDirection {
    /// Pad on the left side (rendered as `"left"` by `as_ref`).
    Left,
    /// Pad on the right side (rendered as `"right"` by `as_ref`). This is the
    /// direction used by `PaddingParams::default()`.
    Right,
}
impl std::convert::AsRef<str> for PaddingDirection {
    fn as_ref(&self) -> &str {
        match self {
            PaddingDirection::Left => "left",
            PaddingDirection::Right => "right",
        }
    }
}
/// Configuration consumed by [`pad_encodings`] to pad a batch of encodings.
///
/// `Default` yields: `BatchLongest` strategy, `Right` direction, no multiple
/// rounding, `pad_id = 0`, `pad_type_id = 0` and `pad_token = "[PAD]"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaddingParams {
    /// How the target length is chosen before any multiple rounding.
    pub strategy: PaddingStrategy,
    /// Side on which padding is added; forwarded to `Encoding::pad`.
    pub direction: PaddingDirection,
    /// When `Some(m)` with `m > 0`, the target length is rounded up to the next
    /// multiple of `m`. `None` and `Some(0)` leave the target length unchanged.
    pub pad_to_multiple_of: Option<usize>,
    /// Forwarded to `Encoding::pad` — presumably the token ID written into
    /// padded positions (confirm against `Encoding::pad`).
    pub pad_id: u32,
    /// Forwarded to `Encoding::pad` — presumably the type ID written into
    /// padded positions (confirm against `Encoding::pad`).
    pub pad_type_id: u32,
    /// Forwarded to `Encoding::pad` — presumably the token string written into
    /// padded positions (confirm against `Encoding::pad`).
    pub pad_token: String,
}
impl Default for PaddingParams {
    /// Baseline configuration: pad on the right up to the longest encoding of
    /// the batch, without multiple rounding, using ID `0`, type ID `0` and the
    /// token `"[PAD]"`.
    fn default() -> Self {
        let strategy = PaddingStrategy::BatchLongest;
        let direction = PaddingDirection::Right;
        let pad_token = "[PAD]".to_owned();

        Self {
            strategy,
            direction,
            pad_to_multiple_of: None,
            pad_id: 0,
            pad_type_id: 0,
            pad_token,
        }
    }
}
/// How [`pad_encodings`] picks the target length for a batch, before the
/// optional `pad_to_multiple_of` rounding is applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PaddingStrategy {
    /// Use the largest `get_ids().len()` found among the encodings of the batch.
    BatchLongest,
    /// Use the given length, independent of the batch contents.
    Fixed(usize),
}
/// Pads every encoding in `encodings` to a common target length.
///
/// The target length is chosen by `params.strategy`:
/// - [`PaddingStrategy::Fixed`] uses the given size;
/// - [`PaddingStrategy::BatchLongest`] uses the largest `get_ids().len()` in the batch.
///
/// If `params.pad_to_multiple_of` is `Some(m)` with `m > 0`, the target is then
/// rounded up to the next multiple of `m`. `None` and `Some(0)` leave it unchanged
/// (a zero multiple is explicitly skipped so no modulo-by-zero can occur).
///
/// Each encoding is then padded in place through `Encoding::pad`, which receives
/// the target length together with `pad_id`, `pad_type_id`, `pad_token` and
/// `direction` from `params`.
///
/// An empty slice is a no-op.
///
/// # Errors
///
/// The current implementation never fails and always returns `Ok(())`.
pub fn pad_encodings(encodings: &mut [Encoding], params: &PaddingParams) -> Result<()> {
    if encodings.is_empty() {
        return Ok(());
    }

    let base_length = match params.strategy {
        PaddingStrategy::Fixed(size) => size,
        PaddingStrategy::BatchLongest => encodings
            .maybe_par_iter()
            .map(|e| e.get_ids().len())
            .max()
            .expect("encodings is non-empty, so a maximum length exists"),
    };

    // Round up to the next multiple only when the multiple is usable (non-zero)
    // and the length is not already aligned.
    let pad_length = match params.pad_to_multiple_of {
        Some(multiple) if multiple > 0 && base_length % multiple > 0 => {
            base_length + (multiple - base_length % multiple)
        }
        _ => base_length,
    };

    encodings.maybe_par_iter_mut().for_each(|encoding| {
        encoding.pad(
            pad_length,
            params.pad_id,
            params.pad_type_id,
            &params.pad_token,
            params.direction,
        )
    });

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    /// Checks that `pad_to_multiple_of` rounds the target length up for both
    /// strategies, and that a multiple of `0` is ignored instead of panicking.
    #[test]
    fn pad_to_multiple() {
        // Builds a fresh, unpadded batch whose ids have lengths 5 and 3; every
        // other `Encoding` field is left empty.
        fn get_encodings() -> [Encoding; 2] {
            [
                Encoding::new(
                    vec![0, 1, 2, 3, 4],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    HashMap::new(),
                ),
                Encoding::new(
                    vec![0, 1, 2],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    vec![],
                    HashMap::new(),
                ),
            ]
        }

        // Fixed(7) is not a multiple of 8, so the target is rounded up to 8.
        let mut encodings = get_encodings();
        let mut params = PaddingParams {
            strategy: PaddingStrategy::Fixed(7),
            direction: PaddingDirection::Right,
            pad_to_multiple_of: Some(8),
            pad_id: 0,
            pad_type_id: 0,
            pad_token: String::from("[PAD]"),
        };
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 8));

        // The longest encoding has 5 ids, which is rounded up to 6.
        let mut encodings = get_encodings();
        params.strategy = PaddingStrategy::BatchLongest;
        params.pad_to_multiple_of = Some(6);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 6));

        // A multiple of 0 must not panic (no modulo by zero) and must leave the
        // already-padded lengths untouched.
        params.pad_to_multiple_of = Some(0);
        pad_encodings(&mut encodings, &params).unwrap();
        assert!(encodings.iter().all(|e| e.get_ids().len() == 6));
    }
}