use candle::{ModuleT, Result, Tensor};
use candle_nn::{FuncT, VarBuilder};
/// The VGG variants that `Vgg::new` can build.
///
/// The variants differ only in how many 3x3 convolutions each of the five
/// convolutional stages contains; the classifier is the same for all of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Models {
    /// VGG-13: 2, 2, 2, 2 and 2 convolutions per stage.
    Vgg13,
    /// VGG-16: 2, 2, 3, 3 and 3 convolutions per stage.
    Vgg16,
    /// VGG-19: 2, 2, 4, 4 and 4 convolutions per stage.
    Vgg19,
}
/// A VGG network: a sequence of blocks (the convolutional stages followed by the
/// fully-connected classifier) that `forward_t` applies in order.
///
/// `'a` is the lifetime of the `VarBuilder` the model was constructed from.
#[derive(Debug)]
pub struct Vgg<'a> {
    /// Blocks applied one after another; built by the `vggNN_blocks` helpers.
    blocks: Vec<FuncT<'a>>,
}
/// Shape description for one pre-logit fully-connected layer whose weight is
/// stored as a 4-D tensor and reshaped into a dense matrix by
/// `get_weights_and_biases`.
///
/// NOTE: despite their names, `target_in` is the ROW count of the reshaped
/// weight (the layer's output width, also the bias length) and `target_out` is
/// the COLUMN count (the layer's input width); candle's `Linear` stores weights
/// as `(out_features, in_features)`.
struct PreLogitConfig {
    /// Shape under which the `weight` tensor is fetched from the `VarBuilder`.
    in_dim: (usize, usize, usize, usize),
    /// Rows of the reshaped weight and length of the bias (output width).
    target_in: usize,
    /// Columns of the reshaped weight (input width); for the first pre-logit
    /// layer this is also the width the incoming features are flattened to.
    target_out: usize,
}
impl<'a> Vgg<'a> {
    pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
        let blocks = match model {
            Models::Vgg13 => vgg13_blocks(vb)?,
            Models::Vgg16 => vgg16_blocks(vb)?,
            Models::Vgg19 => vgg19_blocks(vb)?,
        };
        Ok(Self { blocks })
    }
}
impl ModuleT for Vgg<'_> {
    /// Runs the network on a single image and returns the output of the last block.
    ///
    /// A batch axis of size 1 is inserted at dim 0 before any block runs, so `xs`
    /// must be unbatched — presumably laid out as (channels, height, width) with
    /// 3 channels (the first convolution's input width). The classifier flattens the
    /// final feature map to 512 * 7 * 7 values, which five 2x2 pools yield for
    /// e.g. a 224x224 image.
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        let batched = xs.unsqueeze(0)?;
        self.blocks
            .iter()
            .try_fold(batched, |acc, block| acc.apply_t(block, train))
    }
}
/// Builds one VGG convolutional stage.
///
/// Each `(in_channels, out_channels, name)` entry becomes a 3x3 convolution
/// (stride 1, padding 1) whose parameters are loaded from `vb.pp(name)` and which
/// is followed by a ReLU. The stage ends with a 2x2 max-pool of stride 2.
///
/// # Errors
/// Returns any error raised while loading a convolution's parameters.
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
    let layers = convs
        .iter()
        .map(|&(in_c, out_c, name)| {
            candle_nn::conv2d(
                in_c,
                out_c,
                3,
                candle_nn::Conv2dConfig {
                    stride: 1,
                    padding: 1,
                    ..Default::default()
                },
                vb.pp(name),
            )
        })
        .collect::<Result<Vec<_>>>()?;
    Ok(FuncT::new(move |xs, _train| {
        // Conv -> ReLU for every layer, threading the activation through the fold.
        let xs = layers
            .iter()
            .try_fold(xs.clone(), |acc, layer| acc.apply(layer)?.relu())?;
        xs.max_pool2d_with_stride(2, 2)
    }))
}
/// Builds the classifier: two pre-logit fully-connected layers (`pre_logits.fc1`,
/// `pre_logits.fc2`, each followed by ReLU) and the `head.fc` classification layer,
/// with a dropout (p = 0.5, active only when `train` is true) before each of the three.
///
/// The pre-logit weights are stored as 4-D tensors and reshaped into dense matrices
/// by `get_weights_and_biases`. The head's input width is fc2's output width
/// (`pre_logit_2.target_in`).
///
/// The returned block reshapes its input to `(1, pre_logit_1.target_out)` (a single
/// sample) and returns the head's output unchanged, i.e. raw, unnormalized logits
/// of shape `(1, num_classes)`.
///
/// # Errors
/// Returns any error raised while loading the three layers' parameters.
fn fully_connected(
    num_classes: usize,
    pre_logit_1: PreLogitConfig,
    pre_logit_2: PreLogitConfig,
    vb: VarBuilder,
) -> Result<FuncT> {
    let lin = get_weights_and_biases(
        &vb.pp("pre_logits.fc1"),
        pre_logit_1.in_dim,
        pre_logit_1.target_in,
        pre_logit_1.target_out,
    )?;
    let lin2 = get_weights_and_biases(
        &vb.pp("pre_logits.fc2"),
        pre_logit_2.in_dim,
        pre_logit_2.target_in,
        pre_logit_2.target_out,
    )?;
    // Load the head once here instead of on every forward pass: this avoids
    // re-fetching its weights per call and surfaces a missing `head.fc` at build time.
    let lin3 = candle_nn::linear(pre_logit_2.target_in, num_classes, vb.pp("head.fc"))?;
    let dropout1 = candle_nn::Dropout::new(0.5);
    let dropout2 = candle_nn::Dropout::new(0.5);
    let dropout3 = candle_nn::Dropout::new(0.5);
    let flat_dim = pre_logit_1.target_out;
    Ok(FuncT::new(move |xs, train| {
        let xs = xs.reshape((1, flat_dim))?;
        let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
        let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
        // No activation after the head: a ReLU here would clamp every negative
        // logit to 0 and distort any softmax / top-k taken by the caller.
        xs.apply_t(&dropout3, train)?.apply(&lin3)
    }))
}
/// Loads a dense layer from `vs` whose weight is stored with shape `in_dim`.
///
/// The `weight` tensor is fetched with shape `in_dim` (Kaiming-normal init hint)
/// and reshaped to a `(target_in, target_out)` matrix. The `bias` tensor has length
/// `target_in` and is fetched with a uniform init hint in `[-b, b]`, where
/// `b = 1 / sqrt(target_out)`.
///
/// # Errors
/// Fails if either tensor cannot be fetched or the weight cannot be reshaped
/// (e.g. when `in_dim` does not hold `target_in * target_out` elements).
fn get_weights_and_biases(
    vs: &VarBuilder,
    in_dim: (usize, usize, usize, usize),
    target_in: usize,
    target_out: usize,
) -> Result<candle_nn::Linear> {
    let weight = vs
        .get_with_hints(in_dim, "weight", candle_nn::init::DEFAULT_KAIMING_NORMAL)?
        .reshape((target_in, target_out))?;
    let limit = (target_out as f64).sqrt().recip();
    let bias = vs.get_with_hints(
        target_in,
        "bias",
        candle_nn::Init::Uniform {
            lo: -limit,
            up: limit,
        },
    )?;
    Ok(candle_nn::Linear::new(weight, Some(bias)))
}
/// Builds the VGG-13 stack: five convolutional stages with 2, 2, 2, 2 and 2
/// convolutions, followed by the fully-connected classifier with 1000 outputs.
fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
    let num_classes = 1000;
    // One entry per stage; `conv2d_block` appends the 2x2 max-pool closing each one.
    let stages: [&[(usize, usize, &str)]; 5] = [
        &[(3, 64, "features.0"), (64, 64, "features.2")],
        &[(64, 128, "features.5"), (128, 128, "features.7")],
        &[(128, 256, "features.10"), (256, 256, "features.12")],
        &[(256, 512, "features.15"), (512, 512, "features.17")],
        &[(512, 512, "features.20"), (512, 512, "features.22")],
    ];
    let mut blocks = Vec::with_capacity(stages.len() + 1);
    for &stage in stages.iter() {
        blocks.push(conv2d_block(stage, &vb)?);
    }
    // fc1 consumes the flattened 512x7x7 feature map; fc2 is 4096 -> 4096.
    let fc1 = PreLogitConfig {
        in_dim: (4096, 512, 7, 7),
        target_in: 4096,
        target_out: 512 * 7 * 7,
    };
    let fc2 = PreLogitConfig {
        in_dim: (4096, 4096, 1, 1),
        target_in: 4096,
        target_out: 4096,
    };
    blocks.push(fully_connected(num_classes, fc1, fc2, vb)?);
    Ok(blocks)
}
/// Builds the VGG-16 stack: five convolutional stages with 2, 2, 3, 3 and 3
/// convolutions, followed by the fully-connected classifier with 1000 outputs.
fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
    let num_classes = 1000;
    // One entry per stage; `conv2d_block` appends the 2x2 max-pool closing each one.
    let stages: [&[(usize, usize, &str)]; 5] = [
        &[(3, 64, "features.0"), (64, 64, "features.2")],
        &[(64, 128, "features.5"), (128, 128, "features.7")],
        &[
            (128, 256, "features.10"),
            (256, 256, "features.12"),
            (256, 256, "features.14"),
        ],
        &[
            (256, 512, "features.17"),
            (512, 512, "features.19"),
            (512, 512, "features.21"),
        ],
        &[
            (512, 512, "features.24"),
            (512, 512, "features.26"),
            (512, 512, "features.28"),
        ],
    ];
    let mut blocks = Vec::with_capacity(stages.len() + 1);
    for &stage in stages.iter() {
        blocks.push(conv2d_block(stage, &vb)?);
    }
    // fc1 consumes the flattened 512x7x7 feature map; fc2 is 4096 -> 4096.
    let fc1 = PreLogitConfig {
        in_dim: (4096, 512, 7, 7),
        target_in: 4096,
        target_out: 512 * 7 * 7,
    };
    let fc2 = PreLogitConfig {
        in_dim: (4096, 4096, 1, 1),
        target_in: 4096,
        target_out: 4096,
    };
    blocks.push(fully_connected(num_classes, fc1, fc2, vb)?);
    Ok(blocks)
}
/// Builds the VGG-19 stack: five convolutional stages with 2, 2, 4, 4 and 4
/// convolutions, followed by the fully-connected classifier with 1000 outputs.
fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
    let num_classes = 1000;
    // One entry per stage; `conv2d_block` appends the 2x2 max-pool closing each one.
    let stages: [&[(usize, usize, &str)]; 5] = [
        &[(3, 64, "features.0"), (64, 64, "features.2")],
        &[(64, 128, "features.5"), (128, 128, "features.7")],
        &[
            (128, 256, "features.10"),
            (256, 256, "features.12"),
            (256, 256, "features.14"),
            (256, 256, "features.16"),
        ],
        &[
            (256, 512, "features.19"),
            (512, 512, "features.21"),
            (512, 512, "features.23"),
            (512, 512, "features.25"),
        ],
        &[
            (512, 512, "features.28"),
            (512, 512, "features.30"),
            (512, 512, "features.32"),
            (512, 512, "features.34"),
        ],
    ];
    let mut blocks = Vec::with_capacity(stages.len() + 1);
    for &stage in stages.iter() {
        blocks.push(conv2d_block(stage, &vb)?);
    }
    // fc1 consumes the flattened 512x7x7 feature map; fc2 is 4096 -> 4096.
    let fc1 = PreLogitConfig {
        in_dim: (4096, 512, 7, 7),
        target_in: 4096,
        target_out: 512 * 7 * 7,
    };
    let fc2 = PreLogitConfig {
        in_dim: (4096, 4096, 1, 1),
        target_in: 4096,
        target_out: 4096,
    };
    blocks.push(fully_connected(num_classes, fc1, fc2, vb)?);
    Ok(blocks)
}