use crate::parallelism::*;
use crate::tokenizer::{Offsets, Token};
use crate::utils::padding::PaddingDirection;
use crate::utils::truncation::TruncationDirection;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::ops::Range;
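/// Represents the output of a `Tokenizer`: the ids, tokens, offsets, word ids
/// and masks produced for one input, the ranges each original sequence
/// occupies, and the overflowing encodings generated by truncation.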
#[derive(Default, PartialEq, Debug, Clone, Serialize, Deserialize)]
pub struct Encoding {
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
words: Vec<Option<u32>>,
offsets: Vec<Offsets>,
special_tokens_mask: Vec<u32>,
attention_mask: Vec<u32>,
overflowing: Vec<Encoding>,
sequence_ranges: HashMap<usize, Range<usize>>,
}
impl Encoding {
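/// Creates a new `Encoding` directly from all of its parts.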
#[allow(clippy::too_many_arguments)]
pub fn new(
ids: Vec<u32>,
type_ids: Vec<u32>,
tokens: Vec<String>,
words: Vec<Option<u32>>,
offsets: Vec<Offsets>,
special_tokens_mask: Vec<u32>,
attention_mask: Vec<u32>,
overflowing: Vec<Self>,
sequence_ranges: HashMap<usize, Range<usize>>,
) -> Self {
Self {
ids,
type_ids,
tokens,
words,
offsets,
special_tokens_mask,
attention_mask,
overflowing,
sequence_ranges,
}
}
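/// Creates an empty `Encoding` whose internal vectors are pre-allocated to
/// hold `len` tokens.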
pub fn with_capacity(len: usize) -> Self {
Self {
ids: Vec::with_capacity(len),
type_ids: Vec::with_capacity(len),
tokens: Vec::with_capacity(len),
words: Vec::with_capacity(len),
offsets: Vec::with_capacity(len),
special_tokens_mask: Vec::with_capacity(len),
attention_mask: Vec::with_capacity(len),
overflowing: vec![],
sequence_ranges: HashMap::new(),
}
}
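/// Builds an `Encoding` from a list of `Token`s, assigning the given
/// `type_id` to all of them. Word ids are left unset and no token is marked
/// as special.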
pub fn from_tokens(tokens: Vec<Token>, type_id: u32) -> Self {
let length = tokens.len();
let (ids, tokens, offsets) = tokens.into_iter().fold(
(
Vec::with_capacity(length),
Vec::with_capacity(length),
Vec::with_capacity(length),
),
|(mut ids, mut tokens, mut offsets), t| {
ids.push(t.id);
tokens.push(t.value);
offsets.push(t.offsets);
(ids, tokens, offsets)
},
);
Self {
ids,
tokens,
offsets,
words: vec![None; length],
type_ids: vec![type_id; length],
attention_mask: vec![1; length],
special_tokens_mask: vec![0; length],
overflowing: vec![],
sequence_ranges: HashMap::new(),
}
}
pub fn is_empty(&self) -> bool {
self.ids.is_empty()
}
pub fn len(&self) -> usize {
self.ids.len()
}
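/// Returns the number of sequences combined in this `Encoding`. An encoding
/// without recorded sequence ranges counts as a single sequence.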
pub fn n_sequences(&self) -> usize {
if self.sequence_ranges.is_empty() {
1
} else {
self.sequence_ranges.len()
}
}
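/// Marks the whole encoding as belonging to the sequence `sequence_id`.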
pub fn set_sequence_id(&mut self, sequence_id: usize) {
self.sequence_ranges.insert(sequence_id, 0..self.len());
}
pub fn get_tokens(&self) -> &[String] {
&self.tokens[..]
}
pub fn get_word_ids(&self) -> &[Option<u32>] {
&self.words
}
pub fn get_word_ids_mut(&mut self) -> &mut [Option<u32>] {
&mut self.words
}
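/// Returns, for each token, the id of the sequence it belongs to. Tokens
/// outside every recorded range (e.g. special tokens) map to `None`.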
pub fn get_sequence_ids(&self) -> Vec<Option<usize>> {
let mut sequences = vec![None; self.len()];
for seq_id in 0..self.n_sequences() {
let range = self.sequence_range(seq_id);
let seq_len = range.len();
sequences.splice(range, std::iter::repeat(Some(seq_id)).take(seq_len));
}
sequences
}
pub fn get_ids(&self) -> &[u32] {
&self.ids
}
pub fn get_type_ids(&self) -> &[u32] {
&self.type_ids
}
pub fn set_type_ids(&mut self, type_ids: Vec<u32>) {
self.type_ids = type_ids;
}
pub fn get_offsets(&self) -> &[Offsets] {
&self.offsets
}
pub fn get_offsets_mut(&mut self) -> &mut [Offsets] {
&mut self.offsets
}
pub fn get_special_tokens_mask(&self) -> &[u32] {
&self.special_tokens_mask
}
pub fn get_attention_mask(&self) -> &[u32] {
&self.attention_mask
}
pub fn get_overflowing(&self) -> &Vec<Encoding> {
&self.overflowing
}
pub fn set_overflowing(&mut self, overflowing: Vec<Encoding>) {
self.overflowing = overflowing;
}
pub fn get_overflowing_mut(&mut self) -> &mut Vec<Encoding> {
&mut self.overflowing
}
pub fn take_overflowing(&mut self) -> Vec<Encoding> {
std::mem::take(&mut self.overflowing)
}
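/// Applies `func` to each `(index, (token, &mut offsets))` tuple, allowing
/// the offsets to be modified in place.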
pub(crate) fn process_tokens_with_offsets_mut<F>(&mut self, func: F)
where
F: FnMut((usize, (&String, &mut Offsets))),
{
self.tokens
.iter()
.zip(self.offsets.iter_mut())
.enumerate()
.for_each(func)
}
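/// Returns the range of tokens covered by `sequence_id`, defaulting to the
/// whole encoding when no range was recorded for it.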
fn sequence_range(&self, sequence_id: usize) -> Range<usize> {
self.sequence_ranges
.get(&sequence_id)
.cloned()
.unwrap_or(0..self.len())
}
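/// Returns the id of the sequence containing the token at the given index.
/// Defaults to sequence 0 when no sequence ranges were recorded.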
pub fn token_to_sequence(&self, token: usize) -> Option<usize> {
if token >= self.len() {
None
} else if self.sequence_ranges.is_empty() {
Some(0)
} else {
self.sequence_ranges.iter().find_map(|(seq_id, range)| {
if range.contains(&token) {
Some(*seq_id)
} else {
None
}
})
}
}
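/// Returns the `[start, end)` token range covering the given `word` of the
/// given sequence.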
pub fn word_to_tokens(&self, word: u32, sequence_id: usize) -> Option<(usize, usize)> {
let (mut start, mut end) = (None, None);
let sequence_range = self.sequence_range(sequence_id);
self.words
.get(sequence_range.clone())?
.iter()
.enumerate()
.take_while(|(_, w)| **w <= Some(word))
.filter(|(_, w)| **w == Some(word))
.for_each(|(i, _)| {
if start.is_none() || Some(i) < start {
start = Some(i);
}
if end.is_none() || Some(i) >= end {
end = Some(i + 1);
}
});
if let (Some(start), Some(end)) = (start, end) {
Some((sequence_range.start + start, sequence_range.start + end))
} else {
None
}
}
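/// Returns the span of characters, in the original input, covered by the
/// given `word` of the given sequence.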
pub fn word_to_chars(&self, word: u32, sequence_id: usize) -> Option<Offsets> {
self.word_to_tokens(word, sequence_id)
.and_then(|(start, end)| {
if end == 0 {
None
} else {
Some((self.offsets[start].0, self.offsets[end - 1].1))
}
})
}
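/// Returns the sequence id and the character offsets of the token at the
/// given index.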
pub fn token_to_chars(&self, token: usize) -> Option<(usize, Offsets)> {
Some((
self.token_to_sequence(token)?,
self.offsets.get(token).copied()?,
))
}
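/// Returns the sequence id and the word index of the token at the given index.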
pub fn token_to_word(&self, token: usize) -> Option<(usize, u32)> {
Some((
self.token_to_sequence(token)?,
self.words.get(token).copied().flatten()?,
))
}
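/// Returns the index of the token containing the character at `pos` in the
/// given sequence.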
pub fn char_to_token(&self, pos: usize, sequence_id: usize) -> Option<usize> {
let sequence_range = self.sequence_range(sequence_id);
self.offsets
.get(sequence_range.clone())?
.iter()
.position(|(start, end)| pos >= *start && pos < *end)
.map(|pos| sequence_range.start + pos)
}
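/// Returns the word containing the character at `pos` in the given sequence.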
pub fn char_to_word(&self, pos: usize, sequence_id: usize) -> Option<u32> {
self.char_to_token(pos, sequence_id)
.and_then(|token| self.token_to_word(token))
.map(|(_, word)| word)
}
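/// Truncates the encoding so it keeps at most `max_len` tokens, cutting from
/// the given `direction`. The removed tokens are regrouped into overflowing
/// encodings of up to `max_len` tokens, each sharing `stride` tokens with the
/// previous part. Truncating to `max_len == 0` moves the whole encoding into
/// `overflowing`. Panics if `stride >= max_len` when truncation is needed.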
pub fn truncate(&mut self, max_len: usize, stride: usize, direction: TruncationDirection) {
let encoding_len = self.ids.len();
if max_len >= encoding_len {
return;
}
if max_len == 0 {
let o = std::mem::replace(self, Encoding::with_capacity(0));
self.overflowing.push(o);
return;
}
assert!(stride < max_len, "`stride` must be strictly less than `max_len={}` (note that `max_len` may be shorter than the model's maximum length, since the number of special tokens is subtracted from it)", max_len);
self.sequence_ranges.clear();
let offset = max_len - stride;
let mut end = false;
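// Compute the (start, stop) boundaries of each part: parts hold at most
// `max_len` tokens and consecutive parts share `stride` tokens.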
let parts_ranges: Vec<(usize, usize)> = match direction {
TruncationDirection::Right => (0..encoding_len)
.step_by(offset)
.filter_map(|start| {
if !end {
let stop = std::cmp::min(start + max_len, encoding_len);
end = stop == encoding_len;
Some((start, stop))
} else {
None
}
})
.collect(),
TruncationDirection::Left => (0..encoding_len)
.rev()
.step_by(offset)
.filter_map(|stop| {
let stop = stop + 1;
let start = if stop < max_len { 0 } else { stop - max_len };
if start < stop && !end {
end = start == 0;
Some((start, stop))
} else {
None
}
})
.collect(),
};
// Split the encoding into the computed parts: the first part replaces `self`,
// and the remaining ones become its overflowing encodings.
let mut parts = parts_ranges.iter().map(|&(start, stop)| Encoding {
ids: self.ids[start..stop].to_vec(),
type_ids: self.type_ids[start..stop].to_vec(),
tokens: self.tokens[start..stop].to_vec(),
words: self.words[start..stop].to_vec(),
offsets: self.offsets[start..stop].to_vec(),
special_tokens_mask: self.special_tokens_mask[start..stop].to_vec(),
attention_mask: self.attention_mask[start..stop].to_vec(),
overflowing: vec![],
sequence_ranges: HashMap::new(),
});
// `parts_ranges` is never empty here since `0 < max_len < encoding_len`.
let mut new_encoding = parts.next().unwrap();
new_encoding.overflowing = parts.collect();
*self = new_encoding;
}
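/// Merges all the given encodings, in order, into a single one. When
/// `growing_offsets` is true, each encoding's offsets are shifted to continue
/// after those already merged.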
pub fn merge<I: IntoIterator<Item = Encoding>>(encodings: I, growing_offsets: bool) -> Self {
let mut encoding = Encoding::default();
for sub in encodings {
encoding.merge_with(sub, growing_offsets);
}
encoding
}
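/// Appends the `pair` encoding to this one. The overflowing encodings are
/// rebuilt from every combination of `self`/`pair` with their respective
/// overflowing parts. When `growing_offsets` is true, `pair`'s offsets are
/// shifted to start where the current offsets end.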
pub fn merge_with(&mut self, pair: Encoding, growing_offsets: bool) {
let mut overflowings = vec![];
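// Build the new overflowing list from every combination of `self` (or one of
// its overflowing parts) with `pair` (or one of its overflowing parts).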
for self_o in &self.overflowing {
let mut n_encoding = self_o.clone();
n_encoding.merge_with(pair.clone(), growing_offsets);
overflowings.push(n_encoding);
for other_o in &pair.overflowing {
let mut n_encoding = self_o.clone();
n_encoding.merge_with(other_o.clone(), growing_offsets);
overflowings.push(n_encoding);
}
}
for other_o in &pair.overflowing {
let mut n_encoding = self.clone();
n_encoding.merge_with(other_o.clone(), growing_offsets);
overflowings.push(n_encoding);
}
let original_self_len = self.len();
self.sequence_ranges
.extend(pair.sequence_ranges.into_iter().map(|(seq_id, range)| {
(
seq_id,
original_self_len + range.start..original_self_len + range.end,
)
}));
self.ids.extend(pair.ids);
self.type_ids.extend(pair.type_ids);
self.tokens.extend(pair.tokens);
self.words.extend(pair.words);
let starting_offset = if growing_offsets {
self.offsets.last().map_or(0, |o| o.1)
} else {
0
};
self.offsets.extend(
pair.offsets
.into_iter()
.map(|(start, end)| (start + starting_offset, end + starting_offset)),
);
self.special_tokens_mask.extend(pair.special_tokens_mask);
self.attention_mask.extend(pair.attention_mask);
self.overflowing = overflowings;
}
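/// Pads the encoding (and all its overflowing parts) to `target_length`
/// tokens in the given `direction`, using the provided padding id, type id
/// and token. If the encoding is already `target_length` or longer, only its
/// overflowing parts are padded.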
pub fn pad(
&mut self,
target_length: usize,
pad_id: u32,
pad_type_id: u32,
pad_token: &str,
direction: PaddingDirection,
) {
self.overflowing.maybe_par_iter_mut().for_each(|encoding| {
encoding.pad(target_length, pad_id, pad_type_id, pad_token, direction)
});
if self.ids.len() >= target_length {
return;
}
let pad_length = target_length - self.ids.len();
match direction {
PaddingDirection::Left => {
self.ids = (0..pad_length)
.map(|_| pad_id)
.chain(self.ids.drain(..))
.collect();
self.type_ids = (0..pad_length)
.map(|_| pad_type_id)
.chain(self.type_ids.drain(..))
.collect();
self.tokens = (0..pad_length)
.map(|_| pad_token.to_owned())
.chain(self.tokens.drain(..))
.collect();
self.words = (0..pad_length)
.map(|_| None)
.chain(self.words.drain(..))
.collect();
self.attention_mask = (0..pad_length)
.map(|_| 0)
.chain(self.attention_mask.drain(..))
.collect();
self.special_tokens_mask = (0..pad_length)
.map(|_| 1)
.chain(self.special_tokens_mask.drain(..))
.collect();
self.offsets = (0..pad_length)
.map(|_| (0, 0))
.chain(self.offsets.drain(..))
.collect();
self.sequence_ranges
.iter_mut()
.for_each(|(_seq_id, range)| {
*range = (range.start + pad_length)..(range.end + pad_length)
});
}
PaddingDirection::Right => {
self.ids.extend((0..pad_length).map(|_| pad_id));
self.type_ids.extend((0..pad_length).map(|_| pad_type_id));
self.tokens
.extend((0..pad_length).map(|_| pad_token.to_owned()));
self.words.extend((0..pad_length).map(|_| None));
self.attention_mask.extend((0..pad_length).map(|_| 0));
self.special_tokens_mask.extend((0..pad_length).map(|_| 1));
self.offsets.extend((0..pad_length).map(|_| (0, 0)));
}
}
}
}
impl std::iter::FromIterator<Encoding> for Encoding {
fn from_iter<I: IntoIterator<Item = Encoding>>(iter: I) -> Self {
Self::merge(iter, false)
}
}
impl std::iter::FromIterator<(u32, String, (usize, usize), Option<u32>, u32)> for Encoding {
fn from_iter<I: IntoIterator<Item = (u32, String, (usize, usize), Option<u32>, u32)>>(
iter: I,
) -> Self {
let items = iter.into_iter();
let (lower, upper) = items.size_hint();
let length = upper.unwrap_or(lower);
let mut encoding = Self::with_capacity(length);
for (id, token, offsets, word, type_id) in items {
encoding.ids.push(id);
encoding.tokens.push(token);
encoding.offsets.push(offsets);
encoding.type_ids.push(type_id);
encoding.words.push(word);
encoding.special_tokens_mask.push(0);
encoding.attention_mask.push(1);
}
encoding
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::iter::FromIterator;
#[test]
fn merge_encodings() {
let mut a = Encoding {
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello ")],
words: vec![Some(0)],
offsets: vec![(0, 6)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
..Default::default()
};
let b = Encoding {
ids: vec![2],
type_ids: vec![1],
tokens: vec![String::from("World!")],
words: vec![Some(0)],
offsets: vec![(0, 6)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
..Default::default()
};
a.merge_with(b, true);
assert_eq!(
a,
Encoding {
ids: vec![1, 2],
type_ids: vec![0, 1],
tokens: vec![String::from("Hello "), String::from("World!")],
words: vec![Some(0), Some(0)],
offsets: vec![(0, 6), (6, 12)],
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
..Default::default()
}
);
}
#[test]
fn truncate() {
let mut a = Encoding {
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
String::from("Hello"),
String::from("World"),
String::from("!"),
],
words: vec![Some(0), Some(1), Some(2)],
offsets: vec![(0, 5), (6, 11), (11, 12)],
special_tokens_mask: vec![0, 0, 0],
attention_mask: vec![1, 1, 1],
..Default::default()
};
a.truncate(2, 0, TruncationDirection::Right);
assert_eq!(
a,
Encoding {
ids: vec![1, 2],
type_ids: vec![0, 0],
tokens: vec![String::from("Hello"), String::from("World")],
words: vec![Some(0), Some(1)],
offsets: vec![(0, 5), (6, 11)],
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: vec![Encoding {
ids: vec![3],
type_ids: vec![0],
tokens: vec![String::from("!")],
words: vec![Some(2)],
offsets: vec![(11, 12)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
..Default::default()
}],
..Default::default()
}
);
}
#[test]
fn truncate_to_empty() {
let mut a = Encoding {
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
String::from("Hello"),
String::from("World"),
String::from("!"),
],
words: vec![Some(0), Some(1), Some(2)],
offsets: vec![(0, 5), (6, 11), (11, 12)],
special_tokens_mask: vec![0, 0, 0],
attention_mask: vec![1, 1, 1],
..Default::default()
};
a.truncate(0, 0, TruncationDirection::Right);
assert_eq!(
a,
Encoding {
overflowing: vec![Encoding {
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
String::from("Hello"),
String::from("World"),
String::from("!"),
],
words: vec![Some(0), Some(1), Some(2)],
offsets: vec![(0, 5), (6, 11), (11, 12)],
special_tokens_mask: vec![0, 0, 0],
attention_mask: vec![1, 1, 1],
overflowing: vec![],
..Default::default()
}],
..Default::default()
}
);
}
#[test]
fn truncate_overflow_with_stride() {
let mut enc = Encoding {
ids: vec![1, 2, 3, 4, 5],
type_ids: vec![0, 0, 0, 0, 0],
tokens: vec![
String::from("42"),
String::from("is"),
String::from("the"),
String::from("answer"),
String::from("!"),
],
words: vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13), (13, 14)],
special_tokens_mask: vec![0, 0, 0, 0, 0],
attention_mask: vec![1, 1, 1, 1, 1],
overflowing: vec![],
..Default::default()
};
enc.truncate(4, 2, TruncationDirection::Right);
assert_eq!(
enc,
Encoding {
ids: vec![1, 2, 3, 4],
type_ids: vec![0, 0, 0, 0],
tokens: vec![
String::from("42"),
String::from("is"),
String::from("the"),
String::from("answer"),
],
words: vec![Some(0), Some(1), Some(2), Some(3)],
offsets: vec![(0, 2), (2, 4), (4, 7), (7, 13)],
special_tokens_mask: vec![0, 0, 0, 0],
attention_mask: vec![1, 1, 1, 1],
overflowing: vec![Encoding {
ids: vec![3, 4, 5],
type_ids: vec![0, 0, 0],
tokens: vec![
String::from("the"),
String::from("answer"),
String::from("!"),
],
words: vec![Some(2), Some(3), Some(4)],
offsets: vec![(4, 7), (7, 13), (13, 14)],
special_tokens_mask: vec![0, 0, 0],
attention_mask: vec![1, 1, 1],
overflowing: vec![],
..Default::default()
}],
..Default::default()
}
);
}
#[test]
fn truncate_left() {
let mut a = Encoding {
ids: vec![1, 2, 3],
type_ids: vec![0, 0, 0],
tokens: vec![
String::from("Hello"),
String::from("World"),
String::from("!"),
],
words: vec![Some(0), Some(1), Some(2)],
offsets: vec![(0, 5), (6, 11), (11, 12)],
special_tokens_mask: vec![0, 0, 0],
attention_mask: vec![1, 1, 1],
..Default::default()
};
a.truncate(2, 0, TruncationDirection::Left);
assert_eq!(
a,
Encoding {
ids: vec![2, 3],
type_ids: vec![0, 0],
tokens: vec![String::from("World"), String::from("!")],
words: vec![Some(1), Some(2)],
offsets: vec![(6, 11), (11, 12)],
special_tokens_mask: vec![0, 0],
attention_mask: vec![1, 1],
overflowing: vec![Encoding {
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello")],
words: vec![Some(0)],
offsets: vec![(0, 5)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
..Default::default()
}],
..Default::default()
}
);
}
#[test]
fn mappings() {
let encoding = Encoding {
ids: vec![0; 11],
tokens: vec![
"He".into(),
"llo".into(),
"won".into(),
"der".into(),
"ful".into(),
"friend".into(),
"!".into(),
"How".into(),
"are".into(),
"you".into(),
"?".into(),
],
offsets: vec![
(0, 2),
(2, 5),
(7, 10),
(10, 13),
(13, 16),
(17, 23),
(23, 24),
(0, 3),
(4, 7),
(8, 11),
(11, 12),
],
words: vec![
Some(0),
Some(0),
Some(1),
Some(1),
Some(1),
Some(2),
Some(3),
Some(0),
Some(1),
Some(2),
Some(3),
],
sequence_ranges: HashMap::from_iter(vec![(0, 0..7), (1, 7..11)]),
..Default::default()
};
assert_eq!(encoding.word_to_tokens(0, 0), Some((0, 2)));
assert_eq!(encoding.word_to_tokens(1, 0), Some((2, 5)));
assert_eq!(encoding.word_to_tokens(2, 0), Some((5, 6)));
assert_eq!(encoding.word_to_tokens(3, 0), Some((6, 7)));
assert_eq!(encoding.word_to_tokens(0, 1), Some((7, 8)));
assert_eq!(encoding.word_to_tokens(1, 1), Some((8, 9)));
assert_eq!(encoding.word_to_tokens(2, 1), Some((9, 10)));
assert_eq!(encoding.word_to_tokens(3, 1), Some((10, 11)));
assert_eq!(encoding.word_to_chars(0, 0), Some((0, 5)));
assert_eq!(encoding.word_to_chars(1, 0), Some((7, 16)));
assert_eq!(encoding.word_to_chars(0, 1), Some((0, 3)));
assert_eq!(encoding.word_to_chars(1, 1), Some((4, 7)));
assert_eq!(encoding.token_to_chars(0), Some((0, (0, 2))));
assert_eq!(encoding.token_to_chars(1), Some((0, (2, 5))));
assert_eq!(encoding.token_to_chars(7), Some((1, (0, 3))));
assert_eq!(encoding.token_to_chars(9), Some((1, (8, 11))));
assert_eq!(encoding.token_to_word(1), Some((0, 0)));
assert_eq!(encoding.token_to_word(2), Some((0, 1)));
assert_eq!(encoding.token_to_word(7), Some((1, 0)));
assert_eq!(encoding.token_to_word(9), Some((1, 2)));
assert_eq!(encoding.token_to_word(11), None);
assert_eq!(encoding.char_to_token(3, 0), Some(1));
assert_eq!(encoding.char_to_token(8, 0), Some(2));
assert_eq!(encoding.char_to_token(16, 0), None);
assert_eq!(encoding.char_to_token(23, 0), Some(6));
assert_eq!(encoding.char_to_token(2, 1), Some(7));
assert_eq!(encoding.char_to_token(9, 1), Some(9));
assert_eq!(encoding.char_to_word(3, 0), Some(0));
assert_eq!(encoding.char_to_word(8, 0), Some(1));
assert_eq!(encoding.char_to_word(16, 0), None);
assert_eq!(encoding.char_to_word(23, 0), Some(3));
assert_eq!(encoding.char_to_word(2, 1), Some(0));
assert_eq!(encoding.char_to_word(9, 1), Some(2));
}
#[test]
fn padding() {
let mut a = Encoding {
ids: vec![1],
type_ids: vec![0],
tokens: vec![String::from("Hello ")],
words: vec![Some(0)],
offsets: vec![(0, 6)],
special_tokens_mask: vec![0],
attention_mask: vec![1],
sequence_ranges: HashMap::from([(0, 0..1)]),
..Default::default()
};
let target_length = 2;
let pad_id = 99;
let pad_type_id = 0;
let pad_token = "[PAD]";
a.pad(
target_length,
pad_id,
pad_type_id,
pad_token,
PaddingDirection::Left,
);
assert_eq!(a.sequence_ranges, HashMap::from([(0, 1..2)]));
}
}