use crate::models::bpe::BPE;
use crate::tokenizer::{Model, Result, Token};
use std::{
borrow::Cow,
collections::HashMap,
fs::File,
io::prelude::*,
io::{BufRead, BufReader},
path::{Path, PathBuf},
};
mod serialization;
mod trainer;
pub use trainer::*;
/// Errors that can be produced by the [`WordPiece`] model.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A word had to fall back to the unknown token, but the configured
    /// `unk_token` is not present in the vocabulary.
    #[error("WordPiece error: Missing [UNK] token from the vocabulary")]
    MissingUnkToken,
}
// Forward mapping: token string -> token id.
type Vocab = HashMap<String, u32>;
// Reverse mapping: token id -> token string (derived from `Vocab`).
type VocabR = HashMap<u32, String>;
/// Configuration accumulated by [`WordPieceBuilder`] before `build()`.
struct Config {
    /// Optional path to a vocabulary file; when set, `build()` reads the
    /// vocabulary from this file and overrides `vocab`.
    files: Option<String>,
    /// In-memory vocabulary to use when no file is provided.
    vocab: Vocab,
    /// Token emitted for words that cannot be tokenized.
    unk_token: String,
    /// Prefix prepended to non-initial subword pieces (e.g. `"##"`).
    continuing_subword_prefix: String,
    /// Words longer than this many characters map directly to `unk_token`.
    max_input_chars_per_word: usize,
}
/// Builder used to configure and construct a [`WordPiece`] model.
pub struct WordPieceBuilder {
    // All settings are collected here and consumed by `build()`.
    config: Config,
}
impl Default for WordPieceBuilder {
fn default() -> Self {
Self {
config: Config {
files: None,
vocab: HashMap::new(),
unk_token: String::from("[UNK]"),
continuing_subword_prefix: String::from("##"),
max_input_chars_per_word: 100,
},
}
}
}
impl WordPieceBuilder {
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn files(mut self, vocab: String) -> Self {
self.config.files = Some(vocab);
self
}
#[must_use]
pub fn vocab(mut self, vocab: Vocab) -> Self {
self.config.vocab = vocab;
self
}
#[must_use]
pub fn unk_token(mut self, unk_token: String) -> Self {
self.config.unk_token = unk_token;
self
}
#[must_use]
pub fn continuing_subword_prefix(mut self, continuing_subword_prefix: String) -> Self {
self.config.continuing_subword_prefix = continuing_subword_prefix;
self
}
#[must_use]
pub fn max_input_chars_per_word(mut self, max_input_chars_per_word: usize) -> Self {
self.config.max_input_chars_per_word = max_input_chars_per_word;
self
}
pub fn build(mut self) -> Result<WordPiece> {
if let Some(vocab) = self.config.files {
self.config.vocab = WordPiece::read_file(&vocab)?;
}
let vocab_r = self
.config
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
Ok(WordPiece {
vocab: self.config.vocab,
vocab_r,
unk_token: self.config.unk_token,
continuing_subword_prefix: self.config.continuing_subword_prefix,
max_input_chars_per_word: self.config.max_input_chars_per_word,
})
}
}
/// A WordPiece model, as used by BERT: greedy longest-match-first subword
/// tokenization against a fixed vocabulary.
#[derive(Clone, PartialEq, Eq)]
pub struct WordPiece {
    // Token -> id mapping.
    vocab: Vocab,
    // id -> token mapping, kept in sync with `vocab` at construction time.
    vocab_r: VocabR,
    /// Token emitted for words that cannot be tokenized.
    pub unk_token: String,
    /// Prefix prepended to non-initial subword pieces (e.g. `"##"`).
    pub continuing_subword_prefix: String,
    /// Words longer than this many characters map directly to `unk_token`.
    pub max_input_chars_per_word: usize,
}
impl std::fmt::Debug for WordPiece {
    /// Manual `Debug`: the vocabulary maps are summarized by their size so
    /// that printing a model does not dump every entry.
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut dbg = fmt.debug_struct("WordPiece");
        dbg.field("unk_token", &self.unk_token);
        dbg.field("continuing_subword_prefix", &self.continuing_subword_prefix);
        dbg.field("max_input_chars_per_word", &self.max_input_chars_per_word);
        dbg.field("vocab", &self.vocab.len());
        dbg.finish()
    }
}
impl Default for WordPiece {
fn default() -> Self {
Self {
vocab: HashMap::new(),
vocab_r: HashMap::new(),
unk_token: String::from("[UNK]"),
continuing_subword_prefix: String::from("##"),
max_input_chars_per_word: 100,
}
}
}
impl WordPiece {
pub fn builder() -> WordPieceBuilder {
WordPieceBuilder::new()
}
pub fn read_file(vocab: &str) -> Result<Vocab> {
let file = File::open(vocab)?;
let file = BufReader::new(file);
let mut vocab = HashMap::new();
for (index, line) in file.lines().enumerate() {
let line = line?;
vocab.insert(line.trim_end().to_owned(), index as u32);
}
Ok(vocab)
}
pub fn from_file(vocab: &str) -> WordPieceBuilder {
WordPiece::builder().files(vocab.to_owned())
}
pub fn from_bpe(bpe: &BPE) -> Self {
let mut wp = Self::builder().vocab(bpe.get_vocab()).build().unwrap();
if let Some(unk) = bpe.get_unk_token() {
wp.unk_token = unk.to_owned();
}
if let Some(prefix) = bpe.get_continuing_subword_prefix() {
wp.continuing_subword_prefix = prefix.to_owned();
}
wp
}
}
impl Model for WordPiece {
    type Trainer = WordPieceTrainer;

    /// Return a copy of the token -> id vocabulary.
    fn get_vocab(&self) -> HashMap<String, u32> {
        self.vocab.clone()
    }

    /// Number of tokens in the vocabulary.
    fn get_vocab_size(&self) -> usize {
        self.vocab.len()
    }

    /// Tokenize one word with the greedy longest-match-first WordPiece
    /// algorithm: at each position, find the longest substring of the
    /// remainder present in the vocabulary (non-initial pieces are looked
    /// up with `continuing_subword_prefix` prepended), emit it as a token,
    /// and continue from its end. Offsets are byte offsets into `sequence`.
    ///
    /// The whole word collapses to a single `unk_token` token when it is
    /// longer than `max_input_chars_per_word` or when some remainder
    /// cannot be matched at all.
    ///
    /// # Errors
    /// Returns [`Error::MissingUnkToken`] when the unknown token is needed
    /// but absent from the vocabulary.
    fn tokenize(&self, sequence: &str) -> Result<Vec<Token>> {
        let char_len = sequence.chars().count();
        if char_len > self.max_input_chars_per_word {
            return Ok(vec![Token {
                value: self.unk_token.clone(),
                id: *self
                    .vocab
                    .get(&self.unk_token)
                    .ok_or(Error::MissingUnkToken)?,
                offsets: (0, sequence.len()),
            }]);
        }

        let mut is_bad = false;
        let mut start = 0;
        let mut sub_tokens: Vec<Token> = vec![];

        while start < sequence.len() {
            // Shrink `end` from the right until `[start, end)` is in the vocab.
            let mut end = sequence.len();
            let mut cur_str = None;

            while start < end {
                let mut substr: Cow<str> = Cow::Borrowed(&sequence[start..end]);
                if start > 0 {
                    // Non-initial pieces are stored with the prefix (e.g. "##").
                    substr = Cow::Owned(format!("{}{}", self.continuing_subword_prefix, substr));
                }
                // Single `get` instead of `contains_key` + index: one hash
                // lookup per candidate in this hot inner loop.
                if let Some(&id) = self.vocab.get(substr.as_ref()) {
                    cur_str = Some(Token {
                        id,
                        value: substr.to_string(),
                        offsets: (start, end),
                    });
                    break;
                }
                // Step back by one full character (not one byte) so `end`
                // stays on a char boundary.
                end -= substr.chars().last().map_or(1, |c| c.len_utf8());
            }

            match cur_str {
                Some(token) => sub_tokens.push(token),
                None => {
                    // No prefix of the remainder is in the vocabulary.
                    is_bad = true;
                    break;
                }
            }
            start = end;
        }

        if is_bad {
            Ok(vec![Token {
                value: self.unk_token.clone(),
                id: *self
                    .vocab
                    .get(&self.unk_token)
                    .ok_or(Error::MissingUnkToken)?,
                offsets: (0, sequence.len()),
            }])
        } else {
            Ok(sub_tokens)
        }
    }

    /// Look up the id of a token.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    /// Look up the token string for an id.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.vocab_r.get(&id).cloned()
    }

    /// Save the vocabulary to `folder` as `vocab.txt` (or
    /// `<name>-vocab.txt`), one token per line ordered by id so that
    /// [`WordPiece::read_file`] reconstructs the same ids. Returns the
    /// list of files written.
    ///
    /// # Errors
    /// Propagates any I/O error from creating or writing the file.
    fn save(&self, folder: &Path, name: Option<&str>) -> Result<Vec<PathBuf>> {
        let vocab_file_name = match name {
            Some(name) => format!("{}-vocab.txt", name),
            None => "vocab.txt".to_string(),
        };

        let vocab_path: PathBuf = [folder, Path::new(vocab_file_name.as_str())]
            .iter()
            .collect();
        let mut vocab_file = File::create(&vocab_path)?;

        // Line number must equal token id on reload, so order by id.
        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
        vocab.sort_unstable_by_key(|k| *k.1);

        // Assemble the whole payload in one String instead of allocating an
        // intermediate Vec<u8> per token (as the flat_map version did), then
        // write it with a single call.
        let mut content = String::new();
        for (token, _) in vocab {
            content.push_str(token);
            content.push('\n');
        }
        vocab_file.write_all(content.as_bytes())?;

        Ok(vec![vocab_path])
    }

    /// Default trainer for this model.
    fn get_trainer(&self) -> Self::Trainer {
        WordPieceTrainer::builder().build()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        // The Display impl (via thiserror) must mention the missing [UNK] token.
        let message = Error::MissingUnkToken.to_string();
        assert!(message.contains("Missing [UNK] token"));
    }
}