diff --git a/src/client_server.rs b/src/client_server.rs index a3f47608..eede5fd2 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -2977,11 +2977,11 @@ pub fn send_event_to_device_route( } #[get("/_matrix/media/r0/config")] -pub fn get_media_config_route() -> ConduitResult { - Ok(get_media_config::Response { - upload_size: (20_u32 * 1024 * 1024).into(), // 20 MB - } - .into()) +pub fn get_media_config_route( + db: State<'_, Database>, +) -> ConduitResult { + let upload_size = db.globals.max_request_size().into(); + Ok(get_media_config::Response { upload_size }.into()) } #[post("/_matrix/media/r0/upload", data = "")] diff --git a/src/database/globals.rs b/src/database/globals.rs index 3a257a54..5db28069 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -1,7 +1,7 @@ -use std::convert::TryInto; - use crate::{utils, Error, Result}; use ruma::ServerName; +use std::convert::TryInto; + pub const COUNTER: &str = "c"; pub struct Globals { @@ -9,6 +9,7 @@ pub struct Globals { keypair: ruma::signatures::Ed25519KeyPair, reqwest_client: reqwest::Client, server_name: Box, + max_request_size: u32, registration_disabled: bool, encryption_disabled: bool, } @@ -32,7 +33,12 @@ impl Globals { .unwrap_or("localhost") .to_string() .try_into() - .map_err(|_| Error::BadConfig("Invalid server name found."))?, + .map_err(|_| Error::BadConfig("Invalid server_name."))?, + max_request_size: config + .get_int("max_request_size") + .unwrap_or(20 * 1024 * 1024) // Default to 20 MB + .try_into() + .map_err(|_| Error::BadConfig("Invalid max_request_size."))?, registration_disabled: config.get_bool("registration_disabled").unwrap_or(false), encryption_disabled: config.get_bool("encryption_disabled").unwrap_or(false), }) @@ -69,6 +75,10 @@ impl Globals { self.server_name.as_ref() } + pub fn max_request_size(&self) -> u32 { + self.max_request_size + } + pub fn registration_disabled(&self) -> bool { self.registration_disabled } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 15e50ba3..5b380b37 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -11,8 +11,6 @@ use ruma::{api::Endpoint, DeviceId, UserId}; use std::{convert::TryInto, io::Cursor, ops::Deref}; use tokio::io::AsyncReadExt; -const MESSAGE_LIMIT: u64 = 20 * 1024 * 1024; // 20 MB - /// This struct converts rocket requests into ruma structs by converting them into http requests /// first. pub struct Ruma { @@ -40,13 +38,12 @@ impl<'a, T: Endpoint> FromTransformedData<'a> for Ruma { ) -> FromDataFuture<'a, Self, Self::Error> { Box::pin(async move { let data = rocket::try_outcome!(outcome.owned()); + let db = request + .guard::>() + .await + .expect("database was loaded"); let (user_id, device_id) = if T::METADATA.requires_authentication { - let db = request - .guard::>() - .await - .expect("database was loaded"); - // Get token from header or query value let token = match request .headers() @@ -76,7 +73,8 @@ impl<'a, T: Endpoint> FromTransformedData<'a> for Ruma { http_request = http_request.header(header.name.as_str(), &*header.value); } - let mut handle = data.open().take(MESSAGE_LIMIT); + let limit = db.globals.max_request_size(); + let mut handle = data.open().take(limit.into()); let mut body = Vec::new(); handle.read_to_end(&mut body).await.unwrap();