use candle::{DType, Result, Tensor, D};
use candle_nn::{
    batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
    Func, VarBuilder,
};
/// Un-scaled layout of one MobileOne stage.
struct StageConfig {
    /// Number of (depthwise, pointwise) block pairs built by `mobileone_stage` for this stage.
    /// The entry for stage 0 is not read there, since the stem is built separately.
    blocks: usize,
    /// Base channel count, multiplied by the variant's `alphas[stage]` width factor.
    channels: usize,
}

/// Shorthand constructor so the `STAGES` table below stays compact.
const fn stage_config(blocks: usize, channels: usize) -> StageConfig {
    StageConfig { blocks, channels }
}

/// Base layout of the five MobileOne stages; index 0 describes the stem.
const STAGES: [StageConfig; 5] = [
    stage_config(1, 64),
    stage_config(2, 64),
    stage_config(8, 128),
    stage_config(10, 256),
    stage_config(1, 512),
];
/// Hyper-parameters selecting one of the MobileOne variants (S0 to S4).
///
/// * `k` - number of over-parameterized `kxk` conv+BN branches folded into every block.
/// * `alphas` - per-stage width multipliers applied to the base channel counts of the five
///   stages (index 0 is the stem).
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    k: usize,
    alphas: [f32; 5],
}

impl Config {
    /// MobileOne-S0: four `kxk` branches per block, narrowest widths.
    pub fn s0() -> Self {
        Self {
            k: 4,
            alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
        }
    }

    /// MobileOne-S1.
    pub fn s1() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
        }
    }

    /// MobileOne-S2.
    pub fn s2() -> Self {
        Self {
            k: 1,
            alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
        }
    }

    /// MobileOne-S3.
    pub fn s3() -> Self {
        Self {
            k: 1,
            alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
        }
    }

    /// MobileOne-S4: widest variant.
    pub fn s4() -> Self {
        Self {
            k: 1,
            alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
        }
    }
}
fn squeeze_and_excitation(
    in_channels: usize,
    squeeze_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        ..Default::default()
    };
    let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
    let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
    Ok(Func::new(move |xs| {
        let residual = xs;
        let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
        let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
        residual.broadcast_mul(&xs)
    }))
}
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
    let (gamma, beta) = bn.weight_and_bias().unwrap();
    let mu = bn.running_mean();
    let sigma = (bn.running_var() + bn.eps())?.sqrt();
    let gps = (gamma / sigma)?;
    let bias = (beta - mu * &gps)?;
    let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
    Ok((weights, bias))
}
/// Builds one MobileOne block and re-parameterizes all of its branches into a single conv.
///
/// Parameters read under `vb`:
/// - `k` branches `conv_kxk.{i}.conv` / `conv_kxk.{i}.bn` with `kernel x kernel` weights,
/// - when `kernel > 1`, a `1x1` scale branch `conv_scale.conv` / `conv_scale.bn`,
/// - when `has_identity`, a batch-norm-only identity branch `identity`,
/// - optionally a squeeze-and-excitation module `attn` (skipped if it cannot be loaded).
///
/// Each branch is fused with its batch norm and the results are summed into one conv
/// (`stride`, `padding`, `groups`). The returned function applies that conv, then the
/// optional SE gate, then ReLU.
///
/// * `dim` - channel count of the batch norms, which is also the fused bias length.
/// * `kernel` - side of the square main kernel; expected to be odd so the `1x1` and identity
///   kernels can be centered in it.
///
/// The identity branch assumes `in_channels == out_channels` (TODO confirm at call sites).
#[allow(clippy::too_many_arguments)]
fn mobileone_block(
    has_identity: bool,
    k: usize,
    dim: usize,
    stride: usize,
    padding: usize,
    groups: usize,
    kernel: usize,
    in_channels: usize,
    out_channels: usize,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride,
        padding,
        groups,
        ..Default::default()
    };
    // Accumulate in the VarBuilder's dtype so the sums match the loaded weights; a fixed F32
    // accumulator would fail to add with F16/BF16 weights.
    let dtype = vb.dtype();
    let in_per_group = in_channels / groups;
    // Center offset of a 1x1 tap inside a `kernel x kernel` window.
    let center = kernel / 2;
    let mut w = Tensor::zeros(
        (out_channels, in_per_group, kernel, kernel),
        dtype,
        vb.device(),
    )?;
    let mut b = Tensor::zeros(dim, dtype, vb.device())?;
    for i in 0..k {
        let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
        let conv_kxk = conv2d_no_bias(
            in_channels,
            out_channels,
            kernel,
            conv2d_cfg,
            vb.pp(format!("conv_kxk.{i}.conv")),
        )?;
        let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
        w = (w + wk)?;
        b = (b + bk)?;
    }
    if kernel > 1 {
        let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
        let conv_scale = conv2d_no_bias(
            in_channels,
            out_channels,
            1,
            conv2d_cfg,
            vb.pp("conv_scale.conv"),
        )?;
        let (ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
        // Zero-pad the 1x1 kernel so it sits at the center of the kxk kernel.
        let ws = ws
            .pad_with_zeros(D::Minus1, center, center)?
            .pad_with_zeros(D::Minus2, center, center)?;
        w = (w + ws)?;
        b = (b + bs)?;
    }
    // Not every block ships SE weights; a failed load simply disables the SE gate.
    let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn")).ok();
    if has_identity {
        let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
        // Identity as a grouped conv: output channel `i` reads channel `i % in_per_group` of its
        // own group, with a single 1 at the kernel center. This is valid for any groups/kernel.
        let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
        for i in 0..in_channels {
            let idx = ((i * in_per_group + i % in_per_group) * kernel + center) * kernel + center;
            weights[idx] = 1.0;
        }
        let weights = Tensor::from_vec(weights, w.shape(), w.device())?.to_dtype(dtype)?;
        let (wi, bi) = fuse_conv_bn(&weights, identity_bn)?;
        w = (w + wi)?;
        b = (b + bi)?;
    }
    let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
    Ok(Func::new(move |xs| {
        let xs = xs.apply(&reparam_conv)?;
        let xs = match &se {
            Some(se) => xs.apply(se)?,
            None => xs,
        };
        xs.relu()
    }))
}
/// Output width of MobileOne stage `stage`.
///
/// This is the stage's base channel count times the variant's `alphas[stage]`, truncated to an
/// integer. Stage 0 (the stem) is additionally capped at 64 channels.
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
    let scaled = (STAGES[stage].channels as f32 * cfg.alphas[stage]) as usize;
    if stage == 0 {
        scaled.min(64)
    } else {
        scaled
    }
}
/// Builds MobileOne stage `idx`.
///
/// `idx` must be at least 1, because the input width is taken from stage `idx - 1`.
///
/// The stage consists of `STAGES[idx].blocks` units. Each unit is a depthwise `3x3` block
/// followed by a pointwise `1x1` block, loaded from `vb.pp(2 * unit)` and
/// `vb.pp(2 * unit + 1)`. The first unit uses stride 2 and no identity branch; later units
/// use stride 1 with an identity branch.
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let units = STAGES[idx].blocks;
    let out_channels = output_channels_per_stage(cfg, idx);
    let mut channels = output_channels_per_stage(cfg, idx - 1);
    let mut layers = Vec::with_capacity(2 * units);
    for unit in 0..units {
        let downsample = unit == 0;
        let stride = if downsample { 2 } else { 1 };
        let has_identity = !downsample;
        // Depthwise 3x3: groups == channels, width unchanged.
        layers.push(mobileone_block(
            has_identity,
            cfg.k,
            channels,
            stride,
            1,
            channels,
            3,
            channels,
            channels,
            vb.pp(2 * unit),
        )?);
        // Pointwise 1x1: maps the current width to the stage width.
        layers.push(mobileone_block(
            has_identity,
            cfg.k,
            out_channels,
            1,
            0,
            1,
            1,
            channels,
            out_channels,
            vb.pp(2 * unit + 1),
        )?);
        channels = out_channels;
    }
    Ok(Func::new(move |xs| {
        layers
            .iter()
            .try_fold(xs.clone(), |acc, layer| acc.apply(layer))
    }))
}
/// Assembles the full MobileOne network.
///
/// The network is a 3-input-channel stem (`stem`), four stages (`stages.0` to `stages.3`),
/// and a global average over the last two dimensions. When `nclasses` is `Some(n)`, an
/// `n`-way linear head (`head.fc`) is applied to the pooled features; otherwise the pooled
/// features are returned directly.
fn mobileone_model(
    config: &Config,
    nclasses: Option<usize>,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let head = nclasses
        .map(|n| linear(output_channels_per_stage(config, 4), n, vb.pp("head.fc")))
        .transpose()?;
    let stem_dim = output_channels_per_stage(config, 0);
    let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
    let stages_vb = vb.pp("stages");
    let stages = (1..=4)
        .map(|idx| mobileone_stage(config, idx, stages_vb.pp(idx - 1)))
        .collect::<Result<Vec<_>>>()?;
    Ok(Func::new(move |xs| {
        let features = stages
            .iter()
            .try_fold(xs.apply(&stem)?, |acc, stage| acc.apply(stage))?;
        let pooled = features.mean(D::Minus2)?.mean(D::Minus1)?;
        match &head {
            Some(fc) => pooled.apply(fc),
            None => Ok(pooled),
        }
    }))
}
/// Builds a MobileOne classifier whose output has `nclasses` logits (linear head `head.fc`).
pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let head_classes = Option::from(nclasses);
    mobileone_model(cfg, head_classes, vb)
}
/// Builds the MobileOne backbone without a classification head; the returned function yields
/// the averaged features of the last stage.
pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    let head_classes: Option<usize> = None;
    mobileone_model(cfg, head_classes, vb)
}