use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{embedding, linear_b as linear, Embedding, LayerNorm, Linear, Module, VarBuilder};
use serde::Deserialize;
/// Number of positions for which the rotary cos/sin tables are computed;
/// `FalconRotaryEmbedding::forward` narrows into a table of this length.
const MAX_SEQ_LEN: usize = 5000;
/// Loads a `LayerNorm` of width `size` from `vb`.
///
/// The parameters are looked up as `weight`/`bias` first; if either lookup
/// fails, `gamma`/`beta` are tried instead.
///
/// # Errors
/// When neither pair can be loaded, the error from the `weight`/`bias`
/// attempt is returned (the `weight` error takes precedence over `bias`).
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get(size, "weight");
    let bias = vb.get(size, "bias");
    let primary_err = match (weight, bias) {
        (Ok(w), Ok(b)) => return Ok(LayerNorm::new(w, b, eps)),
        (Err(e), _) | (_, Err(e)) => e,
    };

    // Alternate parameter names; only consulted after the primary pair failed.
    let gamma = vb.get(size, "gamma");
    let beta = vb.get(size, "beta");
    match (gamma, beta) {
        (Ok(w), Ok(b)) => Ok(LayerNorm::new(w, b, eps)),
        _ => Err(primary_err),
    }
}
/// Hyper-parameters and feature flags for the Falcon model defined in this file.
///
/// The struct is deserializable through serde. `Config::validate` reports the
/// flag combinations that this implementation rejects.
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
/// Row count of the word-embedding table and output width of `lm_head`.
pub vocab_size: usize,
/// Width of the hidden state; also the input/output width of the attention
/// and MLP projections.
pub hidden_size: usize,
/// Number of decoder blocks loaded under `transformer.h.{i}`.
pub num_hidden_layers: usize,
/// Number of query heads; `hidden_size / num_attention_heads` is the per-head width.
pub num_attention_heads: usize,
/// Epsilon handed to every `LayerNorm` built by this model.
pub layer_norm_epsilon: f64,
/// Not read by the code visible in this file.
pub initializer_range: f64,
/// When true, each attention layer stores its keys/values and concatenates
/// them with the keys/values of the next call.
pub use_cache: bool,
/// Not read by the code visible in this file.
pub bos_token_id: u32,
/// Not read by the code visible in this file.
pub eos_token_id: u32,
/// Not read by the code visible in this file.
pub hidden_dropout: f64,
/// Not read by the code visible in this file.
pub attention_dropout: f64,
/// Optional key/value head count; `validate` rejects `Some(_)`.
pub n_head_kv: Option<usize>,
/// `validate` rejects `true`; when true no rotary embedding is built.
pub alibi: bool,
/// `validate` rejects `true`.
pub new_decoder_architecture: bool,
/// Selects the fused QKV layout: `hidden_size + 2 * head_dim` outputs when
/// true, `3 * hidden_size` outputs when false.
pub multi_query: bool,
/// When true the MLP consumes the same normalized input as attention and no
/// `post_attention_layernorm` is loaded.
pub parallel_attn: bool,
/// Passed as the bias flag to the attention and MLP `linear_b` layers.
pub bias: bool,
}
impl Default for Config {
    /// Returns the same field values as [`Config::falcon7b`].
    fn default() -> Self {
        Self::falcon7b()
    }
}
impl Config {
    /// Checks that this configuration can be run by the implementation in this file.
    ///
    /// # Errors
    /// Returns an error when `alibi`, `new_decoder_architecture` or `n_head_kv`
    /// is set, when `num_attention_heads` is zero, or when `hidden_size` is not
    /// a multiple of `num_attention_heads`.
    pub fn validate(&self) -> Result<()> {
        if self.alibi {
            candle::bail!("alibi is not supported");
        }
        if self.new_decoder_architecture {
            candle::bail!("new_decoder_architecture is not supported");
        }
        if self.n_head_kv.is_some() {
            candle::bail!("n_head_kv is not supported");
        }
        // `head_dim` divides by the head count, so zero would panic there.
        if self.num_attention_heads == 0 {
            candle::bail!("num_attention_heads must be non-zero");
        }
        // The fused QKV output is reshaped into whole heads of `head_dim`
        // elements, which only works when the division is exact.
        if self.hidden_size % self.num_attention_heads != 0 {
            candle::bail!(
                "hidden_size ({}) must be divisible by num_attention_heads ({})",
                self.hidden_size,
                self.num_attention_heads
            );
        }
        Ok(())
    }

    /// Built-in preset named for Falcon 7B: 32 layers, 71 heads,
    /// hidden size 4544, multi-query attention with parallel attention/MLP.
    pub fn falcon7b() -> Self {
        Self {
            vocab_size: 65024,
            hidden_size: 4544,
            num_hidden_layers: 32,
            num_attention_heads: 71,
            layer_norm_epsilon: 1e-5,
            initializer_range: 0.02,
            use_cache: true,
            bos_token_id: 11,
            eos_token_id: 11,
            hidden_dropout: 0.,
            attention_dropout: 0.,
            n_head_kv: None,
            alibi: false,
            new_decoder_architecture: false,
            multi_query: true,
            parallel_attn: true,
            bias: false,
        }
    }

    /// Width of a single attention head (integer division).
    ///
    /// # Panics
    /// Panics if `num_attention_heads` is zero; `validate` rejects that case.
    fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Rotary position embeddings are used whenever alibi is disabled.
    fn rotary(&self) -> bool {
        !self.alibi
    }
}
/// Splits the last dimension of `x` at its midpoint into `(x1, x2)` and
/// returns the concatenation `(-x2, x1)` along that same dimension.
///
/// For an odd-sized last dimension the second part receives the extra element.
fn rotate_half(x: &Tensor) -> Result<Tensor> {
    let last = x.dim(D::Minus1)?;
    let half = last / 2;
    let first_half = x.narrow(D::Minus1, 0, half)?;
    let negated_second_half = x.narrow(D::Minus1, half, last - half)?.neg()?;
    Tensor::cat(&[&negated_second_half, &first_half], D::Minus1)
}
/// Rotary position embedding state: the inverse frequencies plus a memoized
/// cos/sin table.
#[derive(Debug, Clone)]
struct FalconRotaryEmbedding {
/// `1 / 10000^(i / head_dim)` for every even `i` below `head_dim`, stored as f32.
inv_freq: Tensor,
/// Last computed `(seq_len, cos, sin)` table; reused while the requested
/// `seq_len` is unchanged.
cache: Option<(usize, Tensor, Tensor)>,
}
impl FalconRotaryEmbedding {
    /// Builds the inverse-frequency vector on `device` from the head width in `cfg`.
    fn load(device: &Device, cfg: &Config) -> Result<Self> {
        let dim = cfg.head_dim();
        let mut freqs = Vec::new();
        for i in (0..dim).step_by(2) {
            freqs.push(1f32 / 10000f32.powf(i as f32 / dim as f32));
        }
        let inv_freq = Tensor::new(freqs.as_slice(), device)?;
        Ok(Self {
            inv_freq,
            cache: None,
        })
    }

    /// Returns the `(cos, sin)` tables for positions `0..seq_len`.
    ///
    /// The tables are memoized by `seq_len` only; a hit returns clones of the
    /// stored tensors without looking at `device` or `dtype`.
    fn cos_sin(
        &mut self,
        seq_len: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<(Tensor, Tensor)> {
        if let Some((cached_len, cos, sin)) = &self.cache {
            if *cached_len == seq_len {
                return Ok((cos.clone(), sin.clone()));
            }
        }

        let positions = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
        let inv_freq = self.inv_freq.to_dtype(dtype)?;
        // Outer product of positions and inverse frequencies.
        let angles = positions.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
        // Duplicate along the last dimension so the table spans the full head width.
        let angles = Tensor::cat(&[&angles, &angles], D::Minus1)?;
        let cos = angles.cos()?;
        let sin = angles.sin()?;
        self.cache = Some((seq_len, cos.clone(), sin.clone()));
        Ok((cos, sin))
    }

    /// Applies the rotary embedding to `query` and `key`.
    ///
    /// `query` must be rank 3 with the sequence length in dimension 1. The
    /// rotation uses the table rows `past_kv_len..past_kv_len + seq_len`, and
    /// the same rows are applied to `key`.
    ///
    /// # Errors
    /// Fails if the requested rows fall outside the `MAX_SEQ_LEN` table or if
    /// the tensors cannot be broadcast against the table.
    fn forward(
        &mut self,
        query: &Tensor,
        key: &Tensor,
        past_kv_len: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_, seq_len, _) = query.dims3()?;
        let (cos_table, sin_table) =
            self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
        let cos = cos_table.narrow(0, past_kv_len, seq_len)?;
        let sin = sin_table.narrow(0, past_kv_len, seq_len)?;

        let rotate = |x: &Tensor| -> Result<Tensor> {
            x.broadcast_mul(&cos)? + &rotate_half(x)?.broadcast_mul(&sin)?
        };
        let rotated_query = rotate(query)?;
        let rotated_key = rotate(key)?;
        Ok((rotated_query, rotated_key))
    }
}
/// Returns a tensor that takes the scalar `on_true` wherever `mask` is set and
/// the corresponding element of `on_false` everywhere else.
///
/// The scalar is converted to the dtype of `on_false` and broadcast to the
/// shape of `mask` before the selection.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let fill = Tensor::new(on_true, on_false.device())?.to_dtype(on_false.dtype())?;
    let fill = fill.broadcast_as(mask.shape().dims())?;
    mask.where_cond(&fill, on_false)
}
/// Self-attention layer with a fused query/key/value projection and an
/// optional key/value cache.
#[derive(Debug, Clone)]
struct FalconAttention {
/// Fused projection producing queries, keys and values in one tensor.
query_key_value: Linear,
/// Output projection applied to the merged attention heads.
dense: Linear,
/// Rotary embedding; `None` when the config disables rotary (alibi enabled).
maybe_rotary: Option<FalconRotaryEmbedding>,
/// Keys and values accumulated from previous `forward` calls when `use_cache` is set.
kv_cache: Option<(Tensor, Tensor)>,
/// `1 / sqrt(head_dim)`, multiplied into the attention scores.
inv_norm_factor: f64,
/// Chooses between the two fused-QKV layouts handled by `split_heads`.
multi_query: bool,
/// Whether `forward` reads and updates `kv_cache`.
use_cache: bool,
/// Number of query heads.
num_heads: usize,
/// Width of one attention head.
head_dim: usize,
/// Number of key/value heads assumed when reshaping keys and values.
n_head_kv: usize,
}
impl FalconAttention {
    /// Builds the attention layer, loading the fused `query_key_value` and
    /// `dense` projections from `vb`.
    ///
    /// # Errors
    /// Propagates any error from building the rotary embedding or loading
    /// the projections.
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let maybe_rotary = if cfg.rotary() {
            Some(FalconRotaryEmbedding::load(vb.device(), cfg)?)
        } else {
            None
        };
        let head_dim = cfg.head_dim();
        let hidden_size = cfg.hidden_size;
        let qkv_out_dim = if cfg.multi_query {
            hidden_size + 2 * head_dim
        } else {
            3 * hidden_size
        };
        let query_key_value =
            linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
        let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
        // `split_heads` yields a single key/value head in multi-query mode and
        // one key/value head per query head otherwise; the fallback has to
        // match that layout or the key/value reshape in `forward` fails.
        let default_n_head_kv = if cfg.multi_query {
            1
        } else {
            cfg.num_attention_heads
        };
        Ok(Self {
            query_key_value,
            dense,
            maybe_rotary,
            kv_cache: None,
            inv_norm_factor: 1. / (head_dim as f64).sqrt(),
            multi_query: cfg.multi_query,
            use_cache: cfg.use_cache,
            num_heads: cfg.num_attention_heads,
            n_head_kv: cfg.n_head_kv.unwrap_or(default_n_head_kv),
            head_dim,
        })
    }

    /// Splits the rank-3 fused projection into `(query, key, value)`, each of
    /// rank 4 with layout `(batch, seq_len, heads, head_dim)`.
    ///
    /// In multi-query mode the last two heads of the fused tensor are the
    /// single key head and the single value head. Otherwise every head carries
    /// its own query/key/value triple.
    fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
        let (b_sz, seq_len, _) = fused_qkv.dims3()?;
        if !self.multi_query {
            let fused_qkv =
                fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
            let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
            let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
            let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
            Ok((q, k, v))
        } else {
            let fused_qkv =
                fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
            let d = fused_qkv.dim(D::Minus2)?;
            let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
            let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
            let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
            Ok((q, k, v))
        }
    }

    /// Runs self-attention over `x` and returns a tensor of shape
    /// `(batch, seq_len, num_heads * head_dim)` after the output projection.
    ///
    /// * `mask` - optional mask; positions where it is set get `-1e9` added to
    ///   their scores before the softmax.
    /// * `past_kv_len` - number of cached key positions preceding this call.
    ///
    /// When `use_cache` is set, the keys and values of this call are appended
    /// to `kv_cache`.
    ///
    /// # Errors
    /// Propagates shape and backend errors from the tensor operations.
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
        let fused_qkv = self.query_key_value.forward(x)?;
        let head_dim = self.head_dim;
        let (query, key, value) = self.split_heads(&fused_qkv)?;
        let (b_sz, seq_len, _, _) = query.dims4()?;
        // Fold the head dimension into the batch dimension.
        let query = query
            .transpose(1, 2)?
            .reshape((b_sz * self.num_heads, seq_len, head_dim))?;
        let key = key
            .transpose(1, 2)?
            .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
        let value = value
            .transpose(1, 2)?
            .reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
        let (query, key) = if let Some(r) = &mut self.maybe_rotary {
            r.forward(&query, &key, past_kv_len)?
        } else {
            (query, key)
        };
        let (mut key, mut value) = (key, value);
        if self.use_cache {
            if let Some((cache_k, cache_v)) = &self.kv_cache {
                // Prepend the cached positions along the sequence dimension.
                key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
                value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
            }
            self.kv_cache = Some((key.clone(), value.clone()));
        }
        let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
        let all_len = past_kv_len + seq_len;
        // These reshapes also fail if the key/value length disagrees with
        // `past_kv_len + seq_len`.
        let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
        let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
        // NOTE(review): broadcasting from `b_sz` rows to `b_sz * num_heads`
        // rows looks like it only succeeds for `b_sz == 1` - confirm.
        let (key, value) = if self.n_head_kv == 1 {
            (
                key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
                value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
            )
        } else {
            (key, value)
        };
        let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
        let attention_scores = match mask {
            None => attention_scores,
            Some(mask) => {
                // Additive bias: -1e9 where the mask is set, 0 elsewhere.
                let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?
                    .to_dtype(query.dtype())?;
                attention_scores.broadcast_add(&mask.squeeze(1)?)?
            }
        };
        // Softmax is evaluated in f32, then converted back to the input dtype.
        let attention_scores =
            candle_nn::ops::softmax(&attention_scores.to_dtype(DType::F32)?, D::Minus1)?
                .to_dtype(x.dtype())?;
        let attn_output = attention_scores
            .matmul(&value)?
            .reshape((b_sz, self.num_heads, seq_len, head_dim))?
            .transpose(1, 2)?
            .reshape((b_sz, seq_len, self.num_heads * head_dim))?;
        let attn_output = self.dense.forward(&attn_output)?;
        Ok(attn_output)
    }

    /// Drops the cached keys and values.
    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}
/// Two-layer feed-forward block of a decoder layer.
#[derive(Debug, Clone)]
struct FalconMlp {
/// Projection from `hidden_size` to `4 * hidden_size`.
dense_h_to_4h: Linear,
/// Projection from `4 * hidden_size` back to `hidden_size`.
dense_4h_to_h: Linear,
}
impl FalconMlp {
    /// Loads both projections from `vb`; the inner width is four times the
    /// hidden size and the bias flag comes from `cfg.bias`.
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let inner = 4 * hidden;
        Ok(Self {
            dense_h_to_4h: linear(hidden, inner, cfg.bias, vb.pp("dense_h_to_4h"))?,
            dense_4h_to_h: linear(inner, hidden, cfg.bias, vb.pp("dense_4h_to_h"))?,
        })
    }

    /// Applies the first projection, a GELU activation, then the second projection.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let expanded = self.dense_h_to_4h.forward(x)?;
        let activated = expanded.gelu()?;
        self.dense_4h_to_h.forward(&activated)
    }
}
/// One decoder block: layer norm, self-attention and MLP with residual connections.
#[derive(Debug, Clone)]
struct FalconDecoderLayer {
/// Layer norm applied to the block input, loaded from `input_layernorm`.
inp_layernorm: LayerNorm,
/// Self-attention sub-layer.
self_attention: FalconAttention,
/// Layer norm applied before the MLP; only loaded when `parallel_attn` is false.
post_attention_layernorm: Option<LayerNorm>,
/// Feed-forward sub-layer.
mlp: FalconMlp,
/// When true, the attention output is added to the MLP output.
parallel_attn: bool,
}
impl FalconDecoderLayer {
    /// Loads the MLP, input layer norm, attention and, unless
    /// `cfg.parallel_attn` is set, the post-attention layer norm from `vb`.
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let hidden = cfg.hidden_size;
        let eps = cfg.layer_norm_epsilon;
        let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
        let inp_layernorm = layer_norm(hidden, eps, vb.pp("input_layernorm"))?;
        let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
        let post_attention_layernorm = if cfg.parallel_attn {
            None
        } else {
            Some(layer_norm(hidden, eps, vb.pp("post_attention_layernorm"))?)
        };
        Ok(Self {
            inp_layernorm,
            self_attention,
            post_attention_layernorm,
            mlp,
            parallel_attn: cfg.parallel_attn,
        })
    }

    /// Runs the block on `x`.
    ///
    /// Without a post-attention layer norm, the MLP reads the same normalized
    /// input as attention and the residual is `x` itself. With one, the
    /// residual becomes `attention + x` and the MLP reads its normalized form.
    /// The attention output is added to the MLP output only when
    /// `parallel_attn` is set.
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, past_kv_len: usize) -> Result<Tensor> {
        let normed_input = self.inp_layernorm.forward(x)?;
        let attn_output = self
            .self_attention
            .forward(&normed_input, mask, past_kv_len)?;

        let (residual, mlp_input) = if let Some(post_ln) = &self.post_attention_layernorm {
            let residual = (&attn_output + x)?;
            let mlp_input = post_ln.forward(&residual)?;
            (residual, mlp_input)
        } else {
            (x.clone(), normed_input)
        };

        let mlp_output = self.mlp.forward(&mlp_input)?;
        let branch_sum = if self.parallel_attn {
            (mlp_output + attn_output)?
        } else {
            mlp_output
        };
        branch_sum + residual
    }

    /// Clears the key/value cache of the attention sub-layer.
    pub fn clear_kv_cache(&mut self) {
        self.self_attention.clear_kv_cache()
    }
}
/// Falcon language model: token embedding, decoder blocks, final layer norm
/// and output head.
#[derive(Debug, Clone)]
pub struct Falcon {
/// Token embedding table loaded from `transformer.word_embeddings`.
word_embeddings: Embedding,
/// Decoder blocks loaded from `transformer.h.{i}`, applied in order.
blocks: Vec<FalconDecoderLayer>,
/// Final layer norm loaded from `transformer.ln_f`.
ln_f: LayerNorm,
/// Bias-free projection from `hidden_size` to `vocab_size`, loaded from `lm_head`.
lm_head: Linear,
/// Configuration the model was loaded with.
config: Config,
}
/// Builds a `(t, t)` u8 tensor on the CPU whose entry `(i, j)` is 1 when
/// `j > i` and 0 otherwise.
fn make_causal_mask(t: usize) -> Result<Tensor> {
    let mut flags = Vec::new();
    for row in 0..t {
        for col in 0..t {
            flags.push(u8::from(col > row));
        }
    }
    Tensor::from_slice(&flags, (t, t), &Device::Cpu)
}
/// Returns the causal mask for `seq_len` positions broadcast to shape
/// `(b_sz, 1, seq_len, seq_len)`.
fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
    make_causal_mask(seq_len)?.broadcast_as((b_sz, 1, seq_len, seq_len))
}
impl Falcon {
pub fn config(&self) -> &Config {
&self.config
}
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
let word_embeddings = embedding(
cfg.vocab_size,
cfg.hidden_size,
vb.pp("transformer.word_embeddings"),
)?;
let blocks = (0..cfg.num_hidden_layers)
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
.collect::<Result<Vec<_>>>()?;
let ln_f = layer_norm(
cfg.hidden_size,
cfg.layer_norm_epsilon,
vb.pp("transformer.ln_f"),
)?;
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
Ok(Self {
word_embeddings,
blocks,
ln_f,
lm_head,
config: cfg,
})
}
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.dims2()?;
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
Some((k, _)) => k.dim(1)?,
None => 0,
};
let causal_mask = if seq_len <= 1 {
None
} else {
Some(prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?)
};
for block in self.blocks.iter_mut() {
hidden_state = block.forward(&hidden_state, causal_mask.as_ref(), past_kv_len)?;
}
let hidden_state = self.ln_f.forward(&hidden_state)?;
let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
Ok(logits)
}
pub fn clear_kv_cache(&mut self) {
for block in self.blocks.iter_mut() {
block.clear_kv_cache()
}
}
}