use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{Conv2dConfig, Module, VarBuilder};
/// Expansion ratio handed to the `MBConv` blocks of the first (convolutional) layer.
const MBCONV_EXPAND_RATIO: usize = 4;
/// The hidden width of each block's MLP is `dim * MLP_RATIO`.
const MLP_RATIO: usize = 4;
/// Kernel size of the grouped local convolution inside each `TinyViTBlock`.
const LOCAL_CONV_SIZE: usize = 3;
/// Side length from which the patch grid is derived (`IMG_SIZE / 4`);
/// presumably the expected square input size in pixels — confirm against callers.
const IMG_SIZE: usize = 1024;
/// Number of input channels given to the patch embedding.
const IN_CHANNELS: usize = 3;
/// A bias-free 2D convolution followed by batch normalization.
#[derive(Debug)]
struct Conv2dBN {
/// Convolution created without a bias term (weights under the `c` prefix).
c: candle_nn::Conv2d,
/// Batch norm applied to the convolution output (weights under the `bn` prefix).
bn: candle_nn::BatchNorm,
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl Conv2dBN {
    /// Builds the convolution / batch-norm pair.
    ///
    /// * `in_` / `out` - input and output channel counts.
    /// * `ks` - kernel size of the convolution.
    /// * `cfg` - convolution configuration (stride, padding, groups, ...).
    /// * `vb` - variable builder; the convolution is read from `c`, the norm from `bn`.
    ///
    /// # Errors
    /// Propagates any error raised while creating either layer.
    fn new(in_: usize, out: usize, ks: usize, cfg: Conv2dConfig, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            c: candle_nn::conv2d_no_bias(in_, out, ks, cfg, vb.pp("c"))?,
            bn: candle_nn::batch_norm(out, 1e-5, vb.pp("bn"))?,
            span: tracing::span!(tracing::Level::TRACE, "conv2d-bn"),
        })
    }
}
impl Module for Conv2dBN {
    /// Runs the convolution, then the batch norm with its `train` flag set to `false`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let convolved = self.c.forward(xs)?;
        convolved.apply_t(&self.bn, false)
    }
}
/// Patch embedding made of two `Conv2dBN` stages with a GELU in between.
#[derive(Debug)]
struct PatchEmbed {
/// First stage: `in_chans -> embed_dim / 2`.
conv1: Conv2dBN,
/// Second stage: `embed_dim / 2 -> embed_dim`.
conv2: Conv2dBN,
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl PatchEmbed {
fn new(in_chans: usize, embed_dim: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
stride: 2,
padding: 1,
..Default::default()
};
let conv1 = Conv2dBN::new(in_chans, embed_dim / 2, 3, cfg, vb.pp("seq.0"))?;
let conv2 = Conv2dBN::new(embed_dim / 2, embed_dim, 3, cfg, vb.pp("seq.2"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-embed");
Ok(Self { conv1, conv2, span })
}
}
impl Module for PatchEmbed {
    /// Applies `conv1`, a GELU, then `conv2`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let activated = self.conv1.forward(xs)?.gelu()?;
        self.conv2.forward(&activated)
    }
}
/// Three-convolution block with a residual connection around it.
#[derive(Debug)]
struct MBConv {
/// 1x1 convolution: `in_ -> in_ * expand_ratio`.
conv1: Conv2dBN,
/// 3x3 convolution (padding 1) with as many groups as hidden channels.
conv2: Conv2dBN,
/// 1x1 convolution: hidden channels -> `out`.
conv3: Conv2dBN,
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl MBConv {
fn new(in_: usize, out: usize, expand_ratio: usize, vb: VarBuilder) -> Result<Self> {
let hidden = in_ * expand_ratio;
let cfg2 = candle_nn::Conv2dConfig {
padding: 1,
groups: hidden,
..Default::default()
};
let conv1 = Conv2dBN::new(in_, hidden, 1, Default::default(), vb.pp("conv1"))?;
let conv2 = Conv2dBN::new(hidden, hidden, 3, cfg2, vb.pp("conv2"))?;
let conv3 = Conv2dBN::new(hidden, out, 1, Default::default(), vb.pp("conv3"))?;
let span = tracing::span!(tracing::Level::TRACE, "mb-conv");
Ok(Self {
conv1,
conv2,
conv3,
span,
})
}
}
impl Module for MBConv {
    /// Runs the three convolutions (GELU after the first two), adds the input
    /// back as a residual and applies a final GELU.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let expanded = self.conv1.forward(xs)?.gelu()?;
        let filtered = self.conv2.forward(&expanded)?.gelu()?;
        let projected = self.conv3.forward(&filtered)?;
        (projected + xs)?.gelu()
    }
}
/// Convolutional stage placed at the end of a layer; emits a token sequence.
#[derive(Debug)]
struct PatchMerging {
/// 1x1 convolution: `dim -> out`.
conv1: Conv2dBN,
/// 3x3 grouped convolution (groups = `out`, padding 1) with stride 1 or 2.
conv2: Conv2dBN,
/// 1x1 convolution: `out -> out`.
conv3: Conv2dBN,
/// `(h, w)` used to rebuild a feature map when the input is rank 3.
input_resolution: (usize, usize),
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl PatchMerging {
fn new(
input_resolution: (usize, usize),
dim: usize,
out: usize,
vb: VarBuilder,
) -> Result<Self> {
let stride = if [320, 448, 576].contains(&out) { 1 } else { 2 };
let cfg2 = candle_nn::Conv2dConfig {
padding: 1,
stride,
groups: out,
..Default::default()
};
let conv1 = Conv2dBN::new(dim, out, 1, Default::default(), vb.pp("conv1"))?;
let conv2 = Conv2dBN::new(out, out, 3, cfg2, vb.pp("conv2"))?;
let conv3 = Conv2dBN::new(out, out, 1, Default::default(), vb.pp("conv3"))?;
let span = tracing::span!(tracing::Level::TRACE, "patch-merging");
Ok(Self {
conv1,
conv2,
conv3,
input_resolution,
span,
})
}
}
impl Module for PatchMerging {
    /// Applies the three convolutions and returns the result as tokens.
    ///
    /// A rank-3 input is first reshaped to `(b, h, w, channels)` using the
    /// stored resolution and permuted to channels-first; any other rank is
    /// used as is. The output is flattened from dimension 2 and transposed,
    /// so the flattened spatial positions end up in dimension 1.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let feature_map = match xs.rank() {
            3 => {
                let (h, w) = self.input_resolution;
                let b = xs.dim(0)?;
                xs.reshape((b, h, w, ()))?.permute((0, 3, 1, 2))?
            }
            _ => xs.clone(),
        };
        let reduced = self.conv1.forward(&feature_map)?.gelu()?;
        let merged = self.conv2.forward(&reduced)?.gelu()?;
        self.conv3
            .forward(&merged)?
            .flatten_from(2)?
            .transpose(1, 2)
    }
}
/// A stack of `MBConv` blocks with an optional `PatchMerging` at the end.
#[derive(Debug)]
struct ConvLayer {
/// Blocks applied in order; each maps `dim -> dim`.
blocks: Vec<MBConv>,
/// Optional final stage mapping `dim -> out`.
downsample: Option<PatchMerging>,
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl ConvLayer {
    /// Creates `depth` MBConv blocks and, if `downsample` is set, a trailing
    /// `PatchMerging` from `dim` to `out`.
    ///
    /// Block weights are read from `blocks.<index>` and the merging stage
    /// from `downsample`.
    ///
    /// # Errors
    /// Propagates the first error raised while creating a sub-module.
    fn new(
        dim: usize,
        out: usize,
        input_resolution: (usize, usize),
        depth: usize,
        downsample: bool,
        conv_expand_ratio: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let blocks_vb = vb.pp("blocks");
        let blocks = (0..depth)
            .map(|i| MBConv::new(dim, dim, conv_expand_ratio, blocks_vb.pp(i)))
            .collect::<Result<Vec<_>>>()?;
        let downsample = downsample
            .then(|| PatchMerging::new(input_resolution, dim, out, vb.pp("downsample")))
            .transpose()?;
        Ok(Self {
            blocks,
            downsample,
            span: tracing::span!(tracing::Level::TRACE, "conv-layer"),
        })
    }
}
impl Module for ConvLayer {
    /// Feeds the input through every block in order, then through the
    /// merging stage when one is present.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let features = self
            .blocks
            .iter()
            .try_fold(xs.clone(), |acc, block| block.forward(&acc))?;
        if let Some(merge) = &self.downsample {
            merge.forward(&features)
        } else {
            Ok(features)
        }
    }
}
/// Layer norm followed by a two-layer perceptron with a GELU in between.
#[derive(Debug)]
struct Mlp {
/// Layer norm applied to the input first.
norm: candle_nn::LayerNorm,
/// Linear layer `in_ -> hidden`.
fc1: super::Linear,
/// Linear layer `hidden -> in_`.
fc2: super::Linear,
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl Mlp {
    /// Creates the norm and both linear layers (`in_ -> hidden -> in_`).
    ///
    /// Weights are read from the `norm`, `fc1` and `fc2` prefixes of `vb`.
    ///
    /// # Errors
    /// Propagates any error raised while creating a layer.
    fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            norm: candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?,
            fc1: super::linear(vb.pp("fc1"), in_, hidden, true)?,
            fc2: super::linear(vb.pp("fc2"), hidden, in_, true)?,
            span: tracing::span!(tracing::Level::TRACE, "mlp"),
        })
    }
}
impl Module for Mlp {
    /// Applies norm, `fc1`, GELU and `fc2`, in that order.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let normed = xs.apply(&self.norm)?;
        let hidden = normed.apply(&self.fc1)?.gelu()?;
        hidden.apply(&self.fc2)
    }
}
/// Multi-head self-attention with a learned bias added to the attention scores.
#[derive(Debug)]
struct Attention {
/// Layer norm applied to the input before the qkv projection.
norm: candle_nn::LayerNorm,
/// Joint projection producing the per-head query, key and value slices.
qkv: super::Linear,
/// Output projection `dh -> dim`.
proj: super::Linear,
/// Attention biases gathered for every pair of positions and reshaped to
/// `(_, n, n)`, where `n` is the number of positions in the resolution grid.
ab: Tensor,
/// Per-head width of the query and key slices.
key_dim: usize,
/// Number of attention heads.
num_heads: usize,
/// Per-head width of the value slice (`attn_ratio * key_dim`).
d: usize,
/// Total value width over all heads (`d * num_heads`).
dh: usize,
/// Factor applied to the query/key product: `1 / sqrt(key_dim)`.
scale: f64,
/// Tracing span covering the whole forward pass.
span: tracing::Span,
/// Tracing span entered around each of the two matmuls.
span_matmul: tracing::Span,
/// Tracing span entered around the softmax.
span_softmax: tracing::Span,
}
impl Attention {
    /// Creates the attention module.
    ///
    /// * `dim` - width of the input tokens.
    /// * `key_dim` - per-head width of queries and keys.
    /// * `num_heads` - number of heads.
    /// * `attn_ratio` - the per-head value width is `attn_ratio * key_dim`.
    /// * `resolution` - `(rows, cols)` of the position grid used for the biases.
    /// * `vb` - variable builder; reads `norm`, `qkv`, `proj` and `attention_biases`.
    ///
    /// Every pair of grid positions is mapped to the index of its absolute
    /// `(dx, dy)` offset. Indices are handed out in order of first appearance.
    /// The `attention_biases` variable has one column per distinct offset and
    /// is expanded once, here, into a per-pair table.
    ///
    /// # Errors
    /// Propagates any error raised while loading a variable or building the table.
    fn new(
        dim: usize,
        key_dim: usize,
        num_heads: usize,
        attn_ratio: usize,
        resolution: (usize, usize),
        vb: VarBuilder,
    ) -> Result<Self> {
        let d = attn_ratio * key_dim;
        let dh = d * num_heads;
        // Queries and keys take `key_dim` per head each, values take `d` per head.
        let qkv_width = dh + key_dim * num_heads * 2;
        let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
        let qkv = super::linear(vb.pp("qkv"), dim, qkv_width, true)?;
        let proj = super::linear(vb.pp("proj"), dh, dim, true)?;

        // Grid positions with the first coordinate varying slowest.
        let (rows, cols) = resolution;
        let mut points = Vec::with_capacity(rows * cols);
        for x in 0..rows {
            for y in 0..cols {
                points.push((x as i64, y as i64));
            }
        }

        let mut offset_ids = std::collections::HashMap::<(i64, i64), usize>::new();
        let mut idxs = Vec::with_capacity(points.len() * points.len());
        for &(x1, y1) in &points {
            idxs.extend(points.iter().map(|&(x2, y2)| {
                let offset = ((x2 - x1).abs(), (y2 - y1).abs());
                // A new offset receives the next free index.
                let fresh = offset_ids.len();
                *offset_ids.entry(offset).or_insert(fresh) as u32
            }));
        }

        let attention_biases = vb.get((num_heads, offset_ids.len()), "attention_biases")?;
        let idxs = Tensor::new(idxs, attention_biases.device())?;
        let ab = attention_biases
            .index_select(&idxs, 1)?
            .reshape(((), points.len(), points.len()))?;

        Ok(Self {
            norm,
            qkv,
            proj,
            ab,
            key_dim,
            num_heads,
            d,
            dh,
            scale: (key_dim as f64).sqrt().recip(),
            span: tracing::span!(tracing::Level::TRACE, "attention"),
            span_matmul: tracing::span!(tracing::Level::TRACE, "attn-matmul"),
            span_softmax: tracing::span!(tracing::Level::TRACE, "attn-sm"),
        })
    }
}
impl Module for Attention {
    /// Computes biased self-attention over a rank-3 `(b, n, _)` input.
    ///
    /// The input is normalized and projected by `qkv`, then split per head
    /// into query, key and value slices along the last dimension. The scaled
    /// query/key product receives the precomputed bias before the softmax.
    /// The result is reshaped to `(b, n, dh)` and passed through `proj`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let (b, n, _) = xs.dims3()?;
        let normed = xs.apply(&self.norm)?;
        let qkv = normed
            .apply(&self.qkv)?
            .reshape((b, n, self.num_heads, ()))?;
        // Cuts `len` features starting at `start` out of every head and moves
        // the head axis in front of the token axis.
        let head_slice = |start: usize, len: usize| -> Result<Tensor> {
            qkv.narrow(D::Minus1, start, len)?
                .permute((0, 2, 1, 3))?
                .contiguous()
        };
        let q = head_slice(0, self.key_dim)?;
        let k = head_slice(self.key_dim, self.key_dim)?;
        let v = head_slice(2 * self.key_dim, self.d)?;

        let scores = {
            let _matmul = self.span_matmul.enter();
            (q.matmul(&k.t()?)? * self.scale)?
        };
        let biased = scores.broadcast_add(&self.ab)?;
        let weights = {
            let _softmax = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&biased)?
        };
        let mixed = {
            let _matmul = self.span_matmul.enter();
            weights.matmul(&v)?
        };
        let merged = mixed.transpose(1, 2)?.reshape((b, n, self.dh))?;
        merged.apply(&self.proj)
    }
}
/// Transformer block: attention, a grouped local convolution and an MLP.
#[derive(Debug)]
struct TinyViTBlock {
/// Self-attention, run on the whole grid or on square windows of it.
attn: Attention,
/// Grouped convolution (groups = `dim`) applied after the attention residual.
local_conv: Conv2dBN,
/// MLP whose output is added to the convolution output.
mlp: Mlp,
/// Side length of the square attention windows.
window_size: usize,
/// `(h, w)` of the token grid this block receives.
input_resolution: (usize, usize),
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl TinyViTBlock {
    /// Creates the block.
    ///
    /// * `dim` - width of the tokens.
    /// * `input_resolution` - `(h, w)` of the token grid.
    /// * `num_heads` - number of attention heads; each gets `dim / num_heads` features.
    /// * `window_size` - side length of the square attention windows.
    /// * `vb` - variable builder; reads the `attn`, `mlp` and `local_conv` prefixes.
    ///
    /// # Errors
    /// Propagates any error raised while creating a sub-module.
    fn new(
        dim: usize,
        input_resolution: (usize, usize),
        num_heads: usize,
        window_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let head_dim = dim / num_heads;
        let attn = Attention::new(
            dim,
            head_dim,
            num_heads,
            1,
            (window_size, window_size),
            vb.pp("attn"),
        )?;
        let mlp = Mlp::new(dim, dim * MLP_RATIO, vb.pp("mlp"))?;
        // One group per channel; the padding keeps the spatial size unchanged.
        let cfg = candle_nn::Conv2dConfig {
            padding: LOCAL_CONV_SIZE / 2,
            groups: dim,
            ..Default::default()
        };
        let local_conv = Conv2dBN::new(dim, dim, LOCAL_CONV_SIZE, cfg, vb.pp("local_conv"))?;
        // A name of its own, so traces can separate the block from the inner
        // `Attention` span (which is called "attention").
        let span = tracing::span!(tracing::Level::TRACE, "tiny-vit-block");
        Ok(Self {
            attn,
            local_conv,
            mlp,
            window_size,
            input_resolution,
            span,
        })
    }
}
impl TinyViTBlock {
    /// Runs attention independently on `window_size`-sided square windows of
    /// the token grid and returns tokens with the same shape as `xs`.
    ///
    /// The grid is zero-padded at the bottom and right up to a multiple of the
    /// window size, and the padding is cropped off again afterwards.
    fn windowed_attention(&self, xs: &Tensor) -> Result<Tensor> {
        let ws = self.window_size;
        let (h, w) = self.input_resolution;
        let (b, l, c) = xs.dims3()?;

        let mut grid = xs.reshape((b, h, w, c))?;
        let pad_b = (ws - h % ws) % ws;
        let pad_r = (ws - w % ws) % ws;
        if pad_b > 0 {
            grid = grid.pad_with_zeros(1, 0, pad_b)?;
        }
        if pad_r > 0 {
            grid = grid.pad_with_zeros(2, 0, pad_r)?;
        }
        let (p_h, p_w) = (h + pad_b, w + pad_r);
        let (n_h, n_w) = (p_h / ws, p_w / ws);

        // Split into windows: every window becomes one batch entry.
        let windows = grid
            .reshape((b, n_h, ws, n_w, ws, c))?
            .transpose(2, 3)?
            .reshape((b * n_h * n_w, ws * ws, c))?;
        let attended = self.attn.forward(&windows)?;

        // Undo the window split, then crop the padding.
        let mut merged = attended
            .reshape((b, n_h, n_w, ws, ws, c))?
            .transpose(2, 3)?
            .reshape((b, p_h, p_w, c))?;
        if pad_r > 0 {
            merged = merged.i((.., .., ..w))?.contiguous()?;
        }
        if pad_b > 0 {
            merged = merged.i((.., ..h, ..))?.contiguous()?;
        }
        merged.reshape((b, l, c))
    }
}

impl Module for TinyViTBlock {
    /// Applies attention with a residual, the local convolution, and finally
    /// the MLP with a residual, to a rank-3 `(b, l, c)` input.
    ///
    /// Attention runs on the full sequence when the grid is exactly one
    /// window, and window by window otherwise.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let (h, w) = self.input_resolution;
        let (b, l, c) = xs.dims3()?;
        let single_window = h == self.window_size && w == self.window_size;
        let attended = if single_window {
            self.attn.forward(xs)?
        } else {
            self.windowed_attention(xs)?
        };
        let residual = (attended + xs)?;
        // The convolution works on a channels-first map; go back to tokens after it.
        let conv_out = residual
            .transpose(1, 2)?
            .reshape((b, c, h, w))?
            .apply(&self.local_conv)?
            .reshape((b, c, l))?
            .transpose(1, 2)?;
        let mlp_out = self.mlp.forward(&conv_out)?;
        &conv_out + mlp_out
    }
}
/// A stack of `TinyViTBlock`s with an optional `PatchMerging` at the end.
#[derive(Debug)]
struct BasicLayer {
/// Blocks applied in order.
blocks: Vec<TinyViTBlock>,
/// Optional final stage mapping `dim -> out`.
downsample: Option<PatchMerging>,
/// Tracing span entered for the duration of each forward pass.
span: tracing::Span,
}
impl BasicLayer {
    /// Creates `depth` transformer blocks and, if `downsample` is set, a
    /// trailing `PatchMerging` from `dim` to `out`.
    ///
    /// Block weights are read from `blocks.<index>` and the merging stage
    /// from `downsample`.
    ///
    /// # Errors
    /// Propagates the first error raised while creating a sub-module.
    // The argument list is the layer's full configuration.
    #[allow(clippy::too_many_arguments)]
    fn new(
        dim: usize,
        input_resolution: (usize, usize),
        depth: usize,
        num_heads: usize,
        window_size: usize,
        downsample: bool,
        out: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let blocks_vb = vb.pp("blocks");
        let blocks = (0..depth)
            .map(|i| {
                TinyViTBlock::new(
                    dim,
                    input_resolution,
                    num_heads,
                    window_size,
                    blocks_vb.pp(i),
                )
            })
            .collect::<Result<Vec<_>>>()?;
        let downsample = downsample
            .then(|| PatchMerging::new(input_resolution, dim, out, vb.pp("downsample")))
            .transpose()?;
        Ok(Self {
            blocks,
            downsample,
            span: tracing::span!(tracing::Level::TRACE, "basic-layer"),
        })
    }
}
impl Module for BasicLayer {
    /// Feeds the input through every block in order, then through the
    /// merging stage when one is present.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _guard = self.span.enter();
        let tokens = self
            .blocks
            .iter()
            .try_fold(xs.clone(), |acc, block| block.forward(&acc))?;
        if let Some(merge) = &self.downsample {
            merge.forward(&tokens)
        } else {
            Ok(tokens)
        }
    }
}
/// TinyViT model: patch embedding, one convolutional layer, a sequence of
/// transformer layers and a convolutional neck producing 256 channels.
#[derive(Debug)]
pub struct TinyViT {
/// Patch embedding applied to the input first.
patch_embed: PatchEmbed,
/// First layer, built from `MBConv` blocks.
layer0: ConvLayer,
/// Remaining layers, built from `TinyViTBlock`s.
layers: Vec<BasicLayer>,
/// Neck: bias-free 1x1 convolution from the last embed dim to 256 channels.
neck_conv1: candle_nn::Conv2d,
/// Neck: 2D layer norm after `neck_conv1`.
neck_ln1: super::LayerNorm2d,
/// Neck: bias-free 3x3 convolution (padding 1), 256 -> 256 channels.
neck_conv2: candle_nn::Conv2d,
/// Neck: 2D layer norm after `neck_conv2`.
neck_ln2: super::LayerNorm2d,
/// Tracing span covering the whole forward pass.
span: tracing::Span,
/// Tracing span covering the neck.
span_neck: tracing::Span,
}
impl TinyViT {
pub fn new(
embed_dims: &[usize],
depths: &[usize],
num_heads: &[usize],
window_sizes: &[usize],
_num_classes: usize,
vb: VarBuilder,
) -> Result<Self> {
let patch_embed = PatchEmbed::new(IN_CHANNELS, embed_dims[0], vb.pp("patch_embed"))?;
let patches_resolution = IMG_SIZE / 4;
let vb_l = vb.pp("layers");
let layer0 = ConvLayer::new(
embed_dims[0],
embed_dims[1],
(patches_resolution, patches_resolution),
depths[0],
true,
MBCONV_EXPAND_RATIO,
vb_l.pp(0),
)?;
let num_layers = embed_dims.len();
let mut layers = Vec::with_capacity(num_layers - 1);
for i_layer in 1..num_layers {
let patches_resolution = patches_resolution / (1 << usize::min(i_layer, 2));
let layer = BasicLayer::new(
embed_dims[i_layer],
(patches_resolution, patches_resolution),
depths[i_layer],
num_heads[i_layer],
window_sizes[i_layer],
i_layer < num_layers - 1,
embed_dims[usize::min(i_layer + 1, num_layers - 1)],
vb_l.pp(i_layer),
)?;
layers.push(layer)
}
let last_embed_dim = embed_dims[embed_dims.len() - 1];
let neck_conv1 =
candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
let span_neck = tracing::span!(tracing::Level::TRACE, "neck");
Ok(Self {
patch_embed,
layer0,
layers,
neck_conv1,
neck_ln1,
neck_conv2,
neck_ln2,
span,
span_neck,
})
}
}
impl Module for TinyViT {
    /// Runs the patch embedding, all layers and the neck.
    ///
    /// The token sequence leaving the last layer is reshaped to a 64x64 grid
    /// before the neck, so it must hold exactly 64 * 64 tokens per batch entry.
    // NOTE(review): 64 matches IMG_SIZE = 1024 with the layer resolutions set
    // up in `TinyViT::new`; confirm before using other configurations.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _outer = self.span.enter();
        let embedded = self.patch_embed.forward(xs)?;
        let stem = self.layer0.forward(&embedded)?;
        let tokens = self
            .layers
            .iter()
            .try_fold(stem, |acc, layer| layer.forward(&acc))?;
        let (b, _, c) = tokens.dims3()?;
        let _neck = self.span_neck.enter();
        // Tokens back to a channels-first feature map for the convolutions.
        let feature_map = tokens.reshape((b, 64, 64, c))?.permute((0, 3, 1, 2))?;
        let first = feature_map
            .apply(&self.neck_conv1)?
            .apply(&self.neck_ln1)?;
        first.apply(&self.neck_conv2)?.apply(&self.neck_ln2)
    }
}
/// Builds a [`TinyViT`] with the fixed four-layer configuration below.
///
/// # Errors
/// Propagates any error raised by [`TinyViT::new`].
pub fn tiny_vit_5m(vb: VarBuilder) -> Result<TinyViT> {
    let embed_dims: [usize; 4] = [64, 128, 160, 320];
    let depths: [usize; 4] = [2, 2, 6, 2];
    let num_heads: [usize; 4] = [2, 4, 5, 10];
    let window_sizes: [usize; 4] = [7, 7, 14, 7];
    // The class count is passed through but not used by the constructor.
    let num_classes = 1000;
    TinyViT::new(
        &embed_dims,
        &depths,
        &num_heads,
        &window_sizes,
        num_classes,
        vb,
    )
}