use serde::{Deserialize, Serialize};
use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
use crate::utils::macro_rules_attribute;
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Digits {
pub individual_digits: bool,
}
impl Digits {
pub fn new(individual_digits: bool) -> Self {
Self { individual_digits }
}
}
impl Default for Digits {
fn default() -> Self {
Self::new(false)
}
}
impl PreTokenizer for Digits {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
if self.individual_digits {
pretokenized.split(|_, normalized| {
normalized.split(char::is_numeric, SplitDelimiterBehavior::Isolated)
})
} else {
pretokenized.split(|_, normalized| {
normalized.split(char::is_numeric, SplitDelimiterBehavior::Contiguous)
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{OffsetReferential, OffsetType};
#[test]
fn numbers() {
let pretok = Digits::new(false);
let mut pretokenized = PreTokenizedString::from("Hey 123 friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![("Hey ", (0, 4)), ("123", (4, 7)), (" friend!", (7, 15))]
);
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![("Hey ", (0, 4)), ("123", (4, 7)), (" friend!", (7, 15))]
);
}
#[test]
fn individual_digits() {
let pretok = Digits::new(true);
let mut pretokenized = PreTokenizedString::from("Hey 123 friend!");
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Normalized, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hey ", (0, 4)),
("1", (4, 5)),
("2", (5, 6)),
("3", (6, 7)),
(" friend!", (7, 15))
]
);
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
vec![
("Hey ", (0, 4)),
("1", (4, 5)),
("2", (5, 6)),
("3", (6, 7)),
(" friend!", (7, 15))
]
);
}
}