use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Result, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
/// Tunable settings for a [`ResnetBlock2D`].
///
/// Optional fields left as `None` are resolved when the block is constructed.
#[derive(Clone, Copy, Debug)]
pub struct ResnetBlock2DConfig {
    /// Number of output channels; when `None` the input channel count is reused.
    pub out_channels: Option<usize>,
    /// Input width of the time-embedding projection; `None` means no projection
    /// layer is created.
    pub temb_channels: Option<usize>,
    /// Number of groups used by the first group normalization.
    pub groups: usize,
    /// Number of groups used by the second group normalization; falls back to
    /// `groups` when `None`.
    pub groups_out: Option<usize>,
    /// Epsilon handed to both group normalizations.
    pub eps: f64,
    /// Whether to build the 1x1 shortcut convolution. When `None`, the shortcut
    /// is built only if the input and output channel counts differ.
    pub use_in_shortcut: Option<bool>,
    /// Divisor applied to the sum of the shortcut and the main branch output.
    pub output_scale_factor: f64,
}
impl Default for ResnetBlock2DConfig {
fn default() -> Self {
Self {
out_channels: None,
temb_channels: Some(512),
groups: 32,
groups_out: None,
eps: 1e-6,
use_in_shortcut: None,
output_scale_factor: 1.,
}
}
}
/// Residual block built from two group-norm + 3x3 convolution stages, with an
/// optional time-embedding projection and an optional 1x1 convolution on the
/// shortcut path.
#[derive(Debug)]
pub struct ResnetBlock2D {
/// Group normalization applied to the block input.
norm1: nn::GroupNorm,
/// First 3x3 convolution, mapping input channels to output channels.
conv1: Conv2d,
/// Group normalization applied before the second convolution.
norm2: nn::GroupNorm,
/// Second 3x3 convolution, keeping the output channel count.
conv2: Conv2d,
/// Linear projection of the time embedding to the output channel count;
/// `None` when the configuration has no `temb_channels`.
time_emb_proj: Option<nn::Linear>,
/// 1x1 convolution applied to the input on the shortcut path; when `None`
/// the input is used as-is.
conv_shortcut: Option<Conv2d>,
/// Tracing span entered for the duration of each `forward` call.
span: tracing::Span,
/// Configuration the block was built with; `forward` reads
/// `output_scale_factor` from it.
config: ResnetBlock2DConfig,
}
impl ResnetBlock2D {
    /// Creates the block's layers through `vs`.
    ///
    /// Sub-layers are requested under the prefixes `norm1`, `conv1`, `norm2`,
    /// `conv2`, `conv_shortcut` and `time_emb_proj`. Unset options in `config`
    /// are resolved here: the output channel count defaults to `in_channels`,
    /// the second normalization's group count defaults to `config.groups`, and
    /// the shortcut convolution is created only when the channel counts differ
    /// unless `config.use_in_shortcut` says otherwise.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while building any sub-layer.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        config: ResnetBlock2DConfig,
    ) -> Result<Self> {
        let out_channels = config.out_channels.unwrap_or(in_channels);

        // All convolutions here share stride/groups/dilation and differ only in padding.
        let conv_cfg = |padding| nn::Conv2dConfig {
            stride: 1,
            padding,
            groups: 1,
            dilation: 1,
        };

        // First stage: normalization followed by a padded 3x3 convolution.
        let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?;
        let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg(1), vs.pp("conv1"))?;

        // Second stage works entirely on `out_channels`.
        let groups_out = config.groups_out.unwrap_or(config.groups);
        let norm2 = nn::group_norm(groups_out, out_channels, config.eps, vs.pp("norm2"))?;
        let conv2 = conv2d(out_channels, out_channels, 3, conv_cfg(1), vs.pp("conv2"))?;

        // Unpadded 1x1 convolution on the shortcut path, when requested or needed.
        let needs_shortcut_conv = config
            .use_in_shortcut
            .unwrap_or(in_channels != out_channels);
        let conv_shortcut = needs_shortcut_conv
            .then(|| {
                conv2d(
                    in_channels,
                    out_channels,
                    1,
                    conv_cfg(0),
                    vs.pp("conv_shortcut"),
                )
            })
            .transpose()?;

        // The time-embedding projection exists only when an embedding width is configured.
        let time_emb_proj = config
            .temb_channels
            .map(|temb_channels| nn::linear(temb_channels, out_channels, vs.pp("time_emb_proj")))
            .transpose()?;

        let span = tracing::span!(tracing::Level::TRACE, "resnet2d");
        Ok(Self {
            norm1,
            conv1,
            norm2,
            conv2,
            time_emb_proj,
            conv_shortcut,
            span,
            config,
        })
    }

    /// Runs the block on `xs`, optionally conditioned on the time embedding `temb`.
    ///
    /// The main branch applies `norm1`, SiLU and `conv1`. If both `temb` and the
    /// projection layer are present, it then adds the projected embedding. It
    /// finishes with `norm2`, SiLU and `conv2`. The result is summed with the
    /// shortcut branch and divided by `config.output_scale_factor`.
    ///
    /// NOTE(review): `xs` is presumably laid out as (batch, channels, height,
    /// width), as the 2D convolutions suggest — confirm against callers.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the underlying tensor operations.
    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();

        // Shortcut branch: either the 1x1 convolution of the input or the input itself.
        let residual = if let Some(shortcut) = &self.conv_shortcut {
            shortcut.forward(xs)?
        } else {
            xs.clone()
        };

        // First stage.
        let hidden = self.norm1.forward(xs)?;
        let hidden = nn::ops::silu(&hidden)?;
        let hidden = self.conv1.forward(&hidden)?;

        // The embedding is used only when both the input and the projection are present.
        let hidden = if let (Some(temb), Some(proj)) = (temb, &self.time_emb_proj) {
            let emb = proj.forward(&nn::ops::silu(temb)?)?;
            // Two trailing unit dimensions let the embedding broadcast against `hidden`.
            let emb = emb.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?;
            emb.broadcast_add(&hidden)?
        } else {
            hidden
        };

        // Second stage.
        let hidden = self.norm2.forward(&hidden)?;
        let hidden = nn::ops::silu(&hidden)?;
        let hidden = self.conv2.forward(&hidden)?;

        let summed = (residual + hidden)?;
        summed / self.config.output_scale_factor
    }
}