#![allow(dead_code)]
use candle::{Result, Tensor};
/// A scheduler configuration that can instantiate a [`Scheduler`].
///
/// The `Send + Sync` bounds allow a config to be shared across threads, and the
/// `Debug` bound allows it to be printed.
pub trait SchedulerConfig: Send + Sync + std::fmt::Debug {
    /// Creates a boxed scheduler set up for `inference_steps` steps.
    ///
    /// # Errors
    /// Returns whatever error the implementation produces while building the scheduler.
    fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
}
/// Operations a diffusion scheduler exposes to the sampling loop.
///
/// NOTE(review): the per-method descriptions below are inferred from the method
/// names and signatures; confirm them against the concrete implementations.
pub trait Scheduler {
    /// The timesteps this scheduler walks through.
    ///
    /// The ordering is implementation-defined; TODO confirm whether it is descending.
    fn timesteps(&self) -> &[usize];

    /// Presumably returns `original` noised with `noise` at `timestep`.
    fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;

    /// A scalar tied to the initial noise.
    ///
    /// By its name this is presumably the initial noise's standard deviation; confirm.
    fn init_noise_sigma(&self) -> f64;

    /// Transforms `sample` before it is passed to the model.
    ///
    /// Implementations are free to ignore `timestep`.
    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor>;

    /// Advances `sample` by one scheduler step, given the model's `model_output`
    /// at `timestep`, and returns the resulting tensor.
    fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}
/// Names the schedule used to produce per-timestep beta values.
///
/// NOTE(review): the variant names appear to mirror diffusers' `beta_schedule`
/// strings (`"linear"`, `"scaled_linear"`, `"squaredcos_cap_v2"`); confirm.
#[derive(Clone, Copy, Debug)]
pub enum BetaSchedule {
    /// Linear beta schedule.
    Linear,
    /// Scaled variant of the linear schedule; TODO confirm the exact scaling.
    ScaledLinear,
    /// Cosine-based schedule, presumably the one built by `betas_for_alpha_bar`.
    SquaredcosCapV2,
}
/// Describes what the model's output represents.
///
/// NOTE(review): presumably consulted by `Scheduler::step` implementations when
/// interpreting `model_output`; confirm against the concrete schedulers.
#[derive(Clone, Copy, Debug)]
pub enum PredictionType {
    /// The model output is interpreted as a noise (epsilon) prediction.
    Epsilon,
    /// The model output is interpreted as a "v" prediction.
    VPrediction,
    /// The model output is interpreted as the sample itself.
    Sample,
}
/// Selects how a scheduler spaces its inference timesteps.
///
/// NOTE(review): the variant names appear to mirror diffusers' `timestep_spacing`
/// options (`"leading"`, `"linspace"`, `"trailing"`); confirm against the scheduler
/// implementations that consume this value.
///
/// `Default` is derived, and the default variant is [`TimestepSpacing::Leading`].
#[derive(Debug, Clone, Copy, Default)]
pub enum TimestepSpacing {
    /// The default spacing.
    #[default]
    Leading,
    /// Linearly spaced timesteps (by name; TODO confirm).
    Linspace,
    /// Trailing spacing (by name; TODO confirm).
    Trailing,
}
/// Builds a beta schedule from the cumulative-alpha function
/// `alpha_bar(t) = cos((t + 0.008) / 1.008 * π/2)²`, with `t` normalised to `[0, 1]`.
///
/// For each step `i` in `0..num_diffusion_timesteps` (written `n` below), the beta is
/// `1 - alpha_bar((i + 1) / n) / alpha_bar(i / n)`, clamped to at most `max_beta`.
/// The result is a 1-D tensor of `f64` values of length `num_diffusion_timesteps`
/// on the CPU. An input of `0` yields an empty tensor.
///
/// # Errors
/// Propagates any error returned by `Tensor::from_vec`.
pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Result<Tensor> {
    let alpha_bar = |time_step: f64| {
        f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2)
    };
    // The timestep fraction must be computed in floating point. With integer division,
    // `i / n` is 0 for every step and the schedule collapses to all zeros plus a final
    // `max_beta`.
    let n = num_diffusion_timesteps as f64;
    let betas: Vec<f64> = (0..num_diffusion_timesteps)
        .map(|i| {
            let t1 = i as f64 / n;
            let t2 = (i + 1) as f64 / n;
            (1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta)
        })
        .collect();
    let betas_len = betas.len();
    Tensor::from_vec(betas, betas_len, &candle::Device::Cpu)
}