use crate::tokenizer::{Decoder, Result};
use monostate::MustBe;
use serde::{Deserialize, Serialize};
/// Decoder that turns runs of `<0xNN>` byte tokens back into text.
///
/// The type carries no configuration: its only field is the type tag used
/// when the decoder is serialized or deserialized.
#[derive(Deserialize, Clone, Debug, Serialize, Default)]
#[non_exhaustive]
pub struct ByteFallback {
/// Type tag. serde reads and writes it under the key `type`, and `MustBe!`
/// fixes its value to the literal `"ByteFallback"`.
#[serde(rename = "type")]
type_: MustBe!("ByteFallback"),
}
impl ByteFallback {
pub fn new() -> Self {
Self {
type_: MustBe!("ByteFallback"),
}
}
}
impl Decoder for ByteFallback {
    /// Merges runs of `<0xNN>` byte tokens back into text.
    ///
    /// Tokens are scanned in order. Every token of the exact form `<0xNN>`
    /// (where `NN` is two hexadecimal digits) contributes one byte to the
    /// current run; any other token ends the run and is passed through
    /// unchanged.
    ///
    /// When a run ends (or the input is exhausted):
    /// * if the collected bytes are valid UTF-8, they are emitted as a single
    ///   string;
    /// * otherwise one `"�"` is emitted per byte of the run, so the number of
    ///   output pieces matches the number of byte tokens consumed.
    ///
    /// This implementation never returns `Err`.
    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
        // A run of N byte tokens yields either 1 or N pieces, so the output
        // never has more entries than the input.
        let mut new_tokens: Vec<String> = Vec::with_capacity(tokens.len());
        let mut pending_bytes: Vec<u8> = Vec::new();

        for token in tokens {
            match parse_byte_token(&token) {
                Some(byte) => pending_bytes.push(byte),
                None => {
                    flush_pending_bytes(&mut pending_bytes, &mut new_tokens);
                    new_tokens.push(token);
                }
            }
        }
        // The input may end in the middle of a byte run.
        flush_pending_bytes(&mut pending_bytes, &mut new_tokens);

        Ok(new_tokens)
    }
}

/// Parses a token of the exact form `<0xNN>` into the byte `0xNN`.
///
/// Returns `None` for anything else. Both characters between `<0x` and `>`
/// must be ASCII hex digits: `u8::from_str_radix` alone would also accept a
/// leading `+` (e.g. `<0x+A>`), which is ordinary text rather than a byte token.
fn parse_byte_token(token: &str) -> Option<u8> {
    let hex = token.strip_prefix("<0x")?.strip_suffix('>')?;
    if hex.len() != 2 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    u8::from_str_radix(hex, 16).ok()
}

/// Emits the pending byte run into `out` and empties `pending`.
///
/// Valid UTF-8 becomes one string; an invalid run becomes one `"�"` per byte.
/// Does nothing when `pending` is empty.
fn flush_pending_bytes(pending: &mut Vec<u8>, out: &mut Vec<String>) {
    if pending.is_empty() {
        return;
    }
    // Validate in place so the buffer is only copied when it is actually emitted.
    match std::str::from_utf8(pending.as_slice()) {
        Ok(text) => out.push(text.to_owned()),
        Err(_) => out.extend(std::iter::repeat("�".to_owned()).take(pending.len())),
    }
    pending.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `input` with a `ByteFallback` decoder and asserts that the
    /// produced pieces equal `expected`.
    fn assert_decodes(input: &[&str], expected: &[&str]) {
        let tokens: Vec<String> = input.iter().map(|&piece| String::from(piece)).collect();
        let decoded = ByteFallback::new().decode_chain(tokens).unwrap();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn decode() {
        // Plain tokens pass through untouched.
        assert_decodes(&["Hey", "friend!"], &["Hey", "friend!"]);

        // A single ASCII byte token becomes its character.
        assert_decodes(&["<0x61>"], &["a"]);

        // Incomplete multi-byte sequences yield one replacement per byte.
        assert_decodes(&["<0xE5>"], &["�"]);
        assert_decodes(&["<0xE5>", "<0x8f>"], &["�", "�"]);

        // A complete three-byte sequence collapses into one character.
        assert_decodes(&["<0xE5>", "<0x8f>", "<0xab>"], &["叫"]);

        // A regular token terminates the byte run.
        assert_decodes(&["<0xE5>", "<0x8f>", "<0xab>", "a"], &["叫", "a"]);
        assert_decodes(&["<0xE5>", "<0x8f>", "a"], &["�", "�", "a"]);
    }
}