use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
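/// Patch embedding: a strided convolution that projects image patches to
/// `embed_dim` channels, producing a channels-last (B, H, W, C) feature map.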
#[derive(Debug)]
struct PatchEmbed {
    proj: candle_nn::Conv2d,
    span: tracing::Span,
}
impl PatchEmbed {
    fn new(
        in_chans: usize,
        embed_dim: usize,
        k_size: usize,
        stride: usize,
        padding: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride,
            padding,
            ..Default::default()
        };
        let proj = candle_nn::conv2d(in_chans, embed_dim, k_size, cfg, vb.pp("proj"))?;
        let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
        Ok(Self { proj, span })
    }
}
impl Module for PatchEmbed {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
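        // Conv2d produces (B, C, H, W); permute to the channels-last
        // (B, H, W, C) layout used by the transformer blocks.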
        xs.apply(&self.proj)?.permute((0, 2, 3, 1))
    }
}
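/// Fused CPU kernel computing `attn + rel_h + rel_w` for the decomposed
/// relative position addition, without materializing the broadcast
/// intermediates. The fields are (b, q_h, q_w, k_h, k_w); the inputs are
/// `attn` of shape (b, q_h * q_w, k_h * k_w), `rel_h` of shape
/// (b, q_h, q_w, k_h), and `rel_w` of shape (b, q_h, q_w, k_w).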
struct Add3(usize, usize, usize, usize, usize);
impl candle::CustomOp3 for Add3 {
    fn name(&self) -> &'static str {
        "add3"
    }
    fn cpu_fwd(
        &self,
        s1: &candle::CpuStorage,
        l1: &candle::Layout,
        s2: &candle::CpuStorage,
        l2: &candle::Layout,
        s3: &candle::CpuStorage,
        l3: &candle::Layout,
    ) -> Result<(candle::CpuStorage, candle::Shape)> {
        use rayon::prelude::*;
        let Add3(b, q_h, q_w, k_h, k_w) = *self;
        let s1 = s1.as_slice::<f32>()?;
        let s1 = match l1.contiguous_offsets() {
            None => candle::bail!("input1 has to be contiguous"),
            Some((o1, o2)) => &s1[o1..o2],
        };
        let s2 = s2.as_slice::<f32>()?;
        let s2 = match l2.contiguous_offsets() {
            None => candle::bail!("input2 has to be contiguous"),
            Some((o1, o2)) => &s2[o1..o2],
        };
        let s3 = s3.as_slice::<f32>()?;
        let s3 = match l3.contiguous_offsets() {
            None => candle::bail!("input3 has to be contiguous"),
            Some((o1, o2)) => &s3[o1..o2],
        };
        let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
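        // One chunk per (b, q_h, q_w) attention row: entry (h_idx, w_idx) of a
        // chunk receives the rel_h term for its height offset and the rel_w
        // term for its width offset.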
        dst.par_chunks_exact_mut(k_h * k_w)
            .enumerate()
            .for_each(|(b_idx, dst)| {
                let s1_idx = b_idx * k_h * k_w;
                let s2_idx = b_idx * k_h;
                let s3_idx = b_idx * k_w;
                for h_idx in 0..k_h {
                    let s1_idx = s1_idx + h_idx * k_w;
                    let s2_idx = s2_idx + h_idx;
                    let dst_idx = h_idx * k_w;
                    for w_idx in 0..k_w {
                        let s1_idx = s1_idx + w_idx;
                        let s3_idx = s3_idx + w_idx;
                        let dst_idx = dst_idx + w_idx;
                        dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx];
                    }
                }
            });
        let dst = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
    }
}
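/// Multi-head self-attention over a (B, H, W, C) feature map, with optional
/// decomposed relative position embeddings added to the attention logits.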
#[derive(Debug)]
struct Attention {
    qkv: super::Linear,
    proj: super::Linear,
    num_heads: usize,
    scale: f64,
    rel_pos_hw: Option<(Tensor, Tensor)>,
    span: tracing::Span,
    span_matmul: tracing::Span,
    span_rel_pos: tracing::Span,
    span_softmax: tracing::Span,
}
impl Attention {
    fn new(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        input_size: (usize, usize),
        vb: VarBuilder,
    ) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "attention");
        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
        let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
        let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
        let head_dim = dim / num_heads;
        let scale = 1. / (head_dim as f64).sqrt();
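        // Learned relative position tables, one per spatial axis, with
        // 2 * size - 1 rows covering every possible signed offset.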
        let rel_pos_hw = if use_rel_pos {
            let h = vb.get((2 * input_size.0 - 1, head_dim), "rel_pos_h")?;
            let w = vb.get((2 * input_size.1 - 1, head_dim), "rel_pos_w")?;
            Some((h, w))
        } else {
            None
        };
        Ok(Self {
            qkv,
            proj,
            num_heads,
            scale,
            rel_pos_hw,
            span,
            span_matmul,
            span_rel_pos,
            span_softmax,
        })
    }
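    /// Adds decomposed relative position embeddings to the attention logits:
    /// a per-height term and a per-width term instead of a full 2d table.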
    fn add_decomposed_rel_pos(
        &self,
        attn: Tensor,
        q: &Tensor,
        (q_h, q_w): (usize, usize),
        (k_h, k_w): (usize, usize),
    ) -> Result<Tensor> {
        match &self.rel_pos_hw {
            Some((rel_pos_h, rel_pos_w)) => {
                let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
                let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
                let (b, _, dim) = q.dims3()?;
                let r_q = q.reshape((b, q_h, q_w, dim))?;
                let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
                let rel_w = r_q
                    .transpose(1, 2)?
                    .contiguous()?
                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)?
                    .transpose(1, 2)?
                    .contiguous()?;
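                // On CPU, use the fused Add3 kernel; elsewhere fall back to
                // broadcasting the two terms over the reshaped logits.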
                if attn.device().is_cpu() {
                    let op = Add3(b, q_h, q_w, k_h, k_w);
                    attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
                } else {
                    (attn.reshape((b, q_h, q_w, k_h, k_w))?
                        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
                    .reshape((b, q_h * q_w, k_h * k_w))
                }
            }
            None => Ok(attn),
        }
    }
}
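/// Returns the relative position embeddings for a (q_size, k_size) pair:
/// coordinates are rescaled when the two sizes differ, shifted to be
/// non-negative, and used to index the learned `rel_pos` table.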
fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result<Tensor> {
    let max_rel_dist = 2 * usize::max(q_size, k_size) - 1;
    let dev = rel_pos.device();
    let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist {
        todo!("interpolation")
    } else {
        rel_pos
    };
    let q_coords = Tensor::arange(0u32, q_size as u32, dev)?
        .reshape((q_size, 1))?
        .to_dtype(DType::F32)?;
    let k_coords = Tensor::arange(0u32, k_size as u32, dev)?
        .reshape((1, k_size))?
        .to_dtype(DType::F32)?;
    let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?;
    let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?;
    let relative_coords = (q_coords.broadcast_sub(&k_coords)?
        + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?;
    let (d1, d2) = relative_coords.dims2()?;
    let relative_coords = relative_coords.to_dtype(DType::U32)?;
    rel_pos_resized
        .index_select(&relative_coords.reshape(d1 * d2)?, 0)?
        .reshape((d1, d2, ()))
}
impl Module for Attention {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b, h, w, c) = xs.dims4()?;
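        // A single linear layer produces q, k and v for all heads at once.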
        let qkv = self
            .qkv
            .forward(&xs.flatten_to(1)?)?
            .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))?
            .permute((2, 0, 3, 1, 4))?
            .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?;
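        // qkv now has shape (3, b * num_heads, h * w, head_dim).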
        let q = qkv.i(0)?;
        let k = qkv.i(1)?;
        let v = qkv.i(2)?;
        let attn = {
            let _enter = self.span_matmul.enter();
            (&q * self.scale)?.matmul(&k.t()?)?
        };
        let attn = {
            let _enter = self.span_rel_pos.enter();
            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
        };
        let attn = {
            let _enter = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&attn)?
        };
        let attn = {
            let _enter = self.span_matmul.enter();
            attn.matmul(&v)?
        };
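        // Merge the heads and restore the (B, H, W, C) layout before the
        // output projection.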
        let attn = attn
            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
            .permute((0, 2, 3, 1, 4))?
            .reshape((b, h * w, c))?;
        self.proj.forward(&attn)?.reshape((b, h, w, c))
    }
}
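/// Pre-norm transformer block: (optionally windowed) attention followed by an
/// MLP, each wrapped in a residual connection.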
#[derive(Debug)]
struct Block {
    norm1: LayerNorm,
    attn: Attention,
    norm2: LayerNorm,
    mlp: super::MlpBlock,
    window_size: usize,
    span: tracing::Span,
}
impl Block {
    fn new(
        dim: usize,
        num_heads: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        window_size: usize,
        input_size: (usize, usize),
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?;
        let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?;
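        // window_size == 0 means global attention over the full feature map;
        // otherwise attention runs within window_size x window_size windows.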
        let input_size_attn = if window_size == 0 {
            input_size
        } else {
            (window_size, window_size)
        };
        let attn = Attention::new(
            dim,
            num_heads,
            qkv_bias,
            use_rel_pos,
            input_size_attn,
            vb.pp("attn"),
        )?;
        let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
        let span = tracing::span!(tracing::Level::TRACE, "ie-block");
        Ok(Self {
            norm1,
            attn,
            norm2,
            mlp,
            window_size,
            span,
        })
    }
}
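/// Splits a (B, H, W, C) tensor into non-overlapping windows, zero-padding H
/// and W up to multiples of `window_size`. Returns the windows together with
/// the padded (H, W).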
fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> {
    let (b, h, w, c) = xs.dims4()?;
    let pad_h = (window_size - h % window_size) % window_size;
    let pad_w = (window_size - w % window_size) % window_size;
    let xs = if pad_h > 0 {
        xs.pad_with_zeros(1, 0, pad_h)?
    } else {
        xs
    };
    let xs = if pad_w > 0 {
        xs.pad_with_zeros(2, 0, pad_w)?
    } else {
        xs
    };
    let (h_p, w_p) = (h + pad_h, w + pad_w);
    let windows = xs
        .reshape((
            b,
            h_p / window_size,
            window_size,
            w_p / window_size,
            window_size,
            c,
        ))?
        .transpose(2, 3)?
        .contiguous()?
        .flatten_to(2)?;
    Ok((windows, (h_p, w_p)))
}
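/// Inverse of `window_partition`: reassembles windows into a (B, H, W, C)
/// tensor and crops away the padding.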
fn window_unpartition(
    windows: Tensor,
    window_size: usize,
    (h_p, w_p): (usize, usize),
    (h, w): (usize, usize),
) -> Result<Tensor> {
    let b = windows.dim(0)? / (h_p * w_p / window_size / window_size);
    let xs = windows
        .reshape((
            b,
            h_p / window_size,
            w_p / window_size,
            window_size,
            window_size,
            windows.elem_count() / b / h_p / w_p,
        ))?
        .transpose(2, 3)?
        .contiguous()?
        .reshape((b, h_p, w_p, ()))?;
    let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs };
    let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs };
    Ok(xs)
}
impl Module for Block {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let shortcut = xs;
        let xs = self.norm1.forward(xs)?;
        let hw = (xs.dim(1)?, xs.dim(2)?);
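        // For windowed attention, split the feature map into windows; pad_hw
        // records the padded size needed to un-partition afterwards.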
        let (xs, pad_hw) = if self.window_size > 0 {
            window_partition(xs, self.window_size)?
        } else {
            (xs, (0, 0))
        };
        let xs = self.attn.forward(&xs)?;
        let xs = if self.window_size > 0 {
            window_unpartition(xs, self.window_size, pad_hw, hw)?
        } else {
            xs
        };
        let xs = (xs + shortcut)?;
        &xs + xs.apply(&self.norm2)?.apply(&self.mlp)?
    }
}
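/// ViT-based image encoder producing a dense feature map: patch embedding,
/// optional absolute position embeddings, a stack of windowed or global
/// transformer blocks, and a convolutional neck.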
#[derive(Debug)]
pub struct ImageEncoderViT {
    patch_embed: PatchEmbed,
    blocks: Vec<Block>,
    neck_conv1: candle_nn::Conv2d,
    neck_ln1: super::LayerNorm2d,
    neck_conv2: candle_nn::Conv2d,
    neck_ln2: super::LayerNorm2d,
    pos_embed: Option<Tensor>,
    span: tracing::Span,
}
impl ImageEncoderViT {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        img_size: usize,
        patch_size: usize,
        in_chans: usize,
        embed_dim: usize,
        depth: usize,
        num_heads: usize,
        out_chans: usize,
        qkv_bias: bool,
        use_rel_pos: bool,
        use_abs_pos: bool,
        window_size: usize,
        global_attn_indexes: &[usize],
        vb: VarBuilder,
    ) -> Result<Self> {
        let patch_embed = PatchEmbed::new(
            in_chans,
            embed_dim,
            patch_size,
            patch_size,
            0,
            vb.pp("patch_embed"),
        )?;
        let mut blocks = Vec::with_capacity(depth);
        let vb_b = vb.pp("blocks");
        for i in 0..depth {
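            // Blocks listed in global_attn_indexes attend globally; all the
            // others use windowed attention.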
            let window_size = if global_attn_indexes.contains(&i) {
                0
            } else {
                window_size
            };
            let block = Block::new(
                embed_dim,
                num_heads,
                qkv_bias,
                use_rel_pos,
                window_size,
                (img_size / patch_size, img_size / patch_size),
                vb_b.pp(i),
            )?;
            blocks.push(block)
        }
        let neck_conv1 = candle_nn::conv2d_no_bias(
            embed_dim,
            out_chans,
            1,
            Default::default(),
            vb.pp("neck.0"),
        )?;
        let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
        let cfg = candle_nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
        let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
        let pos_embed = if use_abs_pos {
            let p = vb.get(
                (1, img_size / patch_size, img_size / patch_size, embed_dim),
                "pos_embed",
            )?;
            Some(p)
        } else {
            None
        };
        let span = tracing::span!(tracing::Level::TRACE, "image-encoder-vit");
        Ok(Self {
            patch_embed,
            blocks,
            neck_conv1,
            neck_ln1,
            neck_conv2,
            neck_ln2,
            pos_embed,
            span,
        })
    }
}
impl Module for ImageEncoderViT {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = self.patch_embed.forward(xs)?;
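        // Absolute position embeddings, when enabled, are added right after
        // patch embedding.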
        let mut xs = match &self.pos_embed {
            Some(pos_embed) => (xs + pos_embed)?,
            None => xs,
        };
        for block in self.blocks.iter() {
            xs = block.forward(&xs)?
        }
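        // Back to (B, C, H, W) for the convolutional neck that maps embed_dim
        // down to out_chans.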
        xs.permute((0, 3, 1, 2))?
            .apply(&self.neck_conv1)?
            .apply(&self.neck_ln1)?
            .apply(&self.neck_conv2)?
            .apply(&self.neck_ln2)
    }
}