use crate::tokenizer::{Encoding, Result};
use serde::{Deserialize, Serialize};
use std::cmp;
use std::mem;
/// Direction argument forwarded to `Encoding::truncate` by `truncate_encodings`.
///
/// NOTE(review): presumably selects which end of the sequence loses tokens; the
/// exact semantics live in `Encoding::truncate` — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
pub enum TruncationDirection {
/// Rendered as `"left"` by the `AsRef<str>` implementation.
Left,
/// Rendered as `"right"` by the `AsRef<str>` implementation; this is the `Default` variant.
#[default]
Right,
}
impl std::convert::AsRef<str> for TruncationDirection {
fn as_ref(&self) -> &str {
match self {
TruncationDirection::Left => "left",
TruncationDirection::Right => "right",
}
}
}
/// Settings consumed by `truncate_encodings`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationParams {
/// Direction forwarded to `Encoding::truncate`. Because of `#[serde(default)]`,
/// serialized input that lacks this field deserializes to the enum's default.
#[serde(default)]
pub direction: TruncationDirection,
/// Upper bound on the combined number of ids across the encoding and its
/// optional pair. A value of `0` truncates every provided encoding to length 0.
pub max_length: usize,
/// Selects which encoding(s) are shortened when the combined length exceeds
/// `max_length`.
pub strategy: TruncationStrategy,
/// Stride value passed through unchanged to `Encoding::truncate`.
pub stride: usize,
}
/// Default truncation settings: `max_length` of 512, a stride of 0, and the
/// default strategy and direction of their respective enums.
impl Default for TruncationParams {
    fn default() -> Self {
        let direction = TruncationDirection::default();
        let strategy = TruncationStrategy::default();
        Self {
            direction,
            max_length: 512,
            strategy,
            stride: 0,
        }
    }
}
/// Errors produced by `truncate_encodings`.
#[derive(thiserror::Error, Debug)]
pub enum TruncationError {
/// The strategy targets the second sequence, but no pair encoding was supplied.
#[error("Truncation error: Second sequence not provided")]
SecondSequenceNotProvided,
/// The targeted sequence does not hold more ids than the number that must be
/// removed, so truncating it alone cannot satisfy `max_length`.
#[error("Truncation error: Sequence to truncate too short to respect the provided max_length")]
SequenceTooShort,
}
/// Chooses which encoding(s) `truncate_encodings` shortens when the combined
/// length exceeds `max_length`.
///
/// `Default` is derived (mirroring `TruncationDirection`) and yields
/// [`TruncationStrategy::LongestFirst`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Eq, Default)]
pub enum TruncationStrategy {
    /// With a single encoding, cut it down to `max_length`. With a pair, keep
    /// the shorter encoding whole when it fits in half of the budget and give
    /// the remainder to the longer one; otherwise split the budget in two.
    #[default]
    LongestFirst,
    /// Only the first encoding is shortened.
    OnlyFirst,
    /// Only the pair (second) encoding is shortened; it must be provided.
    OnlySecond,
}
impl std::convert::AsRef<str> for TruncationStrategy {
fn as_ref(&self) -> &str {
match self {
Self::LongestFirst => "longest_first",
Self::OnlyFirst => "only_first",
Self::OnlySecond => "only_second",
}
}
}
/// Shortens `encoding` (and `pair_encoding`, when present) so that the combined
/// number of ids does not exceed `params.max_length`.
///
/// Behaviour, checked in this order:
/// * `max_length == 0`: every provided encoding is truncated to length 0.
/// * Combined length already within `max_length`: the inputs are returned as is.
/// * `LongestFirst` with one encoding: it is truncated to `max_length`.
/// * `LongestFirst` with a pair: the shorter encoding keeps its full length when
///   the longer one can take the rest of the budget; otherwise the shorter gets
///   `max_length / 2` and the longer gets that plus `max_length % 2`. On equal
///   lengths the first encoding is treated as the shorter one.
/// * `OnlyFirst` / `OnlySecond`: only the selected encoding loses the excess ids.
///
/// `params.stride` and `params.direction` are passed through to
/// `Encoding::truncate` on every call.
///
/// # Errors
/// * `TruncationError::SecondSequenceNotProvided` when the strategy is
///   `OnlySecond` and `pair_encoding` is `None`.
/// * `TruncationError::SequenceTooShort` when the selected encoding does not
///   have more ids than the excess that has to be removed.
pub fn truncate_encodings(
    mut encoding: Encoding,
    mut pair_encoding: Option<Encoding>,
    params: &TruncationParams,
) -> Result<(Encoding, Option<Encoding>)> {
    let limit = params.max_length;
    let (stride, direction) = (params.stride, params.direction);

    // A zero budget simply empties everything that was passed in.
    if limit == 0 {
        encoding.truncate(0, stride, direction);
        if let Some(pair) = pair_encoding.as_mut() {
            pair.truncate(0, stride, direction);
        }
        return Ok((encoding, pair_encoding));
    }

    let first_len = encoding.get_ids().len();
    let second_len = pair_encoding
        .as_ref()
        .map_or(0, |pair| pair.get_ids().len());
    let total_length = first_len + second_len;

    // Nothing to do when the inputs already fit.
    if total_length <= limit {
        return Ok((encoding, pair_encoding));
    }
    let excess = total_length - limit;

    match params.strategy {
        TruncationStrategy::LongestFirst => match pair_encoding.as_mut() {
            None => {
                // With a single encoding, `total_length - excess` is exactly `limit`.
                encoding.truncate(limit, stride, direction);
            }
            Some(pair) => {
                // Work on (shorter, longer) lengths and map back to
                // (first, second) at the end.
                let first_is_longer = first_len > second_len;
                let shorter = cmp::min(first_len, second_len);

                // The `shorter > limit` branch avoids underflow in
                // `limit - shorter`, since lengths are unsigned.
                let longer_budget = if shorter > limit {
                    shorter
                } else {
                    cmp::max(shorter, limit - shorter)
                };

                // If keeping the shorter input whole does not fit, split the
                // budget in two; the odd unit goes to the longer input.
                let (mut keep_first, mut keep_second) = if shorter + longer_budget > limit {
                    let half = limit / 2;
                    (half, half + limit % 2)
                } else {
                    (shorter, longer_budget)
                };

                // Undo the (shorter, longer) ordering when the first input is
                // the longer one.
                if first_is_longer {
                    mem::swap(&mut keep_first, &mut keep_second);
                }

                encoding.truncate(keep_first, stride, direction);
                pair.truncate(keep_second, stride, direction);
            }
        },
        TruncationStrategy::OnlyFirst | TruncationStrategy::OnlySecond => {
            // The missing-pair error is propagated with `?` (rather than a
            // direct `return Err`) so its conversion path stays as it was.
            let target = match params.strategy {
                TruncationStrategy::OnlyFirst => Ok(&mut encoding),
                _ => pair_encoding
                    .as_mut()
                    .ok_or_else(|| Box::new(TruncationError::SecondSequenceNotProvided)),
            }?;

            let target_len = target.get_ids().len();
            if target_len <= excess {
                return Err(Box::new(TruncationError::SequenceTooShort));
            }
            target.truncate(target_len - excess, stride, direction);
        }
    }

    Ok((encoding, pair_encoding))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::Encoding;
    use std::collections::HashMap;

    /// Encoding fixture with no tokens.
    fn get_empty() -> Encoding {
        Encoding::new(
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            vec![],
            HashMap::new(),
        )
    }

    /// Encoding fixture with 2 tokens.
    fn get_short() -> Encoding {
        Encoding::new(
            vec![1, 2],
            vec![0, 0],
            vec![String::from("a"), String::from("b")],
            vec![Some(0), Some(1)],
            vec![(0, 1), (1, 2)],
            vec![0, 0],
            vec![1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// Encoding fixture with 4 tokens.
    fn get_medium() -> Encoding {
        Encoding::new(
            vec![3, 4, 5, 6],
            vec![0, 0, 0, 0],
            vec![
                String::from("d"),
                String::from("e"),
                String::from("f"),
                String::from("g"),
            ],
            vec![Some(0), Some(1), Some(2), Some(3)],
            vec![(0, 1), (1, 2), (2, 3), (3, 4)],
            vec![0, 0, 0, 0],
            vec![1, 1, 1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// Encoding fixture with 8 tokens.
    fn get_long() -> Encoding {
        Encoding::new(
            vec![7, 8, 9, 10, 11, 12, 13, 14],
            vec![0, 0, 0, 0, 0, 0, 0, 0],
            vec![
                String::from("h"),
                String::from("i"),
                String::from("j"),
                String::from("k"),
                String::from("l"),
                String::from("m"),
                String::from("n"),
                String::from("o"),
            ],
            vec![
                Some(0),
                Some(1),
                Some(2),
                Some(3),
                Some(4),
                Some(5),
                Some(6),
                Some(7),
            ],
            vec![
                (0, 1),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 6),
                (6, 7),
                (7, 8),
            ],
            vec![0, 0, 0, 0, 0, 0, 0, 0],
            vec![1, 1, 1, 1, 1, 1, 1, 1],
            vec![],
            HashMap::new(),
        )
    }

    /// Truncates the pair with `params` and checks the resulting id counts:
    /// `n1` for the first encoding and `n2` for the second.
    fn truncate_and_assert(
        encoding1: Encoding,
        encoding2: Encoding,
        params: &TruncationParams,
        n1: usize,
        n2: usize,
    ) {
        match truncate_encodings(encoding1, Some(encoding2), params) {
            Ok((e1, Some(e2))) => {
                assert_eq!(e1.get_ids().len(), n1);
                assert_eq!(e2.get_ids().len(), n2);
            }
            _ => panic!("expected Ok with both encodings present"),
        };
    }

    #[test]
    fn truncate_encodings_longest_first() {
        let params = TruncationParams {
            max_length: 7,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_empty(), get_empty(), &params, 0, 0);
        truncate_and_assert(get_empty(), get_short(), &params, 0, 2);
        truncate_and_assert(get_empty(), get_medium(), &params, 0, 4);
        truncate_and_assert(get_empty(), get_long(), &params, 0, 7);

        truncate_and_assert(get_short(), get_empty(), &params, 2, 0);
        truncate_and_assert(get_short(), get_short(), &params, 2, 2);
        truncate_and_assert(get_short(), get_medium(), &params, 2, 4);
        truncate_and_assert(get_short(), get_long(), &params, 2, 5);

        truncate_and_assert(get_medium(), get_empty(), &params, 4, 0);
        truncate_and_assert(get_medium(), get_short(), &params, 4, 2);
        truncate_and_assert(get_medium(), get_medium(), &params, 3, 4);
        truncate_and_assert(get_medium(), get_long(), &params, 3, 4);

        truncate_and_assert(get_long(), get_empty(), &params, 7, 0);
        truncate_and_assert(get_long(), get_short(), &params, 5, 2);
        truncate_and_assert(get_long(), get_medium(), &params, 4, 3);
        truncate_and_assert(get_long(), get_long(), &params, 3, 4);
    }

    #[test]
    fn truncate_encodings_empty() {
        let params = TruncationParams {
            max_length: 0,
            strategy: TruncationStrategy::LongestFirst,
            stride: 0,
            direction: TruncationDirection::Right,
        };

        truncate_and_assert(get_empty(), get_short(), &params, 0, 0);
        truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
        truncate_and_assert(get_long(), get_long(), &params, 0, 0);
    }

    /// Serialized params that predate the `direction` field must still load,
    /// falling back to the default direction.
    #[test]
    fn test_deserialize_defaults() {
        let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;
        let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();
        assert_eq!(params.direction, TruncationDirection::Right);
    }
}