//! Ancestral sampling with Euler method steps.
//!
//! Reference implementation in Rust:
//!
//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs
//!
//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd].
//!
//! [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
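//!
//! A minimal usage sketch; `latents` and `denoiser` are hypothetical
//! stand-ins for real starting noise and a real diffusion model:
//!
//! ```ignore
//! let scheduler = EulerAncestralDiscreteScheduler::new(
//!     30,
//!     EulerAncestralDiscreteSchedulerConfig::default(),
//! )?;
//! // Start from pure noise scaled to the scheduler's initial noise sigma.
//! let mut sample = (latents * scheduler.init_noise_sigma())?;
//! for &t in scheduler.timesteps() {
//!     let input = scheduler.scale_model_input(sample.clone(), t)?;
//!     let model_output = denoiser(&input, t)?;
//!     sample = scheduler.step(&model_output, t, &sample)?;
//! }
//! ```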
use super::{
    schedulers::{
        betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
        TimestepSpacing,
    },
    utils::interp,
};
use candle::{bail, Error, Result, Tensor};

/// The configuration for the EulerAncestral Discrete scheduler.
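///
/// Individual fields can be overridden while keeping the remaining defaults
/// via struct update syntax:
///
/// ```ignore
/// let config = EulerAncestralDiscreteSchedulerConfig {
///     timestep_spacing: TimestepSpacing::Trailing,
///     ..Default::default()
/// };
/// ```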
#[derive(Debug, Clone, Copy)]
pub struct EulerAncestralDiscreteSchedulerConfig {
    /// The value of beta at the beginning of training.
    pub beta_start: f64,
    /// The value of beta at the end of training.
    pub beta_end: f64,
    /// How beta evolved during training.
    pub beta_schedule: BetaSchedule,
    /// Adjust the indexes of the inference schedule by this value.
    pub steps_offset: usize,
    /// Prediction type of the scheduler function: one of `epsilon` (predicting
    /// the noise of the diffusion process), `sample` (directly predicting the noisy sample)
    /// or `v_prediction` (see section 2.4 of https://imagen.research.google/video/paper.pdf).
    pub prediction_type: PredictionType,
    /// The number of diffusion steps used to train the model.
    pub train_timesteps: usize,
    /// Time step spacing for the diffusion process.
    pub timestep_spacing: TimestepSpacing,
}

impl Default for EulerAncestralDiscreteSchedulerConfig {
    fn default() -> Self {
        Self {
            beta_start: 0.00085f64,
            beta_end: 0.012f64,
            beta_schedule: BetaSchedule::ScaledLinear,
            steps_offset: 1,
            prediction_type: PredictionType::Epsilon,
            train_timesteps: 1000,
            timestep_spacing: TimestepSpacing::Leading,
        }
    }
}

impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
        Ok(Box::new(EulerAncestralDiscreteScheduler::new(
            inference_steps,
            *self,
        )?))
    }
}

/// The EulerAncestral Discrete scheduler.
#[derive(Debug, Clone)]
pub struct EulerAncestralDiscreteScheduler {
    timesteps: Vec<usize>,
    sigmas: Vec<f64>,
    init_noise_sigma: f64,
    pub config: EulerAncestralDiscreteSchedulerConfig,
}

// The upstream defaults `clip_sample: False` and `set_alpha_to_one: False` are assumed.
impl EulerAncestralDiscreteScheduler {
    /// Creates a new EulerAncestral Discrete scheduler for the given number of
    /// inference steps; the number of training steps is taken from the config.
    pub fn new(
        inference_steps: usize,
        config: EulerAncestralDiscreteSchedulerConfig,
    ) -> Result<Self> {
        let step_ratio = config.train_timesteps / inference_steps;
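        // Build the inference timesteps for the configured spacing: `Leading`
        // spaces them from the start of the training schedule, `Trailing` from
        // the end, and `Linspace` distributes them evenly.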
        let timesteps: Vec<usize> = match config.timestep_spacing {
            TimestepSpacing::Leading => (0..(inference_steps))
                .map(|s| s * step_ratio + config.steps_offset)
                .rev()
                .collect(),
            TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
                if *n > step_ratio {
                    Some(n - step_ratio)
                } else {
                    None
                }
            })
            .map(|n| n - 1)
            .collect(),
            TimestepSpacing::Linspace => {
                super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
                    .to_vec1::<f64>()?
                    .iter()
                    .map(|&f| f as usize)
                    .rev()
                    .collect()
            }
        };

        let betas = match config.beta_schedule {
            BetaSchedule::ScaledLinear => super::utils::linspace(
                config.beta_start.sqrt(),
                config.beta_end.sqrt(),
                config.train_timesteps,
            )?
            .sqr()?,
            BetaSchedule::Linear => {
                super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
            }
            BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
        };
        let betas = betas.to_vec1::<f64>()?;
        let mut alphas_cumprod = Vec::with_capacity(betas.len());
        for &beta in betas.iter() {
            let alpha = 1.0 - beta;
            alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
        }
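        // Convert the cumulative product of alphas into sigmas:
        // sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t).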
        let sigmas: Vec<f64> = alphas_cumprod
            .iter()
            .map(|&f| ((1. - f) / f).sqrt())
            .collect();

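        // Interpolate the training-schedule sigmas at the selected timesteps
        // (the x-axis is the training step index) and append a final sigma of 0.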
        let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect();

        let mut sigmas_int = interp(
            &timesteps.iter().map(|&t| t as f64).collect::<Vec<_>>(),
            &sigmas_xa,
            &sigmas,
        );
        sigmas_int.push(0.0);

        // Standard deviation of the initial noise distribution.
        // `f64` does not implement `Ord`, so there is no `Iterator::max`; reduce manually instead.
        let init_noise_sigma = *sigmas_int
            .iter()
            .chain(std::iter::once(&0.0))
            .reduce(|a, b| if a > b { a } else { b })
            .expect("init_noise_sigma could not be reduced from sigmas - this should never happen");

        Ok(Self {
            sigmas: sigmas_int,
            timesteps,
            init_noise_sigma,
            config,
        })
    }
}

impl Scheduler for EulerAncestralDiscreteScheduler {
    fn timesteps(&self) -> &[usize] {
        self.timesteps.as_slice()
    }

    /// Ensures interchangeability with schedulers that need to scale the denoising model input
    /// depending on the current timestep.
    ///
    /// Scales the denoising model input by `(sigma.powi(2) + 1.).sqrt()` to match the K-LMS algorithm.
    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
        let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
            Some(i) => i,
            None => bail!("timestep out of this scheduler's bounds: {timestep}"),
        };

        let sigma = self
            .sigmas
            .get(step_index)
            .expect("step_index out of sigma bounds - this shouldn't happen");

        sample / ((sigma.powi(2) + 1.).sqrt())
    }

    /// Performs a backward step during inference.
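    ///
    /// The model output is first converted into a prediction of the original
    /// sample `x_0`, an Euler step is taken from `sigma_from` down to
    /// `sigma_down`, and fresh noise with standard deviation `sigma_up` is
    /// added back, where `sigma_up^2 + sigma_down^2 = sigma_to^2`.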
    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
        let step_index = self
            .timesteps
            .iter()
            .position(|&p| p == timestep)
            .ok_or_else(|| Error::Msg(format!("timestep out of this scheduler's bounds: {timestep}")))?;

        let sigma_from = &self.sigmas[step_index];
        let sigma_to = &self.sigmas[step_index + 1];

        // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        let pred_original_sample = match self.config.prediction_type {
            PredictionType::Epsilon => (sample - (model_output * *sigma_from))?,
            PredictionType::VPrediction => {
                ((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))?
                    + (sample / (sigma_from.powi(2) + 1.0))?)?
            }
            PredictionType::Sample => bail!("prediction_type not implemented yet: sample"),
        };

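        // Ancestral split of the step: `sigma_up` is the std-dev of the noise
        // added back in, `sigma_down` the sigma reached by the deterministic
        // Euler step, chosen so that sigma_up^2 + sigma_down^2 = sigma_to^2.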
        let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
            / sigma_from.powi(2))
        .sqrt();
        let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();

        // 2. convert to an ODE derivative
        let derivative = ((sample - pred_original_sample)? / *sigma_from)?;
        let dt = sigma_down - *sigma_from;
        let prev_sample = (sample + derivative * dt)?;

        let noise = prev_sample.randn_like(0.0, 1.0)?;

        prev_sample + noise * sigma_up
    }

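    /// Adds `noise` scaled by the sigma at `timestep` to `original`, following
    /// the variance-exploding formulation `x_t = x_0 + sigma_t * noise`.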
    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
        let step_index = self
            .timesteps
            .iter()
            .position(|&p| p == timestep)
            .ok_or_else(|| Error::Msg(format!("timestep out of this scheduler's bounds: {timestep}")))?;

        let sigma = self
            .sigmas
            .get(step_index)
            .expect("step_index out of sigma bounds - this shouldn't happen");

        original + (noise * *sigma)?
    }

    fn init_noise_sigma(&self) -> f64 {
        match self.config.timestep_spacing {
            TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
            TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
        }
    }
}
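
// A small sanity check, written as a sketch: the interpolated sigma schedule
// should be non-increasing and end at exactly zero, since a final sigma of 0
// is appended after interpolation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sigmas_decrease_and_end_at_zero() -> Result<()> {
        let scheduler = EulerAncestralDiscreteScheduler::new(
            30,
            EulerAncestralDiscreteSchedulerConfig::default(),
        )?;
        // Timesteps run from high to low, so the sigmas must not increase.
        for w in scheduler.sigmas.windows(2) {
            assert!(w[0] >= w[1], "sigmas should not increase: {w:?}");
        }
        assert_eq!(scheduler.sigmas.last(), Some(&0.0));
        Ok(())
    }
}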