use crate::processors::byte_level::process_offsets;
use crate::tokenizer::{Encoding, PostProcessor, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::iter::FromIterator;
/// Post-processor that adds RoBERTa-style special tokens: `cls A sep` for a single
/// sequence and `cls A sep` followed by `sep B sep` for a pair. Every type id it produces
/// is 0.
///
/// Serialized as an internally tagged object, e.g.
/// `{"type":"RobertaProcessing","sep":["</s>",2],...}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "type")]
pub struct RobertaProcessing {
    /// Separator token as `(token string, token id)`. It closes every sequence and also
    /// opens the second sequence of a pair.
    sep: (String, u32),
    /// Classifier token as `(token string, token id)`, prepended to the first sequence.
    cls: (String, u32),
    /// When true, offsets are adjusted with byte-level `process_offsets` before special
    /// tokens are added.
    trim_offsets: bool,
    /// Forwarded to `process_offsets`; only consulted when `trim_offsets` is true.
    add_prefix_space: bool,
}
impl Default for RobertaProcessing {
    /// RoBERTa's standard configuration: `<s>` (id 0) as the classifier token,
    /// `</s>` (id 2) as the separator, and both offset trimming and
    /// `add_prefix_space` enabled.
    fn default() -> Self {
        let cls = (String::from("<s>"), 0);
        let sep = (String::from("</s>"), 2);
        Self {
            cls,
            sep,
            add_prefix_space: true,
            trim_offsets: true,
        }
    }
}
impl RobertaProcessing {
    /// Creates a processor that uses `sep` and `cls` as its special tokens; every other
    /// setting is taken from the [`Default`] implementation.
    pub fn new(sep: (String, u32), cls: (String, u32)) -> Self {
        Self {
            cls,
            sep,
            ..Self::default()
        }
    }

    /// Builder-style setter: returns `self` with `trim_offsets` replaced by `v`.
    #[must_use]
    pub fn trim_offsets(self, v: bool) -> Self {
        Self {
            trim_offsets: v,
            ..self
        }
    }

    /// Builder-style setter: returns `self` with `add_prefix_space` replaced by `v`.
    #[must_use]
    pub fn add_prefix_space(self, v: bool) -> Self {
        Self {
            add_prefix_space: v,
            ..self
        }
    }
}
impl PostProcessor for RobertaProcessing {
    /// Returns how many special tokens `process_encodings` inserts: 2 for a single
    /// sequence (`cls A sep`) and 4 for a pair (`cls A sep` + `sep B sep`).
    fn added_tokens(&self, is_pair: bool) -> usize {
        if is_pair {
            4
        } else {
            2
        }
    }

    /// Normalizes the encodings and, if requested, wraps them in RoBERTa special tokens.
    ///
    /// First, each encoding and each of its overflowing windows has its offsets trimmed
    /// (when `trim_offsets` is set), and all of its type ids are set to 0.
    ///
    /// If `add_special_tokens` is false, the encodings are returned at that point.
    /// Otherwise:
    /// - encoding 0 becomes `cls … sep` with sequence id 0;
    /// - every later encoding becomes `sep … sep` with sequence id 1;
    /// - overflowing windows are wrapped exactly like their parent.
    fn process_encodings(
        &self,
        mut encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        for encoding in encodings.iter_mut() {
            normalize_encoding(encoding, self.trim_offsets, self.add_prefix_space);
        }
        if !add_special_tokens {
            return Ok(encodings);
        }
        let encodings: Vec<Encoding> = encodings
            .into_iter()
            .enumerate()
            .map(|(i, mut encoding)| {
                // Opening the second sequence with `sep` (not `cls`) yields the
                // `sep sep` boundary between the two sequences.
                let (open, sequence_id) = if i == 0 {
                    (&self.cls, 0)
                } else {
                    (&self.sep, 1)
                };
                let overflowing = encoding
                    .take_overflowing()
                    .iter()
                    .map(|overflow| {
                        wrap_with_special_tokens(overflow, open, &self.sep, sequence_id, vec![])
                    })
                    .collect();
                wrap_with_special_tokens(&encoding, open, &self.sep, sequence_id, overflowing)
            })
            .collect();
        Ok(encodings)
    }
}

/// Trims the offsets of `encoding` and of each of its overflowing windows (only when
/// `trim_offsets` is true), then sets every type id of all of them to 0.
///
/// The overflowing windows are reset too. Otherwise, when no special tokens are added,
/// they would keep the type ids they arrived with while their parent is forced to 0.
fn normalize_encoding(encoding: &mut Encoding, trim_offsets: bool, add_prefix_space: bool) {
    if trim_offsets {
        process_offsets(encoding, add_prefix_space);
    }
    encoding.set_type_ids(vec![0; encoding.len()]);
    for overflow in encoding.get_overflowing_mut().iter_mut() {
        if trim_offsets {
            process_offsets(overflow, add_prefix_space);
        }
        overflow.set_type_ids(vec![0; overflow.len()]);
    }
}

/// Builds `open + encoding + sep` as a new [`Encoding`].
///
/// The two added tokens get word id `None`, offsets `(0, 0)` and special-token mask 1.
/// Every type id is 0 and the attention mask is all 1. The original tokens are
/// recorded under `sequence_id` in the sequence ranges. `overflowing` is attached as-is.
fn wrap_with_special_tokens(
    encoding: &Encoding,
    open: &(String, u32),
    sep: &(String, u32),
    sequence_id: usize,
    overflowing: Vec<Encoding>,
) -> Encoding {
    let inner_len = encoding.get_ids().len();
    let ids = [&[open.1], encoding.get_ids(), &[sep.1]].concat();
    let type_ids = vec![0; inner_len + 2];
    let tokens = [&[open.0.clone()], encoding.get_tokens(), &[sep.0.clone()]].concat();
    let words = [&[None], encoding.get_word_ids(), &[None]].concat();
    let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
    let special_tokens = [&[1u32], &vec![0; inner_len][..], &[1]].concat();
    let attention_mask = vec![1; ids.len()];
    // The wrapped sequence occupies every position except the two added tokens.
    let sequence_ranges = HashMap::from_iter(vec![(sequence_id, 1..ids.len() - 1)]);
    Encoding::new(
        ids,
        type_ids,
        tokens,
        words,
        offsets,
        special_tokens,
        attention_mask,
        overflowing,
        sequence_ranges,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Token;

    #[test]
    fn serde() {
        let expected_json = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .replace(char::is_whitespace, "");
        let processor = RobertaProcessing::default();

        // The default processor serializes to the canonical tagged JSON...
        let serialized = serde_json::to_string(&processor).unwrap();
        assert_eq!(serialized, expected_json);

        // ...and that JSON deserializes back to an identical processor.
        let restored: RobertaProcessing = serde_json::from_str(&expected_json).unwrap();
        assert_eq!(restored, processor);
    }

    #[test]
    fn roberta_processing() {
        let processor = RobertaProcessing::default();
        assert_eq!(processor.added_tokens(false), 2);
        assert_eq!(processor.added_tokens(true), 4);

        let first = Encoding::from_tokens(
            vec![
                Token::new(12, "Hello".into(), (0, 5)),
                Token::new(14, "there".into(), (6, 11)),
            ],
            0,
        );
        let second = Encoding::from_tokens(vec![Token::new(15, "pair".into(), (0, 4))], 0);

        // Single sequence with special tokens: `<s> Hello there </s>`.
        let single = processor.process(first.clone(), None, true).unwrap();
        let expected_single = Encoding::new(
            vec![0, 12, 14, 2],
            vec![0; 4],
            vec!["<s>".into(), "Hello".into(), "there".into(), "</s>".into()],
            vec![None; 4],
            vec![(0, 0), (0, 5), (6, 11), (0, 0)],
            vec![1, 0, 0, 1],
            vec![1; 4],
            vec![],
            HashMap::from_iter(vec![(0, 1..3)]),
        );
        assert_eq!(single, expected_single);
        for &(token, sequence) in &[(2usize, Some(0usize)), (3, None)] {
            assert_eq!(single.token_to_sequence(token), sequence);
        }

        // Pair with special tokens: `<s> Hello there </s></s> pair </s>`.
        let pair = processor
            .process(first.clone(), Some(second.clone()), true)
            .unwrap();
        let expected_pair = Encoding::new(
            vec![0, 12, 14, 2, 2, 15, 2],
            vec![0; 7],
            vec![
                "<s>".into(),
                "Hello".into(),
                "there".into(),
                "</s>".into(),
                "</s>".into(),
                "pair".into(),
                "</s>".into(),
            ],
            vec![None; 7],
            vec![(0, 0), (0, 5), (6, 11), (0, 0), (0, 0), (0, 4), (0, 0)],
            vec![1, 0, 0, 1, 1, 0, 1],
            vec![1; 7],
            vec![],
            HashMap::from_iter(vec![(0, 1..3), (1, 5..6)]),
        );
        assert_eq!(pair, expected_pair);
        let pair_sequences = [
            (2usize, Some(0usize)),
            (3, None),
            (4, None),
            (5, Some(1)),
            (6, None),
        ];
        for &(token, sequence) in &pair_sequences {
            assert_eq!(pair.token_to_sequence(token), sequence);
        }

        // Pair without special tokens: the two sequences are simply concatenated.
        let bare_pair = processor.process(first, Some(second), false).unwrap();
        let expected_bare = Encoding::new(
            vec![12, 14, 15],
            vec![0; 3],
            vec!["Hello".into(), "there".into(), "pair".into()],
            vec![None; 3],
            vec![(0, 5), (6, 11), (0, 4)],
            vec![0; 3],
            vec![1; 3],
            vec![],
            HashMap::from_iter(vec![(0, 0..2), (1, 2..3)]),
        );
        assert_eq!(bare_pair, expected_bare);
        for &(token, sequence) in &[(0usize, Some(0usize)), (1, Some(0)), (2, Some(1))] {
            assert_eq!(bare_pair.token_to_sequence(token), sequence);
        }
    }
}