// Hand-written serde `Serialize` / `Deserialize` implementations for `WordLevel`.
use super::{super::OrderedVocabIter, WordLevel, WordLevelBuilder};
use serde::{
    de::{MapAccess, Visitor},
    ser::SerializeStruct,
    Deserialize, Deserializer, Serialize, Serializer,
};
use std::collections::HashSet;

impl Serialize for WordLevel {
    /// Writes the model as a three-field struct named `WordLevel`.
    ///
    /// Fields are emitted in the order `type`, `vocab`, `unk_token`:
    /// - `type` is always the literal string `"WordLevel"`;
    /// - `vocab` is produced by wrapping `self.vocab_r` in an
    ///   `OrderedVocabIter` (presumably so entries come out in a stable
    ///   order — confirm against `OrderedVocabIter`);
    /// - `unk_token` is the configured unknown token.
    ///
    /// # Errors
    /// Propagates any error reported by the underlying serializer.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut state = serializer.serialize_struct("WordLevel", 3)?;
        state.serialize_field("type", "WordLevel")?;
        state.serialize_field("vocab", &OrderedVocabIter::new(&self.vocab_r))?;
        state.serialize_field("unk_token", &self.unk_token)?;
        state.end()
    }
}

impl<'de> Deserialize<'de> for WordLevel {
    /// Reads a `WordLevel` from a struct-shaped input.
    ///
    /// The actual field handling is delegated to `WordLevelVisitor`; this
    /// method only announces the struct name and the field names it knows
    /// about to the deserializer.
    ///
    /// # Errors
    /// Propagates any error produced by the deserializer or the visitor.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names advertised to the deserializer.
        const FIELDS: &[&str] = &["type", "vocab", "unk_token"];

        deserializer.deserialize_struct("WordLevel", FIELDS, WordLevelVisitor)
    }
}

struct WordLevelVisitor;
impl<'de> Visitor<'de> for WordLevelVisitor {
    type Value = WordLevel;

    fn expecting(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(fmt, "struct WordLevel")
    }

    fn visit_map<V>(self, mut map: V) -> std::result::Result<Self::Value, V::Error>
    where
        V: MapAccess<'de>,
    {
        let mut builder = WordLevelBuilder::new();
        let mut missing_fields = vec![
            // for retrocompatibility the "type" field is not mandatory
            "unk_token",
            "vocab",
        ]
        .into_iter()
        .collect::<HashSet<_>>();
        while let Some(key) = map.next_key::<String>()? {
            match key.as_ref() {
                "vocab" => builder = builder.vocab(map.next_value()?),
                "unk_token" => builder = builder.unk_token(map.next_value()?),
                "type" => match map.next_value()? {
                    "WordLevel" => {}
                    u => {
                        return Err(serde::de::Error::invalid_value(
                            serde::de::Unexpected::Str(u),
                            &"WordLevel",
                        ))
                    }
                },
                _ => {}
            }
            missing_fields.remove::<str>(&key);
        }

        if !missing_fields.is_empty() {
            Err(serde::de::Error::missing_field(
                missing_fields.iter().next().unwrap(),
            ))
        } else {
            Ok(builder.build().map_err(serde::de::Error::custom)?)
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::models::wordlevel::{Vocab, WordLevel, WordLevelBuilder};

    /// Checks that `model` serializes to exactly `json`, and that parsing
    /// `json` yields a value equal to `model`.
    fn assert_round_trip(model: &WordLevel, json: &str) {
        let serialized = serde_json::to_string(model).unwrap();
        assert_eq!(serialized, json);

        let parsed = serde_json::from_str::<WordLevel>(json).unwrap();
        assert_eq!(&parsed, model);
    }

    /// A default model round-trips through its JSON form.
    #[test]
    fn serde() {
        assert_round_trip(
            &WordLevel::default(),
            r#"{"type":"WordLevel","vocab":{},"unk_token":"<unk>"}"#,
        );
    }

    /// A vocabulary whose ids are not contiguous (0 and 2, no 1) still
    /// round-trips unchanged.
    #[test]
    fn incomplete_vocab() {
        let entries = vec![("<unk>", 0), ("b", 2)];
        let vocab: Vocab = entries
            .into_iter()
            .map(|(token, id)| (token.into(), id))
            .collect();

        let builder = WordLevelBuilder::default()
            .vocab(vocab)
            .unk_token("<unk>".to_string());
        let wordlevel = builder.build().unwrap();

        assert_round_trip(
            &wordlevel,
            r#"{"type":"WordLevel","vocab":{"<unk>":0,"b":2},"unk_token":"<unk>"}"#,
        );
    }

    /// Invalid inputs are rejected with the expected error messages.
    #[test]
    fn deserialization_should_fail() {
        // Parses `json`, expecting a failure, and returns the error text.
        let error_message = |json: &str| {
            serde_json::from_str::<WordLevel>(json)
                .unwrap_err()
                .to_string()
        };

        // `unk_token` is mandatory.
        let missing_unk = error_message(r#"{"type":"WordLevel","vocab":{}}"#);
        assert!(missing_unk.starts_with("missing field `unk_token`"));

        // A `type` other than "WordLevel" is refused.
        let wrong_type = error_message(r#"{"type":"WordPiece","vocab":{}}"#);
        assert!(wrong_type.starts_with("invalid value: string \"WordPiece\", expected WordLevel"));
    }
}