use super::attention::{
    AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
};
use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
use crate::models::with_tracing::{conv2d, Conv2d};
use candle::{Module, Result, Tensor, D};
use candle_nn as nn;
/// Downsampling layer: a stride-2, kernel-3 convolution when one was requested,
/// otherwise `avg_pool2d(2)`.
#[derive(Debug)]
struct Downsample2D {
    /// Strided convolution; `None` selects the pooling path in `forward`.
    conv: Option<Conv2d>,
    /// Padding handed to the convolution config. A value of `0` makes `forward`
    /// zero-pad the input itself before convolving.
    padding: usize,
    span: tracing::Span,
}

impl Downsample2D {
    /// Creates the layer.
    ///
    /// When `use_conv` is true the convolution weights are read from the `conv`
    /// sub-path of `vs`; when it is false nothing is loaded and the channel
    /// arguments are not used.
    fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        use_conv: bool,
        out_channels: usize,
        padding: usize,
    ) -> Result<Self> {
        let conv = use_conv
            .then(|| {
                let conv_cfg = nn::Conv2dConfig {
                    stride: 2,
                    padding,
                    ..Default::default()
                };
                conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv"))
            })
            .transpose()?;
        Ok(Self {
            conv,
            padding,
            span: tracing::span!(tracing::Level::TRACE, "downsample2d"),
        })
    }
}

impl Module for Downsample2D {
    /// Applies the pooling or convolution path selected at construction time.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        match (&self.conv, self.padding) {
            (None, _) => xs.avg_pool2d(2),
            (Some(conv), 0) => {
                // The convolution itself adds no padding here, so append one
                // zero at the end of each of the last two dimensions first.
                let padded = xs
                    .pad_with_zeros(D::Minus1, 0, 1)?
                    .pad_with_zeros(D::Minus2, 0, 1)?;
                conv.forward(&padded)
            }
            (Some(conv), _) => conv.forward(xs),
        }
    }
}
/// Upsampling layer: nearest-neighbour resize followed by a kernel-3,
/// padding-1 convolution.
#[derive(Debug)]
struct Upsample2D {
    conv: Conv2d,
    span: tracing::Span,
}

impl Upsample2D {
    /// Loads the convolution from the `conv` sub-path of `vs`.
    fn new(vs: nn::VarBuilder, in_channels: usize, out_channels: usize) -> Result<Self> {
        let conv_cfg = nn::Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        Ok(Self {
            conv: conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv"))?,
            span: tracing::span!(tracing::Level::TRACE, "upsample2d"),
        })
    }

    /// Resizes `xs` and convolves the result.
    ///
    /// With `size == Some((h, w))` the input is resized to exactly that size.
    /// With `None`, `xs` must be 4-dimensional and its last two dimensions are
    /// doubled.
    fn forward(&self, xs: &Tensor, size: Option<(usize, usize)>) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (target_h, target_w) = match size {
            Some(hw) => hw,
            None => {
                let (_, _, h, w) = xs.dims4()?;
                (2 * h, 2 * w)
            }
        };
        let resized = xs.upsample_nearest2d(target_h, target_w)?;
        self.conv.forward(&resized)
    }
}
/// Settings for [`DownEncoderBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct DownEncoderBlock2DConfig {
    /// Number of resnet blocks in the stack.
    pub num_layers: usize,
    /// Forwarded as `eps` to every resnet block.
    pub resnet_eps: f64,
    /// Forwarded as `groups` to every resnet block.
    pub resnet_groups: usize,
    /// Forwarded unchanged to every resnet block.
    pub output_scale_factor: f64,
    /// Whether a downsampling convolution follows the resnets.
    pub add_downsample: bool,
    /// Padding of the downsampling convolution.
    pub downsample_padding: usize,
}

impl Default for DownEncoderBlock2DConfig {
    /// One layer, eps `1e-6`, 32 groups, unit scale factor, and a downsampler
    /// with padding 1.
    fn default() -> Self {
        Self {
            add_downsample: true,
            downsample_padding: 1,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
/// A stack of resnet blocks (called without `temb`) optionally followed by a
/// convolutional downsampler.
#[derive(Debug)]
pub struct DownEncoderBlock2D {
    resnets: Vec<ResnetBlock2D>,
    downsampler: Option<Downsample2D>,
    span: tracing::Span,
    pub config: DownEncoderBlock2DConfig,
}

impl DownEncoderBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// Resnets are read from `resnets.{i}`; the first maps `in_channels` to
    /// `out_channels` and the remaining ones keep `out_channels`. When
    /// `config.add_downsample` is set, the downsampler is read from
    /// `downsamplers.0`.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        config: DownEncoderBlock2DConfig,
    ) -> Result<Self> {
        let resnet_cfg = ResnetBlock2DConfig {
            eps: config.resnet_eps,
            out_channels: Some(out_channels),
            groups: config.resnet_groups,
            output_scale_factor: config.output_scale_factor,
            temb_channels: None,
            ..Default::default()
        };
        let vs_resnets = vs.pp("resnets");
        let mut resnets = Vec::new();
        for layer in 0..config.num_layers {
            let layer_in = if layer == 0 { in_channels } else { out_channels };
            let block = ResnetBlock2D::new(vs_resnets.pp(&layer.to_string()), layer_in, resnet_cfg)?;
            resnets.push(block);
        }
        let downsampler = config
            .add_downsample
            .then(|| {
                Downsample2D::new(
                    vs.pp("downsamplers").pp("0"),
                    out_channels,
                    true,
                    out_channels,
                    config.downsample_padding,
                )
            })
            .transpose()?;
        Ok(Self {
            resnets,
            downsampler,
            span: tracing::span!(tracing::Level::TRACE, "down-enc2d"),
            config,
        })
    }
}

impl Module for DownEncoderBlock2D {
    /// Runs `xs` through every resnet in order, then through the downsampler
    /// when one is present.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden = self
            .resnets
            .iter()
            .try_fold(xs.clone(), |acc, resnet| resnet.forward(&acc, None))?;
        match &self.downsampler {
            Some(down) => down.forward(&hidden),
            None => Ok(hidden),
        }
    }
}
/// Settings for [`UpDecoderBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UpDecoderBlock2DConfig {
    /// Number of resnet blocks in the stack.
    pub num_layers: usize,
    /// Forwarded as `eps` to every resnet block.
    pub resnet_eps: f64,
    /// Forwarded as `groups` to every resnet block.
    pub resnet_groups: usize,
    /// Forwarded unchanged to every resnet block.
    pub output_scale_factor: f64,
    /// Whether an upsampler follows the resnets.
    pub add_upsample: bool,
}

impl Default for UpDecoderBlock2DConfig {
    /// One layer, eps `1e-6`, 32 groups, unit scale factor, upsampling enabled.
    fn default() -> Self {
        Self {
            add_upsample: true,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
/// A stack of resnet blocks (called without `temb`) optionally followed by an
/// upsampler.
#[derive(Debug)]
pub struct UpDecoderBlock2D {
    resnets: Vec<ResnetBlock2D>,
    upsampler: Option<Upsample2D>,
    span: tracing::Span,
    pub config: UpDecoderBlock2DConfig,
}

impl UpDecoderBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// Resnets are read from `resnets.{i}`; the first maps `in_channels` to
    /// `out_channels` and the remaining ones keep `out_channels`. When
    /// `config.add_upsample` is set, the upsampler is read from `upsamplers.0`.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        config: UpDecoderBlock2DConfig,
    ) -> Result<Self> {
        let resnet_cfg = ResnetBlock2DConfig {
            out_channels: Some(out_channels),
            eps: config.resnet_eps,
            groups: config.resnet_groups,
            output_scale_factor: config.output_scale_factor,
            temb_channels: None,
            ..Default::default()
        };
        let vs_resnets = vs.pp("resnets");
        let mut resnets = Vec::new();
        for layer in 0..config.num_layers {
            let layer_in = if layer == 0 { in_channels } else { out_channels };
            let block = ResnetBlock2D::new(vs_resnets.pp(&layer.to_string()), layer_in, resnet_cfg)?;
            resnets.push(block);
        }
        let upsampler = config
            .add_upsample
            .then(|| Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels))
            .transpose()?;
        Ok(Self {
            resnets,
            upsampler,
            span: tracing::span!(tracing::Level::TRACE, "up-dec2d"),
            config,
        })
    }
}

impl Module for UpDecoderBlock2D {
    /// Runs `xs` through every resnet in order, then through the upsampler
    /// (with no explicit target size) when one is present.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let hidden = self
            .resnets
            .iter()
            .try_fold(xs.clone(), |acc, resnet| resnet.forward(&acc, None))?;
        match &self.upsampler {
            Some(up) => up.forward(&hidden, None),
            None => Ok(hidden),
        }
    }
}
/// Settings for [`UNetMidBlock2D`].
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DConfig {
    /// Number of (attention, resnet) pairs that follow the leading resnet.
    pub num_layers: usize,
    /// Forwarded as `eps` to the resnets and the attention blocks.
    pub resnet_eps: f64,
    /// Group count for resnets and attention blocks. `None` falls back to
    /// `min(in_channels / 4, 32)` at construction time.
    pub resnet_groups: Option<usize>,
    /// Forwarded as `num_head_channels` to the attention blocks.
    pub attn_num_head_channels: Option<usize>,
    /// Forwarded to the resnets, and to the attention blocks as
    /// `rescale_output_factor`.
    pub output_scale_factor: f64,
}

impl Default for UNetMidBlock2DConfig {
    /// One layer, eps `1e-6`, 32 groups, one head channel, unit scale factor.
    fn default() -> Self {
        Self {
            attn_num_head_channels: Some(1),
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: Some(32),
        }
    }
}
/// Middle block: one resnet followed by `num_layers` (attention, resnet) pairs,
/// all operating on `in_channels` channels.
#[derive(Debug)]
pub struct UNetMidBlock2D {
    resnet: ResnetBlock2D,
    attn_resnets: Vec<(AttentionBlock, ResnetBlock2D)>,
    span: tracing::Span,
    pub config: UNetMidBlock2DConfig,
}

impl UNetMidBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// The leading resnet is read from `resnets.0`. Pair `i` reads its attention
    /// block from `attentions.{i}` and its resnet from `resnets.{i + 1}`.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        temb_channels: Option<usize>,
        config: UNetMidBlock2DConfig,
    ) -> Result<Self> {
        let vs_resnets = vs.pp("resnets");
        let vs_attns = vs.pp("attentions");
        let groups = match config.resnet_groups {
            Some(groups) => groups,
            None => usize::min(in_channels / 4, 32),
        };
        let resnet_cfg = ResnetBlock2DConfig {
            eps: config.resnet_eps,
            groups,
            output_scale_factor: config.output_scale_factor,
            temb_channels,
            ..Default::default()
        };
        let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
        let attn_cfg = AttentionBlockConfig {
            num_head_channels: config.attn_num_head_channels,
            num_groups: groups,
            rescale_output_factor: config.output_scale_factor,
            eps: config.resnet_eps,
        };
        let attn_resnets = (0..config.num_layers)
            .map(|index| -> Result<(AttentionBlock, ResnetBlock2D)> {
                let attn =
                    AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
                // Resnet index 0 is the leading resnet, so pairs start at 1.
                let paired = ResnetBlock2D::new(
                    vs_resnets.pp(&(index + 1).to_string()),
                    in_channels,
                    resnet_cfg,
                )?;
                Ok((attn, paired))
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            resnet,
            attn_resnets,
            span: tracing::span!(tracing::Level::TRACE, "mid2d"),
            config,
        })
    }

    /// Applies the leading resnet, then each attention block followed by its
    /// resnet. `temb` is forwarded to every resnet.
    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden = self.resnet.forward(xs, temb)?;
        for (attn, resnet) in &self.attn_resnets {
            let attended = attn.forward(&hidden)?;
            hidden = resnet.forward(&attended, temb)?;
        }
        Ok(hidden)
    }
}
/// Settings for [`UNetMidBlock2DCrossAttn`].
#[derive(Debug, Clone, Copy)]
pub struct UNetMidBlock2DCrossAttnConfig {
    /// Number of (spatial transformer, resnet) pairs after the leading resnet.
    pub num_layers: usize,
    /// Forwarded as `eps` to every resnet block.
    pub resnet_eps: f64,
    /// Group count for resnets and transformers. `None` falls back to
    /// `min(in_channels / 4, 32)` at construction time.
    pub resnet_groups: Option<usize>,
    /// Passed to each spatial transformer as its head count; the per-head size
    /// is `in_channels / attn_num_head_channels`.
    pub attn_num_head_channels: usize,
    /// Forwarded unchanged to every resnet block.
    pub output_scale_factor: f64,
    /// Forwarded (wrapped in `Some`) as the transformers' `context_dim`.
    pub cross_attn_dim: usize,
    /// Forwarded unchanged to the transformer config.
    pub sliced_attention_size: Option<usize>,
    /// Forwarded unchanged to the transformer config.
    pub use_linear_projection: bool,
    /// Forwarded as the transformers' `depth`.
    pub transformer_layers_per_block: usize,
}

impl Default for UNetMidBlock2DCrossAttnConfig {
    /// One layer, eps `1e-6`, 32 groups, one head, unit scale factor, context
    /// dimension 1280, no attention slicing, no linear projection, depth 1.
    fn default() -> Self {
        Self {
            attn_num_head_channels: 1,
            cross_attn_dim: 1280,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: Some(32),
            sliced_attention_size: None,
            transformer_layers_per_block: 1,
            use_linear_projection: false,
        }
    }
}
/// Middle block with cross-attention: one resnet followed by `num_layers`
/// (spatial transformer, resnet) pairs, all on `in_channels` channels.
#[derive(Debug)]
pub struct UNetMidBlock2DCrossAttn {
    resnet: ResnetBlock2D,
    attn_resnets: Vec<(SpatialTransformer, ResnetBlock2D)>,
    span: tracing::Span,
    pub config: UNetMidBlock2DCrossAttnConfig,
}

impl UNetMidBlock2DCrossAttn {
    /// Builds the block from weights under `vs`.
    ///
    /// The leading resnet is read from `resnets.0`. Pair `i` reads its
    /// transformer from `attentions.{i}` and its resnet from `resnets.{i + 1}`.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        temb_channels: Option<usize>,
        use_flash_attn: bool,
        config: UNetMidBlock2DCrossAttnConfig,
    ) -> Result<Self> {
        let vs_resnets = vs.pp("resnets");
        let vs_attns = vs.pp("attentions");
        let groups = match config.resnet_groups {
            Some(groups) => groups,
            None => usize::min(in_channels / 4, 32),
        };
        let resnet_cfg = ResnetBlock2DConfig {
            eps: config.resnet_eps,
            groups,
            output_scale_factor: config.output_scale_factor,
            temb_channels,
            ..Default::default()
        };
        let resnet = ResnetBlock2D::new(vs_resnets.pp("0"), in_channels, resnet_cfg)?;
        let heads = config.attn_num_head_channels;
        let transformer_cfg = SpatialTransformerConfig {
            depth: config.transformer_layers_per_block,
            num_groups: groups,
            context_dim: Some(config.cross_attn_dim),
            sliced_attention_size: config.sliced_attention_size,
            use_linear_projection: config.use_linear_projection,
        };
        let attn_resnets = (0..config.num_layers)
            .map(|index| -> Result<(SpatialTransformer, ResnetBlock2D)> {
                let attn = SpatialTransformer::new(
                    vs_attns.pp(&index.to_string()),
                    in_channels,
                    heads,
                    in_channels / heads,
                    use_flash_attn,
                    transformer_cfg,
                )?;
                // Resnet index 0 is the leading resnet, so pairs start at 1.
                let paired = ResnetBlock2D::new(
                    vs_resnets.pp(&(index + 1).to_string()),
                    in_channels,
                    resnet_cfg,
                )?;
                Ok((attn, paired))
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            resnet,
            attn_resnets,
            span: tracing::span!(tracing::Level::TRACE, "xa-mid2d"),
            config,
        })
    }

    /// Applies the leading resnet, then each transformer followed by its resnet.
    ///
    /// `temb` is forwarded to every resnet and `encoder_hidden_states` to every
    /// transformer.
    pub fn forward(
        &self,
        xs: &Tensor,
        temb: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden = self.resnet.forward(xs, temb)?;
        for (attn, resnet) in &self.attn_resnets {
            let attended = attn.forward(&hidden, encoder_hidden_states)?;
            hidden = resnet.forward(&attended, temb)?;
        }
        Ok(hidden)
    }
}
/// Settings for [`DownBlock2D`], also embedded in
/// [`CrossAttnDownBlock2DConfig`].
#[derive(Debug, Clone, Copy)]
pub struct DownBlock2DConfig {
    /// Number of resnet blocks in the stack.
    pub num_layers: usize,
    /// Forwarded as `eps` to every resnet block.
    pub resnet_eps: f64,
    /// Group count; `CrossAttnDownBlock2D::new` passes it as `num_groups` to
    /// its spatial transformers.
    pub resnet_groups: usize,
    /// Forwarded unchanged to every resnet block.
    pub output_scale_factor: f64,
    /// Whether a downsampling convolution follows the resnets.
    pub add_downsample: bool,
    /// Padding of the downsampling convolution.
    pub downsample_padding: usize,
}

impl Default for DownBlock2DConfig {
    /// One layer, eps `1e-6`, 32 groups, unit scale factor, and a downsampler
    /// with padding 1.
    fn default() -> Self {
        Self {
            add_downsample: true,
            downsample_padding: 1,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
/// A stack of resnet blocks optionally followed by a convolutional
/// downsampler. Every intermediate output is returned alongside the final one.
#[derive(Debug)]
pub struct DownBlock2D {
    resnets: Vec<ResnetBlock2D>,
    downsampler: Option<Downsample2D>,
    span: tracing::Span,
    pub config: DownBlock2DConfig,
}

impl DownBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// Resnets are read from `resnets.{i}`; the first maps `in_channels` to
    /// `out_channels` and the remaining ones keep `out_channels`. When
    /// `config.add_downsample` is set, the downsampler is read from
    /// `downsamplers.0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from building a resnet or the downsampler.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        temb_channels: Option<usize>,
        config: DownBlock2DConfig,
    ) -> Result<Self> {
        let vs_resnets = vs.pp("resnets");
        let resnet_cfg = ResnetBlock2DConfig {
            out_channels: Some(out_channels),
            eps: config.resnet_eps,
            // Honour the configured group count instead of silently falling
            // back to the resnet config's default, as the encoder/decoder
            // blocks in this file already do.
            groups: config.resnet_groups,
            output_scale_factor: config.output_scale_factor,
            temb_channels,
            ..Default::default()
        };
        let resnets = (0..config.num_layers)
            .map(|i| {
                let in_channels = if i == 0 { in_channels } else { out_channels };
                ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
            })
            .collect::<Result<Vec<_>>>()?;
        let downsampler = config
            .add_downsample
            .then(|| {
                Downsample2D::new(
                    vs.pp("downsamplers").pp("0"),
                    out_channels,
                    true,
                    out_channels,
                    config.downsample_padding,
                )
            })
            .transpose()?;
        let span = tracing::span!(tracing::Level::TRACE, "down2d");
        Ok(Self {
            resnets,
            downsampler,
            span,
            config,
        })
    }

    /// Runs `xs` through every resnet (forwarding `temb`) and then through the
    /// downsampler when one is present.
    ///
    /// Returns the final tensor together with the output of each resnet, in
    /// order, followed by the downsampler output if there is a downsampler.
    pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Result<(Tensor, Vec<Tensor>)> {
        let _enter = self.span.enter();
        let mut xs = xs.clone();
        let mut output_states = vec![];
        for resnet in self.resnets.iter() {
            xs = resnet.forward(&xs, temb)?;
            output_states.push(xs.clone());
        }
        if let Some(downsampler) = &self.downsampler {
            xs = downsampler.forward(&xs)?;
            output_states.push(xs.clone());
        }
        Ok((xs, output_states))
    }
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnDownBlock2DConfig {
    pub downblock: DownBlock2DConfig,
    pub attn_num_head_channels: usize,
    pub cross_attention_dim: usize,
    pub sliced_attention_size: Option<usize>,
    pub use_linear_projection: bool,
    pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnDownBlock2DConfig {
    fn default() -> Self {
        Self {
            downblock: Default::default(),
            attn_num_head_channels: 1,
            cross_attention_dim: 1280,
            sliced_attention_size: None,
            use_linear_projection: false,
            transformer_layers_per_block: 1,
        }
    }
}
/// A [`DownBlock2D`] whose resnets are each followed by a spatial transformer.
#[derive(Debug)]
pub struct CrossAttnDownBlock2D {
    downblock: DownBlock2D,
    attentions: Vec<SpatialTransformer>,
    span: tracing::Span,
    pub config: CrossAttnDownBlock2DConfig,
}

impl CrossAttnDownBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// The wrapped down block is built from `vs` itself; transformer `i` is
    /// read from `attentions.{i}` and operates on `out_channels` channels.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        out_channels: usize,
        temb_channels: Option<usize>,
        use_flash_attn: bool,
        config: CrossAttnDownBlock2DConfig,
    ) -> Result<Self> {
        let downblock = DownBlock2D::new(
            vs.clone(),
            in_channels,
            out_channels,
            temb_channels,
            config.downblock,
        )?;
        let transformer_cfg = SpatialTransformerConfig {
            depth: config.transformer_layers_per_block,
            context_dim: Some(config.cross_attention_dim),
            num_groups: config.downblock.resnet_groups,
            sliced_attention_size: config.sliced_attention_size,
            use_linear_projection: config.use_linear_projection,
        };
        let heads = config.attn_num_head_channels;
        let vs_attn = vs.pp("attentions");
        let mut attentions = Vec::new();
        for layer in 0..config.downblock.num_layers {
            let attn = SpatialTransformer::new(
                vs_attn.pp(&layer.to_string()),
                out_channels,
                heads,
                out_channels / heads,
                use_flash_attn,
                transformer_cfg,
            )?;
            attentions.push(attn);
        }
        Ok(Self {
            downblock,
            attentions,
            span: tracing::span!(tracing::Level::TRACE, "xa-down2d"),
            config,
        })
    }

    /// Runs each (resnet, transformer) pair in order, then the down block's
    /// downsampler when one is present.
    ///
    /// Returns the final tensor together with the output of each pair, in
    /// order, followed by the downsampler output if there is a downsampler.
    pub fn forward(
        &self,
        xs: &Tensor,
        temb: Option<&Tensor>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<(Tensor, Vec<Tensor>)> {
        let _enter = self.span.enter();
        let mut output_states = Vec::new();
        let mut hidden = xs.clone();
        for (resnet, attn) in self.downblock.resnets.iter().zip(&self.attentions) {
            let after_resnet = resnet.forward(&hidden, temb)?;
            hidden = attn.forward(&after_resnet, encoder_hidden_states)?;
            output_states.push(hidden.clone());
        }
        if let Some(downsampler) = &self.downblock.downsampler {
            hidden = downsampler.forward(&hidden)?;
            output_states.push(hidden.clone());
        }
        Ok((hidden, output_states))
    }
}
/// Settings for [`UpBlock2D`], also embedded in [`CrossAttnUpBlock2DConfig`].
#[derive(Debug, Clone, Copy)]
pub struct UpBlock2DConfig {
    /// Number of resnet blocks in the stack.
    pub num_layers: usize,
    /// Forwarded as `eps` to every resnet block.
    pub resnet_eps: f64,
    /// Group count; `CrossAttnUpBlock2D::new` passes it as `num_groups` to its
    /// spatial transformers.
    pub resnet_groups: usize,
    /// Forwarded unchanged to every resnet block.
    pub output_scale_factor: f64,
    /// Whether an upsampler follows the resnets.
    pub add_upsample: bool,
}

impl Default for UpBlock2DConfig {
    /// One layer, eps `1e-6`, 32 groups, unit scale factor, upsampling enabled.
    fn default() -> Self {
        Self {
            add_upsample: true,
            num_layers: 1,
            output_scale_factor: 1.0,
            resnet_eps: 1e-6,
            resnet_groups: 32,
        }
    }
}
/// A stack of resnet blocks, each fed the running tensor concatenated with one
/// residual state, optionally followed by an upsampler.
#[derive(Debug)]
pub struct UpBlock2D {
    pub resnets: Vec<ResnetBlock2D>,
    upsampler: Option<Upsample2D>,
    span: tracing::Span,
    pub config: UpBlock2DConfig,
}

impl UpBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// Resnet `i` is read from `resnets.{i}`. Its input width is the sum of
    /// two parts: `prev_output_channels` for the first resnet (`out_channels`
    /// afterwards), plus `in_channels` for the last resnet (`out_channels`
    /// before it). When `config.add_upsample` is set, the upsampler is read
    /// from `upsamplers.0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from building a resnet or the upsampler.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        prev_output_channels: usize,
        out_channels: usize,
        temb_channels: Option<usize>,
        config: UpBlock2DConfig,
    ) -> Result<Self> {
        let vs_resnets = vs.pp("resnets");
        let resnet_cfg = ResnetBlock2DConfig {
            out_channels: Some(out_channels),
            temb_channels,
            eps: config.resnet_eps,
            // Honour the configured group count instead of silently falling
            // back to the resnet config's default, as the encoder/decoder
            // blocks in this file already do.
            groups: config.resnet_groups,
            output_scale_factor: config.output_scale_factor,
            ..Default::default()
        };
        let resnets = (0..config.num_layers)
            .map(|i| {
                let res_skip_channels = if i == config.num_layers - 1 {
                    in_channels
                } else {
                    out_channels
                };
                let resnet_in_channels = if i == 0 {
                    prev_output_channels
                } else {
                    out_channels
                };
                let in_channels = resnet_in_channels + res_skip_channels;
                ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
            })
            .collect::<Result<Vec<_>>>()?;
        let upsampler = config
            .add_upsample
            .then(|| Upsample2D::new(vs.pp("upsamplers").pp("0"), out_channels, out_channels))
            .transpose()?;
        let span = tracing::span!(tracing::Level::TRACE, "up2d");
        Ok(Self {
            resnets,
            upsampler,
            span,
            config,
        })
    }

    /// Runs the resnet stack and then the upsampler when one is present.
    ///
    /// Before resnet `index`, the running tensor is concatenated along
    /// dimension 1 with `res_xs[res_xs.len() - index - 1]`, so residual states
    /// are consumed from the back. `temb` is forwarded to every resnet and
    /// `upsample_size` to the upsampler.
    ///
    /// # Panics
    ///
    /// Panics if `res_xs` holds fewer tensors than there are resnets.
    pub fn forward(
        &self,
        xs: &Tensor,
        res_xs: &[Tensor],
        temb: Option<&Tensor>,
        upsample_size: Option<(usize, usize)>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut xs = xs.clone();
        for (index, resnet) in self.resnets.iter().enumerate() {
            xs = Tensor::cat(&[&xs, &res_xs[res_xs.len() - index - 1]], 1)?;
            xs = xs.contiguous()?;
            xs = resnet.forward(&xs, temb)?;
        }
        match &self.upsampler {
            Some(upsampler) => upsampler.forward(&xs, upsample_size),
            None => Ok(xs),
        }
    }
}
#[derive(Debug, Clone, Copy)]
pub struct CrossAttnUpBlock2DConfig {
    pub upblock: UpBlock2DConfig,
    pub attn_num_head_channels: usize,
    pub cross_attention_dim: usize,
    pub sliced_attention_size: Option<usize>,
    pub use_linear_projection: bool,
    pub transformer_layers_per_block: usize,
}
impl Default for CrossAttnUpBlock2DConfig {
    fn default() -> Self {
        Self {
            upblock: Default::default(),
            attn_num_head_channels: 1,
            cross_attention_dim: 1280,
            sliced_attention_size: None,
            use_linear_projection: false,
            transformer_layers_per_block: 1,
        }
    }
}
/// An [`UpBlock2D`] whose resnets are each followed by a spatial transformer.
#[derive(Debug)]
pub struct CrossAttnUpBlock2D {
    pub upblock: UpBlock2D,
    pub attentions: Vec<SpatialTransformer>,
    span: tracing::Span,
    pub config: CrossAttnUpBlock2DConfig,
}

impl CrossAttnUpBlock2D {
    /// Builds the block from weights under `vs`.
    ///
    /// The wrapped up block is built from `vs` itself; transformer `i` is read
    /// from `attentions.{i}` and operates on `out_channels` channels.
    pub fn new(
        vs: nn::VarBuilder,
        in_channels: usize,
        prev_output_channels: usize,
        out_channels: usize,
        temb_channels: Option<usize>,
        use_flash_attn: bool,
        config: CrossAttnUpBlock2DConfig,
    ) -> Result<Self> {
        let upblock = UpBlock2D::new(
            vs.clone(),
            in_channels,
            prev_output_channels,
            out_channels,
            temb_channels,
            config.upblock,
        )?;
        let transformer_cfg = SpatialTransformerConfig {
            depth: config.transformer_layers_per_block,
            context_dim: Some(config.cross_attention_dim),
            num_groups: config.upblock.resnet_groups,
            sliced_attention_size: config.sliced_attention_size,
            use_linear_projection: config.use_linear_projection,
        };
        let heads = config.attn_num_head_channels;
        let vs_attn = vs.pp("attentions");
        let mut attentions = Vec::new();
        for layer in 0..config.upblock.num_layers {
            let attn = SpatialTransformer::new(
                vs_attn.pp(&layer.to_string()),
                out_channels,
                heads,
                out_channels / heads,
                use_flash_attn,
                transformer_cfg,
            )?;
            attentions.push(attn);
        }
        Ok(Self {
            upblock,
            attentions,
            span: tracing::span!(tracing::Level::TRACE, "xa-up2d"),
            config,
        })
    }

    /// Runs each resnet followed by the transformer at the same index, then
    /// the up block's upsampler when one is present.
    ///
    /// Before resnet `index`, the running tensor is concatenated along
    /// dimension 1 with `res_xs[res_xs.len() - index - 1]`. `temb` goes to the
    /// resnets, `encoder_hidden_states` to the transformers and `upsample_size`
    /// to the upsampler.
    ///
    /// # Panics
    ///
    /// Panics if `res_xs` or `attentions` holds fewer entries than there are
    /// resnets.
    pub fn forward(
        &self,
        xs: &Tensor,
        res_xs: &[Tensor],
        temb: Option<&Tensor>,
        upsample_size: Option<(usize, usize)>,
        encoder_hidden_states: Option<&Tensor>,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let mut hidden = xs.clone();
        for (index, resnet) in self.upblock.resnets.iter().enumerate() {
            let residual = &res_xs[res_xs.len() - index - 1];
            let joined = Tensor::cat(&[&hidden, residual], 1)?.contiguous()?;
            let after_resnet = resnet.forward(&joined, temb)?;
            hidden = self.attentions[index].forward(&after_resnet, encoder_hidden_states)?;
        }
        match &self.upblock.upsampler {
            Some(up) => up.forward(&hidden, upsample_size),
            None => Ok(hidden),
        }
    }
}