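//! DDPM scheduler with a squared-cosine noise schedule, operating on
//! continuous timesteps in `[0, 1]`.
//!
//! A minimal sketch of a sampling loop, assuming `prior` (a denoising model
//! taking a latent tensor and an `f64` timestep) and a mutable `latents`
//! tensor are provided by the surrounding pipeline:
//!
//! ```ignore
//! let scheduler = DDPMWScheduler::new(60, Default::default())?;
//! let timesteps = scheduler.timesteps();
//! // The final timestep (0.0) is never passed to `step`.
//! for &t in timesteps.iter().take(timesteps.len() - 1) {
//!     let noise_pred = prior.forward(&latents, t)?;
//!     latents = scheduler.step(&noise_pred, t, &latents)?;
//! }
//! ```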
use candle::{Result, Tensor};

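/// Configuration for [`DDPMWScheduler`].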
#[derive(Debug, Clone)]
pub struct DDPMWSchedulerConfig {
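    /// Exponent used to warp the timestep before it is fed to the noise
    /// schedule; `1.0` leaves the schedule unchanged.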
    scaler: f64,
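    /// Small offset added to the timestep in the cosine schedule
    /// (`0.008` by default).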
    s: f64,
}

impl Default for DDPMWSchedulerConfig {
    fn default() -> Self {
        Self {
            scaler: 1f64,
            s: 0.008f64,
        }
    }
}

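/// Scheduler holding the precomputed normalization constant of the cosine
/// schedule, the initial noise sigma, and the timesteps to iterate over.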
pub struct DDPMWScheduler {
    init_alpha_cumprod: f64,
    init_noise_sigma: f64,
    timesteps: Vec<f64>,
    pub config: DDPMWSchedulerConfig,
}

impl DDPMWScheduler {
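    /// Creates a scheduler with `inference_steps + 1` evenly spaced timesteps
    /// going from 1.0 down to 0.0.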
    pub fn new(inference_steps: usize, config: DDPMWSchedulerConfig) -> Result<Self> {
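        // alpha_cumprod at t = 0, used to normalize the cosine schedule.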
        let init_alpha_cumprod = (config.s / (1. + config.s) * std::f64::consts::PI * 0.5)
            .cos()
            .powi(2);
        let timesteps = (0..=inference_steps)
            .map(|i| 1. - i as f64 / inference_steps as f64)
            .collect::<Vec<_>>();
        Ok(Self {
            init_alpha_cumprod,
            init_noise_sigma: 1.0,
            timesteps,
            config,
        })
    }

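    /// The timesteps to iterate over during sampling, from 1.0 down to 0.0.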
    pub fn timesteps(&self) -> &[f64] {
        &self.timesteps
    }

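    /// Cumulative alpha product at timestep `t`: the squared-cosine schedule
    /// `cos^2((t + s) / (1 + s) * pi / 2)`, normalized by its value at `t = 0`
    /// and clamped to `[0.0001, 0.9999]`.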
    fn alpha_cumprod(&self, t: f64) -> f64 {
        let scaler = self.config.scaler;
        let s = self.config.s;
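        // Optionally warp the timestep with an exponent before evaluating the
        // schedule.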
        let t = if scaler > 1. {
            1. - (1. - t).powf(scaler)
        } else if scaler < 1. {
            t.powf(scaler)
        } else {
            t
        };
        let alpha_cumprod = ((t + s) / (1. + s) * std::f64::consts::PI * 0.5)
            .cos()
            .powi(2)
            / self.init_alpha_cumprod;
        alpha_cumprod.clamp(0.0001, 0.9999)
    }

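    /// Returns the timestep that follows `ts` in the schedule: the entry right
    /// after the one closest to `ts`. Must not be called with the final
    /// timestep (0.0), which has no successor.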
    fn previous_timestep(&self, ts: f64) -> f64 {
        let index = self
            .timesteps
            .iter()
            .enumerate()
            .map(|(idx, v)| (idx, (v - ts).abs()))
            .min_by(|x, y| x.1.total_cmp(&y.1))
            .unwrap()
            .0;
        self.timesteps[index + 1]
    }

    /// Ensures interchangeability with schedulers that need to scale the denoising model input
    /// depending on the current timestep.
    pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
        sample
    }

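    /// Performs one reverse diffusion step: computes the posterior mean from
    /// the predicted noise `model_output` and the current `sample`, adding
    /// scaled Gaussian noise unless the next timestep is the final one.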
    pub fn step(&self, model_output: &Tensor, ts: f64, sample: &Tensor) -> Result<Tensor> {
        let prev_t = self.previous_timestep(ts);

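        // Per-step alpha, recovered as the ratio of consecutive cumulative
        // products.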
        let alpha_cumprod = self.alpha_cumprod(ts);
        let alpha_cumprod_prev = self.alpha_cumprod(prev_t);
        let alpha = alpha_cumprod / alpha_cumprod_prev;

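        // Posterior mean: (sample - (1 - alpha) / sqrt(1 - alpha_cumprod) * eps) / sqrt(alpha).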
        let mu = (sample - model_output * ((1. - alpha) / (1. - alpha_cumprod).sqrt()))?;
        let mu = (mu * (1. / alpha).sqrt())?;

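        // Posterior standard deviation, applied to fresh Gaussian noise.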
        let std_noise = mu.randn_like(0., 1.)?;
        let std =
            std_noise * ((1. - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt();
        if prev_t == 0. {
            Ok(mu)
        } else {
            mu + std
        }
    }

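    /// The standard deviation of the initial noise distribution.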
    pub fn init_noise_sigma(&self) -> f64 {
        self.init_noise_sigma
    }
}