use super::WordLevel;
use crate::utils::parallelism::*;
use crate::{AddedToken, Result, Trainer};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashMap;

#[non_exhaustive]
#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
pub struct WordLevelTrainer {
    /// The minimum frequency a word must have to be part of the vocabulary
    #[builder(default = "0")]
    pub min_frequency: u64,
    /// The target vocabulary size
    #[builder(default = "30_000")]
    pub vocab_size: usize,
    /// Whether to show progress while training
    #[builder(default = "true")]
    pub show_progress: bool,
    /// A list of special tokens that the model should know of
    #[builder(default)]
    pub special_tokens: Vec<AddedToken>,

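    /// Word counts accumulated by `feed`, consumed by `do_train`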
    #[builder(default, private)]
    words: HashMap<String, u64>,
}

impl Default for WordLevelTrainer {
    fn default() -> Self {
        Self::builder().build().unwrap()
    }
}

impl WordLevelTrainer {
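    /// Get a builder for this trainer. A minimal usage sketch (the values
    /// shown are arbitrary examples):
    ///
    /// ```ignore
    /// let trainer = WordLevelTrainer::builder()
    ///     .min_frequency(2)
    ///     .vocab_size(10_000)
    ///     .special_tokens(vec![AddedToken::from("[UNK]", true)])
    ///     .build()
    ///     .unwrap();
    /// ```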
    pub fn builder() -> WordLevelTrainerBuilder {
        WordLevelTrainerBuilder::default()
    }

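    /// Build the vocabulary from pre-computed word counts. Special tokens
    /// take the first ids; the remaining words follow in order of descending
    /// frequency (ties broken alphabetically), filtered by `min_frequency`,
    /// with the whole vocabulary capped at `vocab_size`.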
    fn do_train(
        &self,
        word_counts: &HashMap<String, u64>,
        model: &mut WordLevel,
    ) -> Result<Vec<AddedToken>> {
        let mut ordered_counts = word_counts.iter().collect::<Vec<_>>();

        // Sort the word counts first by descending count and then by word, to
        // keep the ordering deterministic in case of equal counts
        let cmp = |l: &(&String, &u64), r: &(&String, &u64)| -> Ordering {
            let count_comp: Ordering = l.1.cmp(r.1);
            if count_comp != Ordering::Equal {
                return count_comp.reverse();
            }
            l.0.cmp(r.0)
        };

        ordered_counts.sort_by(cmp);

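        // Ids are assigned by position: special tokens first, then the
        // surviving words; `take(vocab_size)` caps the total, specials
        // included.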
        let word_level = WordLevel::builder()
            .vocab(
                self.special_tokens
                    .iter()
                    .map(|token| token.content.clone())
                    .chain(
                        ordered_counts
                            .into_iter()
                            .filter(|(_, n)| **n >= self.min_frequency)
                            .map(|(w, _)| w.to_owned()),
                    )
                    .take(self.vocab_size)
                    .enumerate()
                    .map(|(i, w)| (w, i as u32))
                    .collect(),
            )
            .build()?;

        // Transfer the vocab
        model.vocab = word_level.vocab;
        model.vocab_r = word_level.vocab_r;

        Ok(self.special_tokens.clone())
    }
}

impl Trainer for WordLevelTrainer {
    type Model = WordLevel;

    /// Train a WordLevel model
    fn train(&self, model: &mut WordLevel) -> Result<Vec<AddedToken>> {
        self.do_train(&self.words, model)
    }

    /// Whether we should show progress
    fn should_show_progress(&self) -> bool {
        self.show_progress
    }

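    /// Count words from an iterator of sequences: each sequence is run
    /// through `process` (in parallel when `maybe_par_bridge` enables it),
    /// the per-sequence counts are merged, and the result replaces
    /// `self.words`.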
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        let words: Result<HashMap<String, u64>> = iterator
            .maybe_par_bridge()
            .map(|sequence| {
                let words = process(sequence.as_ref())?;
                let mut map = HashMap::new();
                for word in words {
                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                }
                Ok(map)
            })
            .reduce(
                || Ok(HashMap::new()),
                |acc, ws| {
                    let mut acc = acc?;
                    for (k, v) in ws? {
                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
                    }
                    Ok(acc)
                },
            );

        self.words = words?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_train() {
        let word_counts: HashMap<String, u64> = [
            ("the".into(), 25),
            ("roses".into(), 22),
            ("are".into(), 24),
            ("red".into(), 12),
            ("violets".into(), 10),
            ("blue".into(), 16),
        ]
        .iter()
        .cloned()
        .collect();

        let mut trainer = WordLevelTrainer {
            vocab_size: 5,
            ..Default::default()
        };

        let mut model = WordLevel::default();
        trainer.do_train(&word_counts, &mut model).unwrap();
        let expected_vocab: HashMap<String, u32> = [
            ("the".into(), 0),
            ("are".into(), 1),
            ("roses".into(), 2),
            ("blue".into(), 3),
            ("red".into(), 4),
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.vocab, expected_vocab);

        // If we specify a min_frequency
        trainer.min_frequency = 15;
        let mut model = WordLevel::default();
        trainer.do_train(&word_counts, &mut model).unwrap();
        let expected_vocab: HashMap<String, u32> = [
            ("the".into(), 0),
            ("are".into(), 1),
            ("roses".into(), 2),
            ("blue".into(), 3),
        ]
        .iter()
        .cloned()
        .collect();

        assert_eq!(model.vocab, expected_vocab);
    }
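
    // An illustrative end-to-end sketch (added example): `process` here is a
    // plain whitespace splitter standing in for whatever pre-tokenization the
    // caller provides.
    #[test]
    fn test_feed_and_train() {
        let mut trainer = WordLevelTrainer {
            vocab_size: 3,
            ..Default::default()
        };
        trainer
            .feed(
                ["the roses are red", "the roses are blue"].iter(),
                |seq: &str| Ok(seq.split_whitespace().map(|w| w.to_owned()).collect()),
            )
            .unwrap();

        let mut model = WordLevel::default();
        trainer.train(&mut model).unwrap();

        // Counts: the=2, roses=2, are=2, red=1, blue=1. Ties sort
        // alphabetically, so the top 3 are "are", "roses", "the".
        let expected_vocab: HashMap<String, u32> = [
            ("are".into(), 0),
            ("roses".into(), 1),
            ("the".into(), 2),
        ]
        .iter()
        .cloned()
        .collect();
        assert_eq!(model.vocab, expected_vocab);
    }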
}