use candle::{Result, Tensor, D};
use candle_nn::{
    batch_norm, conv2d_no_bias, linear, BatchNorm, Conv2d, Conv2dConfig, Func, VarBuilder,
};
// Base channel count of each part of the network before the width multipliers are applied.
// Index 0 is used for the stem, indices 1 to 4 for the four stages.
const CHANNELS_PER_STAGE: [usize; 5] = [64, 64, 128, 256, 512];
/// Hyper-parameters selecting a RepVGG variant.
///
/// Values are obtained through the preset constructors (`a0`, `b1`, `b2g4`, ...).
#[derive(Clone)]
pub struct Config {
    /// Width multiplier applied to the base channel counts of the stem and of stages 1 to 3.
    a: f32,
    /// Width multiplier applied to the base channel count of the last stage.
    b: f32,
    /// Convolution group count used by the layers whose model-wide index is odd;
    /// the other layers always use a single group.
    groups: usize,
    /// Number of layers in each of the four stages.
    stages: [usize; 4],
}

impl Config {
    /// Stage depths shared by the `a*` presets.
    const A_DEPTHS: [usize; 4] = [2, 4, 14, 1];
    /// Stage depths shared by the `b*` presets.
    const B_DEPTHS: [usize; 4] = [4, 6, 16, 1];

    /// Assembles a configuration from its raw parts; every preset goes through here.
    fn preset(a: f32, b: f32, groups: usize, stages: [usize; 4]) -> Self {
        Self {
            a,
            b,
            groups,
            stages,
        }
    }

    /// `a0` preset: a = 0.75, b = 2.5, 1 group, depths `[2, 4, 14, 1]`.
    pub fn a0() -> Self {
        Self::preset(0.75, 2.5, 1, Self::A_DEPTHS)
    }

    /// `a1` preset: a = 1.0, b = 2.5, 1 group, depths `[2, 4, 14, 1]`.
    pub fn a1() -> Self {
        Self::preset(1.0, 2.5, 1, Self::A_DEPTHS)
    }

    /// `a2` preset: a = 1.5, b = 2.75, 1 group, depths `[2, 4, 14, 1]`.
    pub fn a2() -> Self {
        Self::preset(1.5, 2.75, 1, Self::A_DEPTHS)
    }

    /// `b0` preset: a = 1.0, b = 2.5, 1 group, depths `[4, 6, 16, 1]`.
    pub fn b0() -> Self {
        Self::preset(1.0, 2.5, 1, Self::B_DEPTHS)
    }

    /// `b1` preset: a = 2.0, b = 4.0, 1 group, depths `[4, 6, 16, 1]`.
    pub fn b1() -> Self {
        Self::preset(2.0, 4.0, 1, Self::B_DEPTHS)
    }

    /// `b2` preset: a = 2.5, b = 5.0, 1 group, depths `[4, 6, 16, 1]`.
    pub fn b2() -> Self {
        Self::preset(2.5, 5.0, 1, Self::B_DEPTHS)
    }

    /// `b3` preset: a = 3.0, b = 5.0, 1 group, depths `[4, 6, 16, 1]`.
    pub fn b3() -> Self {
        Self::preset(3.0, 5.0, 1, Self::B_DEPTHS)
    }

    /// `b1g4` preset: same as `b1` but with 4 groups.
    pub fn b1g4() -> Self {
        Self::preset(2.0, 4.0, 4, Self::B_DEPTHS)
    }

    /// `b2g4` preset: same as `b2` but with 4 groups.
    pub fn b2g4() -> Self {
        Self::preset(2.5, 5.0, 4, Self::B_DEPTHS)
    }

    /// `b3g4` preset: same as `b3` but with 4 groups.
    pub fn b3g4() -> Self {
        Self::preset(3.0, 5.0, 4, Self::B_DEPTHS)
    }
}
/// Folds a batch norm into the weights of the bias-free convolution that precedes it.
///
/// With `scale = gamma / sqrt(running_var + eps)`, the returned pair is
/// `(weights * scale, beta - running_mean * scale)`, where `scale` is broadcast along the
/// first dimension of `weights` (it is reshaped to `(N, 1, 1, 1)`).
///
/// # Panics
///
/// Panics if `bn` carries no weight and bias.
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
    let (gamma, beta) = bn.weight_and_bias().unwrap();
    let running_mean = bn.running_mean();

    // Per-channel standard deviation and the resulting multiplicative factor.
    let std_dev = (bn.running_var() + bn.eps())?.sqrt()?;
    let scale = (gamma / &std_dev)?;

    let shifted_mean = (running_mean * &scale)?;
    let fused_bias = (beta - &shifted_mean)?;

    let per_filter_scale = scale.reshape(((), 1, 1, 1))?;
    let fused_weights = weights.broadcast_mul(&per_filter_scale)?;

    Ok((fused_weights, fused_bias))
}
/// Builds one layer as a single re-parameterized 3x3 convolution followed by a ReLU.
///
/// The 1x1 branch (`conv_1x1`), the 3x3 branch (`conv_kxk`) and, when `has_identity` is set,
/// the batch-norm-only branch (`identity`) are read from `vb`. Each branch is fused with its
/// batch norm and the results are summed into one kernel and one bias.
///
/// * `dim` - feature count of every batch norm loaded here.
/// * `stride`, `groups` - stride and group count of the convolutions (padding is always 1).
/// * `in_channels`, `out_channels` - channel counts of the convolutions.
fn repvgg_layer(
    has_identity: bool,
    dim: usize,
    stride: usize,
    in_channels: usize,
    out_channels: usize,
    groups: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        padding: 1,
        stride,
        groups,
        ..Default::default()
    };

    // 1x1 branch: fuse with its batch norm, then zero-pad the kernel up to 3x3.
    let bn_1x1 = batch_norm(dim, 1e-5, vb.pp("conv_1x1.bn"))?;
    let conv_1x1 = conv2d_no_bias(
        in_channels,
        out_channels,
        1,
        conv2d_cfg,
        vb.pp("conv_1x1.conv"),
    )?;
    let (kernel_1x1, bias_1x1) = fuse_conv_bn(conv_1x1.weight(), bn_1x1)?;
    let kernel_1x1 = kernel_1x1
        .pad_with_zeros(D::Minus1, 1, 1)?
        .pad_with_zeros(D::Minus2, 1, 1)?;

    // 3x3 branch: fuse with its batch norm.
    let bn_3x3 = batch_norm(dim, 1e-5, vb.pp("conv_kxk.bn"))?;
    let conv_3x3 = conv2d_no_bias(
        in_channels,
        out_channels,
        3,
        conv2d_cfg,
        vb.pp("conv_kxk.conv"),
    )?;
    let (kernel_3x3, bias_3x3) = fuse_conv_bn(conv_3x3.weight(), bn_3x3)?;

    let kernel = (kernel_1x1 + kernel_3x3)?;
    let bias = (bias_1x1 + bias_3x3)?;

    let (kernel, bias) = if has_identity {
        // Identity branch: express it as a 3x3 kernel whose only non-zero taps are the
        // centre ones, then fuse that kernel with the branch's batch norm.
        const TAPS: usize = 3 * 3;
        const CENTER: usize = 4;

        let bn_identity = batch_norm(dim, 1e-5, vb.pp("identity"))?;
        let mut taps = vec![0f32; conv_3x3.weight().elem_count()];
        let in_dim = in_channels / groups;
        for channel in 0..in_channels {
            // Flat offset of [channel, channel % in_dim, 1, 1].
            let offset = channel * in_dim * TAPS + (channel % in_dim) * TAPS + CENTER;
            taps[offset] = 1.0;
        }
        let identity_kernel = Tensor::from_vec(taps, kernel.shape(), kernel.device())?;
        let (kernel_id, bias_id) = fuse_conv_bn(&identity_kernel, bn_identity)?;
        ((kernel + kernel_id)?, (bias + bias_id)?)
    } else {
        (kernel, bias)
    };

    // Single convolution equivalent to the sum of all the branches.
    let reparam_conv = Conv2d::new(kernel, Some(bias), conv2d_cfg);
    Ok(Func::new(move |xs| xs.apply(&reparam_conv)?.relu()))
}
/// Returns the channel count of `stage` once the width multipliers are applied.
///
/// Stage 4 is scaled by `b`, every other stage by `a`; stage 0 is additionally capped at 64.
/// The scaled value is truncated to an integer.
///
/// # Panics
///
/// Panics if `stage` is not a valid index into `CHANNELS_PER_STAGE`.
fn output_channels_per_stage(a: f32, b: f32, stage: usize) -> usize {
    let base = CHANNELS_PER_STAGE[stage] as f32;
    if stage == 4 {
        return (base * b) as usize;
    }
    let scaled = (base * a) as usize;
    if stage == 0 {
        scaled.min(64)
    } else {
        scaled
    }
}
/// Builds stage `idx` (1-based) as a sequence of `cfg.stages[idx - 1]` layers.
///
/// The first layer of the stage uses stride 2, takes the previous stage's channel count as
/// input and has no identity branch; the remaining layers use stride 1 and have one.
/// Layers whose index counted across all preceding stages is odd use `cfg.groups` groups,
/// the others use a single group.
fn repvgg_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let nlayers = cfg.stages[idx - 1];
    // Number of layers in the stages before this one, used to get a model-wide layer index.
    let layers_before: usize = cfg.stages[..idx - 1].iter().sum();
    let channels_in = output_channels_per_stage(cfg.a, cfg.b, idx - 1);
    let channels_out = output_channels_per_stage(cfg.a, cfg.b, idx);

    let layers = (0..nlayers)
        .map(|layer_idx| {
            let is_first = layer_idx == 0;
            let stride = if is_first { 2 } else { 1 };
            let in_channels = if is_first { channels_in } else { channels_out };

            let global_idx = layers_before + layer_idx;
            let groups = if global_idx % 2 == 0 { 1 } else { cfg.groups };

            repvgg_layer(
                !is_first,
                channels_out,
                stride,
                in_channels,
                channels_out,
                groups,
                vb.pp(layer_idx),
            )
        })
        .collect::<Result<Vec<_>>>()?;

    Ok(Func::new(move |xs| {
        layers
            .iter()
            .try_fold(xs.clone(), |acc, layer| acc.apply(layer))
    }))
}
/// Assembles the full network: stem, four stages, averaging over the last two dimensions
/// and, when `nclasses` is given, a linear head read from `head.fc`.
///
/// The stem is read from `stem` and stage `i` (1-based) from `stages.{i - 1}`.
/// The stem convolution expects 3 input channels.
fn repvgg_model(config: &Config, nclasses: Option<usize>, vb: VarBuilder) -> Result<Func<'static>> {
    // Optional classification head, fed with the channel count of the last stage.
    let cls = nclasses
        .map(|nclasses| {
            let in_features = output_channels_per_stage(config.a, config.b, 4);
            linear(in_features, nclasses, vb.pp("head.fc"))
        })
        .transpose()?;

    let stem_dim = output_channels_per_stage(config.a, config.b, 0);
    let stem = repvgg_layer(false, stem_dim, 2, 3, stem_dim, 1, vb.pp("stem"))?;

    let stages_vb = vb.pp("stages");
    let stages = (1..=config.stages.len())
        .map(|idx| repvgg_stage(config, idx, stages_vb.pp(idx - 1)))
        .collect::<Result<Vec<_>>>()?;

    Ok(Func::new(move |xs| {
        let features = stages
            .iter()
            .try_fold(xs.apply(&stem)?, |acc, stage| acc.apply(stage))?;
        // Average over the last dimension twice, i.e. over the two trailing dimensions.
        let pooled = features.mean(D::Minus1)?.mean(D::Minus1)?;
        if let Some(head) = &cls {
            pooled.apply(head)
        } else {
            Ok(pooled)
        }
    }))
}
/// Builds the RepVGG network described by `cfg`, including a final linear layer with
/// `nclasses` outputs.
///
/// All the weights are read from `vb`.
// NOTE(review): the input is presumably an image batch laid out as (batch, 3, height, width) — confirm.
pub fn repvgg(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    repvgg_model(cfg, Some(nclasses), vb)
}
/// Builds the RepVGG network described by `cfg` without the final linear layer: the returned
/// function outputs the features averaged over their last two dimensions.
///
/// All the weights are read from `vb`.
pub fn repvgg_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    repvgg_model(cfg, None, vb)
}