use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use serde::{Deserialize, Deserializer, Serialize};
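/// Controls when the replacement character is prepended to a split when
/// `add_prefix_space` is enabled; serialized in snake_case ("first", "never", "always").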
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PrependScheme {
    /// Prepend only to the very first split of a sequence
    First,
    /// Never prepend the replacement character
    Never,
    /// Prepend to every split that does not already start with the replacement
    Always,
}
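/// Replaces every whitespace with the provided meta character and then splits on this
/// character, keeping the meta character attached to the piece that follows it.
///
/// A minimal usage sketch, marked `ignore` because the import paths assume the
/// crate-root re-exports also used by the tests in this module:
///
/// ```ignore
/// use tokenizers::pre_tokenizers::metaspace::Metaspace;
/// use tokenizers::{PreTokenizedString, PreTokenizer};
///
/// let pretok = Metaspace::new('▁', true);
/// let mut pretokenized = PreTokenizedString::from("Hey friend!");
/// pretok.pre_tokenize(&mut pretokenized).unwrap();
/// // Expected splits: "▁Hey" and "▁friend!" (see the `basic` test below)
/// ```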
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(tag = "type")]
pub struct Metaspace {
    replacement: char,
    pub add_prefix_space: bool,
    pub prepend_scheme: PrependScheme,
    #[serde(skip)]
    str_rep: String,
}
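// Manual `Deserialize` implementation: it accepts configurations that predate
// `prepend_scheme` by defaulting it to `Always`, and goes through the constructor
// so the skipped `str_rep` field is rebuilt from `replacement`.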
impl<'de> Deserialize<'de> for Metaspace {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        enum Type {
            Metaspace,
        }
        fn default_prepend_scheme_value() -> PrependScheme {
            PrependScheme::Always
        }
        #[derive(Deserialize)]
        pub struct MetaspaceHelper {
            #[serde(rename = "type")]
            _type: Type,
            replacement: char,
            pub add_prefix_space: bool,
            #[serde(default = "default_prepend_scheme_value")]
            pub prepend_scheme: PrependScheme,
            #[serde(skip, rename = "str_rep")]
            _str_rep: String,
        }
        let helper = MetaspaceHelper::deserialize(deserializer)?;
        let instance = Self::new_with_prepend_scheme(
            helper.replacement,
            helper.add_prefix_space,
            helper.prepend_scheme,
        );
        Ok(instance)
    }
}
impl Metaspace {
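    /// Build a `Metaspace` using the default `PrependScheme::Always`.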
    pub fn new(replacement: char, add_prefix_space: bool) -> Self {
        Self::new_with_prepend_scheme(
            replacement,
            add_prefix_space,
            PrependScheme::Always,
        )
    }
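    /// Build a `Metaspace` with an explicit prepend scheme; the cached `str_rep`
    /// string is derived from `replacement` here.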
    pub fn new_with_prepend_scheme(
        replacement: char,
        add_prefix_space: bool,
        prepend_scheme: PrependScheme,
    ) -> Self {
        Self {
            replacement,
            str_rep: replacement.to_string(),
            add_prefix_space,
            prepend_scheme,
        }
    }
    pub fn get_replacement(&self) -> char {
        self.replacement
    }
    pub fn set_replacement(&mut self, replacement: char) {
        self.replacement = replacement;
        self.str_rep = replacement.to_string();
    }
    pub fn get_prepend_scheme(&self) -> PrependScheme {
        self.prepend_scheme
    }
    pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
        self.prepend_scheme = scheme;
    }
}
impl Default for Metaspace {
    fn default() -> Self {
        Self::new('▁', true)
    }
}
impl PreTokenizer for Metaspace {
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        let mut first_split = true;
        pretokenized.split(|_, mut normalized| {
            // Turn every whitespace into the visible meta character
            normalized.replace(' ', &self.str_rep)?;
            if self.add_prefix_space && !normalized.get().starts_with(self.replacement) {
                if self.prepend_scheme == PrependScheme::Always {
                    normalized.prepend(&self.str_rep);
                } else if self.prepend_scheme == PrependScheme::First && first_split {
                    normalized.prepend(&self.str_rep);
                    first_split = false;
                }
            } else {
                first_split = false;
            }
            // Split on the meta character, keeping it attached to the piece that follows
            normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
        })
    }
}
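// Decoding reverses the pre-tokenization: every meta character becomes a regular
// space again (see the first-token special case below for the added prefix space).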
impl Decoder for Metaspace {
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        Ok(tokens
            .iter()
            .enumerate()
            .map(|(i, token)| {
                token
                    .chars()
                    .flat_map(|c| {
                        if c == self.replacement {
                            // In the first token the meta character stands for the
                            // added prefix space, so drop it instead of emitting a space
                            if i == 0 && self.add_prefix_space {
                                None
                            } else {
                                Some(' ')
                            }
                        } else {
                            Some(c)
                        }
                    })
                    .collect::<String>()
            })
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use regex::Regex;
    use super::*;
    use crate::{OffsetReferential, OffsetType};
    #[test]
    fn serialization() {
        let metaspace = Metaspace::new('_', true);
        let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
        assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
        assert_eq!(
            serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
            metaspace
        );
        let metaspace = Metaspace::new('_', true);
        let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
        assert_eq!(
            serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
            metaspace
        );
        let metaspace_parsed: Metaspace = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#,
        )
        .unwrap();
        assert_eq!(metaspace_parsed, metaspace);
    }
    #[test]
    fn basic() {
        let pretok = Metaspace::new('▁', true);
        let mut pretokenized = PreTokenizedString::from("Hey friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey", (0, 6)), ("▁friend!", (6, 16))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))]
        );
    }
    #[test]
    fn multiple_spaces() {
        let pretok = Metaspace::new('▁', true);
        let mut pretokenized = PreTokenizedString::from("Hey   friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁", (6, 9)),
                ("▁", (9, 12)),
                ("▁friend!", (12, 22)),
            ]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 3)),
                ("▁", (3, 4)),
                ("▁", (4, 5)),
                ("▁friend!", (5, 13)),
            ]
        );
    }
    #[test]
    fn non_legacy_meta_space() {
        assert_eq!(
            Metaspace::new('▁', true),
            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
        );
        let mut pretok = Metaspace::new('▁', true);
        pretok.set_prepend_scheme(PrependScheme::Always);
        assert_eq!(
            pretok,
            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Always)
        );
        pretok.set_prepend_scheme(PrependScheme::Never);
        assert_eq!(
            pretok,
            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::Never)
        );
        pretok.set_prepend_scheme(PrependScheme::First);
        assert_eq!(
            pretok,
            Metaspace::new_with_prepend_scheme('▁', true, PrependScheme::First)
        );
        let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
        let re_ref = Regex::new(r"(<s>)").unwrap();
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁my", (6, 11)),
                ("▁friend", (11, 20)),
                ("▁", (20, 23)),
                ("<s>", (23, 26)),
                ("how", (26, 29)),
                ("▁are", (29, 35)),
                ("▁you", (35, 41))
            ]
        );
        pretok.set_prepend_scheme(PrependScheme::Always);
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁my", (6, 11)),
                ("▁friend", (11, 20)),
                ("▁", (20, 23)),
                ("▁<s>", (23, 29)),
                ("▁how", (29, 35)),
                ("▁are", (35, 41)),
                ("▁you", (41, 47))
            ]
        );
        pretok.set_prepend_scheme(PrependScheme::First);
        let mut pretokenized = PreTokenizedString::from(" Hey <s>how");
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁", (6, 9)),
                ("<s>", (9, 12)),
                ("how", (12, 15))
            ]
        );
        let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you");
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁", (6, 9)),
                ("<s>", (9, 12)),
                ("how", (12, 15)),
                ("▁", (15, 18)),
                ("<s>", (18, 21)),
                ("are", (21, 24)),
                ("▁", (24, 27)),
                ("<s>", (27, 30)),
                ("▁you", (30, 36))
            ]
        );
    }
    #[test]
    fn decode() {
        let decoder = Metaspace::new('▁', true);
        let res = decoder
            .decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
            .unwrap();
        assert_eq!(res, vec!["Hey", " friend!"]);
    }
}