use nom::{number::complete::le_u32, IResult, ToUsize};
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
use std::convert::TryFrom;
use unicode_segmentation::UnicodeSegmentation;
#[derive(Default, Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", try_from = "PrecompiledDeserializer")]
pub struct Precompiled {
#[serde(serialize_with = "as_base64", deserialize_with = "from_base64")]
precompiled_charsmap: Vec<u8>,
#[serde(skip)]
normalized: String,
#[serde(skip)]
trie: DoubleArray,
}
#[doc(hidden)]
#[derive(Deserialize)]
#[serde(tag = "type")]
struct PrecompiledDeserializer {
#[serde(deserialize_with = "from_base64")]
precompiled_charsmap: Vec<u8>,
}
fn as_base64<T, S>(key: &T, serializer: S) -> Result<S::Ok, S::Error>
where
T: AsRef<[u8]>,
S: Serializer,
{
serializer.serialize_str(&base64::encode(key.as_ref()))
}
fn from_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
where
D: Deserializer<'de>,
{
let s: &str = Deserialize::deserialize(deserializer)?;
let precompiled_charsmap = base64::decode(s).map_err(|err| Error::custom(err.to_string()))?;
Ok(precompiled_charsmap)
}
impl TryFrom<PrecompiledDeserializer> for Precompiled {
type Error = PrecompiledError;
fn try_from(t: PrecompiledDeserializer) -> Result<Self, Self::Error> {
Self::from(&t.precompiled_charsmap)
}
}
pub type ArrayUnit = usize;
trait ArrayUnitTrait {
fn has_leaf(&self) -> bool;
fn value(&self) -> isize;
fn label(&self) -> usize;
fn offset(&self) -> usize;
}
impl ArrayUnitTrait for ArrayUnit {
fn has_leaf(&self) -> bool {
(self >> 8) & 1 == 1
}
fn value(&self) -> isize {
(self & ((1usize << 31) - 1)) as isize
}
fn label(&self) -> usize {
self & ((1usize << 31) | 0xFF)
}
fn offset(&self) -> usize {
(self >> 10) << ((self & (1usize << 9)) >> 6)
}
}
type Array = Vec<ArrayUnit>;
#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct DoubleArray {
array: Array,
}
impl DoubleArray {
fn from(array: Array) -> Self {
Self { array }
}
pub fn common_prefix_search(&self, key: &[u8]) -> Vec<isize> {
let mut node_pos = 0;
let mut results = vec![];
let mut unit = self.array[node_pos];
node_pos ^= unit.offset();
for c in key {
if *c == 0u8 {
break;
}
node_pos ^= *c as usize;
unit = self.array[node_pos];
if unit.label() != *c as usize {
return results;
}
node_pos ^= unit.offset();
if unit.has_leaf() {
results.push(self.array[node_pos].value());
}
}
results
}
}
fn parse(precompiled_charsmap: &[u8]) -> IResult<&[u8], Array> {
let (mut rest, trie_size) = le_u32(precompiled_charsmap)?;
let trie_char_size = trie_size / 4;
let mut trie_blob = Vec::with_capacity(trie_char_size as usize);
for _ in 0..trie_char_size {
let (rest2, n) = le_u32(rest)?;
rest = rest2;
trie_blob.push(n.to_usize());
}
let normalized_blob = rest;
Ok((normalized_blob, trie_blob))
}
#[derive(Debug)]
pub enum PrecompiledError {
ParseError,
NormalizedInvalidUtf8,
}
impl std::fmt::Display for PrecompiledError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "Cannot parse precompiled_charsmap")
}
}
impl std::error::Error for PrecompiledError {}
impl Precompiled {
pub fn from(precompiled_charsmap: &[u8]) -> Result<Precompiled, PrecompiledError> {
let (normalized_blob, trie_blob) =
parse(precompiled_charsmap).map_err(|_| PrecompiledError::ParseError)?;
let normalized = String::from_utf8(normalized_blob.to_vec())
.map_err(|_| PrecompiledError::NormalizedInvalidUtf8)?;
let trie = DoubleArray::from(trie_blob);
let precompiled = Precompiled {
precompiled_charsmap: precompiled_charsmap.to_vec(),
normalized,
trie,
};
Ok(precompiled)
}
pub fn transform(&self, chunk: &str) -> Option<&str> {
let results = self.trie.common_prefix_search(chunk.as_bytes());
if results.is_empty() {
None
} else {
let index = results[0] as usize;
let mut index2 = index;
while index2 < self.normalized.len() {
if *self.normalized.as_bytes().get(index2)? == 0u8 {
break;
}
index2 += 1;
}
let normalized = &self.normalized[index..index2];
Some(normalized)
}
}
pub fn normalize_string(&self, original: &str) -> String {
let mut string = String::with_capacity(original.len());
original.graphemes(true).for_each(|grapheme| {
if grapheme.len() < 6 {
if let Some(norm) = self.transform(grapheme) {
for c in norm.chars() {
string.push(c);
}
return;
}
}
for (char_index, c) in grapheme.char_indices() {
let part = &grapheme[char_index..char_index + c.len_utf8()];
if let Some(norm) = self.transform(part) {
for c in norm.chars() {
string.push(c);
}
} else {
string.push(c);
}
}
});
string
}
}
#[cfg(test)]
mod tests;