diff --git a/src/main.rs b/src/main.rs index bfc80ff..2a16d0e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,7 @@ use error::AppError; use exif::Exif; use image::{imageops::FilterType, DynamicImage}; use serde::{Serialize, Deserialize}; -use tokio::task; +use tokio::{task, sync::Semaphore}; use tower_http::{trace::{self, TraceLayer}, compression::CompressionLayer}; use tracing::Level; use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; @@ -34,6 +34,7 @@ type SecretKey = [u8; 64]; type ImageListCache = Arc>>; type ImageCache = Arc, Vec)>>>; type Config = Arc>; +type CpuTaskLimit = Arc; #[derive(Clone, extract::FromRef)] struct ApplicationState { @@ -43,6 +44,7 @@ struct ApplicationState { image_dir: ImageDir, image_list_cache: ImageListCache, secret_key: SecretKey, + cpu_task_limit: CpuTaskLimit, } #[tokio::main] @@ -102,6 +104,7 @@ async fn main() { image_dir, image_list_cache, secret_key, + cpu_task_limit: Arc::new(Semaphore::new(4)), }); let addr = SocketAddr::from(([0, 0, 0, 0], 3000)); @@ -265,6 +268,7 @@ async fn converted_image( State(image_dir): State, State(secret_key): State, session: ReadableSession, + State(cpu_task_limit): State, ) -> Result { session.get::<()>("logged_in") .ok_or(anyhow!("Trying to load image while not logged in!")) @@ -308,6 +312,7 @@ async fn converted_image( image_buffer } None => { + let _cpu_task_permit = cpu_task_limit.clone().acquire_owned().await?; let image_buffer = task::spawn_blocking(move || { convert_image(&image_path) .with_context(|| format!("Could not convert image {:?}", image_path))