//! This module provides the [`ComputerVisionSvc`] struct and its associated methods for image processing.
//!
//! The primary functionality includes handling single and batch image processing requests using gRPC.
//! The [`ComputerVisionSvc`] utilizes an [`ImageProcessor`] to perform the actual processing of images
//! and a semaphore to limit the number of concurrent requests for efficient resource management.
use std::sync::Arc;
use tokio::task::{self, JoinError};
use tokio::sync::{mpsc, Semaphore, OwnedSemaphorePermit};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status, Streaming};
use candle_core::{Device, Result as CandleResult};
use crate::image_captioning::ImageProcessor;
use crate::image_captioning::model_loader::Models;
use crate::proto::{ImgProcRequest, ImgProcResponse, ModelType};
use crate::proto::computer_vision_server::ComputerVision;
/// Maximum number of concurrent requests that can be processed.
///
/// This is the permit count the service's semaphore is created with in
/// `ComputerVisionSvc::new`. Both RPC handlers acquire one permit per image before
/// handing it to the image processor.
const MAX_CONCURRENT_REQUESTS: usize = 16;
/// Type alias for a result that returns a gRPC [`Response`] or a [`Status`].
///
/// Used as the return type of the RPC handlers in this module: `Ok` carries the
/// response wrapper, `Err` carries the gRPC status reported to the caller.
type ResponseResult<T> = Result<Response<T>, Status>;
/// The [`ComputerVisionSvc`] struct provides methods for processing images.
/// It holds an [`ImageProcessor`] instance and a semaphore for limiting concurrent requests.
///
/// Both fields are reference-counted so the RPC handlers can clone them into the
/// tasks that perform the actual processing.
pub struct ComputerVisionSvc {
/// Image processor shared by every request; cloned into the blocking tasks that run it.
processor: Arc<ImageProcessor>,
/// Concurrency limiter created with `MAX_CONCURRENT_REQUESTS` permits; one permit is
/// acquired per image being processed.
semaphore: Arc<Semaphore>,
}
impl ComputerVisionSvc {
    /// Builds a [`ComputerVisionSvc`] that is ready to serve requests.
    ///
    /// The [`ImageProcessor`] is constructed from `models` on `device`, and the concurrency
    /// semaphore is created with `MAX_CONCURRENT_REQUESTS` permits.
    ///
    /// # Arguments
    ///
    /// * `models` - A reference to the [`Models`] struct passed on to the image processor.
    /// * `device` - The device handed to the image processor.
    ///
    /// # Returns
    ///
    /// A [`CandleResult`] holding the new service.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `ImageProcessor::new`.
    pub fn new(models: &Models, device: Device) -> CandleResult<Self> {
        let image_processor = ImageProcessor::new(models, device)?;
        let request_limiter = Semaphore::new(MAX_CONCURRENT_REQUESTS);

        Ok(Self {
            processor: Arc::new(image_processor),
            semaphore: Arc::new(request_limiter),
        })
    }

    /// Checks that an [`ImgProcRequest`] is well-formed before it is processed.
    ///
    /// A request passes when its `image` field is non-empty and its `model` field converts
    /// to a [`ModelType`].
    ///
    /// # Arguments
    ///
    /// * `request` - The request to inspect; it is not modified.
    ///
    /// # Returns
    ///
    /// `Ok(())` when the request is acceptable.
    ///
    /// # Errors
    ///
    /// Returns [`Status::invalid_argument`] when the image is empty or when the model value
    /// does not map to a [`ModelType`]. The empty-image check runs first.
    fn validate_request(&self, request: &ImgProcRequest) -> Result<(), Status> {
        if request.image.is_empty() {
            return Err(Status::invalid_argument("Empty vector of bytes"));
        }

        match ModelType::try_from(request.model) {
            Ok(_) => Ok(()),
            Err(_) => Err(Status::invalid_argument("Invalid model type")),
        }
    }
}
#[tonic::async_trait]
impl ComputerVision for ComputerVisionSvc {
/// The stream type for the `process_image_batch` method.
type ProcessImageBatchStream = ReceiverStream<Result<ImgProcResponse, Status>>;
/// Processes a single image and returns a description.
///
/// This method handles the processing of a single image request by validating the request,
/// acquiring a semaphore permit to limit concurrency, and then spawning a blocking task to
/// perform the actual image processing. The result is then sent back as a gRPC response.
///
/// # Arguments
///
/// * `request` - A gRPC [`Request`] containing the [`ImgProcRequest`].
///
/// # Returns
///
/// A [`ResponseResult`] containing an [`ImgProcResponse`] with the image description or a gRPC
/// `Status` on error.
///
/// # Errors
///
/// Returns a [`Status::invalid_argument`] if the request is invalid, [`Status::resource_exhausted`]
/// if too many concurrent requests are being processed, or [`Status::internal`] if an error occurs
/// during processing.
async fn process_image(&self, request: Request<ImgProcRequest>) -> ResponseResult<ImgProcResponse> {
tracing::info!(peer_addr = ?request.remote_addr(), "ProcessImage Invoked");
self.validate_request(request.get_ref())?;
let ImgProcRequest { model, image } = request.into_inner();
// Safely unwrap as validation ensures validity
let model = ModelType::try_from(model).unwrap();
let processor: Arc<ImageProcessor> = Arc::clone(&self.processor);
let semaphore: Arc<Semaphore> = Arc::clone(&self.semaphore);
let _permit: OwnedSemaphorePermit = semaphore
.acquire_owned()
.await
.map_err(|_| Status::resource_exhausted("Too many concurrent requests"))?;
let process_result: Result<CandleResult<String>, JoinError> =
task::spawn_blocking(move || processor.process_image(model, &image)).await;
drop(_permit);
match process_result {
Ok(Ok(description)) => {
let response = ImgProcResponse { description };
Ok(Response::new(response))
}
Ok(Err(e)) => {
tracing::error!("Error processing image: {:?}", e);
Err(Status::internal(format!("Error processing image: {}", e)))
}
Err(e) => {
tracing::error!("Error executing blocking task: {:?}", e);
Err(Status::internal(format!("Error executing blocking task: {}", e)))
}
}
}
/// Processes a stream of image requests and returns a stream of responses.
///
/// This method handles the processing of a batch of image requests received as a stream.
/// It validates each request, acquires a semaphore permit, and spawns a blocking task for each
/// image processing operation. The responses are sent back as a stream of [`ImgProcResponse`].
///
/// # Arguments
///
/// * `request` - A gRPC [`Request`] containing a [`Streaming<ImgProcRequest>`].
///
/// # Returns
///
/// A [`ResponseResult`] containing a stream of [`ImgProcResponse`] or a gRPC [`Status`] on error.
///
/// # Errors
///
/// Returns a [`Status::resource_exhausted`] if too many concurrent requests are being processed,
/// or [`Status::internal`] if an error occurs during processing.
async fn process_image_batch(&self, request: Request<Streaming<ImgProcRequest>>) -> ResponseResult<Self::ProcessImageBatchStream> {
tracing::info!(peer_addr = ?request.remote_addr(), "ProcessImageBatch Invoked");
let mut stream: Streaming<ImgProcRequest> = request.into_inner();
let (tx, rx): (mpsc::Sender<_>, mpsc::Receiver<_>) = mpsc::channel(128);
while let Some(request) = stream.message().await? {
let tx: mpsc::Sender<_> = tx.clone();
let semaphore: Arc<Semaphore> = Arc::clone(&self.semaphore);
let processor: Arc<ImageProcessor> = Arc::clone(&self.processor);
let _permit: OwnedSemaphorePermit = semaphore.acquire_owned().await
.map_err(|_| Status::resource_exhausted("Too many concurrent requests"))?;
tokio::spawn(async move {
// TODO: add request validation
let ImgProcRequest { model, image } = request;
let model = ModelType::try_from(model).unwrap();
let process_result: Result<CandleResult<String>, JoinError> =
task::spawn_blocking(move || processor.process_image(model, &image)).await;
let response: Result<ImgProcResponse, Status> = match process_result {
Ok(Ok(description)) => {
let response = ImgProcResponse { description };
Ok(response)
}
Ok(Err(e)) => {
tracing::error!("Error processing image: {:?}", e);
Err(Status::internal(format!("Error processing image: {}", e)))
}
Err(e) => {
tracing::error!("Error executing blocking task: {:?}", e);
Err(Status::internal(format!("Error executing blocking task: {}", e)))
}
};
if let Err(e) = tx.send(response).await {
tracing::error!("Error sending response: {:?}", e);
}
drop(_permit);
});
}
Ok(Response::new(ReceiverStream::new(rx)))
}
}