use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;
use crate::models::metavoice::repeat_interleave;
use candle::{Module, Result, Tensor, D};
pub mod transformer {
use super::*;
type Config = crate::models::metavoice::transformer::Config;
#[derive(Debug, Clone)]
struct FeedForward {
w1: Linear,
w2: Linear,
w3: Linear,
span: tracing::Span,
}
impl FeedForward {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let i_size = cfg.intermediate_size();
let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
Ok(Self {
w1,
w2,
w3,
span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
})
}
}
impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
swiglu.apply(&self.w2)
}
}
#[derive(Debug, Clone)]
struct Attention {
wqkv: Linear,
wo: Linear,
dim: usize,
kv_size: usize,
n_local_heads: usize,
head_dim: usize,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
span: tracing::Span,
}
impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let n_local_heads = cfg.n_local_heads();
let head_dim = cfg.head_dim();
let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
Ok(Self {
wqkv,
wo,
dim: cfg.dim,
kv_size: n_local_heads * head_dim,
n_local_heads,
head_dim,
n_head: cfg.n_head,
kv_cache: None,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seqlen, _) = xs.dims3()?;
let qkv = xs.apply(&self.wqkv)?;
let q = qkv.narrow(D::Minus1, 0, self.dim)?;
let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
let q = q
.reshape((b_sz, seqlen, self.n_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &k], 2)?;
let v = Tensor::cat(&[prev_v, &v], 2)?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
let attn_weights = attn_weights.broadcast_add(mask)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, seqlen, self.dim))?
.apply(&self.wo)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct Block {
attention: Attention,
feed_forward: FeedForward,
ffn_norm: RmsNorm,
attention_norm: RmsNorm,
span: tracing::Span,
}
impl Block {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = Attention::new(cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
Ok(Self {
attention,
feed_forward,
ffn_norm,
attention_norm,
span: tracing::span!(tracing::Level::TRACE, "block"),
})
}
fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hs = xs.apply(&self.attention_norm)?;
let hs = (xs + self.attention.forward(&hs, pos, mask))?;
&hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
}
fn clear_kv_cache(&mut self) {
self.attention.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
tok_embeddings: Embedding,
pos_embeddings: Embedding,
speaker_cond_pos: Linear,
layers: Vec<Block>,
norm: RmsNorm,
output: Linear,
spk_cond_mask: Tensor,
span: tracing::Span,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
let speaker_cond_pos = linear_b(
cfg.speaker_emb_dim,
cfg.dim,
false,
vb.pp("speaker_cond_pos"),
)?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
let spk_cond_mask = Tensor::cat(
&[
Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
],
0,
)?;
Ok(Self {
tok_embeddings,
pos_embeddings,
speaker_cond_pos,
layers,
norm,
output,
spk_cond_mask,
span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
})
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_sz, seqlen) = xs.dims2()?;
let mask: Vec<_> = (0..seqlen)
.flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
let tok_embeddings = xs.apply(&self.tok_embeddings)?;
let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
let mut xs = tok_embeddings
.broadcast_add(&pos_embeddings)?
.broadcast_add(
&spk_emb
.apply(&self.speaker_cond_pos)?
.broadcast_mul(&self.spk_cond_mask)?,
)?;
let mask = mask.to_dtype(xs.dtype())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, pos, &mask)?
}
xs.narrow(1, seqlen - 1, 1)?
.contiguous()?
.apply(&self.norm)?
.apply(&self.output)
}
}
}