use super::{super::OrderedVocabIter, convert_merges_to_hashmap, BpeBuilder, Pair, BPE};
use serde::{
de::{Error, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
use std::collections::HashMap;
impl Serialize for BPE {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut model = serializer.serialize_struct("BPE", 8)?;
model.serialize_field("type", "BPE")?;
model.serialize_field("dropout", &self.dropout)?;
model.serialize_field("unk_token", &self.unk_token)?;
model.serialize_field("continuing_subword_prefix", &self.continuing_subword_prefix)?;
model.serialize_field("end_of_word_suffix", &self.end_of_word_suffix)?;
model.serialize_field("fuse_unk", &self.fuse_unk)?;
model.serialize_field("byte_fallback", &self.byte_fallback)?;
let mut merges: Vec<(&Pair, &u32)> = self
.merges
.iter()
.map(|(pair, (rank, _))| (pair, rank))
.collect();
merges.sort_unstable_by_key(|k| *k.1);
let merges_str = merges
.into_iter()
.map(|(pair, _)| format!("{} {}", self.vocab_r[&pair.0], self.vocab_r[&pair.1]))
.collect::<Vec<_>>();
let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);
model.serialize_field("vocab", &ordered_vocab)?;
model.serialize_field("merges", &merges_str)?;
model.end()
}
}
impl<'de> Deserialize<'de> for BPE {
    /// Deserializes a `BPE` by asking the deserializer for a struct named
    /// `"BPE"` and delegating the actual field handling to `BPEVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names advertised to the deserializer.
        const FIELDS: &[&str] = &[
            "type",
            "dropout",
            "unk_token",
            "continuing_subword_prefix",
            "end_of_word_suffix",
            "fuse_unk",
            "byte_fallback",
            "vocab",
            "merges",
        ];

        deserializer.deserialize_struct("BPE", FIELDS, BPEVisitor)
    }
}
struct BPEVisitor;
impl<'de> Visitor<'de> for BPEVisitor {
type Value = BPE;
fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(fmt, "struct BPE")
}
fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut builder = BpeBuilder::new();
let mut vocab: Option<HashMap<String, u32>> = None;
let mut merges: Option<Vec<String>> = None;
while let Some(key) = map.next_key::<String>()? {
match key.as_ref() {
"dropout" => {
if let Some(dropout) = map.next_value()? {
builder = builder.dropout(dropout);
}
}
"unk_token" => {
if let Some(unk) = map.next_value()? {
builder = builder.unk_token(unk);
}
}
"continuing_subword_prefix" => {
if let Some(prefix) = map.next_value()? {
builder = builder.continuing_subword_prefix(prefix);
}
}
"end_of_word_suffix" => {
if let Some(suffix) = map.next_value()? {
builder = builder.end_of_word_suffix(suffix);
}
}
"fuse_unk" => {
if let Some(suffix) = map.next_value()? {
builder = builder.fuse_unk(suffix);
}
}
"byte_fallback" => {
if let Some(suffix) = map.next_value()? {
builder = builder.byte_fallback(suffix);
}
}
"vocab" => vocab = Some(map.next_value()?),
"merges" => merges = Some(map.next_value()?),
"type" => match map.next_value()? {
"BPE" => {}
u => {
return Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(u),
&"BPE",
))
}
},
_ => {}
}
}
if let (Some(vocab), Some(merges)) = (vocab, merges) {
let merges =
convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?;
builder = builder.vocab_and_merges(vocab, merges);
Ok(builder.build().map_err(Error::custom)?)
} else {
Err(Error::custom("Missing vocab/merges"))
}
}
}