#![allow(dead_code)]
use super::unet_2d_blocks::{
DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig,
UpDecoderBlock2D, UpDecoderBlock2DConfig,
};
use candle::{Result, Tensor};
use candle_nn as nn;
use candle_nn::Module;
/// Hyper-parameters used to build the VAE encoder.
#[derive(Debug, Clone)]
struct EncoderConfig {
    /// Output channel count of each down block, in order; the first entry is
    /// also the output width of the encoder's input convolution.
    block_out_channels: Vec<usize>,
    /// Number of resnet layers inside each down block.
    layers_per_block: usize,
    /// Group count shared by every group-normalisation layer.
    norm_num_groups: usize,
    /// When `true`, the final convolution emits twice the requested output
    /// channels (two stacked halves) instead of exactly that many.
    double_z: bool,
}

impl Default for EncoderConfig {
    fn default() -> Self {
        let block_out_channels = vec![64];
        EncoderConfig {
            double_z: true,
            norm_num_groups: 32,
            layers_per_block: 2,
            block_out_channels,
        }
    }
}
/// Convolutional encoder half of the VAE.
///
/// Forward pipeline: `conv_in` -> down blocks -> mid block -> group norm ->
/// SiLU -> `conv_out`.
#[derive(Debug)]
struct Encoder {
    conv_in: nn::Conv2d,
    down_blocks: Vec<DownEncoderBlock2D>,
    mid_block: UNetMidBlock2D,
    conv_norm_out: nn::GroupNorm,
    conv_out: nn::Conv2d,
    #[allow(dead_code)]
    config: EncoderConfig,
}

impl Encoder {
    /// Creates the encoder, pulling every weight from `vs`.
    ///
    /// * `in_channels` - channels of the tensor later passed to `forward`.
    /// * `out_channels` - requested output channels; `conv_out` produces twice
    ///   this many when `config.double_z` is set.
    ///
    /// # Panics
    /// Panics if `config.block_out_channels` is empty.
    fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        config: EncoderConfig,
    ) -> Result<Self> {
        // 3x3 kernels with padding 1 keep the spatial size unchanged.
        let same_padding = nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let channels = &config.block_out_channels;
        let n_blocks = channels.len();

        let conv_in = nn::conv2d(in_channels, channels[0], 3, same_padding, vs.pp("conv_in"))?;

        let block_vs = vs.pp("down_blocks");
        let down_blocks = channels
            .iter()
            .enumerate()
            .map(|(i, &block_out)| {
                // Block 0 consumes the output of `conv_in` (channels[0]); every
                // later block consumes the output of its predecessor.
                let block_in = channels[i.saturating_sub(1)];
                let block_cfg = DownEncoderBlock2DConfig {
                    num_layers: config.layers_per_block,
                    resnet_eps: 1e-6,
                    resnet_groups: config.norm_num_groups,
                    // Only the last block skips spatial downsampling.
                    add_downsample: i + 1 != n_blocks,
                    downsample_padding: 0,
                    ..Default::default()
                };
                DownEncoderBlock2D::new(block_vs.pp(&i.to_string()), block_in, block_out, block_cfg)
            })
            .collect::<Result<Vec<_>>>()?;

        let final_channels = channels[n_blocks - 1];
        let mid_block = UNetMidBlock2D::new(
            vs.pp("mid_block"),
            final_channels,
            None,
            UNetMidBlock2DConfig {
                resnet_eps: 1e-6,
                output_scale_factor: 1.,
                attn_num_head_channels: None,
                resnet_groups: Some(config.norm_num_groups),
                ..Default::default()
            },
        )?;
        let conv_norm_out =
            nn::group_norm(config.norm_num_groups, final_channels, 1e-6, vs.pp("conv_norm_out"))?;

        // With `double_z` the output holds two stacked halves of
        // `out_channels` each.
        let z_channels = if config.double_z { out_channels * 2 } else { out_channels };
        let conv_out = nn::conv2d(final_channels, z_channels, 3, same_padding, vs.pp("conv_out"))?;

        Ok(Self {
            conv_in,
            down_blocks,
            mid_block,
            conv_norm_out,
            conv_out,
            config,
        })
    }

    /// Runs the full encoder pipeline on `xs`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = xs.apply(&self.conv_in)?;
        let hidden = self
            .down_blocks
            .iter()
            .try_fold(hidden, |acc, block| acc.apply(block))?;
        let hidden = self.mid_block.forward(&hidden, None)?;
        let hidden = hidden.apply(&self.conv_norm_out)?;
        nn::ops::silu(&hidden)?.apply(&self.conv_out)
    }
}
/// Hyper-parameters used to build the VAE decoder.
#[derive(Debug, Clone)]
struct DecoderConfig {
    /// Per-block channel widths, listed in encoder order; the decoder walks
    /// them in reverse.
    block_out_channels: Vec<usize>,
    /// Base resnet layer count; each up block is built with one more layer.
    layers_per_block: usize,
    /// Group count shared by every group-normalisation layer.
    norm_num_groups: usize,
}

impl Default for DecoderConfig {
    fn default() -> Self {
        DecoderConfig {
            norm_num_groups: 32,
            layers_per_block: 2,
            block_out_channels: Vec::from([64]),
        }
    }
}
/// Convolutional decoder half of the VAE.
///
/// Forward pipeline: `conv_in` -> mid block -> up blocks -> group norm ->
/// SiLU -> `conv_out`.
#[derive(Debug)]
struct Decoder {
    conv_in: nn::Conv2d,
    up_blocks: Vec<UpDecoderBlock2D>,
    mid_block: UNetMidBlock2D,
    conv_norm_out: nn::GroupNorm,
    conv_out: nn::Conv2d,
    #[allow(dead_code)]
    config: DecoderConfig,
}

impl Decoder {
    /// Creates the decoder, pulling every weight from `vs`.
    ///
    /// * `in_channels` - channels of the tensor later passed to `forward`.
    /// * `out_channels` - channels produced by `conv_out`.
    ///
    /// The up blocks walk `config.block_out_channels` in reverse, so the
    /// decoder starts at the last listed width and ends at the first.
    ///
    /// # Panics
    /// Panics if `config.block_out_channels` is empty.
    fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        config: DecoderConfig,
    ) -> Result<Self> {
        // 3x3 kernels with padding 1 keep the spatial size unchanged.
        let same_padding = nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let channels = &config.block_out_channels;
        let n_blocks = channels.len();
        let top_channels = *channels.last().unwrap();

        let conv_in = nn::conv2d(in_channels, top_channels, 3, same_padding, vs.pp("conv_in"))?;
        let mid_block = UNetMidBlock2D::new(
            vs.pp("mid_block"),
            top_channels,
            None,
            UNetMidBlock2DConfig {
                resnet_eps: 1e-6,
                output_scale_factor: 1.,
                attn_num_head_channels: None,
                resnet_groups: Some(config.norm_num_groups),
                ..Default::default()
            },
        )?;

        let reversed: Vec<usize> = channels.iter().rev().copied().collect();
        let block_vs = vs.pp("up_blocks");
        let up_blocks = reversed
            .iter()
            .enumerate()
            .map(|(i, &block_out)| {
                // Block 0 consumes the mid block output (reversed[0]); every
                // later block consumes the output of its predecessor.
                let block_in = reversed[i.saturating_sub(1)];
                let block_cfg = UpDecoderBlock2DConfig {
                    // Up blocks carry one extra resnet layer beyond `layers_per_block`.
                    num_layers: config.layers_per_block + 1,
                    resnet_eps: 1e-6,
                    resnet_groups: config.norm_num_groups,
                    // Only the last block skips spatial upsampling.
                    add_upsample: i + 1 != n_blocks,
                    ..Default::default()
                };
                UpDecoderBlock2D::new(block_vs.pp(&i.to_string()), block_in, block_out, block_cfg)
            })
            .collect::<Result<Vec<_>>>()?;

        let base_channels = channels[0];
        let conv_norm_out =
            nn::group_norm(config.norm_num_groups, base_channels, 1e-6, vs.pp("conv_norm_out"))?;
        let conv_out = nn::conv2d(base_channels, out_channels, 3, same_padding, vs.pp("conv_out"))?;

        Ok(Self {
            conv_in,
            up_blocks,
            mid_block,
            conv_norm_out,
            conv_out,
            config,
        })
    }

    /// Runs the full decoder pipeline on `xs`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let hidden = xs.apply(&self.conv_in)?;
        let hidden = self.mid_block.forward(&hidden, None)?;
        let hidden = self
            .up_blocks
            .iter()
            .try_fold(hidden, |acc, up_block| up_block.forward(&acc))?;
        let hidden = hidden.apply(&self.conv_norm_out)?;
        nn::ops::silu(&hidden)?.apply(&self.conv_out)
    }
}
/// Hyper-parameters of [`AutoEncoderKL`].
#[derive(Debug, Clone)]
pub struct AutoEncoderKLConfig {
    /// Per-block channel widths shared by the encoder and (reversed) decoder.
    pub block_out_channels: Vec<usize>,
    /// Resnet layers per encoder block; decoder blocks use one more.
    pub layers_per_block: usize,
    /// Channel count of the latent tensor consumed by the decoder.
    pub latent_channels: usize,
    /// Group count shared by every group-normalisation layer.
    pub norm_num_groups: usize,
}

impl Default for AutoEncoderKLConfig {
    fn default() -> Self {
        let (layers_per_block, latent_channels, norm_num_groups) = (1, 4, 32);
        AutoEncoderKLConfig {
            block_out_channels: Vec::from([64]),
            layers_per_block,
            latent_channels,
            norm_num_groups,
        }
    }
}
pub struct DiagonalGaussianDistribution {
mean: Tensor,
std: Tensor,
}
impl DiagonalGaussianDistribution {
pub fn new(parameters: &Tensor) -> Result<Self> {
let mut parameters = parameters.chunk(2, 1)?.into_iter();
let mean = parameters.next().unwrap();
let logvar = parameters.next().unwrap();
let std = (logvar * 0.5)?.exp()?;
Ok(DiagonalGaussianDistribution { mean, std })
}
pub fn sample(&self) -> Result<Tensor> {
let sample = self.mean.randn_like(0., 1.);
&self.mean + &self.std * sample
}
}
#[derive(Debug)]
pub struct AutoEncoderKL {
encoder: Encoder,
decoder: Decoder,
quant_conv: nn::Conv2d,
post_quant_conv: nn::Conv2d,
pub config: AutoEncoderKLConfig,
}
impl AutoEncoderKL {
pub fn new(
vs: nn::VarBuilder,
in_channels: usize,
out_channels: usize,
config: AutoEncoderKLConfig,
) -> Result<Self> {
let latent_channels = config.latent_channels;
let encoder_cfg = EncoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
double_z: true,
};
let encoder = Encoder::new(vs.pp("encoder"), in_channels, latent_channels, encoder_cfg)?;
let decoder_cfg = DecoderConfig {
block_out_channels: config.block_out_channels.clone(),
layers_per_block: config.layers_per_block,
norm_num_groups: config.norm_num_groups,
};
let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?;
let conv_cfg = Default::default();
let quant_conv = nn::conv2d(
2 * latent_channels,
2 * latent_channels,
1,
conv_cfg,
vs.pp("quant_conv"),
)?;
let post_quant_conv = nn::conv2d(
latent_channels,
latent_channels,
1,
conv_cfg,
vs.pp("post_quant_conv"),
)?;
Ok(Self {
encoder,
decoder,
quant_conv,
post_quant_conv,
config,
})
}
pub fn encode(&self, xs: &Tensor) -> Result<DiagonalGaussianDistribution> {
let xs = self.encoder.forward(xs)?;
let parameters = self.quant_conv.forward(&xs)?;
DiagonalGaussianDistribution::new(¶meters)
}
pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.post_quant_conv.forward(xs)?;
self.decoder.forward(&xs)
}
}