pub mod bert;
pub mod roberta;
pub mod sequence;
pub mod template;
pub use super::pre_tokenizers::byte_level;
use serde::{Deserialize, Serialize};
use crate::pre_tokenizers::byte_level::ByteLevel;
use crate::processors::bert::BertProcessing;
use crate::processors::roberta::RobertaProcessing;
use crate::processors::sequence::Sequence;
use crate::processors::template::TemplateProcessing;
use crate::{Encoding, PostProcessor, Result};
/// Serializable wrapper that can hold any of the post-processors in this module.
///
/// The enum is `#[serde(untagged)]`: when deserializing, serde tries the variants
/// in declaration order and keeps the first one that succeeds. The variant order
/// below is therefore significant and must not be changed casually.
#[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Eq)]
#[serde(untagged)]
pub enum PostProcessorWrapper {
// NOTE(review): Roberta presumably sits before Bert because a Roberta JSON payload
// also carries `sep`/`cls` (plus `trim_offsets`/`add_prefix_space`) and might
// otherwise be accepted by `BertProcessing` -- confirm before reordering.
Roberta(RobertaProcessing),
Bert(BertProcessing),
ByteLevel(ByteLevel),
Template(TemplateProcessing),
Sequence(Sequence),
}
/// Dispatches every `PostProcessor` call to the concrete processor held by the
/// wrapper, so a `PostProcessorWrapper` can be used wherever a post-processor is expected.
impl PostProcessor for PostProcessorWrapper {
    /// Delegates to the wrapped processor's `added_tokens(is_pair)`.
    fn added_tokens(&self, is_pair: bool) -> usize {
        // Arms follow the enum's declaration order.
        match self {
            Self::Roberta(inner) => inner.added_tokens(is_pair),
            Self::Bert(inner) => inner.added_tokens(is_pair),
            Self::ByteLevel(inner) => inner.added_tokens(is_pair),
            Self::Template(inner) => inner.added_tokens(is_pair),
            Self::Sequence(inner) => inner.added_tokens(is_pair),
        }
    }

    /// Delegates to the wrapped processor's `process_encodings`, passing the
    /// encodings and the `add_special_tokens` flag through unchanged and
    /// returning its result as-is (including any error).
    fn process_encodings(
        &self,
        encodings: Vec<Encoding>,
        add_special_tokens: bool,
    ) -> Result<Vec<Encoding>> {
        match self {
            Self::Roberta(inner) => inner.process_encodings(encodings, add_special_tokens),
            Self::Bert(inner) => inner.process_encodings(encodings, add_special_tokens),
            Self::ByteLevel(inner) => inner.process_encodings(encodings, add_special_tokens),
            Self::Template(inner) => inner.process_encodings(encodings, add_special_tokens),
            Self::Sequence(inner) => inner.process_encodings(encodings, add_special_tokens),
        }
    }
}
// One conversion per wrapper variant, listed in the enum's declaration order.
// NOTE(review): `impl_enum_from!` is defined elsewhere in the crate; it is assumed
// to generate `From<Inner> for PostProcessorWrapper` for the named variant -- confirm.
impl_enum_from!(RobertaProcessing, PostProcessorWrapper, Roberta);
impl_enum_from!(BertProcessing, PostProcessorWrapper, Bert);
impl_enum_from!(ByteLevel, PostProcessorWrapper, ByteLevel);
impl_enum_from!(TemplateProcessing, PostProcessorWrapper, Template);
impl_enum_from!(Sequence, PostProcessorWrapper, Sequence);
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the default Roberta and Bert processors, compares the output to
    /// the expected JSON, and checks that the untagged `PostProcessorWrapper`
    /// deserializes each payload back into the matching variant.
    #[test]
    fn deserialize_bert_roberta_correctly() {
        // Roberta: the expectation is written pretty-printed for readability and
        // every whitespace character is stripped before comparing.
        let roberta_json: String = r#"{
            "type":"RobertaProcessing",
            "sep":["</s>",2],
            "cls":["<s>",0],
            "trim_offsets":true,
            "add_prefix_space":true
        }"#
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
        let roberta = RobertaProcessing::default();
        assert_eq!(serde_json::to_string(&roberta).unwrap(), roberta_json);
        let parsed_roberta: PostProcessorWrapper = serde_json::from_str(&roberta_json).unwrap();
        assert_eq!(parsed_roberta, PostProcessorWrapper::Roberta(roberta));

        // Bert: the expectation is already compact and is used verbatim.
        let bert_json = r#"{"type":"BertProcessing","sep":["[SEP]",102],"cls":["[CLS]",101]}"#;
        let bert = BertProcessing::default();
        assert_eq!(serde_json::to_string(&bert).unwrap(), bert_json);
        let parsed_bert: PostProcessorWrapper = serde_json::from_str(bert_json).unwrap();
        assert_eq!(parsed_bert, PostProcessorWrapper::Bert(bert));
    }
}