use super::embeddings::{TimestepEmbedding, Timesteps};
use super::unet_2d_blocks::*;
use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
/// Configuration of one resolution level of the UNet.
///
/// The same entry describes both the down block at that level and its
/// mirrored up block.
#[derive(Debug, Clone, Copy)]
pub struct BlockConfig {
/// Number of output channels produced by the block at this level.
pub out_channels: usize,
/// When `Some(n)`, a cross-attention block is built and `n` is passed on as
/// its `transformer_layers_per_block`; when `None`, a basic block without
/// cross-attention is built.
pub use_cross_attn: Option<usize>,
/// Passed to the cross-attention blocks as `attn_num_head_channels`.
pub attention_head_dim: usize,
}
/// Configuration for [`UNet2DConditionModel`].
#[derive(Debug, Clone)]
pub struct UNet2DConditionModelConfig {
/// When set, the input sample is rescaled as `xs * 2 - 1` before the first
/// convolution.
pub center_input_sample: bool,
/// Forwarded to the `Timesteps` projection.
pub flip_sin_to_cos: bool,
/// Forwarded to the `Timesteps` projection.
pub freq_shift: f64,
/// One entry per resolution level. Down blocks are built in this order and
/// up blocks in the reverse order.
pub blocks: Vec<BlockConfig>,
/// Number of layers in each down block; up blocks use `layers_per_block + 1`.
pub layers_per_block: usize,
/// Forwarded to every down block configuration.
pub downsample_padding: usize,
/// Used as the `output_scale_factor` of the mid block.
pub mid_block_scale_factor: f64,
/// Number of groups used by the down/mid/up blocks and by the final group
/// normalization.
pub norm_num_groups: usize,
/// Epsilon used by the down/mid/up blocks and by the final group
/// normalization.
pub norm_eps: f64,
/// Cross-attention dimension forwarded to the cross-attention down/up blocks
/// and to the mid block.
pub cross_attention_dim: usize,
/// Forwarded to the cross-attention down/up blocks. `Some(0)` is replaced by
/// half of the block's `attention_head_dim`.
pub sliced_attention_size: Option<usize>,
/// Forwarded to the cross-attention down/up blocks and to the mid block.
pub use_linear_projection: bool,
}
impl Default for UNet2DConditionModelConfig {
    /// Returns a four-level configuration (320, 640, 1280, 1280 channels) in
    /// which every level except the last one uses cross-attention with a
    /// single transformer layer.
    fn default() -> Self {
        // All default levels share the same attention head dimension, so only
        // the channel count and the cross-attention setting vary.
        let level = |out_channels: usize, use_cross_attn: Option<usize>| BlockConfig {
            out_channels,
            use_cross_attn,
            attention_head_dim: 8,
        };
        let blocks = vec![
            level(320, Some(1)),
            level(640, Some(1)),
            level(1280, Some(1)),
            level(1280, None),
        ];

        Self {
            center_input_sample: false,
            flip_sin_to_cos: true,
            freq_shift: 0.0,
            blocks,
            layers_per_block: 2,
            downsample_padding: 1,
            mid_block_scale_factor: 1.0,
            norm_num_groups: 32,
            norm_eps: 1e-5,
            cross_attention_dim: 1280,
            sliced_attention_size: None,
            use_linear_projection: false,
        }
    }
}
/// A down-sampling stage of the UNet, with or without cross-attention.
#[derive(Debug)]
pub(crate) enum UNetDownBlock {
/// Down block without cross-attention; its forward pass takes the sample and
/// the time embedding only.
Basic(DownBlock2D),
/// Down block with cross-attention; its forward pass additionally receives
/// the encoder hidden states.
CrossAttn(CrossAttnDownBlock2D),
}
/// An up-sampling stage of the UNet, with or without cross-attention.
#[derive(Debug)]
enum UNetUpBlock {
/// Up block without cross-attention.
Basic(UpBlock2D),
/// Up block with cross-attention; its forward pass additionally receives the
/// encoder hidden states.
CrossAttn(CrossAttnUpBlock2D),
}
/// Conditional 2D UNet: an input convolution, a stack of down blocks, a
/// cross-attention mid block, a stack of up blocks fed with the down-path
/// residuals, and a final group-norm / SiLU / convolution head.
#[derive(Debug)]
pub struct UNet2DConditionModel {
// Input convolution mapping the sample to the first block's channel count.
conv_in: Conv2d,
// Projects the scalar timestep before it is embedded.
time_proj: Timesteps,
// Embeds the projected timestep; the result is passed to every block.
time_embedding: TimestepEmbedding,
// Down path, in the order of `config.blocks`.
down_blocks: Vec<UNetDownBlock>,
mid_block: UNetMidBlock2DCrossAttn,
// Up path, in the reverse order of `config.blocks`.
up_blocks: Vec<UNetUpBlock>,
// Output head: group norm, then SiLU, then `conv_out`.
conv_norm_out: nn::GroupNorm,
conv_out: Conv2d,
// Tracing span entered by `forward`.
span: tracing::Span,
// Kept so that the forward pass can read the block count and the
// `center_input_sample` flag.
config: UNet2DConditionModelConfig,
}
impl UNet2DConditionModel {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
use_flash_attn: bool,
config: UNet2DConditionModelConfig,
) -> Result<Self> {
let n_blocks = config.blocks.len();
let b_channels = config.blocks[0].out_channels;
let bl_channels = config.blocks.last().unwrap().out_channels;
let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim;
let time_embed_dim = b_channels * 4;
let conv_cfg = nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?;
let time_proj = Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift);
let time_embedding =
TimestepEmbedding::new(vs.pp("time_embedding"), b_channels, time_embed_dim)?;
let vs_db = vs.pp("down_blocks");
let down_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[i];
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let in_channels = if i > 0 {
config.blocks[i - 1].out_channels
} else {
b_channels
};
let db_cfg = DownBlock2DConfig {
num_layers: config.layers_per_block,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_downsample: i < n_blocks - 1,
downsample_padding: config.downsample_padding,
..Default::default()
};
if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnDownBlock2DConfig {
downblock: db_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
vs_db.pp(&i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
db_cfg,
)?;
Ok(UNetDownBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let mid_transformer_layers_per_block = match config.blocks.last() {
None => 1,
Some(block) => block.use_cross_attn.unwrap_or(1),
};
let mid_cfg = UNetMidBlock2DCrossAttnConfig {
resnet_eps: config.norm_eps,
output_scale_factor: config.mid_block_scale_factor,
cross_attn_dim: config.cross_attention_dim,
attn_num_head_channels: bl_attention_head_dim,
resnet_groups: Some(config.norm_num_groups),
use_linear_projection: config.use_linear_projection,
transformer_layers_per_block: mid_transformer_layers_per_block,
..Default::default()
};
let mid_block = UNetMidBlock2DCrossAttn::new(
vs.pp("mid_block"),
bl_channels,
Some(time_embed_dim),
use_flash_attn,
mid_cfg,
)?;
let vs_ub = vs.pp("up_blocks");
let up_blocks = (0..n_blocks)
.map(|i| {
let BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
} = config.blocks[n_blocks - 1 - i];
let sliced_attention_size = match config.sliced_attention_size {
Some(0) => Some(attention_head_dim / 2),
_ => config.sliced_attention_size,
};
let prev_out_channels = if i > 0 {
config.blocks[n_blocks - i].out_channels
} else {
bl_channels
};
let in_channels = {
let index = if i == n_blocks - 1 {
0
} else {
n_blocks - i - 2
};
config.blocks[index].out_channels
};
let ub_cfg = UpBlock2DConfig {
num_layers: config.layers_per_block + 1,
resnet_eps: config.norm_eps,
resnet_groups: config.norm_num_groups,
add_upsample: i < n_blocks - 1,
..Default::default()
};
if let Some(transformer_layers_per_block) = use_cross_attn {
let config = CrossAttnUpBlock2DConfig {
upblock: ub_cfg,
attn_num_head_channels: attention_head_dim,
cross_attention_dim: config.cross_attention_dim,
sliced_attention_size,
use_linear_projection: config.use_linear_projection,
transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
use_flash_attn,
config,
)?;
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
vs_ub.pp(&i.to_string()),
in_channels,
prev_out_channels,
out_channels,
Some(time_embed_dim),
ub_cfg,
)?;
Ok(UNetUpBlock::Basic(block))
}
})
.collect::<Result<Vec<_>>>()?;
let conv_norm_out = nn::group_norm(
config.norm_num_groups,
b_channels,
config.norm_eps,
vs.pp("conv_norm_out"),
)?;
let conv_out = conv2d(b_channels, out_channels, 3, conv_cfg, vs.pp("conv_out"))?;
let span = tracing::span!(tracing::Level::TRACE, "unet2d");
Ok(Self {
conv_in,
time_proj,
time_embedding,
down_blocks,
mid_block,
up_blocks,
conv_norm_out,
conv_out,
span,
config,
})
}
/// Runs the UNet on `xs` for the given `timestep`, conditioned on
/// `encoder_hidden_states`, without any additional residuals.
///
/// The call is recorded under the model's tracing span.
pub fn forward(
    &self,
    xs: &Tensor,
    timestep: f64,
    encoder_hidden_states: &Tensor,
) -> Result<Tensor> {
    let _guard = self.span.enter();
    let no_down_residuals: Option<&[Tensor]> = None;
    let no_mid_residual: Option<&Tensor> = None;
    self.forward_with_additional_residuals(
        xs,
        timestep,
        encoder_hidden_states,
        no_down_residuals,
        no_mid_residual,
    )
}
pub fn forward_with_additional_residuals(
&self,
xs: &Tensor,
timestep: f64,
encoder_hidden_states: &Tensor,
down_block_additional_residuals: Option<&[Tensor]>,
mid_block_additional_residual: Option<&Tensor>,
) -> Result<Tensor> {
let (bsize, _channels, height, width) = xs.dims4()?;
let device = xs.device();
let n_blocks = self.config.blocks.len();
let num_upsamplers = n_blocks - 1;
let default_overall_up_factor = 2usize.pow(num_upsamplers as u32);
let forward_upsample_size =
height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0;
let xs = if self.config.center_input_sample {
((xs * 2.0)? - 1.0)?
} else {
xs.clone()
};
let emb = (Tensor::ones(bsize, xs.dtype(), device)? * timestep)?;
let emb = self.time_proj.forward(&emb)?;
let emb = self.time_embedding.forward(&emb)?;
let xs = self.conv_in.forward(&xs)?;
let mut down_block_res_xs = vec![xs.clone()];
let mut xs = xs;
for down_block in self.down_blocks.iter() {
let (_xs, res_xs) = match down_block {
UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb))?,
UNetDownBlock::CrossAttn(b) => {
b.forward(&xs, Some(&emb), Some(encoder_hidden_states))?
}
};
down_block_res_xs.extend(res_xs);
xs = _xs;
}
let new_down_block_res_xs =
if let Some(down_block_additional_residuals) = down_block_additional_residuals {
let mut v = vec![];
for (i, residuals) in down_block_additional_residuals.iter().enumerate() {
v.push((&down_block_res_xs[i] + residuals)?)
}
v
} else {
down_block_res_xs
};
let mut down_block_res_xs = new_down_block_res_xs;
let xs = self
.mid_block
.forward(&xs, Some(&emb), Some(encoder_hidden_states))?;
let xs = match mid_block_additional_residual {
None => xs,
Some(m) => (m + xs)?,
};
let mut xs = xs;
let mut upsample_size = None;
for (i, up_block) in self.up_blocks.iter().enumerate() {
let n_resnets = match up_block {
UNetUpBlock::Basic(b) => b.resnets.len(),
UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(),
};
let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets);
if i < n_blocks - 1 && forward_upsample_size {
let (_, _, h, w) = down_block_res_xs.last().unwrap().dims4()?;
upsample_size = Some((h, w))
}
xs = match up_block {
UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size)?,
UNetUpBlock::CrossAttn(b) => b.forward(
&xs,
&res_xs,
Some(&emb),
upsample_size,
Some(encoder_hidden_states),
)?,
};
}
let xs = self.conv_norm_out.forward(&xs)?;
let xs = nn::ops::silu(&xs)?;
self.conv_out.forward(&xs)
}
}