use super::common::{AttnBlock, GlobalResponseNorm, LayerNormNoWeights, TimestepBlock, WLayerNorm};
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
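
/// ConvNeXt-style residual block used throughout stage B: a depthwise
/// convolution and layer norm followed by a channelwise MLP
/// (linear -> GELU -> global response norm -> linear) applied in NHWC layout.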
#[derive(Debug)]
pub struct ResBlockStageB {
depthwise: candle_nn::Conv2d,
norm: WLayerNorm,
channelwise_lin1: candle_nn::Linear,
channelwise_grn: GlobalResponseNorm,
channelwise_lin2: candle_nn::Linear,
}
impl ResBlockStageB {
pub fn new(c: usize, c_skip: usize, ksize: usize, vb: VarBuilder) -> Result<Self> {
let cfg = candle_nn::Conv2dConfig {
groups: c,
padding: ksize / 2,
..Default::default()
};
let depthwise = candle_nn::conv2d(c, c, ksize, cfg, vb.pp("depthwise"))?;
let norm = WLayerNorm::new(c)?;
let channelwise_lin1 = candle_nn::linear(c + c_skip, c * 4, vb.pp("channelwise.0"))?;
let channelwise_grn = GlobalResponseNorm::new(4 * c, vb.pp("channelwise.2"))?;
let channelwise_lin2 = candle_nn::linear(c * 4, c, vb.pp("channelwise.4"))?;
Ok(Self {
depthwise,
norm,
channelwise_lin1,
channelwise_grn,
channelwise_lin2,
})
}
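
    /// Runs the block; `x_skip`, when present, is concatenated along the
    /// channel dimension after the depthwise convolution, which is why the
    /// first channelwise linear takes `c + c_skip` inputs.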
pub fn forward(&self, xs: &Tensor, x_skip: Option<&Tensor>) -> Result<Tensor> {
let x_res = xs;
let xs = xs.apply(&self.depthwise)?.apply(&self.norm)?;
let xs = match x_skip {
None => xs.clone(),
Some(x_skip) => Tensor::cat(&[&xs, x_skip], 1)?,
};
let xs = xs
.permute((0, 2, 3, 1))?
.contiguous()?
.apply(&self.channelwise_lin1)?
.gelu()?
.apply(&self.channelwise_grn)?
.apply(&self.channelwise_lin2)?
.permute((0, 3, 1, 2))?;
xs + x_res
}
}
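
/// One unit of a level: a residual block, timestep conditioning, and
/// (at every level past the first) cross-attention over the CLIP embedding.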
#[derive(Debug)]
struct SubBlock {
res_block: ResBlockStageB,
ts_block: TimestepBlock,
attn_block: Option<AttnBlock>,
}
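
/// Encoder level: an optional stride-2 downsample (layer norm + conv,
/// absent at the first level) followed by its sub-blocks.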
#[derive(Debug)]
struct DownBlock {
layer_norm: Option<WLayerNorm>,
conv: Option<candle_nn::Conv2d>,
sub_blocks: Vec<SubBlock>,
}
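
/// Decoder level: sub-blocks followed by an optional transposed-convolution
/// upsample (absent at the shallowest level).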
#[derive(Debug)]
struct UpBlock {
sub_blocks: Vec<SubBlock>,
layer_norm: Option<WLayerNorm>,
conv: Option<candle_nn::ConvTranspose2d>,
}
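
/// The DiffNeXt denoiser used as Würstchen's stage B: a UNet-style
/// encoder/decoder over `SubBlock`s, conditioned on an EfficientNet latent,
/// a sinusoidal timestep embedding, and (via cross-attention) CLIP embeddings.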
#[derive(Debug)]
pub struct WDiffNeXt {
clip_mapper: candle_nn::Linear,
effnet_mappers: Vec<Option<candle_nn::Conv2d>>,
seq_norm: LayerNormNoWeights,
embedding_conv: candle_nn::Conv2d,
embedding_ln: WLayerNorm,
down_blocks: Vec<DownBlock>,
up_blocks: Vec<UpBlock>,
clf_ln: WLayerNorm,
clf_conv: candle_nn::Conv2d,
c_r: usize,
patch_size: usize,
}
impl WDiffNeXt {
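    /// Builds the model from a `VarBuilder` rooted at the model weights.
    ///
    /// A construction sketch; the argument values below are illustrative
    /// placeholders rather than a specific checkpoint's hyper-parameters,
    /// and `vb` is assumed to already point at the right weight prefix:
    ///
    /// ```ignore
    /// let model = WDiffNeXt::new(
    ///     4,     // c_in: latent channels in
    ///     4,     // c_out: latent channels out
    ///     64,    // c_r: timestep embedding width
    ///     1024,  // c_cond: conditioning width
    ///     1024,  // clip_embd: width of the incoming CLIP embeddings
    ///     2,     // patch_size: pixel-(un)shuffle factor
    ///     false, // use_flash_attn
    ///     vb,
    /// )?;
    /// ```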
#[allow(clippy::too_many_arguments)]
pub fn new(
c_in: usize,
c_out: usize,
c_r: usize,
c_cond: usize,
clip_embd: usize,
patch_size: usize,
use_flash_attn: bool,
vb: VarBuilder,
) -> Result<Self> {
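        // Per-level hyper-parameters: channel width, sub-block count,
        // attention heads, and whether the EfficientNet latent is injected.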
const C_HIDDEN: [usize; 4] = [320, 640, 1280, 1280];
const BLOCKS: [usize; 4] = [4, 4, 14, 4];
const NHEAD: [usize; 4] = [1, 10, 20, 20];
const INJECT_EFFNET: [bool; 4] = [false, true, true, true];
const EFFNET_EMBD: usize = 16;
let clip_mapper = candle_nn::linear(clip_embd, c_cond, vb.pp("clip_mapper"))?;
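        // 1x1-conv mappers that project the EfficientNet latent to `c_cond`:
        // first one per encoder level, then one per decoder level (deepest
        // first); `None` where no injection happens.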
let mut effnet_mappers = Vec::with_capacity(2 * INJECT_EFFNET.len());
let vb_e = vb.pp("effnet_mappers");
for (i, &inject) in INJECT_EFFNET.iter().enumerate() {
let c = if inject {
Some(candle_nn::conv2d(
EFFNET_EMBD,
c_cond,
1,
Default::default(),
vb_e.pp(i),
)?)
} else {
None
};
effnet_mappers.push(c)
}
for (i, &inject) in INJECT_EFFNET.iter().rev().enumerate() {
let c = if inject {
Some(candle_nn::conv2d(
EFFNET_EMBD,
c_cond,
1,
Default::default(),
vb_e.pp(i + INJECT_EFFNET.len()),
)?)
} else {
None
};
effnet_mappers.push(c)
}
let seq_norm = LayerNormNoWeights::new(c_cond)?;
let embedding_ln = WLayerNorm::new(C_HIDDEN[0])?;
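        // The input is pixel-unshuffled by `patch_size` before this 1x1
        // convolution, hence the inflated input channel count.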
let embedding_conv = candle_nn::conv2d(
c_in * patch_size * patch_size,
C_HIDDEN[0],
1,
Default::default(),
vb.pp("embedding.1"),
)?;
let mut down_blocks = Vec::with_capacity(C_HIDDEN.len());
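        // Encoder levels: every level but the first starts with a stride-2
        // downsample stored at sub-module index 0, so its sub-block weights
        // start at index 1.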
for (i, &c_hidden) in C_HIDDEN.iter().enumerate() {
let vb = vb.pp("down_blocks").pp(i);
let (layer_norm, conv, start_layer_i) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
let cfg = candle_nn::Conv2dConfig {
stride: 2,
..Default::default()
};
let conv = candle_nn::conv2d(C_HIDDEN[i - 1], c_hidden, 2, cfg, vb.pp("0.1"))?;
(Some(layer_norm), Some(conv), 1)
} else {
(None, None, 0)
};
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
let mut layer_i = start_layer_i;
for _j in 0..BLOCKS[i] {
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
let res_block = ResBlockStageB::new(c_hidden, c_skip, 3, vb.pp(layer_i))?;
layer_i += 1;
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
layer_i += 1;
let attn_block = if i == 0 {
None
} else {
let attn_block = AttnBlock::new(
c_hidden,
c_cond,
NHEAD[i],
true,
use_flash_attn,
vb.pp(layer_i),
)?;
layer_i += 1;
Some(attn_block)
};
let sub_block = SubBlock {
res_block,
ts_block,
attn_block,
};
sub_blocks.push(sub_block)
}
let down_block = DownBlock {
layer_norm,
conv,
sub_blocks,
};
down_blocks.push(down_block)
}
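        // Decoder levels, built deepest-first to mirror the encoder.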
let mut up_blocks = Vec::with_capacity(C_HIDDEN.len());
for (i, &c_hidden) in C_HIDDEN.iter().enumerate().rev() {
let vb = vb.pp("up_blocks").pp(C_HIDDEN.len() - 1 - i);
let mut sub_blocks = Vec::with_capacity(BLOCKS[i]);
let mut layer_i = 0;
for j in 0..BLOCKS[i] {
let c_skip = if INJECT_EFFNET[i] { c_cond } else { 0 };
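                // The first sub-block of every level but the deepest also
                // consumes the encoder skip, hence the extra `c_hidden` inputs.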
let c_skip_res = if i < BLOCKS.len() - 1 && j == 0 {
c_hidden + c_skip
} else {
c_skip
};
let res_block = ResBlockStageB::new(c_hidden, c_skip_res, 3, vb.pp(layer_i))?;
layer_i += 1;
let ts_block = TimestepBlock::new(c_hidden, c_r, vb.pp(layer_i))?;
layer_i += 1;
let attn_block = if i == 0 {
None
} else {
let attn_block = AttnBlock::new(
c_hidden,
c_cond,
NHEAD[i],
true,
use_flash_attn,
vb.pp(layer_i),
)?;
layer_i += 1;
Some(attn_block)
};
let sub_block = SubBlock {
res_block,
ts_block,
attn_block,
};
sub_blocks.push(sub_block)
}
let (layer_norm, conv) = if i > 0 {
let layer_norm = WLayerNorm::new(C_HIDDEN[i - 1])?;
let cfg = candle_nn::ConvTranspose2dConfig {
stride: 2,
..Default::default()
};
let conv = candle_nn::conv_transpose2d(
c_hidden,
C_HIDDEN[i - 1],
2,
cfg,
vb.pp(layer_i).pp(1),
)?;
(Some(layer_norm), Some(conv))
} else {
(None, None)
};
let up_block = UpBlock {
layer_norm,
conv,
sub_blocks,
};
up_blocks.push(up_block)
}
let clf_ln = WLayerNorm::new(C_HIDDEN[0])?;
let clf_conv = candle_nn::conv2d(
C_HIDDEN[0],
2 * c_out * patch_size * patch_size,
1,
Default::default(),
vb.pp("clf.1"),
)?;
Ok(Self {
clip_mapper,
effnet_mappers,
seq_norm,
embedding_conv,
embedding_ln,
down_blocks,
up_blocks,
clf_ln,
clf_conv,
c_r,
patch_size,
})
}
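
    /// Sinusoidal timestep embedding: `r` is scaled by `MAX_POSITIONS`, run
    /// through a bank of frequencies whose sin and cos are concatenated, and
    /// zero-padded by one channel when `c_r` is odd.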
fn gen_r_embedding(&self, r: &Tensor) -> Result<Tensor> {
const MAX_POSITIONS: usize = 10000;
let r = (r * MAX_POSITIONS as f64)?;
let half_dim = self.c_r / 2;
let emb = (MAX_POSITIONS as f64).ln() / (half_dim - 1) as f64;
let emb = (Tensor::arange(0u32, half_dim as u32, r.device())?.to_dtype(DType::F32)?
* -emb)?
.exp()?;
let emb = r.unsqueeze(1)?.broadcast_mul(&emb.unsqueeze(0)?)?;
let emb = Tensor::cat(&[emb.sin()?, emb.cos()?], 1)?;
let emb = if self.c_r % 2 == 1 {
emb.pad_with_zeros(D::Minus1, 0, 1)?
} else {
emb
};
emb.to_dtype(r.dtype())
}
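
    /// Projects CLIP embeddings to the conditioning width and normalizes them.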
fn gen_c_embeddings(&self, clip: &Tensor) -> Result<Tensor> {
clip.apply(&self.clip_mapper)?.apply(&self.seq_norm)
}
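
    /// Denoises `xs` at timestep ratio `r`, conditioned on the EfficientNet
    /// latent `effnet`. `clip` must be `Some` here: every level past the
    /// first has attention blocks that unwrap it.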
pub fn forward(
&self,
xs: &Tensor,
r: &Tensor,
effnet: &Tensor,
clip: Option<&Tensor>,
) -> Result<Tensor> {
const EPS: f64 = 1e-3;
let r_embed = self.gen_r_embedding(r)?;
let clip = match clip {
None => None,
Some(clip) => Some(self.gen_c_embeddings(clip)?),
};
let x_in = xs;
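        // Patchify: fold `patch_size`^2 pixels into channels, then embed.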
let mut xs = xs
.apply(&|xs: &_| candle_nn::ops::pixel_unshuffle(xs, self.patch_size))?
.apply(&self.embedding_conv)?
.apply(&self.embedding_ln)?;
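        // Encoder pass, recording each level's output for the decoder skips.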
let mut level_outputs = Vec::new();
for (i, down_block) in self.down_blocks.iter().enumerate() {
if let Some(ln) = &down_block.layer_norm {
xs = xs.apply(ln)?
}
if let Some(conv) = &down_block.conv {
xs = xs.apply(conv)?
}
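            // Resize the EfficientNet latent to the current resolution
            // before projecting it to the conditioning width.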
let skip = match &self.effnet_mappers[i] {
None => None,
Some(m) => {
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
Some(m.forward(&effnet)?)
}
};
for block in down_block.sub_blocks.iter() {
xs = block.res_block.forward(&xs, skip.as_ref())?;
xs = block.ts_block.forward(&xs, &r_embed)?;
if let Some(attn_block) = &block.attn_block {
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
}
}
level_outputs.push(xs.clone())
}
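        // Decoder pass; after the reverse, level_outputs[0] is the bottleneck
        // output and the rest are the encoder skips, deepest first.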
level_outputs.reverse();
let mut xs = level_outputs[0].clone();
for (i, up_block) in self.up_blocks.iter().enumerate() {
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
None => None,
Some(m) => {
let effnet = effnet.interpolate2d(xs.dim(D::Minus2)?, xs.dim(D::Minus1)?)?;
Some(m.forward(&effnet)?)
}
};
for (j, block) in up_block.sub_blocks.iter().enumerate() {
let skip = if j == 0 && i > 0 {
Some(&level_outputs[i])
} else {
None
};
let skip = match (skip, effnet_c.as_ref()) {
(Some(skip), Some(effnet_c)) => Some(Tensor::cat(&[skip, effnet_c], 1)?),
(None, Some(skip)) | (Some(skip), None) => Some(skip.clone()),
(None, None) => None,
};
xs = block.res_block.forward(&xs, skip.as_ref())?;
xs = block.ts_block.forward(&xs, &r_embed)?;
if let Some(attn_block) = &block.attn_block {
xs = attn_block.forward(&xs, clip.as_ref().unwrap())?;
}
}
if let Some(ln) = &up_block.layer_norm {
xs = xs.apply(ln)?
}
if let Some(conv) = &up_block.conv {
xs = xs.apply(conv)?
}
}
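        // Head: predict per-pixel (a, b), un-patchify, and renormalize the
        // input as (x_in - a) / b, with b squashed into [EPS, 1 - EPS] to
        // keep the division stable.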
let ab = xs
.apply(&self.clf_ln)?
.apply(&self.clf_conv)?
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
.chunk(2, 1)?;
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
(x_in - &ab[0])? / b
}
}