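//! Diffusion prior for the Wuerstchen text-to-image pipeline: a stack of
//! residual / timestep / attention blocks that maps a noised latent, a noise
//! ratio `r`, and text conditioning `c` to a denoised latent estimate.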
use super::common::{AttnBlock, ResBlock, TimestepBlock, WLayerNorm};
use candle::{DType, Result, Tensor, D};
use candle_nn::VarBuilder;

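/// A single prior stage: a residual block, a timestep-conditioning block, and
/// an attention block, applied in that order by [`WPrior::forward`].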
#[derive(Debug)]
struct Block {
    res_block: ResBlock,
    ts_block: TimestepBlock,
    attn_block: AttnBlock,
}

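/// The prior network: a 1x1 projection into `c` channels, `depth` [`Block`]
/// stages conditioned on the timestep embedding and the mapped text
/// conditioning, and an output head (layer norm plus 1x1 conv) producing
/// `2 * c_in` channels.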
#[derive(Debug)]
pub struct WPrior {
    projection: candle_nn::Conv2d,
    cond_mapper_lin1: candle_nn::Linear,
    cond_mapper_lin2: candle_nn::Linear,
    blocks: Vec<Block>,
    out_ln: WLayerNorm,
    out_conv: candle_nn::Conv2d,
    c_r: usize,
}

impl WPrior {
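    /// Creates a `WPrior`. `c_in` is the latent channel count (both input and
    /// output), `c` the inner width, `c_cond` the width of the incoming
    /// conditioning, `c_r` the timestep-embedding width, `depth` the number of
    /// stages, and `nhead` the attention head count.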
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        c_in: usize,
        c: usize,
        c_cond: usize,
        c_r: usize,
        depth: usize,
        nhead: usize,
        use_flash_attn: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let projection = candle_nn::conv2d(c_in, c, 1, Default::default(), vb.pp("projection"))?;
        let cond_mapper_lin1 = candle_nn::linear(c_cond, c, vb.pp("cond_mapper.0"))?;
        let cond_mapper_lin2 = candle_nn::linear(c, c, vb.pp("cond_mapper.2"))?;
        let out_ln = WLayerNorm::new(c)?;
        let out_conv = candle_nn::conv2d(c, c_in * 2, 1, Default::default(), vb.pp("out.1"))?;
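        // The checkpoint stores each stage's three sub-blocks flattened into a
        // single `blocks` list, so stage `index` loads its weights from entries
        // `3 * index`, `3 * index + 1`, and `3 * index + 2`.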
        let mut blocks = Vec::with_capacity(depth);
        for index in 0..depth {
            let res_block = ResBlock::new(c, 0, 3, vb.pp(format!("blocks.{}", 3 * index)))?;
            let ts_block = TimestepBlock::new(c, c_r, vb.pp(format!("blocks.{}", 3 * index + 1)))?;
            let attn_block = AttnBlock::new(
                c,
                c,
                nhead,
                true,
                use_flash_attn,
                vb.pp(format!("blocks.{}", 3 * index + 2)),
            )?;
            blocks.push(Block {
                res_block,
                ts_block,
                attn_block,
            });
        }
        Ok(Self {
            projection,
            cond_mapper_lin1,
            cond_mapper_lin2,
            blocks,
            out_ln,
            out_conv,
            c_r,
        })
    }

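    /// Sinusoidal embedding of the noise ratio `r` (a 1D tensor, typically with
    /// values in `[0, 1]`): `r` is scaled by `10_000` and expanded to `c_r`
    /// channels with the usual transformer `sin`/`cos` frequency bands; for odd
    /// `c_r` the result is zero-padded by one channel.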
    pub fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
        const MAX_POSITIONS: usize = 10000;
        let r = (r * MAX_POSITIONS as f64)?;
        let half_dim = self.c_r / 2;
        let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
        let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
            * -emb)?
            .exp()?;
        let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
        let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
        let emb = if self.c_r % 2 == 1 {
            emb.pad_with_zeros(D::Minus1, 0, 1)?
        } else {
            emb
        };
        emb.to_dtype(r.dtype())
    }

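    /// Applies the prior to a noised latent `xs` of shape `(b, c_in, h, w)`,
    /// given the noise ratio `r` of shape `(b,)` and conditioning `c` of shape
    /// `(b, seq_len, c_cond)`. The conditioning is mapped to width `c` by a
    /// two-layer MLP with a leaky ReLU in between before being fed to each
    /// attention block.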
    pub fn forward(&self, xs: &Tensor, r: &Tensor, c: &Tensor) -> Result<Tensor> {
        let x_in = xs;
        let mut xs = xs.apply(&self.projection)?;
        let c_embed = c
            .apply(&self.cond_mapper_lin1)?
            .apply(&|xs: &_| candle_nn::ops::leaky_relu(xs, 0.2))?
            .apply(&self.cond_mapper_lin2)?;
        let r_embed = self.gen_r_embedding(r)?;
        for block in self.blocks.iter() {
            xs = block.res_block.forward(&xs, None)?;
            xs = block.ts_block.forward(&xs, &r_embed)?;
            xs = block.attn_block.forward(&xs, &c_embed)?;
        }
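        // The output head produces `2 * c_in` channels, split along the channel
        // dimension into `a` and `b`; the result reparameterizes the input as
        // `(x_in - a) / (|b - 1| + 1e-5)`.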
        let ab = xs.apply(&self.out_ln)?.apply(&self.out_conv)?.chunk(2, 1)?;
        (x_in - &ab[0])? / ((&ab[1] - 1.)?.abs()? + 1e-5)
    }
}
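
// A minimal smoke-test sketch, not part of the upstream crate: it builds a
// tiny prior with zero-initialized weights via `VarBuilder::zeros` and only
// checks that `forward` preserves the latent shape. All dimensions below are
// illustrative and far smaller than a real Wuerstchen configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Device;

    #[test]
    fn wprior_forward_preserves_shape() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        let prior = WPrior::new(
            /* c_in */ 4,
            /* c */ 8,
            /* c_cond */ 16,
            /* c_r */ 8,
            /* depth */ 1,
            /* nhead */ 2,
            /* use_flash_attn */ false,
            vb,
        )?;
        let xs = Tensor::zeros((1, 4, 6, 6), DType::F32, &dev)?;
        let r = Tensor::zeros(1, DType::F32, &dev)?;
        let c = Tensor::zeros((1, 7, 16), DType::F32, &dev)?;
        let out = prior.forward(&xs, &r, &c)?;
        // The prior maps a `(b, c_in, h, w)` latent back to the same shape.
        assert_eq!(out.dims(), &[1, 4, 6, 6]);
        Ok(())
    }
}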