//! This module provides functionality for loading and processing models used for image captioning.
//! It supports different model variants including BLIP and quantized BLIP models.
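//!
//! A minimal usage sketch (how the [`model_loader::Models`] registry is populated is
//! left to `model_loader` and assumed here):
//!
//! ```ignore
//! use candle_core::Device;
//!
//! // `models` is a `model_loader::Models` registry populated elsewhere.
//! let processor = ImageProcessor::new(&models, Device::Cpu)?;
//!
//! // Caption an image held in memory with the full-precision BLIP model.
//! let bytes = std::fs::read("photo.jpg")?;
//! let caption = processor.process_image(ModelType::Blip, &bytes)?;
//! println!("{caption}");
//! ```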
#![allow(unused)]
pub mod model_loader;
pub mod token_output_stream;
pub mod utils;
use std::collections::HashMap;
use tokenizers::Tokenizer;
use image::{ImageBuffer, Rgb};
use candle_core::{Result, Tensor, DType, Device, Error, Module};
use candle_nn::var_builder::{VarBuilder, VarBuilderArgs, SimpleBackend};
use candle_transformers::models::{blip, quantized_blip};
use candle_transformers::generation::{Sampling, LogitsProcessor};
use crate::proto::ModelType;
use crate::image_captioning::model_loader::{Models, Model};
/// The separator token ID used for ending generated sequences.
const SEP_TOKEN_ID: u32 = 102;
/// Represents different variants of image captioning models.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum ModelVariant {
    Blip(blip::BlipForConditionalGeneration),
    QuantizedBlip(quantized_blip::BlipForConditionalGeneration),
}
impl Module for ModelVariant {
    /// Performs a forward pass for the vision model.
    ///
    /// This function takes an input tensor, passes it through the vision model, and returns the
    /// resulting tensor.
    ///
    /// # Arguments
    ///
    /// * `xs` - A reference to the input tensor to be processed by the vision model.
    ///
    /// # Returns
    ///
    /// A [`Result`] containing the output tensor if the forward pass is successful, or an error if
    /// the forward pass fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the vision model's forward pass encounters any issues.
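    ///
    /// # Examples
    ///
    /// A sketch of the call made by [`ImageProcessor::process_image`]; `variant` is a loaded
    /// [`ModelVariant`] and `pixels` an already preprocessed image tensor (e.g. `(3, 384, 384)`
    /// for the BLIP-large configuration):
    ///
    /// ```ignore
    /// let image_embeddings = pixels.unsqueeze(0)?.apply(&variant)?;
    /// ```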
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Blip(m) => m.vision_model().forward(xs),
            Self::QuantizedBlip(m) => m.vision_model().forward(xs),
        }
    }
}
impl ModelVariant {
    /// Performs a forward pass for the text decoder model.
    ///
    /// This function takes an input tensor and an image embeddings tensor, passes them through the
    /// text decoder model, and returns the resulting tensor.
    ///
    /// # Arguments
    ///
    /// * `xs` - A reference to the input tensor for the text decoder.
    /// * `img_xs` - A reference to the input tensor containing the image embeddings.
    ///
    /// # Returns
    ///
    /// A [`Result`] containing the output tensor if the forward pass is successful, or an error if
    /// the forward pass fails.
    ///
    /// # Errors
    ///
    /// Returns an error if the text decoder model's forward pass encounters any issues.
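    ///
    /// # Examples
    ///
    /// A sketch of a single decoding step as performed by [`ImageProcessor::generate_text`];
    /// `variant` is a mutable [`ModelVariant`], `token_ids` the tokens generated so far, and
    /// `image_embeddings` the output of the vision model:
    ///
    /// ```ignore
    /// let input_ids = Tensor::new(token_ids.as_slice(), &device)?.unsqueeze(0)?;
    /// let logits = variant.text_decoder_forward(&input_ids, &image_embeddings)?;
    /// ```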
    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Blip(m) => m.text_decoder().forward(xs, img_xs),
            Self::QuantizedBlip(m) => m.text_decoder().forward(xs, img_xs),
        }
    }
    /// Resets the key-value cache of the model.
    ///
    /// This function resets any cached key-value pairs in the model to ensure that new predictions
    /// do not rely on previous state.
    fn reset_kv_cache(&mut self) {
        match self {
            Self::Blip(m) => m.reset_kv_cache(),
            Self::QuantizedBlip(m) => m.reset_kv_cache(),
        }
    }
}
/// Struct for processing images and generating captions.
#[derive(Clone)]
pub struct ImageProcessor {
    models: HashMap<ModelType, ModelVariant>,
    device: Device,
    tokenizer: Tokenizer,
    sampling: Sampling,
}
impl ImageProcessor {
    /// Creates a new instance of [`ImageProcessor`].
    ///
    /// This function initializes the [`ImageProcessor`] with the provided models and device. It loads
    /// the BLIP and quantized BLIP models, sets up the tokenizer, and prepares the processor for
    /// image captioning tasks.
    ///
    /// # Arguments
    ///
    /// * `models` - A reference to a [`Models`] struct containing model configurations.
    /// * `device` - The device on which the models will be loaded (e.g., CPU or GPU).
    ///
    /// # Returns
    ///
    /// A [`Result`] containing the new [`ImageProcessor`] instance or an error if initialization fails.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the required models cannot be found or initialized.
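    ///
    /// # Examples
    ///
    /// A construction sketch; the populated [`Models`] registry is assumed to come from
    /// [`model_loader`]:
    ///
    /// ```ignore
    /// let device = Device::cuda_if_available(0)?;
    /// let processor = ImageProcessor::new(&models, device)?;
    /// ```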
    pub fn new(models: &Models, device: Device) -> Result<Self> {
        let blip_cfg: &Model = models
            .get("Salesforce/blip-image-captioning-large")
            .ok_or_else(|| Error::Msg("BLIP Model not found".into()))?;
        let blip_quantized_cfg: &Model = models
            .get("lmz/candle-blip")
            .ok_or_else(|| Error::Msg("Quantized BLIP Model not found".into()))?;
        let config = blip::Config::image_captioning_large();
        let mut model_map: HashMap<ModelType, ModelVariant> = HashMap::new();
        let vb: VarBuilderArgs<Box<dyn SimpleBackend>> = unsafe {
            VarBuilder::from_mmaped_safetensors(&[blip_cfg.model_path()], DType::F32, &device)?
        };
        model_map.insert(
            ModelType::Blip,
            ModelVariant::Blip(blip::BlipForConditionalGeneration::new(&config, vb)?),
        );
        let vb = quantized_blip::VarBuilder::from_gguf(blip_quantized_cfg.model_path(), &device)?;
        model_map.insert(
            ModelType::BlipQuantized,
            ModelVariant::QuantizedBlip(quantized_blip::BlipForConditionalGeneration::new(
                &config, vb,
            )?),
        );
        let tokenizer = Tokenizer::from_file(blip_cfg.tokenizer_path()).map_err(Error::Wrapped)?;
        Ok(Self {
            models: model_map,
            device,
            tokenizer,
            sampling: Sampling::ArgMax,
        })
    }
    /// Processes an image and generates a caption.
    ///
    /// This function processes the input image using the specified model and generates a textual
    /// description of the image. It involves preprocessing the image, converting it into a tensor,
    /// passing it through the model to get image embeddings, and then generating text based on
    /// these embeddings.
    ///
    /// # Arguments
    ///
    /// * `model` - The type of model to use for processing the image.
    /// * `image` - A byte slice containing the image data.
    ///
    /// # Returns
    ///
    /// A [`Result`] containing the generated caption as a [`String`] or an error if processing fails.
    ///
    /// # Errors
    ///
    /// Returns an error if image processing or caption generation fails.
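    ///
    /// # Examples
    ///
    /// A sketch captioning an image read from disk; `processor` is an existing
    /// [`ImageProcessor`], and the path and model choice are illustrative:
    ///
    /// ```ignore
    /// let bytes = std::fs::read("cat.jpg")?;
    /// let caption = processor.process_image(ModelType::BlipQuantized, &bytes)?;
    /// tracing::info!("caption: {caption}");
    /// ```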
    pub fn process_image(&self, model: ModelType, image: &[u8]) -> Result<String> {
        let model_var: &ModelVariant = self
            .models
            .get(&model)
            .ok_or_else(|| Error::Msg("requested model is not loaded".into()))?;
        let image: ImageBuffer<Rgb<u8>, Vec<u8>> =
            utils::process_image(image).map_err(Error::wrap)?;
        let tensor: Tensor =
            utils::create_tensor(&image.into_raw(), &Device::Cpu)?.to_device(&self.device)?;
        tracing::debug!("Image tensor: {:?}", tensor);
        let image_embeddings: Tensor = tensor.unsqueeze(0)?.apply(model_var)?;
        self.generate_text(model, &image_embeddings)
    }
    /// Generates text from image embeddings.
    ///
    /// This function generates a caption by running the image embeddings through the text decoder
    /// model and using a logits processor to sample tokens until the separator token
    /// ([`SEP_TOKEN_ID`]) is encountered. It uses a sampling strategy (e.g., argmax) to decide the
    /// next token at each step.
    ///
    /// # Arguments
    ///
    /// * `model` - The type of model to use for generating text.
    /// * `image_embeds` - A reference to the tensor containing image embeddings.
    ///
    /// # Returns
    ///
    /// A [`Result`] containing the generated text as a [`String`] or an error if generation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if text generation fails.
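    ///
    /// # Examples
    ///
    /// A sketch of the internal call made at the end of [`ImageProcessor::process_image`],
    /// once the vision model has produced the image embeddings:
    ///
    /// ```ignore
    /// let caption = self.generate_text(ModelType::Blip, &image_embeddings)?;
    /// ```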
    fn generate_text(&self, model: ModelType, image_embeds: &Tensor) -> Result<String> {
        let mut model: ModelVariant = self
            .models
            .get(&model)
            .ok_or_else(|| Error::Msg("requested model is not loaded".into()))?
            .clone();
        let mut logits_processor: LogitsProcessor =
            LogitsProcessor::from_sampling(1337, self.sampling.clone());
        // Start decoding from the BLIP text decoder's BOS token.
        let mut token_ids: Vec<u32> = vec![30522];
        // Generate at most 1000 tokens, stopping early at the separator token.
        for index in 0..1000 {
            // After the first step the decoder's KV cache holds the context, so only the
            // most recent token needs to be fed in.
            let context_size: usize = if index > 0 { 1 } else { token_ids.len() };
            let start_pos: usize = token_ids.len().saturating_sub(context_size);
            let input_ids: Tensor =
                Tensor::new(&token_ids[start_pos..], &self.device)?.unsqueeze(0)?;
            let logits: Tensor = model.text_decoder_forward(&input_ids, image_embeds)?.squeeze(0)?;
            let logits: Tensor = logits.get(logits.dim(0)? - 1)?;
            let token: u32 = logits_processor.sample(&logits)?;
            if token == SEP_TOKEN_ID {
                break;
            }
            token_ids.push(token);
        }
        self.tokenizer.decode(&token_ids, true).map_err(Error::Wrapped)
    }
}