1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
use serde::Deserialize;

/// Hyper-parameters of the text model defined in this file.
///
/// Derives `Deserialize`, so it can be loaded through serde.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
    /// Number of rows of the word-embedding table and output width of the LM decoder.
    pub vocab_size: usize,
    /// Width of the hidden states (embeddings, attention and feed-forward outputs).
    pub hidden_size: usize,
    /// Input width of the key/value projections of the cross-attention blocks.
    pub encoder_hidden_size: usize,
    /// Output width of the per-layer intermediate projection.
    pub intermediate_size: usize,
    /// Not read by the code visible in this file.
    pub projection_dim: usize,
    /// Number of `TextLayer`s stacked in the encoder.
    pub num_hidden_layers: usize,
    /// Number of attention heads; each head has width `hidden_size / num_attention_heads`.
    pub num_attention_heads: usize,
    /// Number of rows of the position-embedding table.
    pub max_position_embeddings: usize,
    /// Activation used by the intermediate projection and the prediction-head transform.
    pub hidden_act: candle_nn::Activation,
    /// Epsilon passed to every layer norm.
    pub layer_norm_eps: f64,
    /// When true each layer also builds a cross-attention block; the layer forward pass
    /// returns an error when that block is absent.
    pub is_decoder: bool,
}

#[derive(Debug, Clone)]
struct TextEmbeddings {
    word_embedddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    position_ids: Tensor,
}

impl TextEmbeddings {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let word_embedddings =
            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
        let position_embeddings = Embedding::new(
            cfg.max_position_embeddings,
            cfg.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        let position_ids =
            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
        Ok(Self {
            word_embedddings,
            position_embeddings,
            layer_norm,
            position_ids,
        })
    }

    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let seq_len = xs.dim(1)?;
        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
        let embeddings = self.word_embedddings.forward(xs)?;
        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
        (embeddings + position_embeddings)?.apply(&self.layer_norm)
    }
}

/// Multi-head attention core: query/key/value projections plus a key/value cache.
#[derive(Debug, Clone)]
struct TextSelfAttention {
    query: Linear,
    /// Reads `encoder_hidden_size` inputs when built for cross-attention, `hidden_size`
    /// otherwise (same for `value`).
    key: Linear,
    value: Linear,
    /// `hidden_size / num_attention_heads` (integer division).
    attention_head_size: usize,
    num_attention_heads: usize,
    /// `1 / sqrt(attention_head_size)`, multiplied into the raw attention scores.
    attention_scale: f64,
    /// Keys and values accumulated over previous calls that did not receive encoder hidden
    /// states; `None` until the first such call and after a reset.
    kv_cache: Option<(Tensor, Tensor)>,
}

impl TextSelfAttention {
    /// Builds the query/key/value projections from the `query`, `key` and `value` sub-paths
    /// of `vb`.
    ///
    /// With `is_cross_attention` set, keys and values are projected from
    /// `cfg.encoder_hidden_size` inputs instead of `cfg.hidden_size`.
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        let head_dim = cfg.hidden_size / cfg.num_attention_heads;
        let proj_dim = cfg.num_attention_heads * head_dim;
        let kv_in_dim = if is_cross_attention {
            cfg.encoder_hidden_size
        } else {
            cfg.hidden_size
        };
        let query = linear(cfg.hidden_size, proj_dim, vb.pp("query"))?;
        let key = linear(kv_in_dim, proj_dim, vb.pp("key"))?;
        let value = linear(kv_in_dim, proj_dim, vb.pp("value"))?;
        Ok(Self {
            query,
            key,
            value,
            attention_head_size: head_dim,
            num_attention_heads: cfg.num_attention_heads,
            attention_scale: (head_dim as f64).sqrt().recip(),
            kv_cache: None,
        })
    }

    /// Splits the last dimension of a rank-3 tensor into `(heads, head_size)` and moves the
    /// head dimension in front of the sequence dimension.
    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let (batch, len, _) = xs.dims3()?;
        let split_heads = (
            batch,
            len,
            self.num_attention_heads,
            self.attention_head_size,
        );
        xs.reshape(split_heads)?.permute((0, 2, 1, 3))
    }

    /// Drops the cached keys and values.
    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    /// Scaled dot-product attention.
    ///
    /// Queries always come from `xs`. Keys and values come from `encoder_hidden_states`
    /// when given (nothing is cached in that case); otherwise they come from `xs`, are
    /// appended to the cache along dimension 2 and the cache is updated. `attention_mask`,
    /// when given, is broadcast-added to the scaled scores before the softmax.
    ///
    /// Returns the per-head outputs merged back into a single trailing dimension.
    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let q = self.query.forward(xs)?;
        let q = self.transpose_for_scores(&q)?.contiguous()?;

        let (k, v) = if let Some(encoder_states) = encoder_hidden_states {
            // Cross-attention: no kv-cache here, though the projections could be memoized.
            let k = self.transpose_for_scores(&self.key.forward(encoder_states)?)?;
            let v = self.transpose_for_scores(&self.value.forward(encoder_states)?)?;
            (k, v)
        } else {
            let mut k = self.transpose_for_scores(&self.key.forward(xs)?)?;
            let mut v = self.transpose_for_scores(&self.value.forward(xs)?)?;
            if let Some((past_k, past_v)) = &self.kv_cache {
                // Grow along the sequence dimension.
                k = Tensor::cat(&[past_k, &k], 2)?;
                v = Tensor::cat(&[past_v, &v], 2)?;
            }
            self.kv_cache = Some((k.clone(), v.clone()));
            (k, v)
        };

        let k = k.contiguous()?;
        let v = v.contiguous()?;
        let scores = q.matmul(&k.t()?)?;
        let scores = (scores * self.attention_scale)?;
        let scores = if let Some(mask) = attention_mask {
            scores.broadcast_add(mask)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let context = probs.matmul(&v)?;
        // Back to sequence-major layout, then merge (heads, head_size).
        context.permute((0, 2, 1, 3))?.flatten_from(D::Minus2)
    }
}

/// Output stage of an attention block: linear projection, residual add, layer norm.
#[derive(Debug, Clone)]
struct TextSelfOutput {
    /// `hidden_size -> hidden_size` projection.
    dense: Linear,
    layer_norm: LayerNorm,
}

impl TextSelfOutput {
    /// Loads the `dense` projection and the `LayerNorm` weights from `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            dense: linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?,
            layer_norm: layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?,
        })
    }

    /// Projects `xs`, adds the residual `input_tensor` and normalizes the sum.
    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let projected = self.dense.forward(xs)?;
        let with_residual = (projected + input_tensor)?;
        self.layer_norm.forward(&with_residual)
    }
}

/// An attention core (`self_`) paired with its output stage (`output`).
///
/// Used for both self-attention and cross-attention; the flavor is chosen at construction.
#[derive(Debug, Clone)]
struct TextAttention {
    // Trailing underscore because `self` is a reserved word; the weights live under "self".
    self_: TextSelfAttention,
    output: TextSelfOutput,
}

impl TextAttention {
    /// Loads the attention core from the `self` sub-path and the output stage from the
    /// `output` sub-path of `vb`.
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            self_: TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?,
            output: TextSelfOutput::new(cfg, vb.pp("output"))?,
        })
    }

    /// Clears the key/value cache of the attention core.
    fn reset_kv_cache(&mut self) {
        self.self_.reset_kv_cache()
    }

    /// Runs attention over `xs`, then feeds the result through the output stage with `xs`
    /// as the residual input.
    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let attended = self
            .self_
            .forward(xs, encoder_hidden_states, attention_mask)?;
        self.output.forward(&attended, xs)
    }
}

/// First half of a layer's feed-forward part: linear projection followed by an activation.
#[derive(Debug, Clone)]
struct TextIntermediate {
    /// `hidden_size -> intermediate_size` projection.
    dense: Linear,
    /// Copied from `Config::hidden_act`.
    intermediate_act_fn: candle_nn::Activation,
}

impl TextIntermediate {
    /// Loads the `dense` projection from `vb` and picks up the activation from `cfg`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
        let intermediate_act_fn = cfg.hidden_act;
        Ok(Self {
            dense,
            intermediate_act_fn,
        })
    }
}

impl Module for TextIntermediate {
    /// Applies the dense projection, then the configured activation.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let projected = self.dense.forward(xs)?;
        self.intermediate_act_fn.forward(&projected)
    }
}

/// Second half of a layer's feed-forward part: projection back to the hidden width,
/// residual add, layer norm.
#[derive(Debug, Clone)]
struct TextOutput {
    /// `intermediate_size -> hidden_size` projection.
    dense: Linear,
    layer_norm: LayerNorm,
}

impl TextOutput {
    /// Loads the `dense` projection and the `LayerNorm` weights from `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            dense: linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?,
            layer_norm: layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?,
        })
    }

    /// Projects `xs`, adds the residual `input_tensor` and normalizes the sum.
    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let projected = self.dense.forward(xs)?;
        let with_residual = (projected + input_tensor)?;
        self.layer_norm.forward(&with_residual)
    }
}

/// One layer of the stack: self-attention, cross-attention, then a feed-forward part.
#[derive(Debug, Clone)]
struct TextLayer {
    attention: TextAttention,
    /// Only built when `Config::is_decoder` is set; the forward pass returns an error when
    /// this is `None`.
    cross_attention: Option<TextAttention>,
    intermediate: TextIntermediate,
    output: TextOutput,
}

impl TextLayer {
    /// Loads the layer from the `attention`, `crossattention` (decoder configs only),
    /// `intermediate` and `output` sub-paths of `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
        let cross_attention = cfg
            .is_decoder
            .then(|| TextAttention::new(cfg, true, vb.pp("crossattention")))
            .transpose()?;
        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
        let output = TextOutput::new(cfg, vb.pp("output"))?;
        Ok(Self {
            attention,
            cross_attention,
            intermediate,
            output,
        })
    }

    /// Clears the key/value caches of the self-attention and, if present, cross-attention.
    fn reset_kv_cache(&mut self) {
        self.attention.reset_kv_cache();
        if let Some(cross) = self.cross_attention.as_mut() {
            cross.reset_kv_cache()
        }
    }

    /// Runs masked self-attention over `xs`, unmasked cross-attention over
    /// `encoder_hidden_states`, then the feed-forward part with a residual connection.
    ///
    /// # Errors
    /// Returns an error when the layer was built without cross-attention. The
    /// self-attention step has already run (and updated its cache) at that point.
    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let self_attended = self.attention.forward(xs, None, Some(attention_mask))?;
        let cross = match self.cross_attention.as_mut() {
            None => candle::bail!("expected some cross-attn"),
            Some(cross) => cross,
        };
        let cross_attended = cross.forward(&self_attended, Some(encoder_hidden_states), None)?;
        let expanded = self.intermediate.forward(&cross_attended)?;
        self.output.forward(&expanded, &cross_attended)
    }
}

/// Stack of `TextLayer`s applied in order.
#[derive(Debug, Clone)]
struct TextEncoder {
    /// `Config::num_hidden_layers` layers, loaded from the `layer.{index}` sub-paths.
    layers: Vec<TextLayer>,
}

impl TextEncoder {
    /// Loads `cfg.num_hidden_layers` layers from `layer.0`, `layer.1`, ... under `vb`,
    /// stopping at the first layer that fails to load.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let layer_vb = vb.pp("layer");
        let layers = (0..cfg.num_hidden_layers)
            .map(|index| TextLayer::new(cfg, layer_vb.pp(index)))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { layers })
    }

    /// Clears the key/value caches of every layer.
    fn reset_kv_cache(&mut self) {
        for layer in &mut self.layers {
            layer.reset_kv_cache()
        }
    }

    /// Threads `xs` through every layer in order, passing the same
    /// `encoder_hidden_states` and `attention_mask` to each of them.
    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        self.layers.iter_mut().try_fold(xs.clone(), |hidden, layer| {
            layer.forward(&hidden, encoder_hidden_states, attention_mask)
        })
    }
}

/// Pooling head made of a single `hidden_size -> hidden_size` projection.
///
/// Not used by `TextModel` in this file.
#[derive(Debug, Clone)]
pub struct TextPooler {
    dense: Linear,
}

impl TextPooler {
    /// Loads the `dense` projection from `vb`.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            dense: linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?,
        })
    }
}

impl Module for TextPooler {
    /// Pools a sequence by taking its first position, projecting it and applying `tanh`.
    ///
    /// Assumes `xs` is laid out as `(batch, seq_len, hidden_size)` — TODO confirm against
    /// callers; nothing in this file calls the pooler.
    ///
    /// # Errors
    /// Fails when `xs` has fewer than two dimensions, an empty dimension 1, or a trailing
    /// dimension the dense projection cannot consume.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Select index 0 along dimension 1 and keep the trailing dimension intact: `dense`
        // consumes `hidden_size` features, so slicing the last dimension (as was done
        // before) left it with the wrong input width.
        xs.narrow(1, 0, 1)?
            .squeeze(1)?
            .apply(&self.dense)?
            .tanh()
    }
}

/// Transform applied before the vocabulary projection: dense, activation, layer norm.
#[derive(Debug, Clone)]
struct TextPredictionHeadTransform {
    /// `hidden_size -> hidden_size` projection.
    dense: Linear,
    /// Copied from `Config::hidden_act`.
    transform_act_fn: candle_nn::Activation,
    layer_norm: LayerNorm,
}

impl TextPredictionHeadTransform {
    /// Loads the `dense` projection and the `LayerNorm` weights from `vb`; the activation
    /// comes from `cfg`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let dense = linear(hidden, hidden, vb.pp("dense"))?;
        let layer_norm = layer_norm(hidden, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        let transform_act_fn = cfg.hidden_act;
        Ok(Self {
            dense,
            transform_act_fn,
            layer_norm,
        })
    }
}

impl Module for TextPredictionHeadTransform {
    /// Applies the dense projection, the activation and the layer norm, in that order.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let projected = self.dense.forward(xs)?;
        let activated = self.transform_act_fn.forward(&projected)?;
        self.layer_norm.forward(&activated)
    }
}

/// Language-model head: the prediction transform followed by a projection to the vocabulary.
#[derive(Debug, Clone)]
struct TextLMPredictionHead {
    transform: TextPredictionHeadTransform,
    /// `hidden_size -> vocab_size` projection with bias.
    decoder: Linear,
}

impl TextLMPredictionHead {
    /// Loads the transform from the `transform` sub-path and assembles the decoder from the
    /// `decoder.weight` and `bias` tensors.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
        let decoder = {
            let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
            // The bias is read from directly under this head, not from under `decoder`.
            let bias = vb.get(cfg.vocab_size, "bias")?;
            Linear::from_weights(weight, Some(bias))
        };
        Ok(Self { transform, decoder })
    }
}

impl Module for TextLMPredictionHead {
    /// Runs the transform, then projects onto the vocabulary.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let transformed = self.transform.forward(xs)?;
        self.decoder.forward(&transformed)
    }
}

/// Thin wrapper whose only content is the prediction head stored under `predictions`.
#[derive(Debug, Clone)]
struct TextOnlyMLMHead {
    predictions: TextLMPredictionHead,
}

impl TextOnlyMLMHead {
    /// Loads the prediction head from the `predictions` sub-path of `vb`.
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            predictions: TextLMPredictionHead::new(cfg, vb.pp("predictions"))?,
        })
    }
}

impl Module for TextOnlyMLMHead {
    /// Delegates to the wrapped prediction head.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.predictions)
    }
}

/// Embeddings plus the layer stack, with a running count of already-processed positions.
#[derive(Debug, Clone)]
struct TextModel {
    embeddings: TextEmbeddings,
    encoder: TextEncoder,
    /// Number of positions consumed by successful forward calls since construction or the
    /// last reset; used as the offset into the position embeddings.
    past_kv_len: usize,
    // We do not need the pooler for caption generation
}

impl TextModel {
    /// Loads the model from the `embeddings` and `encoder` sub-paths of `vb`, starting with
    /// no processed positions.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            embeddings: TextEmbeddings::new(cfg, vb.pp("embeddings"))?,
            encoder: TextEncoder::new(cfg, vb.pp("encoder"))?,
            past_kv_len: 0,
        })
    }

    /// Embeds the rank-2 `input_ids` at the current position offset, runs the layer stack
    /// and returns the sequence output (not a pooled output).
    ///
    /// The position offset only advances when the whole pass succeeds.
    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = input_ids.dims2()?.1;
        let hidden = self.embeddings.forward(input_ids, self.past_kv_len)?;
        let hidden = self
            .encoder
            .forward(&hidden, encoder_hidden_states, attention_mask)?;
        self.past_kv_len += seq_len;
        Ok(hidden)
    }

    /// Clears every layer's key/value cache and rewinds the position offset to zero.
    fn reset_kv_cache(&mut self) {
        self.encoder.reset_kv_cache();
        self.past_kv_len = 0;
    }
}

/// Text model topped with a language-model head that produces vocabulary scores.
#[derive(Debug, Clone)]
pub struct TextLMHeadModel {
    /// Loaded from the `bert` sub-path.
    bert: TextModel,
    /// Loaded from the `cls` sub-path.
    cls: TextOnlyMLMHead,
}

impl TextLMHeadModel {
    /// Loads the text model from the `bert` sub-path and the head from the `cls` sub-path
    /// of `vb`.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let bert = TextModel::new(cfg, vb.pp("bert"))?;
        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
        Ok(Self { bert, cls })
    }

    /// Runs the model on `input_ids` while cross-attending over `encoder_hidden_states`,
    /// and returns the prediction scores for every input position.
    ///
    /// Keys and values of earlier calls are kept in the kv-cache, so after the first call
    /// only the new tokens need to be passed; call [`Self::reset_kv_cache`] before starting
    /// a new sequence.
    ///
    /// # Errors
    /// Fails when `input_ids` has no dimension 1 or when any tensor operation in the model
    /// fails.
    pub fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = input_ids.dim(1)?;
        // Self-attention keys cover the cached positions plus the new ones, so the mask
        // must be `(seq_len, past_len + seq_len)`. A square `(seq_len, seq_len)` mask only
        // broadcasts against the scores when the cache is empty or `seq_len == 1`.
        let past_len = self.bert.past_kv_len;
        let kv_len = past_len + seq_len;
        // Query `i` sits at absolute position `past_len + i` and may attend to every key at
        // or before that position.
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| {
                (0..kv_len).map(move |j| {
                    if j > past_len + i {
                        f32::NEG_INFINITY
                    } else {
                        0f32
                    }
                })
            })
            .collect();
        // NOTE(review): the mask is always f32; presumably `broadcast_add` needs it to match
        // the dtype of the attention scores — confirm for half-precision weights.
        let mask = Tensor::from_vec(mask, (seq_len, kv_len), input_ids.device())?;
        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
        let prediction_scores = self.cls.forward(&sequence_output)?;
        // return_logits is false so we don't discard the last sequence element.
        Ok(prediction_scores)
    }

    /// Clears the kv-cache and the position offset so the next call starts a new sequence.
    pub fn reset_kv_cache(&mut self) {
        self.bert.reset_kv_cache()
    }
}