use crate::processors::PostProcessorWrapper;
use crate::tokenizer::{Encoding, PostProcessor, Result};
use crate::utils::macro_rules_attribute;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct Sequence {
processors: Vec<PostProcessorWrapper>,
}
impl Sequence {
pub fn new(processors: Vec<PostProcessorWrapper>) -> Self {
Self { processors }
}
}
impl PostProcessor for Sequence {
fn added_tokens(&self, is_pair: bool) -> usize {
self.processors
.iter()
.map(|p| p.added_tokens(is_pair))
.sum::<usize>()
}
fn process_encodings(
&self,
mut encodings: Vec<Encoding>,
add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
for processor in &self.processors {
encodings = processor.process_encodings(encodings, add_special_tokens)?;
}
Ok(encodings)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::processors::{ByteLevel, PostProcessorWrapper};
use crate::tokenizer::{Encoding, PostProcessor};
use std::collections::HashMap;
use std::iter::FromIterator;
#[test]
fn process_chain() {
let start = Encoding::new(
vec![0; 5],
vec![0; 5],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![(0, 1), (0, 11), (11, 18), (18, 25), (25, 29)],
vec![],
vec![],
vec![],
HashMap::new(),
);
let bytelevel = ByteLevel::default().trim_offsets(true);
let sequence = Sequence::new(vec![PostProcessorWrapper::ByteLevel(bytelevel)]);
let expected = Encoding::new(
vec![0; 5],
vec![0; 5],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![(0, 0), (4, 9), (13, 18), (18, 23), (29, 29)],
vec![],
vec![],
vec![],
HashMap::from_iter(vec![(0, 0..5)]),
);
assert_eq!(
expected,
bytelevel.process(start.clone(), None, false).unwrap()
);
assert_eq!(
expected,
sequence.process(start.clone(), None, false).unwrap()
);
let pair_expected = Encoding::new(
vec![0; 10],
vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
vec![
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
"Ġ".into(),
"ĠĠĠĠHelloĠĠ".into(),
"ĠĠHello".into(),
"HelloĠĠ".into(),
"ĠĠĠĠ".into(),
],
vec![],
vec![
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
(0, 0),
(4, 9),
(13, 18),
(18, 23),
(29, 29),
],
vec![],
vec![],
vec![],
HashMap::from_iter(vec![(0, 0..5), (1, 5..10)]),
);
assert_eq!(
pair_expected,
bytelevel
.process(start.clone(), Some(start.clone()), false)
.unwrap()
);
assert_eq!(
pair_expected,
sequence.process(start.clone(), Some(start), false).unwrap()
);
}
}