use crate::models::with_tracing::{linear, Embedding as E, Linear};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
const MAX_SEQ_LEN: usize = 4096;
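/// Configuration for the MixFormer (phi) architecture. The constructors below
/// are presets matching the checkpoints supported by this implementation.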
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
    pub(crate) n_positions: usize,
    pub(crate) n_embd: usize,
    pub(crate) n_layer: usize,
    pub(crate) n_inner: Option<usize>,
    pub(crate) n_head: usize,
    pub(crate) rotary_dim: usize,
    pub(crate) activation_function: Activation,
    pub(crate) layer_norm_epsilon: f64,
    pub(crate) tie_word_embeddings: bool,
    pub(crate) pad_vocab_size_multiple: usize,
}
impl Config {
    pub fn v1() -> Self {
        Self {
            vocab_size: 50304,
            n_positions: 2048,
            n_embd: 1024,
            n_layer: 20,
            n_inner: None,
            n_head: 16,
            rotary_dim: usize::min(32, 1024 / 16),
            activation_function: Activation::Gelu,
            layer_norm_epsilon: 1e-5,
            tie_word_embeddings: false,
            pad_vocab_size_multiple: 64,
        }
    }
    pub fn v1_5() -> Self {
        Self {
            vocab_size: 51200,
            n_positions: 2048,
            n_embd: 2048,
            n_layer: 24,
            n_inner: None,
            n_head: 32,
            rotary_dim: usize::min(32, 2048 / 32),
            activation_function: Activation::Gelu,
            layer_norm_epsilon: 1e-5,
            tie_word_embeddings: false,
            pad_vocab_size_multiple: 64,
        }
    }
    pub fn v2() -> Self {
        Self {
            vocab_size: 51200,
            n_positions: 2048,
            n_embd: 2560,
            n_layer: 32,
            n_inner: None,
            n_head: 32,
            rotary_dim: usize::min(32, 2560 / 32),
            activation_function: Activation::Gelu,
            layer_norm_epsilon: 1e-5,
            tie_word_embeddings: false,
            pad_vocab_size_multiple: 64,
        }
    }
    pub fn puffin_phi_v2() -> Self {
        Self {
            vocab_size: 50304,
            n_positions: 2048,
            n_embd: 2048,
            n_layer: 24,
            n_inner: None,
            n_head: 32,
            rotary_dim: usize::min(32, 2048 / 32),
            activation_function: Activation::Gelu,
            layer_norm_epsilon: 1e-5,
            tie_word_embeddings: false,
            pad_vocab_size_multiple: 64,
        }
    }
    pub fn phi_hermes_1_3b() -> Self {
        Self {
            vocab_size: 50304,
            n_positions: 2048,
            n_embd: 2048,
            n_layer: 24,
            n_inner: None,
            n_head: 32,
            rotary_dim: usize::min(32, 2048 / 32),
            activation_function: Activation::NewGelu,
            layer_norm_epsilon: 1e-5,
            tie_word_embeddings: false,
            pad_vocab_size_multiple: 64,
        }
    }
}
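/// Token embedding layer, a thin wrapper around the `wte` weights.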
#[derive(Debug, Clone)]
struct Embedding {
    wte: E,
}
impl Embedding {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
        Ok(Self { wte })
    }
}
impl Module for Embedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.wte.forward(xs)
    }
}
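/// Builds a causal attention mask: entries above the diagonal are set to -inf
/// so each position can only attend to itself and to earlier positions.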
fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
}
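/// Precomputed sin/cos tables for rotary position embeddings.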
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}
impl RotaryEmbedding {
    fn new(dim: usize, max_seq_len: usize, dtype: DType, dev: &Device) -> Result<Self> {
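        // Inverse frequencies 1 / 10000^(i / dim) for i = 0, 2, 4, ..., one per rotated channel pair.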
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
            sin: freqs.sin()?.to_dtype(dtype)?,
            cos: freqs.cos()?.to_dtype(dtype)?,
        })
    }
    fn apply_rotary_emb_qkv(
        &self,
        qkv: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
        if three != 3 {
            candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
        }
        let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
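        // The cos/sin tables store one value per channel pair, so the full rotated
        // width is twice their last dimension.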
        let rotary_dim = rotary_dim * 2;
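        // Only the first `rotary_dim` channels of q and k are rotated; the rest pass through unchanged.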
        let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?.contiguous()?;
        let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
        let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?.contiguous()?;
        let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
        let c = self.cos.narrow(0, seqlen_offset, seqlen)?;
        let s = self.sin.narrow(0, seqlen_offset, seqlen)?;
        let q_rot = candle_nn::rotary_emb::rope_thd(&q_rot, &c, &s)?;
        let k_rot = candle_nn::rotary_emb::rope_thd(&k_rot, &c, &s)?;
        let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
        let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
        let v = qkv.i((.., .., 2))?;
        Ok((q, k, v))
    }
}
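/// Two-layer feed-forward network: fc1 -> activation -> fc2.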
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    fc1: Linear,
    fc2: Linear,
    act: Activation,
    span: tracing::Span,
}
impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
        let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
        let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
        Ok(Self {
            fc1,
            fc2,
            act: cfg.activation_function,
            span: tracing::span!(tracing::Level::TRACE, "mlp"),
        })
    }
}
impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
    }
}
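/// Final layer norm followed by the projection to vocabulary logits.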
#[derive(Debug, Clone)]
struct CausalLMHead {
    ln: candle_nn::LayerNorm,
    linear: Linear,
}
impl CausalLMHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
        let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
        Ok(Self { ln, linear })
    }
}
impl Module for CausalLMHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.ln)?
            .apply(&self.linear)?
            .to_dtype(DType::F32)
    }
}
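/// Multi-head attention with a fused QKV projection, partial rotary embeddings,
/// and a KV cache for incremental decoding.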
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MHA {
    wqkv: Linear,
    out_proj: Linear,
    rotary_emb: RotaryEmbedding,
    kv_cache: Option<(Tensor, Tensor)>,
    head_dim: usize,
    softmax_scale: f64,
    span: tracing::Span,
    span_rope: tracing::Span,
    span_mask: tracing::Span,
    span_softmax: tracing::Span,
}
impl MHA {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let head_dim = cfg.n_embd / cfg.n_head;
        let op_size = cfg.n_embd;
        let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
        let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
        let rotary_emb =
            RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.dtype(), vb.device())?;
        let softmax_scale = 1f64 / (head_dim as f64).sqrt();
        Ok(Self {
            wqkv,
            out_proj,
            head_dim,
            kv_cache: None,
            rotary_emb,
            softmax_scale,
            span: tracing::span!(tracing::Level::TRACE, "mha"),
            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }
    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b_size, seq_len, _n_embd) = xs.dims3()?;
        let qkv = self
            .wqkv
            .forward(xs)?
            .reshape((b_size, seq_len, 3, (), self.head_dim))?;
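        // The rotary offset equals the number of tokens already stored in the KV cache.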
        let seqlen_offset = match &self.kv_cache {
            None => 0,
            Some((prev_k, _)) => prev_k.dim(1)?,
        };
        let (q, k, v) = {
            let _enter = self.span_rope.enter();
            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
        };
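        // Append the new keys/values to the cache along the sequence dimension.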
        let (k, v) = match &self.kv_cache {
            None => (k, v),
            Some((prev_k, prev_v)) => {
                let k = Tensor::cat(&[prev_k, &k], 1)?;
                let v = Tensor::cat(&[prev_v, &v], 1)?;
                (k, v)
            }
        };
        self.kv_cache = Some((k.clone(), v.clone()));
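        // Fold batch and heads together, then compute scaled dot-product attention scores.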
        let q = q.transpose(1, 2)?.flatten_to(1)?;
        let k = k.transpose(1, 2)?.flatten_to(1)?;
        let v = v.transpose(1, 2)?.flatten_to(1)?;
        let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?;
        let attn_weights = match mask {
            None => attn_weights,
            Some(mask) => {
                let _enter = self.span_mask.enter();
                attn_weights.broadcast_add(mask)?
            }
        };
        let attn_weights = {
            let _enter = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&attn_weights)?
        };
        let attn_output = attn_weights.matmul(&v)?;
        let attn_output = attn_output
            .reshape((b_size, (), seq_len, self.head_dim))?
            .transpose(1, 2)?
            .flatten_from(D::Minus2)?;
        attn_output.apply(&self.out_proj)
    }
    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}
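/// Phi-style parallel block: attention and MLP both read the same layer-normed
/// input, and their outputs are summed with the residual.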
#[derive(Debug, Clone)]
struct ParallelBlock {
    ln: candle_nn::LayerNorm,
    mixer: MHA,
    mlp: MLP,
    span: tracing::Span,
}
impl ParallelBlock {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
        let mixer = MHA::new(cfg, vb.pp("mixer"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        Ok(Self {
            ln,
            mixer,
            mlp,
            span: tracing::span!(tracing::Level::TRACE, "block"),
        })
    }
    fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        let residual = xs;
        let xs = xs.apply(&self.ln)?;
        let attn_outputs = self.mixer.forward(&xs, mask)?;
        let feed_forward_hidden_states = self.mlp.forward(&xs)?;
        attn_outputs + feed_forward_hidden_states + residual
    }
    fn clear_kv_cache(&mut self) {
        self.mixer.clear_kv_cache()
    }
}
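/// MixFormer model with a causal language-modeling head.
///
/// Minimal usage sketch (assumes the weights are already loaded into a
/// `VarBuilder` and that `input_ids` is a `(batch, seq_len)` tensor of token ids):
///
/// ```ignore
/// let cfg = Config::v2();
/// let mut model = MixFormerSequentialForCausalLM::new_v2(&cfg, vb)?;
/// // Returns f32 logits for the last position, shape (batch, vocab_size).
/// let logits = model.forward(&input_ids)?;
/// ```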
#[derive(Debug, Clone)]
pub struct MixFormerSequentialForCausalLM {
    embedding: Embedding,
    blocks: Vec<ParallelBlock>,
    head: CausalLMHead,
    span: tracing::Span,
}
impl MixFormerSequentialForCausalLM {
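    /// Builds the model from weights stored in the `transformer`/`lm_head`
    /// layout used by the v2-style checkpoints.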
    pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_head = vb.pp("lm_head");
        let vb = vb.pp("transformer");
        let embedding = Embedding::new(cfg, vb.pp("embd"))?;
        let mut blocks = Vec::new();
        for i in 0..cfg.n_layer {
            let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
            blocks.push(block)
        }
        let head = CausalLMHead::new(cfg, vb_head)?;
        Ok(Self {
            embedding,
            blocks,
            head,
            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
        })
    }
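    /// Builds the model from weights stored in the sequential `layers` layout:
    /// the embedding at index 0, the blocks following it, and the head at index
    /// `n_layer + 1`.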
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layers");
        let embedding = Embedding::new(cfg, vb.pp(0))?;
        let mut blocks = Vec::new();
        for i in 0..cfg.n_layer {
            let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
            blocks.push(block)
        }
        let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
        Ok(Self {
            embedding,
            blocks,
            head,
            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
        })
    }
    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (_b_size, seq_len) = xs.dims2()?;
        let mut xs = xs.apply(&self.embedding)?;
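        // A causal mask is only needed when processing more than one token at a
        // time; single-token decoding attends to everything already in the KV cache.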
        let mask = if seq_len <= 1 {
            None
        } else {
            Some(get_mask(seq_len, xs.dtype(), xs.device())?)
        };
        for block in self.blocks.iter_mut() {
            xs = block.forward(&xs, mask.as_ref())?
        }
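        // Only the logits for the last position are returned.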
        xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
    }
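    /// Forward pass that prepends the embedded `bos_token` and the provided
    /// `img_embeds` to the embedded text tokens, for image-conditioned variants
    /// of the model.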
    pub fn forward_with_img(
        &mut self,
        bos_token: &Tensor,
        xs: &Tensor,
        img_embeds: &Tensor,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let xs = xs.apply(&self.embedding)?;
        let bos_token = bos_token.apply(&self.embedding)?;
        let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
        let (_b_size, seq_len, _embds) = xs.dims3()?;
        let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
        for block in self.blocks.iter_mut() {
            xs = block.forward(&xs, mask.as_ref())?
        }
        let xs = xs
            .narrow(1, seq_len - 1, 1)?
            .apply(&self.head)?
            .squeeze(1)?;
        Ok(xs)
    }
    pub fn clear_kv_cache(&mut self) {
        self.blocks.iter_mut().for_each(|b| b.clear_kv_cache())
    }
}