use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};
pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
    let img = img.unsqueeze(dim + 1)?;
    let mut dims = img.dims().to_vec();
    dims[dim + 1] = repeats;
    img.broadcast_as(dims)?.flatten(dim, dim + 1)
}
pub mod speaker_encoder {
    use super::*;
    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        pub sampling_rate: usize,
        pub partial_n_frames: usize,
        pub model_hidden_size: usize,
        pub model_embedding_size: usize,
        pub model_num_layers: usize,
        pub mel_window_length: usize,
        pub mel_window_step: usize,
        pub mel_n_channels: usize,
    }
    impl Config {
        pub fn cfg() -> Self {
            Self {
                sampling_rate: 16_000,
                partial_n_frames: 160,
                model_hidden_size: 256,
                model_embedding_size: 256,
                model_num_layers: 3,
                mel_window_length: 25,
                mel_window_step: 10,
                mel_n_channels: 40,
            }
        }
    }
    pub struct Model {
        lstms: Vec<candle_nn::LSTM>,
        linear: Linear,
        cfg: Config,
    }
    type Slice = (usize, usize);
    impl Model {
        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
            let mut lstms = Vec::with_capacity(cfg.model_num_layers);
            let vb_l = vb.pp("lstm");
            for layer_idx in 0..cfg.model_num_layers {
                let c = candle_nn::LSTMConfig {
                    layer_idx,
                    ..Default::default()
                };
                let lstm = candle_nn::lstm(
                    cfg.mel_n_channels,
                    cfg.model_hidden_size,
                    c,
                    vb_l.pp(layer_idx),
                )?;
                lstms.push(lstm)
            }
            let linear = linear_b(
                cfg.model_hidden_size,
                cfg.model_embedding_size,
                true,
                vb.pp("linear"),
            )?;
            Ok(Self { lstms, linear, cfg })
        }
        fn compute_partial_slices(
            &self,
            n_samples: usize,
            rate: f64,
            min_coverage: f64,
        ) -> (Vec<Slice>, Vec<Slice>) {
            let c = &self.cfg;
            let samples_per_frame = c.sampling_rate * c.mel_window_step / 1000;
            let n_frames = n_samples / samples_per_frame + 1;
            let frame_step =
                (c.sampling_rate as f64 / rate / samples_per_frame as f64).round() as usize;
            let steps = (n_frames + frame_step).saturating_sub(c.partial_n_frames) + 1;
            let mut wav_slices = vec![];
            let mut mel_slices = vec![];
            for i in (0..steps).step_by(frame_step) {
                let mel_range = (i, i + c.partial_n_frames);
                let wav_range = (
                    i * samples_per_frame,
                    (i + c.partial_n_frames) * samples_per_frame,
                );
                mel_slices.push(mel_range);
                wav_slices.push(wav_range);
            }
            let last_wav_range = match wav_slices.last() {
                None => return (wav_slices, mel_slices),
                Some(l) => *l,
            };
            let coverage = (n_samples - last_wav_range.0) as f64
                / (last_wav_range.1 - last_wav_range.0) as f64;
            if coverage > min_coverage && mel_slices.len() > 1 {
                mel_slices.pop();
                wav_slices.pop();
            }
            (wav_slices, mel_slices)
        }
        pub fn embed_utterance(
            &self,
            wav: &[f32],
            mel_filters: &[f32],
            rate: f64,
            min_c: f64,
            device: &Device,
        ) -> Result<Tensor> {
            let (wav_slices, mel_slices) = self.compute_partial_slices(wav.len(), rate, min_c);
            let max_wave_length = match wav_slices.last() {
                Some(v) => v.1,
                None => candle::bail!("empty wav slices"),
            };
            let wav = if max_wave_length > wav.len() {
                let mut wav = wav.to_vec();
                wav.resize(max_wave_length - wav.len(), 0.0);
                std::borrow::Cow::Owned(wav)
            } else {
                std::borrow::Cow::Borrowed(wav)
            };
            let mel = crate::models::whisper::audio::log_mel_spectrogram_(
                wav.as_ref(),
                mel_filters,
                self.cfg.mel_window_length,
                self.cfg.mel_window_step,
                self.cfg.mel_n_channels,
                false,
            );
            let mels = mel_slices
                .iter()
                .flat_map(|s| [mel[s.0], mel[s.1]])
                .collect::<Vec<_>>();
            let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
            let partial_embeds = self.forward(&mels)?;
            let raw_embed = partial_embeds.mean(0)?;
            let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
            raw_embed.broadcast_div(&norm)
        }
    }
    impl Module for Model {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            use candle_nn::RNN;
            let xs = xs.t()?;
            let mut xs = xs.clone();
            for layer in self.lstms.iter() {
                let states = layer.seq(&xs)?;
                xs = layer.states_to_tensor(&states)?;
            }
            let xs = xs.t()?;
            let embeds_raw = xs.apply(&self.linear)?.relu()?;
            let norm = embeds_raw.sqr()?.sum_keepdim(1)?.sqrt()?;
            embeds_raw.broadcast_div(&norm)
        }
    }
}
type Rank = u32;
pub mod tokenizers {
    use super::*;
    use std::collections::HashMap;
    pub struct BPE {
        pub re: fancy_regex::Regex,
        pub end_of_text: usize,
        pub offset: usize,
        pub ranks: HashMap<Vec<u8>, Rank>,
        span: tracing::Span,
    }
    impl BPE {
        pub fn from_json(json: &serde_json::Value, end_of_text: usize) -> Result<Self> {
            let json = match json.as_object() {
                None => candle::bail!("json value is not an object"),
                Some(json) => json,
            };
            let re = match json.get("pat_str") {
                None => candle::bail!("json object has no pat_str field"),
                Some(pat_str) => match pat_str.as_str() {
                    None => candle::bail!("pat_str field is not a string"),
                    Some(pat_str) => fancy_regex::Regex::new(pat_str).map_err(E::wrap)?,
                },
            };
            let offset = match json.get("offset") {
                None => candle::bail!("json object has no offset field"),
                Some(offset) => match offset.as_u64() {
                    None => candle::bail!("offset field is not a positive int"),
                    Some(offset) => offset as usize,
                },
            };
            let mut ranks = HashMap::new();
            for id in 0u8..=255 {
                ranks.insert(vec![id], id as u32);
            }
            let mergeable_ranks = match json.get("mergeable_ranks") {
                None => candle::bail!("json object has no mergeable_ranks field"),
                Some(mr) => match mr.as_object() {
                    None => candle::bail!("mergeable_ranks is not an object"),
                    Some(mr) => mr,
                },
            };
            for (key, value) in mergeable_ranks.iter() {
                let value = match value.as_u64() {
                    None => candle::bail!("mergeable_ranks '{key}' is not a u64"),
                    Some(value) => value as u32,
                };
                if value < 256 {
                    continue;
                }
                let key = key.as_bytes().to_vec();
                ranks.insert(key, value);
            }
            Ok(Self {
                re,
                end_of_text,
                offset,
                ranks,
                span: tracing::span!(tracing::Level::TRACE, "bpe"),
            })
        }
        fn _byte_pair_merge(&self, piece: &[u8]) -> Vec<(usize, Rank)> {
            let mut parts = Vec::with_capacity(piece.len() + 1);
            let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
            for i in 0..piece.len() - 1 {
                let rank = *self.ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX);
                if rank < min_rank.0 {
                    min_rank = (rank, i);
                }
                parts.push((i, rank));
            }
            parts.push((piece.len() - 1, Rank::MAX));
            parts.push((piece.len(), Rank::MAX));
            let get_rank = {
                #[inline(always)]
                |parts: &Vec<(usize, Rank)>, i: usize| {
                    if (i + 3) < parts.len() {
                        *self
                            .ranks
                            .get(&piece[parts[i].0..parts[i + 3].0])
                            .unwrap_or(&Rank::MAX)
                    } else {
                        Rank::MAX
                    }
                }
            };
            while min_rank.0 != Rank::MAX {
                let i = min_rank.1;
                if i > 0 {
                    parts[i - 1].1 = get_rank(&parts, i - 1);
                }
                parts[i].1 = get_rank(&parts, i);
                parts.remove(i + 1);
                min_rank = (Rank::MAX, usize::MAX);
                for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
                    if rank < min_rank.0 {
                        min_rank = (rank, i);
                    }
                }
            }
            parts
        }
        pub fn byte_pair_encode(&self, piece: &[u8]) -> Vec<Rank> {
            if piece.is_empty() {
                return Vec::new();
            }
            if piece.len() == 1 {
                return vec![self.ranks[piece]];
            }
            assert!(piece.len() > 1);
            self._byte_pair_merge(piece)
                .windows(2)
                .map(|part| self.ranks[&piece[part[0].0..part[1].0]])
                .collect()
        }
        pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
            let _enter = self.span.enter();
            let mut bpe_tokens: Vec<u32> = Vec::new();
            for word in self.re.find_iter(text) {
                let word = word.map_err(E::wrap)?;
                let word_tokens = self.byte_pair_encode(word.as_str().as_bytes());
                for &token in word_tokens.iter() {
                    bpe_tokens.push(token + self.offset as u32)
                }
            }
            bpe_tokens.push((self.end_of_text + self.offset) as u32);
            Ok(bpe_tokens)
        }
    }
}
pub mod gpt {
    use super::*;
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NormType {
        LayerNorm,
        RMSNorm,
    }
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum AttnKernelType {
        Fa2,
        TorchAttn,
        Hand,
    }
    #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
    pub enum NonLinearityType {
        Gelu,
        Swiglu,
    }
    enum Norm {
        RMSNorm(candle_nn::RmsNorm),
        LayerNorm(candle_nn::LayerNorm),
    }
    #[derive(Debug, Clone)]
    pub struct Config {
        pub block_size: usize,
        pub vocab_sizes: Vec<usize>,
        pub target_vocab_sizes: Vec<usize>,
        pub n_layer: usize,
        pub n_head: usize,
        pub n_embd: usize,
        pub bias: bool,
        pub causal: bool,
        pub spk_emb_on_text: bool,
        pub norm_type: NormType,
        pub rmsnorm_eps: f64,
        pub nonlinearity_type: NonLinearityType,
        pub swiglu_multiple_of: Option<usize>,
        pub attn_kernel_type: AttnKernelType,
        pub kv_cache_enabled: bool,
    }
    impl Config {
        pub fn cfg1b_v0_1() -> Self {
            Self {
                n_layer: 6,
                n_head: 6,
                n_embd: 384,
                block_size: 1024,
                bias: false,
                vocab_sizes: vec![1538, 1025],
                causal: false,
                target_vocab_sizes: vec![1025, 1025, 1025, 1025, 1025, 1025],
                swiglu_multiple_of: Some(256),
                norm_type: NormType::LayerNorm,
                kv_cache_enabled: false,
                attn_kernel_type: AttnKernelType::TorchAttn,
                spk_emb_on_text: true,
                nonlinearity_type: NonLinearityType::Gelu,
                rmsnorm_eps: 1e-5,
            }
        }
    }
    impl Norm {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            match cfg.norm_type {
                NormType::RMSNorm => {
                    let rms_norm = candle_nn::rms_norm(cfg.n_embd, cfg.rmsnorm_eps, vb)?;
                    Ok(Self::RMSNorm(rms_norm))
                }
                NormType::LayerNorm => {
                    let ln_cfg = candle_nn::LayerNormConfig {
                        affine: cfg.bias,
                        ..Default::default()
                    };
                    let layer_norm = candle_nn::layer_norm(cfg.n_embd, ln_cfg, vb)?;
                    Ok(Self::LayerNorm(layer_norm))
                }
            }
        }
    }
    impl Module for Norm {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            match self {
                Self::RMSNorm(m) => m.forward(xs),
                Self::LayerNorm(m) => m.forward(xs),
            }
        }
    }
    struct SelfAttention {
        c_attn: Linear,
        c_proj: Linear,
        n_head: usize,
        span: tracing::Span,
    }
    impl SelfAttention {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            if cfg.attn_kernel_type != AttnKernelType::TorchAttn {
                candle::bail!("only TorchAttn is supported")
            }
            if cfg.kv_cache_enabled {
                candle::bail!("kv_cache_enabled=true is not supported")
            }
            let c_attn = linear_b(cfg.n_embd, cfg.n_embd * 3, cfg.bias, vb.pp("c_attn"))?;
            let c_proj = linear_b(cfg.n_embd, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
            Ok(Self {
                c_attn,
                c_proj,
                n_head: cfg.n_head,
                span: tracing::span!(tracing::Level::TRACE, "self-attn"),
            })
        }
    }
    impl Module for SelfAttention {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let _enter = self.span.enter();
            let (b, t, c) = xs.dims3()?;
            let c_x = xs
                .apply(&self.c_attn)?
                .reshape((b, t, 3, self.n_head, c / self.n_head))?;
            let q = c_x.i((.., .., 0))?;
            let k = c_x.i((.., .., 1))?;
            let v = c_x.i((.., .., 2))?;
            let q = q.transpose(1, 2)?.contiguous()?;
            let k = k.transpose(1, 2)?.contiguous()?;
            let v = v.transpose(1, 2)?.contiguous()?;
            let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
            let att = candle_nn::ops::softmax_last_dim(&att)?;
            let att = att.matmul(&v)?.transpose(1, 2)?;
            att.reshape((b, t, c))?.apply(&self.c_proj)
        }
    }
    #[allow(clippy::upper_case_acronyms)]
    enum MLP {
        Gelu {
            c_fc: Linear,
            c_proj: Linear,
            span: tracing::Span,
        },
        Swiglu {
            w1: Linear,
            w3: Linear,
            c_proj: Linear,
            span: tracing::Span,
        },
    }
    impl MLP {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let hidden_dim = 4 * cfg.n_embd;
            let slf = match cfg.nonlinearity_type {
                NonLinearityType::Gelu => {
                    let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
                    Self::Gelu {
                        c_fc,
                        c_proj,
                        span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
                    }
                }
                NonLinearityType::Swiglu => {
                    let hidden_dim = (2 * hidden_dim) / 3;
                    let swiglu_multiple_of = match cfg.swiglu_multiple_of {
                        None => candle::bail!("swiglu-multiple-of has to be set"),
                        Some(smo) => smo,
                    };
                    let hidden_dim = swiglu_multiple_of * (hidden_dim + swiglu_multiple_of - 1)
                        / swiglu_multiple_of;
                    let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
                    let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
                    let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
                    Self::Swiglu {
                        w1,
                        w3,
                        c_proj,
                        span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
                    }
                }
            };
            Ok(slf)
        }
    }
    impl Module for MLP {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            match self {
                Self::Gelu { c_fc, c_proj, span } => {
                    let _enter = span.enter();
                    xs.apply(c_fc)?.gelu()?.apply(c_proj)
                }
                Self::Swiglu {
                    w1,
                    w3,
                    c_proj,
                    span,
                } => {
                    let _enter = span.enter();
                    let w1 = xs.apply(w1)?;
                    let w3 = xs.apply(w3)?;
                    (w1.silu()? * w3)?.apply(c_proj)
                }
            }
        }
    }
    struct Block {
        ln_1: Norm,
        ln_2: Norm,
        attn: SelfAttention,
        mlp: MLP,
        span: tracing::Span,
    }
    impl Block {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let ln_1 = Norm::new(cfg, vb.pp("ln_1"))?;
            let ln_2 = Norm::new(cfg, vb.pp("ln_2"))?;
            let attn = SelfAttention::new(cfg, vb.pp("attn"))?;
            let mlp = MLP::new(cfg, vb.pp("mlp"))?;
            Ok(Block {
                ln_1,
                ln_2,
                attn,
                mlp,
                span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
            })
        }
    }
    impl Module for Block {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let _enter = self.span.enter();
            let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
            let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
            Ok(xs)
        }
    }
    #[allow(clippy::upper_case_acronyms)]
    pub struct Model {
        wtes: Vec<candle_nn::Embedding>,
        wpe: candle_nn::Embedding,
        h: Vec<Block>,
        ln_f: Norm,
        lm_heads: Vec<Linear>,
        cfg: Config,
        dtype: DType,
        span: tracing::Span,
    }
    impl Model {
        pub fn new(cfg: Config, vb: VarBuilder) -> Result<Self> {
            let vb_t = vb.pp("transformer");
            let ln_f = Norm::new(&cfg, vb_t.pp("ln_f"))?;
            let mut wtes = Vec::with_capacity(cfg.vocab_sizes.len());
            let vb_w = vb_t.pp("wtes");
            for (idx, vocab_size) in cfg.vocab_sizes.iter().enumerate() {
                let wte = candle_nn::embedding(*vocab_size, cfg.n_embd, vb_w.pp(idx))?;
                wtes.push(wte)
            }
            let wpe = candle_nn::embedding(cfg.block_size, cfg.n_embd, vb_t.pp("wpe"))?;
            let mut h = Vec::with_capacity(cfg.n_layer);
            let vb_h = vb_t.pp("h");
            for idx in 0..cfg.n_layer {
                let block = Block::new(&cfg, vb_h.pp(idx))?;
                h.push(block)
            }
            let mut lm_heads = Vec::with_capacity(cfg.target_vocab_sizes.len());
            let vb_l = vb.pp("lm_heads");
            for (idx, vocab_size) in cfg.target_vocab_sizes.iter().enumerate() {
                let head = linear_b(cfg.n_embd, *vocab_size, false, vb_l.pp(idx))?;
                lm_heads.push(head)
            }
            Ok(Self {
                wtes,
                wpe,
                h,
                ln_f,
                lm_heads,
                cfg,
                dtype: vb.dtype(),
                span: tracing::span!(tracing::Level::TRACE, "gpt"),
            })
        }
        pub fn config(&self) -> &Config {
            &self.cfg
        }
        pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
            let _enter = self.span.enter();
            let device = idx.device();
            let (b, _num_hierarchies, t) = idx.dims3()?;
            let pos = Tensor::arange(0u32, t as u32, device)?;
            let pos_emb = pos.apply(&self.wpe)?;
            let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
            for (wte_idx, wte) in self.wtes.iter().enumerate() {
                let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
                tok_emb = (tok_emb + emb)?;
            }
            let spk_emb = 0f64;
            let mut xs = (pos_emb.broadcast_add(&tok_emb)? + spk_emb)?;
            for block in self.h.iter() {
                xs = xs.apply(block)?
            }
            let xs = xs.apply(&self.ln_f)?;
            let mut logits = Vec::with_capacity(self.lm_heads.len());
            for lm_head in self.lm_heads.iter() {
                let ys = xs.apply(lm_head)?;
                logits.push(ys)
            }
            Ok(logits)
        }
    }
}
pub mod transformer {
    use super::*;
    #[derive(Debug, Clone, serde::Deserialize)]
    pub struct Config {
        pub block_size: usize,
        pub vocab_size: usize,
        pub n_layer: usize,
        pub n_head: usize,
        pub dim: usize,
        pub speaker_emb_dim: usize,
        pub intermediate_size: Option<usize>,
        pub n_local_heads: Option<usize>,
        pub norm_eps: f64,
    }
    impl Config {
        pub fn cfg1b_v0_1() -> Self {
            Self {
                n_layer: 24,
                n_head: 16,
                dim: 2048,
                vocab_size: 2562,
                speaker_emb_dim: 256,
                block_size: 2048,
                intermediate_size: None,
                n_local_heads: None,
                norm_eps: 1e-5,
            }
        }
        pub(crate) fn n_local_heads(&self) -> usize {
            self.n_local_heads.unwrap_or(self.n_head)
        }
        pub(crate) fn head_dim(&self) -> usize {
            self.dim / self.n_head
        }
        pub(crate) fn intermediate_size(&self) -> usize {
            match self.intermediate_size {
                Some(intermediate_size) => intermediate_size,
                None => {
                    let hidden_dim = self.dim * 4;
                    let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
                    (n_hidden + 255) / 256 * 256
                }
            }
        }
    }
    #[derive(Debug, Clone)]
    struct FeedForward {
        w1: Linear,
        w2: Linear,
        w3: Linear,
        span: tracing::Span,
    }
    impl FeedForward {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let i_size = cfg.intermediate_size();
            let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
            let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
            let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
            Ok(Self {
                w1,
                w2,
                w3,
                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
            })
        }
    }
    impl Module for FeedForward {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let _enter = self.span.enter();
            let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
            swiglu.apply(&self.w2)
        }
    }
    #[derive(Debug, Clone)]
    struct Attention {
        wqkv: Linear,
        wo: Linear,
        dim: usize,
        kv_size: usize,
        n_local_heads: usize,
        head_dim: usize,
        n_head: usize,
        kv_cache: Option<(Tensor, Tensor)>,
        span: tracing::Span,
    }
    impl Attention {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let n_local_heads = cfg.n_local_heads();
            let head_dim = cfg.head_dim();
            let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
            let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
            let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
            Ok(Self {
                wqkv,
                wo,
                dim: cfg.dim,
                kv_size: n_local_heads * head_dim,
                n_local_heads,
                head_dim,
                n_head: cfg.n_head,
                kv_cache: None,
                span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
            })
        }
        fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
            let _enter = self.span.enter();
            let (b_sz, seqlen, _) = xs.dims3()?;
            let qkv = xs.apply(&self.wqkv)?;
            let q = qkv.narrow(D::Minus1, 0, self.dim)?;
            let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
            let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
            let q = q
                .reshape((b_sz, seqlen, self.n_head, self.head_dim))?
                .transpose(1, 2)?
                .contiguous()?;
            let k = k
                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
                .transpose(1, 2)?;
            let v = v
                .reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
                .transpose(1, 2)?;
            let (k, v) = match &self.kv_cache {
                None => (k, v),
                Some((prev_k, prev_v)) => {
                    let k = Tensor::cat(&[prev_k, &k], 2)?;
                    let v = Tensor::cat(&[prev_v, &v], 2)?;
                    (k, v)
                }
            };
            self.kv_cache = Some((k.clone(), v.clone()));
            let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
            let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
            let attn_weights = attn_weights.broadcast_add(mask)?;
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            let attn_output = attn_weights.matmul(&v)?;
            attn_output
                .transpose(1, 2)?
                .reshape((b_sz, seqlen, self.dim))?
                .apply(&self.wo)
        }
        fn clear_kv_cache(&mut self) {
            self.kv_cache = None
        }
    }
    #[derive(Debug, Clone)]
    struct Block {
        attention: Attention,
        feed_forward: FeedForward,
        ffn_norm: RmsNorm,
        attention_norm: RmsNorm,
        span: tracing::Span,
    }
    impl Block {
        fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let attention = Attention::new(cfg, vb.pp("attention"))?;
            let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
            let ffn_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
            let attention_norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
            Ok(Self {
                attention,
                feed_forward,
                ffn_norm,
                attention_norm,
                span: tracing::span!(tracing::Level::TRACE, "block"),
            })
        }
        fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
            let _enter = self.span.enter();
            let hs = xs.apply(&self.attention_norm)?;
            let hs = (xs + self.attention.forward(&hs, pos, mask))?;
            &hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
        }
        fn clear_kv_cache(&mut self) {
            self.attention.clear_kv_cache()
        }
    }
    #[derive(Debug, Clone)]
    pub struct Model {
        tok_embeddings: Embedding,
        pos_embeddings: Embedding,
        speaker_cond_pos: Linear,
        layers: Vec<Block>,
        norm: RmsNorm,
        output: Linear,
        spk_cond_mask: Tensor,
        span: tracing::Span,
    }
    impl Model {
        pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
            let tok_embeddings = embedding(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
            let pos_embeddings = embedding(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
            let speaker_cond_pos = linear_b(
                cfg.speaker_emb_dim,
                cfg.dim,
                false,
                vb.pp("speaker_cond_pos"),
            )?;
            let mut layers = Vec::with_capacity(cfg.n_layer);
            let vb_l = vb.pp("layers");
            for layer_idx in 0..cfg.n_layer {
                let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
                layers.push(layer)
            }
            let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
            let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
            let dtype = vb.dtype();
            let spk_cond_mask = Tensor::cat(
                &[
                    Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
                    Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
                ],
                0,
            )?;
            Ok(Self {
                tok_embeddings,
                pos_embeddings,
                speaker_cond_pos,
                layers,
                norm,
                output,
                spk_cond_mask,
                span: tracing::span!(tracing::Level::TRACE, "transformer"),
            })
        }
        pub fn clear_kv_cache(&mut self) {
            for layer in self.layers.iter_mut() {
                layer.clear_kv_cache()
            }
        }
        pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
            let _enter = self.span.enter();
            let (_b_sz, seqlen) = xs.dims2()?;
            let mask: Vec<_> = (0..seqlen)
                .flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
                .collect();
            let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
            let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
            let tok_embeddings = xs.apply(&self.tok_embeddings)?;
            let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
            let mut xs = tok_embeddings
                .broadcast_add(&pos_embeddings)?
                .broadcast_add(
                    &spk_emb
                        .apply(&self.speaker_cond_pos)?
                        .broadcast_mul(&self.spk_cond_mask)?,
                )?;
            let mask = mask.to_dtype(xs.dtype())?;
            for layer in self.layers.iter_mut() {
                xs = layer.forward(&xs, pos, &mask)?
            }
            xs.narrow(1, seqlen - 1, 1)?
                .apply(&self.norm)?
                .apply(&self.output)
        }
    }
}
pub mod adapters {
    pub struct TiltedEncodec {
        end_of_audio_token: u32,
        span: tracing::Span,
    }
    impl TiltedEncodec {
        pub fn new(end_of_audio_token: u32) -> Self {
            Self {
                end_of_audio_token,
                span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
            }
        }
        pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
            let _enter = self.span.enter();
            let mut text_ids = vec![];
            let mut extracted_audio_ids = vec![];
            let mut min_audio_ids_len = usize::MAX;
            for (book_id, tokens) in tokens.iter().enumerate() {
                let mut audio_ids = vec![];
                for &t in tokens.iter() {
                    #[allow(clippy::comparison_chain)]
                    if t > self.end_of_audio_token {
                        if book_id == 0 {
                            text_ids.push(t)
                        }
                    } else if t < self.end_of_audio_token {
                        audio_ids.push(t)
                    }
                }
                min_audio_ids_len = usize::min(min_audio_ids_len, audio_ids.len());
                extracted_audio_ids.push(audio_ids)
            }
            for audio_ids in extracted_audio_ids.iter_mut() {
                audio_ids.truncate(min_audio_ids_len)
            }
            (text_ids, extracted_audio_ids)
        }
    }
    pub struct FlattenedInterleavedEncodec2Codebook {
        end_of_audio_token: u32,
        span: tracing::Span,
    }
    impl FlattenedInterleavedEncodec2Codebook {
        pub fn new(end_of_audio_token: u32) -> Self {
            Self {
                end_of_audio_token,
                span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
            }
        }
        pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
            let _enter = self.span.enter();
            let mut text_ids = vec![];
            let mut audio_ids1 = vec![];
            let mut audio_ids2 = vec![];
            for &t in tokens.iter() {
                #[allow(clippy::comparison_chain)]
                if t < self.end_of_audio_token {
                    audio_ids1.push(t)
                } else if t < 2 * self.end_of_audio_token {
                    audio_ids2.push(t - self.end_of_audio_token)
                } else {
                    text_ids.push(t)
                }
            }
            (text_ids, audio_ids1, audio_ids2)
        }
    }
}