diff --git a/Cargo.lock b/Cargo.lock index 4f58ef36..6ed4ee7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -144,43 +144,12 @@ dependencies = [ "serde", ] -[[package]] -name = "bindgen" -version = "0.59.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "453c49e5950bb0eb63bb3df640e31618846c89d5b7faa54040d76e98e0134375" -dependencies = [ - "bitflags", - "cexpr", - "clang-sys", - "lazy_static", - "lazycell", - "peeking_take_while", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", -] - [[package]] name = "bitflags" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" -[[package]] -name = "bitvec" -version = "0.19.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8942c8d352ae1838c9dda0b0ca2ab657696ef2232a20147cf1b30ae1a9cb4321" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] - [[package]] name = "blake2b_simd" version = "0.5.11" @@ -234,15 +203,6 @@ dependencies = [ "jobserver", ] -[[package]] -name = "cexpr" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db507a7679252d2276ed0dd8113c6875ec56d3089f9225b2b42c30cc1f8e5c89" -dependencies = [ - "nom", -] - [[package]] name = "cfg-if" version = "0.1.10" @@ -268,17 +228,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "clang-sys" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "853eda514c284c2287f4bf20ae614f8781f40a81d32ecda6e91449304dfe077c" -dependencies = [ - "glob", - "libc", - "libloading", -] - [[package]] name = "color_quant" version = "1.1.0" @@ -308,7 +257,6 @@ dependencies = [ "reqwest", "ring", "rocket", - "rocksdb", "ruma", "rusqlite", "rust-argon2", @@ -725,12 +673,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "funty" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7" - [[package]] name = "futures" version = "0.3.16" @@ -1243,40 +1185,12 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "libc" version = "0.2.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" -[[package]] -name = "libloading" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f84d96438c15fcd6c3f244c8fce01d1e2b9c6b5623e9c711dc9286d8fc92d6a" -dependencies = [ - "cfg-if 1.0.0", - "winapi", -] - -[[package]] -name = "librocksdb-sys" -version = "6.20.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c309a9d2470844aceb9a4a098cf5286154d20596868b75a6b36357d2bb9ca25d" -dependencies = [ - "bindgen", - "cc", - "glob", - "libc", -] - [[package]] name = "libsqlite3-sys" version = "0.22.2" @@ -1445,18 +1359,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "nom" -version = "6.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7413f999671bd4745a7b624bd370a569fb6bc574b23c83a3c5ed2e453f3d5e2" -dependencies = [ - "bitvec", - "funty", - "memchr", - "version_check", -] - [[package]] name = "ntapi" version = "0.3.6" @@ -1649,12 +1551,6 @@ dependencies = [ "syn", ] -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - [[package]] name = "pem" version = "0.8.3" @@ -1817,12 +1713,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radium" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "941ba9d78d8e2f7ce474c015eea4d9c6d25b6a3327f9832ee29a4de27f91bbb8" - [[package]] name = "rand" version = "0.7.3" @@ -2122,16 +2012,6 @@ dependencies = [ "uncased", ] -[[package]] -name = "rocksdb" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c749134fda8bfc90d0de643d59bfc841dcb3ac8a1062e12b6754bd60235c48b3" -dependencies = [ - "libc", - "librocksdb-sys", -] - [[package]] name = "ruma" version = "0.2.0" @@ -2415,12 +2295,6 @@ dependencies = [ "crossbeam-utils 0.8.5", ] -[[package]] -name = "rustc-hash" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" - [[package]] name = "rustc_version" version = "0.2.3" @@ -2647,12 +2521,6 @@ dependencies = [ "lazy_static", ] -[[package]] -name = "shlex" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42a568c8f2cd051a4d283bd6eb0343ac214c1b0f1ac19f93e1175b2dee38c73d" - [[package]] name = "signal-hook-registry" version = "1.4.0" @@ -2864,12 +2732,6 @@ dependencies = [ "unicode-xid", ] -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - [[package]] name = "tempfile" version = "3.2.0" @@ -3540,12 +3402,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "wyz" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85e60b0d1b5f99db2556934e21937020776a5d31520bf169e851ac44e6420214" - [[package]] name = "yaml-rust" version = "0.4.5" diff --git a/Cargo.toml b/Cargo.toml index 19ce6b10..3d18bfb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,6 @@ ruma = { git = "https://github.com/timokoesters/ruma", rev = "a2d93500e1dbc87e70 tokio = "1.8.2" # Used for storing data permanently sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } -rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } # Used for the http request / response body type for Ruma endpoints used with reqwest @@ -84,7 +83,6 @@ heed = { git = "https://github.com/timokoesters/heed.git", rev = "f6f825da7fb2c7 [features] default = ["conduit_bin", "backend_sqlite"] backend_sled = ["sled"] -backend_rocksdb = ["rocksdb"] backend_sqlite = ["sqlite"] backend_heed = ["heed", "crossbeam"] sqlite = ["rusqlite", "parking_lot", "crossbeam", "tokio/signal"] diff --git a/src/client_server/account.rs b/src/client_server/account.rs index c00cc871..87e3731f 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -243,15 +243,15 @@ pub async fn register_route( let room_id = RoomId::new(db.globals.server_name()); - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let mut content = ruma::events::room::create::CreateEventContent::new(conduit_user.clone()); content.federate = true; @@ -270,7 +270,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 2. Make conduit bot join @@ -293,7 +293,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 3. Power levels @@ -318,7 +318,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 4.1 Join Rules @@ -336,7 +336,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 4.2 History Visibility @@ -356,7 +356,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 4.3 Guest Access @@ -374,7 +374,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 6. Events implied by name and topic @@ -393,7 +393,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; db.rooms.build_and_append_pdu( @@ -410,7 +410,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // Room alias @@ -433,7 +433,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; @@ -458,7 +458,7 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; db.rooms.build_and_append_pdu( PduBuilder { @@ -479,7 +479,7 @@ pub async fn register_route( &user_id, &room_id, &db, - &mutex_lock, + &state_lock, )?; // Send welcome message @@ -498,13 +498,13 @@ pub async fn register_route( &conduit_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; } info!("{} registered on this server", user_id); - db.flush().await?; + db.flush()?; Ok(register::Response { access_token: Some(token), @@ -580,7 +580,7 @@ pub async fn change_password_route( } } - db.flush().await?; + db.flush()?; Ok(change_password::Response {}.into()) } @@ -656,11 +656,17 @@ pub async fn deactivate_route( } // Leave all joined rooms and reject all invitations - for room_id in db.rooms.rooms_joined(&sender_user).chain( - db.rooms - .rooms_invited(&sender_user) - .map(|t| t.map(|(r, _)| r)), - ) { + let all_rooms = db + .rooms + .rooms_joined(&sender_user) + .chain( + db.rooms + .rooms_invited(&sender_user) + .map(|t| t.map(|(r, _)| r)), + ) + .collect::>(); + + for room_id in all_rooms { let room_id = room_id?; let event = member::MemberEventContent { membership: member::MembershipState::Leave, @@ -671,15 +677,15 @@ pub async fn deactivate_route( blurhash: None, }; - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; db.rooms.build_and_append_pdu( PduBuilder { @@ -692,7 +698,7 @@ pub async fn deactivate_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; } @@ -701,7 +707,7 @@ pub async fn deactivate_route( info!("{} deactivated their account", sender_user); - db.flush().await?; + db.flush()?; Ok(deactivate::Response { id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index f5d9f64b..143e6071 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -31,7 +31,7 @@ pub async fn create_alias_route( db.rooms .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; - db.flush().await?; + db.flush()?; Ok(create_alias::Response::new().into()) } @@ -47,7 +47,7 @@ pub async fn delete_alias_route( ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; - db.flush().await?; + db.flush()?; Ok(delete_alias::Response::new().into()) } @@ -85,8 +85,7 @@ pub async fn get_alias_helper( match db.rooms.id_from_alias(&room_alias)? { Some(r) => room_id = Some(r), None => { - let iter = db.appservice.iter_all()?; - for (_id, registration) in iter.filter_map(|r| r.ok()) { + for (_id, registration) in db.appservice.all()? { let aliases = registration .get("namespaces") .and_then(|ns| ns.get("aliases")) diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index 6d540cb1..06f9818d 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -26,7 +26,7 @@ pub async fn create_backup_route( .key_backups .create_backup(&sender_user, &body.algorithm, &db.globals)?; - db.flush().await?; + db.flush()?; Ok(create_backup::Response { version }.into()) } @@ -44,7 +44,7 @@ pub async fn update_backup_route( db.key_backups .update_backup(&sender_user, &body.version, &body.algorithm, &db.globals)?; - db.flush().await?; + db.flush()?; Ok(update_backup::Response {}.into()) } @@ -117,7 +117,7 @@ pub async fn delete_backup_route( db.key_backups.delete_backup(&sender_user, &body.version)?; - db.flush().await?; + db.flush()?; Ok(delete_backup::Response {}.into()) } @@ -147,7 +147,7 @@ pub async fn add_backup_keys_route( } } - db.flush().await?; + db.flush()?; Ok(add_backup_keys::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -179,7 +179,7 @@ pub async fn add_backup_key_sessions_route( )? } - db.flush().await?; + db.flush()?; Ok(add_backup_key_sessions::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -209,7 +209,7 @@ pub async fn add_backup_key_session_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(add_backup_key_session::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -288,7 +288,7 @@ pub async fn delete_backup_keys_route( db.key_backups .delete_all_keys(&sender_user, &body.version)?; - db.flush().await?; + db.flush()?; Ok(delete_backup_keys::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -311,7 +311,7 @@ pub async fn delete_backup_key_sessions_route( db.key_backups .delete_room_keys(&sender_user, &body.version, &body.room_id)?; - db.flush().await?; + db.flush()?; Ok(delete_backup_key_sessions::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), @@ -334,7 +334,7 @@ pub async fn delete_backup_key_session_route( db.key_backups .delete_room_key(&sender_user, &body.version, &body.room_id, &body.session_id)?; - db.flush().await?; + db.flush()?; Ok(delete_backup_key_session::Response { count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), diff --git a/src/client_server/config.rs b/src/client_server/config.rs index b9826bfb..b6927493 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -43,7 +43,7 @@ pub async fn set_global_account_data_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_global_account_data::Response {}.into()) } @@ -78,7 +78,7 @@ pub async fn set_room_account_data_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_room_account_data::Response {}.into()) } @@ -98,7 +98,7 @@ pub async fn get_global_account_data_route( .account_data .get::>(None, sender_user, body.event_type.clone().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - db.flush().await?; + db.flush()?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? @@ -129,7 +129,7 @@ pub async fn get_room_account_data_route( body.event_type.clone().into(), )? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - db.flush().await?; + db.flush()?; let account_data = serde_json::from_str::(event.get()) .map_err(|_| Error::bad_database("Invalid account data event in db."))? diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 085d034f..52104673 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -71,7 +71,7 @@ pub async fn update_device_route( db.users .update_device_metadata(&sender_user, &body.device_id, &device)?; - db.flush().await?; + db.flush()?; Ok(update_device::Response {}.into()) } @@ -123,7 +123,7 @@ pub async fn delete_device_route( db.users.remove_device(&sender_user, &body.device_id)?; - db.flush().await?; + db.flush()?; Ok(delete_device::Response {}.into()) } @@ -177,7 +177,7 @@ pub async fn delete_devices_route( db.users.remove_device(&sender_user, &device_id)? } - db.flush().await?; + db.flush()?; Ok(delete_devices::Response {}.into()) } diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index f1ec4b86..589aacda 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,3 +1,5 @@ +use std::convert::TryInto; + use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::{ @@ -21,7 +23,7 @@ use ruma::{ serde::Raw, ServerName, UInt, }; -use tracing::info; +use tracing::{info, warn}; #[cfg(feature = "conduit_bin")] use rocket::{get, post, put}; @@ -100,7 +102,7 @@ pub async fn set_room_visibility_route( } } - db.flush().await?; + db.flush()?; Ok(set_room_visibility::Response {}.into()) } @@ -234,7 +236,15 @@ pub async fn get_public_rooms_filtered_helper( .name .map(|n| n.to_owned().into())) })?, - num_joined_members: (db.rooms.room_members(&room_id).count() as u32).into(), + num_joined_members: db + .rooms + .room_joined_count(&room_id)? + .unwrap_or_else(|| { + warn!("Room {} has no member count", room_id); + 0 + }) + .try_into() + .expect("user count should not be that big"), topic: db .rooms .room_state_get(&room_id, &EventType::RoomTopic, "")? diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 418e41af..8db7688d 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -64,7 +64,7 @@ pub async fn upload_keys_route( } } - db.flush().await?; + db.flush()?; Ok(upload_keys::Response { one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, @@ -105,7 +105,7 @@ pub async fn claim_keys_route( ) -> ConduitResult { let response = claim_keys_helper(&body.one_time_keys, &db).await?; - db.flush().await?; + db.flush()?; Ok(response.into()) } @@ -166,7 +166,7 @@ pub async fn upload_signing_keys_route( )?; } - db.flush().await?; + db.flush()?; Ok(upload_signing_keys::Response {}.into()) } @@ -227,7 +227,7 @@ pub async fn upload_signatures_route( } } - db.flush().await?; + db.flush()?; Ok(upload_signatures::Response {}.into()) } diff --git a/src/client_server/media.rs b/src/client_server/media.rs index eaaf9399..2bd189a4 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -52,7 +52,7 @@ pub async fn create_content_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(create_content::Response { content_uri: mxc.try_into().expect("Invalid mxc:// URI"), diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index ea7fdab5..716a615f 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -74,7 +74,7 @@ pub async fn join_room_by_id_route( ) .await; - db.flush().await?; + db.flush()?; ret } @@ -125,7 +125,7 @@ pub async fn join_room_by_id_or_alias_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(join_room_by_id_or_alias::Response { room_id: join_room_response.0.room_id, @@ -146,7 +146,7 @@ pub async fn leave_room_route( db.rooms.leave_room(sender_user, &body.room_id, &db).await?; - db.flush().await?; + db.flush()?; Ok(leave_room::Response::new().into()) } @@ -164,7 +164,7 @@ pub async fn invite_user_route( if let invite_user::IncomingInvitationRecipient::UserId { user_id } = &body.recipient { invite_helper(sender_user, user_id, &body.room_id, &db, false).await?; - db.flush().await?; + db.flush()?; Ok(invite_user::Response {}.into()) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) @@ -203,15 +203,15 @@ pub async fn kick_user_route( event.membership = ruma::events::room::member::MembershipState::Leave; // TODO: reason - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(body.room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; db.rooms.build_and_append_pdu( PduBuilder { @@ -224,12 +224,12 @@ pub async fn kick_user_route( &sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; Ok(kick_user::Response::new().into()) } @@ -275,15 +275,15 @@ pub async fn ban_user_route( }, )?; - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(body.room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; db.rooms.build_and_append_pdu( PduBuilder { @@ -296,12 +296,12 @@ pub async fn ban_user_route( &sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; Ok(ban_user::Response::new().into()) } @@ -337,15 +337,15 @@ pub async fn unban_user_route( event.membership = ruma::events::room::member::MembershipState::Leave; - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(body.room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; db.rooms.build_and_append_pdu( PduBuilder { @@ -358,12 +358,12 @@ pub async fn unban_user_route( &sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; Ok(unban_user::Response::new().into()) } @@ -381,7 +381,7 @@ pub async fn forget_room_route( db.rooms.forget(&body.room_id, &sender_user)?; - db.flush().await?; + db.flush()?; Ok(forget_room::Response::new().into()) } @@ -486,15 +486,15 @@ async fn join_room_by_id_helper( ) -> ConduitResult { let sender_user = sender_user.expect("user is authenticated"); - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; // Ask a remote server if we don't have this room if !db.rooms.exists(&room_id)? && room_id.server_name() != db.globals.server_name() { @@ -706,13 +706,13 @@ async fn join_room_by_id_helper( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; } - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; Ok(join_room_by_id::Response::new(room_id.clone()).into()) } @@ -788,155 +788,165 @@ pub async fn invite_helper<'a>( db: &Database, is_direct: bool, ) -> Result<()> { - let mutex = Arc::clone( - db.globals - .roomid_mutex - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - if user_id.server_name() != db.globals.server_name() { - let prev_events = db - .rooms - .get_pdu_leaves(room_id)? - .into_iter() - .take(20) - .collect::>(); - - let create_event = db - .rooms - .room_state_get(room_id, &EventType::RoomCreate, "")?; - - let create_event_content = create_event - .as_ref() - .map(|create_event| { - serde_json::from_value::>(create_event.content.clone()) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| Error::bad_database("Invalid PowerLevels event in db.")) - }) - .transpose()?; - - let create_prev_event = if prev_events.len() == 1 - && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) - { - create_event - } else { - None - }; - - // If there was no create event yet, assume we are creating a version 6 room right now - let room_version_id = create_event_content - .map_or(RoomVersionId::Version6, |create_event| { - create_event.room_version - }); - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - - let content = serde_json::to_value(MemberEventContent { - avatar_url: None, - displayname: None, - is_direct: Some(is_direct), - membership: MembershipState::Invite, - third_party_invite: None, - blurhash: None, - }) - .expect("member event is valid value"); - - let state_key = user_id.to_string(); - let kind = EventType::RoomMember; - - let auth_events = - db.rooms - .get_auth_events(room_id, &kind, &sender_user, Some(&state_key), &content)?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = BTreeMap::new(); - - if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? { - unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); - unsigned.insert( - "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + let (room_version_id, pdu_json, invite_room_state) = { + let mutex_state = Arc::clone( + db.globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), ); - } + let state_lock = mutex_state.lock().await; - let pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater"), - room_id: room_id.clone(), - sender: sender_user.clone(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind, - content, - state_key: Some(state_key), - prev_events, - depth, - auth_events: auth_events + let prev_events = db + .rooms + .get_pdu_leaves(room_id)? + .into_iter() + .take(20) + .collect::>(); + + let create_event = db + .rooms + .room_state_get(room_id, &EventType::RoomCreate, "")?; + + let create_event_content = create_event + .as_ref() + .map(|create_event| { + serde_json::from_value::>(create_event.content.clone()) + .expect("Raw::from_value always works.") + .deserialize() + .map_err(|_| Error::bad_database("Invalid PowerLevels event in db.")) + }) + .transpose()?; + + let create_prev_event = if prev_events.len() == 1 + && Some(&prev_events[0]) == create_event.as_ref().map(|c| &c.event_id) + { + create_event + } else { + None + }; + + // If there was no create event yet, assume we are creating a version 6 room right now + let room_version_id = create_event_content + .map_or(RoomVersionId::Version6, |create_event| { + create_event.room_version + }); + let room_version = + RoomVersion::new(&room_version_id).expect("room version is supported"); + + let content = serde_json::to_value(MemberEventContent { + avatar_url: None, + displayname: None, + is_direct: Some(is_direct), + membership: MembershipState::Invite, + third_party_invite: None, + blurhash: None, + }) + .expect("member event is valid value"); + + let state_key = user_id.to_string(); + let kind = EventType::RoomMember; + + let auth_events = db.rooms.get_auth_events( + room_id, + &kind, + &sender_user, + Some(&state_key), + &content, + )?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events .iter() - .map(|(_, pdu)| pdu.event_id.clone()) - .collect(), - redacts: None, - unsigned, - hashes: ruma::events::pdu::EventHash { - sha256: "aaa".to_owned(), - }, - signatures: BTreeMap::new(), + .filter_map(|event_id| Some(db.rooms.get_pdu(event_id).ok()??.depth)) + .max() + .unwrap_or_else(|| uint!(0)) + + uint!(1); + + let mut unsigned = BTreeMap::new(); + + if let Some(prev_pdu) = db.rooms.room_state_get(room_id, &kind, &state_key)? { + unsigned.insert("prev_content".to_owned(), prev_pdu.content.clone()); + unsigned.insert( + "prev_sender".to_owned(), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + ); + } + + let pdu = PduEvent { + event_id: ruma::event_id!("$thiswillbefilledinlater"), + room_id: room_id.clone(), + sender: sender_user.clone(), + origin_server_ts: utils::millis_since_unix_epoch() + .try_into() + .expect("time is valid"), + kind, + content, + state_key: Some(state_key), + prev_events, + depth, + auth_events: auth_events + .iter() + .map(|(_, pdu)| pdu.event_id.clone()) + .collect(), + redacts: None, + unsigned, + hashes: ruma::events::pdu::EventHash { + sha256: "aaa".to_owned(), + }, + signatures: BTreeMap::new(), + }; + + let auth_check = state_res::auth_check( + &room_version, + &Arc::new(pdu.clone()), + create_prev_event, + &auth_events, + None, // TODO: third_party_invite + ) + .map_err(|e| { + error!("{:?}", e); + Error::bad_database("Auth check failed.") + })?; + + if !auth_check { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Event is not authorized.", + )); + } + + // Hash and sign + let mut pdu_json = + utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); + + pdu_json.remove("event_id"); + + // Add origin because synapse likes that (and it's required in the spec) + pdu_json.insert( + "origin".to_owned(), + to_canonical_value(db.globals.server_name()) + .expect("server name is a valid CanonicalJsonValue"), + ); + + ruma::signatures::hash_and_sign_event( + db.globals.server_name().as_str(), + db.globals.keypair(), + &mut pdu_json, + &room_version_id, + ) + .expect("event is valid, we just created it"); + + let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; + + drop(state_lock); + + (room_version_id, pdu_json, invite_room_state) }; - let auth_check = state_res::auth_check( - &room_version, - &Arc::new(pdu.clone()), - create_prev_event, - &auth_events, - None, // TODO: third_party_invite - ) - .map_err(|e| { - error!("{:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = - utils::to_canonical_object(&pdu).expect("event is valid, we just created it"); - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(db.globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut pdu_json, - &room_version_id, - ) - .expect("event is valid, we just created it"); - - drop(mutex_lock); - - let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; let response = db .sending .send_federation_request( @@ -1008,6 +1018,16 @@ pub async fn invite_helper<'a>( return Ok(()); } + let mutex_state = Arc::clone( + db.globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + db.rooms.build_and_append_pdu( PduBuilder { event_type: EventType::RoomMember, @@ -1027,8 +1047,10 @@ pub async fn invite_helper<'a>( &sender_user, room_id, &db, - &mutex_lock, + &state_lock, )?; + drop(state_lock); + Ok(()) } diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 3d8218c6..9cb6faab 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -28,15 +28,15 @@ pub async fn send_message_event_route( let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(body.room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; // Check if this is a new transaction id if let Some(response) = @@ -75,7 +75,7 @@ pub async fn send_message_event_route( &sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; db.transaction_ids.add_txnid( @@ -85,9 +85,9 @@ pub async fn send_message_event_route( event_id.as_bytes(), )?; - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; Ok(send_message_event::Response::new(event_id).into()) } diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index ca78a88b..7312cb32 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -41,7 +41,7 @@ pub async fn set_presence_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_presence::Response {}.into()) } diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 693254fe..de1babab 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -32,9 +32,10 @@ pub async fn set_displayname_route( .set_displayname(&sender_user, body.displayname.clone())?; // Send a new membership event and presence update into all joined rooms - for (pdu_builder, room_id) in db - .rooms - .rooms_joined(&sender_user) + let all_rooms_joined = db.rooms.rooms_joined(&sender_user).collect::>(); + + for (pdu_builder, room_id) in all_rooms_joined + .into_iter() .filter_map(|r| r.ok()) .map(|room_id| { Ok::<_, Error>(( @@ -72,19 +73,19 @@ pub async fn set_displayname_route( }) .filter_map(|r| r.ok()) { - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let _ = db.rooms - .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &mutex_lock); + .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &state_lock); // Presence update db.rooms.edus.update_presence( @@ -109,7 +110,7 @@ pub async fn set_displayname_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_display_name::Response {}.into()) } @@ -165,9 +166,10 @@ pub async fn set_avatar_url_route( db.users.set_blurhash(&sender_user, body.blurhash.clone())?; // Send a new membership event and presence update into all joined rooms - for (pdu_builder, room_id) in db - .rooms - .rooms_joined(&sender_user) + let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::>(); + + for (pdu_builder, room_id) in all_joined_rooms + .into_iter() .filter_map(|r| r.ok()) .map(|room_id| { Ok::<_, Error>(( @@ -205,19 +207,19 @@ pub async fn set_avatar_url_route( }) .filter_map(|r| r.ok()) { - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let _ = db.rooms - .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &mutex_lock); + .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &state_lock); // Presence update db.rooms.edus.update_presence( @@ -242,7 +244,7 @@ pub async fn set_avatar_url_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_avatar_url::Response {}.into()) } diff --git a/src/client_server/push.rs b/src/client_server/push.rs index 867b4525..9489f070 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -192,7 +192,7 @@ pub async fn set_pushrule_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_pushrule::Response {}.into()) } @@ -248,7 +248,7 @@ pub async fn get_pushrule_actions_route( _ => None, }; - db.flush().await?; + db.flush()?; Ok(get_pushrule_actions::Response { actions: actions.unwrap_or_default(), @@ -325,7 +325,7 @@ pub async fn set_pushrule_actions_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_pushrule_actions::Response {}.into()) } @@ -386,7 +386,7 @@ pub async fn get_pushrule_enabled_route( _ => false, }; - db.flush().await?; + db.flush()?; Ok(get_pushrule_enabled::Response { enabled }.into()) } @@ -465,7 +465,7 @@ pub async fn set_pushrule_enabled_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(set_pushrule_enabled::Response {}.into()) } @@ -534,7 +534,7 @@ pub async fn delete_pushrule_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(delete_pushrule::Response {}.into()) } @@ -570,7 +570,7 @@ pub async fn set_pushers_route( db.pusher.set_pusher(sender_user, pusher)?; - db.flush().await?; + db.flush()?; Ok(set_pusher::Response::default().into()) } diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index f5e2924e..85b0bf67 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -75,7 +75,7 @@ pub async fn set_read_marker_route( )?; } - db.flush().await?; + db.flush()?; Ok(set_read_marker::Response {}.into()) } @@ -128,7 +128,7 @@ pub async fn create_receipt_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(create_receipt::Response {}.into()) } diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index 2e4c6519..63bf103b 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -20,15 +20,15 @@ pub async fn redact_event_route( ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(body.room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let event_id = db.rooms.build_and_append_pdu( PduBuilder { @@ -44,12 +44,12 @@ pub async fn redact_event_route( &sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; Ok(redact_event::Response { event_id }.into()) } diff --git a/src/client_server/room.rs b/src/client_server/room.rs index d5188e8b..f73d5445 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -33,15 +33,15 @@ pub async fn create_room_route( let room_id = RoomId::new(db.globals.server_name()); - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let alias = body .room_alias_name @@ -79,7 +79,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 2. Let the room creator join @@ -102,7 +102,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 3. Power levels @@ -157,7 +157,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 4. Events set by preset @@ -184,7 +184,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 4.2 History Visibility @@ -202,7 +202,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 4.3 Guest Access @@ -228,7 +228,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; // 5. Events listed in initial_state @@ -244,7 +244,7 @@ pub async fn create_room_route( } db.rooms - .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &mutex_lock)?; + .build_and_append_pdu(pdu_builder, &sender_user, &room_id, &db, &state_lock)?; } // 6. Events implied by name and topic @@ -261,7 +261,7 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; } @@ -280,12 +280,12 @@ pub async fn create_room_route( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; } // 7. Events implied by invite (and TODO: invite_3pid) - drop(mutex_lock); + drop(state_lock); for user_id in &body.invite { let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await; } @@ -301,7 +301,7 @@ pub async fn create_room_route( info!("{} created a room", sender_user); - db.flush().await?; + db.flush()?; Ok(create_room::Response::new(room_id).into()) } @@ -364,13 +364,12 @@ pub async fn get_room_aliases_route( #[cfg_attr( feature = "conduit_bin", - post("/_matrix/client/r0/rooms/<_room_id>/upgrade", data = "") + post("/_matrix/client/r0/rooms/<_>/upgrade", data = "") )] #[tracing::instrument(skip(db, body))] pub async fn upgrade_room_route( db: DatabaseGuard, body: Ruma>, - _room_id: String, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -387,15 +386,15 @@ pub async fn upgrade_room_route( // Create a replacement room let replacement_room = RoomId::new(db.globals.server_name()); - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(body.room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Fail if the sender does not have the required permissions @@ -414,9 +413,21 @@ pub async fn upgrade_room_route( sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; + // Change lock to replacement room + drop(state_lock); + let mutex_state = Arc::clone( + db.globals + .roomid_mutex_state + .write() + .unwrap() + .entry(replacement_room.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + // Get the old room federations status let federate = serde_json::from_value::>( db.rooms @@ -455,7 +466,7 @@ pub async fn upgrade_room_route( sender_user, &replacement_room, &db, - &mutex_lock, + &state_lock, )?; // Join the new room @@ -478,7 +489,7 @@ pub async fn upgrade_room_route( sender_user, &replacement_room, &db, - &mutex_lock, + &state_lock, )?; // Recommended transferable state events list from the specs @@ -512,7 +523,7 @@ pub async fn upgrade_room_route( sender_user, &replacement_room, &db, - &mutex_lock, + &state_lock, )?; } @@ -556,12 +567,12 @@ pub async fn upgrade_room_route( sender_user, &body.room_id, &db, - &mutex_lock, + &state_lock, )?; - drop(mutex_lock); + drop(state_lock); - db.flush().await?; + db.flush()?; // Return the replacement room id Ok(upgrade_room::Response { replacement_room }.into()) diff --git a/src/client_server/session.rs b/src/client_server/session.rs index f8452e0d..d4d3c033 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -143,7 +143,7 @@ pub async fn login_route( info!("{} logged in", user_id); - db.flush().await?; + db.flush()?; Ok(login::Response { user_id, @@ -175,7 +175,7 @@ pub async fn logout_route( db.users.remove_device(&sender_user, sender_device)?; - db.flush().await?; + db.flush()?; Ok(logout::Response::new().into()) } @@ -204,7 +204,7 @@ pub async fn logout_all_route( db.users.remove_device(&sender_user, &device_id)?; } - db.flush().await?; + db.flush()?; Ok(logout_all::Response::new().into()) } diff --git a/src/client_server/state.rs b/src/client_server/state.rs index e0e5d29a..aa020b50 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -43,7 +43,7 @@ pub async fn send_state_event_for_key_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(send_state_event::Response { event_id }.into()) } @@ -69,7 +69,7 @@ pub async fn send_state_event_for_empty_key_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(send_state_event::Response { event_id }.into()) } @@ -259,15 +259,15 @@ pub async fn send_state_event_for_key_helper( } } - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let event_id = db.rooms.build_and_append_pdu( PduBuilder { @@ -280,7 +280,7 @@ pub async fn send_state_event_for_key_helper( &sender_user, &room_id, &db, - &mutex_lock, + &state_lock, )?; Ok(event_id) diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 541045ec..937a252e 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -186,20 +186,22 @@ async fn sync_helper( .filter_map(|r| r.ok()), ); - for room_id in db.rooms.rooms_joined(&sender_user) { + let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::>(); + for room_id in all_joined_rooms { let room_id = room_id?; // Get and drop the lock to wait for remaining operations to finish - let mutex = Arc::clone( + // This will make sure the we have all events until next_batch + let mutex_insert = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_insert .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; - drop(mutex_lock); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); let mut non_timeline_pdus = db .rooms @@ -658,20 +660,21 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - for result in db.rooms.rooms_left(&sender_user) { + let all_left_rooms = db.rooms.rooms_left(&sender_user).collect::>(); + for result in all_left_rooms { let (room_id, left_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish - let mutex = Arc::clone( + let mutex_insert = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_insert .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; - drop(mutex_lock); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; @@ -697,20 +700,21 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - for result in db.rooms.rooms_invited(&sender_user) { + let all_invited_rooms = db.rooms.rooms_invited(&sender_user).collect::>(); + for result in all_invited_rooms { let (room_id, invite_state_events) = result?; // Get and drop the lock to wait for remaining operations to finish - let mutex = Arc::clone( + let mutex_insert = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_insert .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; - drop(mutex_lock); + let insert_lock = mutex_insert.lock().unwrap(); + drop(insert_lock); let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 223d122c..5582bcdd 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -40,7 +40,7 @@ pub async fn update_tag_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(create_tag::Response {}.into()) } @@ -74,7 +74,7 @@ pub async fn delete_tag_route( &db.globals, )?; - db.flush().await?; + db.flush()?; Ok(delete_tag::Response {}.into()) } diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index d3f7d255..69147c9a 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -95,7 +95,7 @@ pub async fn send_event_to_device_route( db.transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; - db.flush().await?; + db.flush()?; Ok(send_event_to_device::Response {}.into()) } diff --git a/src/database.rs b/src/database.rs index 5e9e025d..4b7c7fe8 100644 --- a/src/database.rs +++ b/src/database.rs @@ -24,10 +24,11 @@ use rocket::{ request::{FromRequest, Request}, Shutdown, State, }; -use ruma::{DeviceId, ServerName, UserId}; +use ruma::{DeviceId, RoomId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; use std::{ collections::{BTreeMap, HashMap}, + convert::TryFrom, fs::{self, remove_dir_all}, io::Write, ops::Deref, @@ -45,18 +46,8 @@ pub struct Config { database_path: String, #[serde(default = "default_db_cache_capacity_mb")] db_cache_capacity_mb: f64, - #[serde(default = "default_sqlite_read_pool_size")] - sqlite_read_pool_size: usize, - #[serde(default = "true_fn")] - sqlite_wal_clean_timer: bool, #[serde(default = "default_sqlite_wal_clean_second_interval")] sqlite_wal_clean_second_interval: u32, - #[serde(default = "default_sqlite_wal_clean_second_timeout")] - sqlite_wal_clean_second_timeout: u32, - #[serde(default = "default_sqlite_spillover_reap_fraction")] - sqlite_spillover_reap_fraction: f64, - #[serde(default = "default_sqlite_spillover_reap_interval_secs")] - sqlite_spillover_reap_interval_secs: u32, #[serde(default = "default_max_request_size")] max_request_size: u32, #[serde(default = "default_max_concurrent_requests")] @@ -115,24 +106,8 @@ fn default_db_cache_capacity_mb() -> f64 { 200.0 } -fn default_sqlite_read_pool_size() -> usize { - num_cpus::get().max(1) -} - fn default_sqlite_wal_clean_second_interval() -> u32 { - 60 * 60 -} - -fn default_sqlite_wal_clean_second_timeout() -> u32 { - 2 -} - -fn default_sqlite_spillover_reap_fraction() -> f64 { - 0.5 -} - -fn default_sqlite_spillover_reap_interval_secs() -> u32 { - 60 + 15 * 60 // every 15 minutes } fn default_max_request_size() -> u32 { @@ -150,9 +125,6 @@ fn default_log() -> String { #[cfg(feature = "sled")] pub type Engine = abstraction::sled::Engine; -#[cfg(feature = "rocksdb")] -pub type Engine = abstraction::rocksdb::Engine; - #[cfg(feature = "sqlite")] pub type Engine = abstraction::sqlite::Engine; @@ -278,6 +250,7 @@ impl Database { serverroomids: builder.open_tree("serverroomids")?, userroomid_joined: builder.open_tree("userroomid_joined")?, roomuserid_joined: builder.open_tree("roomuserid_joined")?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, @@ -297,8 +270,8 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, referencedevents: builder.open_tree("referencedevents")?, - pdu_cache: Mutex::new(LruCache::new(100_000)), - auth_chain_cache: Mutex::new(LruCache::new(100_000)), + pdu_cache: Mutex::new(LruCache::new(0)), + auth_chain_cache: Mutex::new(LruCache::new(0)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, @@ -449,6 +422,21 @@ impl Database { println!("Migration: 4 -> 5 finished"); } + + if db.globals.database_version()? < 6 { + // TODO update to 6 + // Set room member count + for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { + let room_id = + RoomId::try_from(utils::string_from_bytes(&roomid).unwrap()).unwrap(); + + db.rooms.update_joined_count(&room_id)?; + } + + db.globals.bump_database_version(6)?; + + println!("Migration: 5 -> 6 finished"); + } } let guard = db.read().await; @@ -465,8 +453,7 @@ impl Database { #[cfg(feature = "sqlite")] { - Self::start_wal_clean_task(&db, &config).await; - Self::start_spillover_reap_task(builder, &config).await; + Self::start_wal_clean_task(Arc::clone(&db), &config).await; } Ok(db) @@ -511,6 +498,16 @@ impl Database { .watch_prefix(&userid_prefix), ); futures.push(self.rooms.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.rooms + .userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push( + self.rooms + .userroomid_highlightcount + .watch_prefix(&userid_prefix), + ); // Events for rooms we are in for room_id in self.rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { @@ -576,7 +573,7 @@ impl Database { } #[tracing::instrument(skip(self))] - pub async fn flush(&self) -> Result<()> { + pub fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); let res = self._db.flush(); @@ -593,51 +590,17 @@ impl Database { } #[cfg(feature = "sqlite")] - #[tracing::instrument(skip(engine, config))] - pub async fn start_spillover_reap_task(engine: Arc, config: &Config) { - let fraction = config.sqlite_spillover_reap_fraction.clamp(0.01, 1.0); - let interval_secs = config.sqlite_spillover_reap_interval_secs as u64; - - let weak = Arc::downgrade(&engine); - - tokio::spawn(async move { - use tokio::time::interval; - - use std::{sync::Weak, time::Duration}; - - let mut i = interval(Duration::from_secs(interval_secs)); - - loop { - i.tick().await; - - if let Some(arc) = Weak::upgrade(&weak) { - arc.reap_spillover_by_fraction(fraction); - } else { - break; - } - } - }); - } - - #[cfg(feature = "sqlite")] - #[tracing::instrument(skip(lock, config))] - pub async fn start_wal_clean_task(lock: &Arc>, config: &Config) { - use tokio::time::{interval, timeout}; + #[tracing::instrument(skip(db, config))] + pub async fn start_wal_clean_task(db: Arc>, config: &Config) { + use tokio::time::interval; #[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; use tracing::info; - use std::{ - sync::Weak, - time::{Duration, Instant}, - }; + use std::time::{Duration, Instant}; - let weak: Weak> = Arc::downgrade(&lock); - - let lock_timeout = Duration::from_secs(config.sqlite_wal_clean_second_timeout as u64); let timer_interval = Duration::from_secs(config.sqlite_wal_clean_second_interval as u64); - let do_timer = config.sqlite_wal_clean_timer; tokio::spawn(async move { let mut i = interval(timer_interval); @@ -647,45 +610,24 @@ impl Database { loop { #[cfg(unix)] tokio::select! { - _ = i.tick(), if do_timer => { - info!(target: "wal-trunc", "Timer ticked") + _ = i.tick() => { + info!("wal-trunc: Timer ticked"); } _ = s.recv() => { - info!(target: "wal-trunc", "Received SIGHUP") + info!("wal-trunc: Received SIGHUP"); } }; #[cfg(not(unix))] - if do_timer { + { i.tick().await; - info!(target: "wal-trunc", "Timer ticked") - } else { - // timer disabled, and there's no concept of signals on windows, bailing... - return; + info!("wal-trunc: Timer ticked") } - if let Some(arc) = Weak::upgrade(&weak) { - info!(target: "wal-trunc", "Rotating sync helpers..."); - // This actually creates a very small race condition between firing this and trying to acquire the subsequent write lock. - // Though it is not a huge deal if the write lock doesn't "catch", as it'll harmlessly time out. - arc.read().await.globals.rotate.fire(); - info!(target: "wal-trunc", "Locking..."); - let guard = { - if let Ok(guard) = timeout(lock_timeout, arc.write()).await { - guard - } else { - info!(target: "wal-trunc", "Lock failed in timeout, canceled."); - continue; - } - }; - info!(target: "wal-trunc", "Locked, flushing..."); - let start = Instant::now(); - if let Err(e) = guard.flush_wal() { - error!(target: "wal-trunc", "Errored: {}", e); - } else { - info!(target: "wal-trunc", "Flushed in {:?}", start.elapsed()); - } + let start = Instant::now(); + if let Err(e) = db.read().await.flush_wal() { + error!("wal-trunc: Errored: {}", e); } else { - break; + info!("wal-trunc: Flushed in {:?}", start.elapsed()); } } }); diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 8ccac787..f381ce9f 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -3,9 +3,6 @@ use crate::Result; use std::{future::Future, pin::Pin, sync::Arc}; -#[cfg(feature = "rocksdb")] -pub mod rocksdb; - #[cfg(feature = "sled")] pub mod sled; @@ -25,23 +22,24 @@ pub trait Tree: Send + Sync { fn get(&self, key: &[u8]) -> Result>>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; + fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()>; fn remove(&self, key: &[u8]) -> Result<()>; - fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a>; + fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + Send + 'a>; + ) -> Box, Vec)> + 'a>; fn increment(&self, key: &[u8]) -> Result>; fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + Send + 'a>; + ) -> Box, Vec)> + 'a>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; diff --git a/src/database/abstraction/heed.rs b/src/database/abstraction/heed.rs index 0421b140..e767e22b 100644 --- a/src/database/abstraction/heed.rs +++ b/src/database/abstraction/heed.rs @@ -81,7 +81,7 @@ impl EngineTree { let (s, r) = bounded::(100); let engine = Arc::clone(&self.engine); - let lock = self.engine.iter_pool.lock().unwrap(); + let lock = self.engine.iter_pool.lock().await; if lock.active_count() < lock.max_count() { lock.execute(move || { iter_from_thread_work(tree, &engine.env.read_txn().unwrap(), from, backwards, &s); diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs deleted file mode 100644 index 4699b2d5..00000000 --- a/src/database/abstraction/rocksdb.rs +++ /dev/null @@ -1,176 +0,0 @@ -use super::super::Config; -use crate::{utils, Result}; - -use std::{future::Future, pin::Pin, sync::Arc}; - -use super::{DatabaseEngine, Tree}; - -use std::{collections::HashMap, sync::RwLock}; - -pub struct Engine(rocksdb::DBWithThreadMode); - -pub struct RocksDbEngineTree<'a> { - db: Arc, - name: &'a str, - watchers: RwLock, Vec>>>, -} - -impl DatabaseEngine for Engine { - fn open(config: &Config) -> Result> { - let mut db_opts = rocksdb::Options::default(); - db_opts.create_if_missing(true); - db_opts.set_max_open_files(16); - db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); - db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); - db_opts.set_target_file_size_base(256 << 20); - db_opts.set_write_buffer_size(256 << 20); - - let mut block_based_options = rocksdb::BlockBasedOptions::default(); - block_based_options.set_block_size(512 << 10); - db_opts.set_block_based_table_factory(&block_based_options); - - let cfs = rocksdb::DBWithThreadMode::::list_cf( - &db_opts, - &config.database_path, - ) - .unwrap_or_default(); - - let mut options = rocksdb::Options::default(); - options.set_merge_operator_associative("increment", utils::increment_rocksdb); - - let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( - &db_opts, - &config.database_path, - cfs.iter() - .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), - )?; - - Ok(Arc::new(Engine(db))) - } - - fn open_tree(self: &Arc, name: &'static str) -> Result> { - let mut options = rocksdb::Options::default(); - options.set_merge_operator_associative("increment", utils::increment_rocksdb); - - // Create if it doesn't exist - let _ = self.0.create_cf(name, &options); - - Ok(Arc::new(RocksDbEngineTree { - name, - db: Arc::clone(self), - watchers: RwLock::new(HashMap::new()), - })) - } -} - -impl RocksDbEngineTree<'_> { - fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { - self.db.0.cf_handle(self.name).unwrap() - } -} - -impl Tree for RocksDbEngineTree<'_> { - fn get(&self, key: &[u8]) -> Result>> { - Ok(self.db.0.get_cf(self.cf(), key)?) - } - - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let watchers = self.watchers.read().unwrap(); - let mut triggered = Vec::new(); - - for length in 0..=key.len() { - if watchers.contains_key(&key[..length]) { - triggered.push(&key[..length]); - } - } - - drop(watchers); - - if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); - for prefix in triggered { - if let Some(txs) = watchers.remove(prefix) { - for tx in txs { - let _ = tx.send(()); - } - } - } - } - - Ok(self.db.0.put_cf(self.cf(), key, value)?) - } - - fn remove(&self, key: &[u8]) -> Result<()> { - Ok(self.db.0.delete_cf(self.cf(), key)?) - } - - fn iter<'a>(&'a self) -> Box, Vec)> + Send + Sync + 'a> { - Box::new( - self.db - .0 - .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), - ) - } - - fn iter_from<'a>( - &'a self, - from: &[u8], - backwards: bool, - ) -> Box, Vec)> + 'a> { - Box::new(self.db.0.iterator_cf( - self.cf(), - rocksdb::IteratorMode::From( - from, - if backwards { - rocksdb::Direction::Reverse - } else { - rocksdb::Direction::Forward - }, - ), - )) - } - - fn increment(&self, key: &[u8]) -> Result> { - let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); - dbg!(stats.mem_table_total); - dbg!(stats.mem_table_unflushed); - dbg!(stats.mem_table_readers_total); - dbg!(stats.cache_total); - // TODO: atomic? - let old = self.get(key)?; - let new = utils::increment(old.as_deref()).unwrap(); - self.insert(key, &new)?; - Ok(new) - } - - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Vec)> + Send + 'a> { - Box::new( - self.db - .0 - .iterator_cf( - self.cf(), - rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), - ) - .take_while(move |(k, _)| k.starts_with(&prefix)), - ) - } - - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - let (tx, rx) = tokio::sync::oneshot::channel(); - - self.watchers - .write() - .unwrap() - .entry(prefix.to_vec()) - .or_default() - .push(tx); - - Box::pin(async move { - // Tx is never destroyed - rx.await.unwrap(); - }) - } -} diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index d99ce264..35ba1b29 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -39,12 +39,21 @@ impl Tree for SledEngineTree { Ok(()) } + #[tracing::instrument(skip(self, iter))] + fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + for (key, value) in iter { + self.0.insert(key, value)?; + } + + Ok(()) + } + fn remove(&self, key: &[u8]) -> Result<()> { self.0.remove(key)?; Ok(()) } - fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { Box::new( self.0 .iter() @@ -62,7 +71,7 @@ impl Tree for SledEngineTree { &self, from: &[u8], backwards: bool, - ) -> Box, Vec)> + Send> { + ) -> Box, Vec)>> { let iter = if backwards { self.0.range(..=from) } else { @@ -95,7 +104,7 @@ impl Tree for SledEngineTree { fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Vec)> + Send + 'a> { + ) -> Box, Vec)> + 'a> { let iter = self .0 .scan_prefix(prefix) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index a46d3ada..0dbb2615 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,138 +1,60 @@ use super::{DatabaseEngine, Tree}; use crate::{database::Config, Result}; -use crossbeam::channel::{ - bounded, unbounded, Receiver as ChannelReceiver, Sender as ChannelSender, TryRecvError, -}; use parking_lot::{Mutex, MutexGuard, RwLock}; -use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension, Params}; +use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use std::{ + cell::RefCell, collections::HashMap, future::Future, - ops::Deref, path::{Path, PathBuf}, pin::Pin, sync::Arc, time::{Duration, Instant}, }; -use threadpool::ThreadPool; use tokio::sync::oneshot::Sender; use tracing::{debug, warn}; -struct Pool { - writer: Mutex, - readers: Vec>, - spills: ConnectionRecycler, - spill_tracker: Arc<()>, - path: PathBuf, -} - pub const MILLI: Duration = Duration::from_millis(1); -enum HoldingConn<'a> { - FromGuard(MutexGuard<'a, Connection>), - FromRecycled(RecycledConn, Arc<()>), +thread_local! { + static READ_CONNECTION: RefCell> = RefCell::new(None); } -impl<'a> Deref for HoldingConn<'a> { - type Target = Connection; +struct PreparedStatementIterator<'a> { + pub iterator: Box + 'a>, + pub statement_ref: NonAliasingBox>, +} - fn deref(&self) -> &Self::Target { - match self { - HoldingConn::FromGuard(guard) => guard.deref(), - HoldingConn::FromRecycled(conn, _) => conn.deref(), - } +impl Iterator for PreparedStatementIterator<'_> { + type Item = TupleOfBytes; + + fn next(&mut self) -> Option { + self.iterator.next() } } -struct ConnectionRecycler(ChannelSender, ChannelReceiver); - -impl ConnectionRecycler { - fn new() -> Self { - let (s, r) = unbounded(); - Self(s, r) - } - - fn recycle(&self, conn: Connection) -> RecycledConn { - let sender = self.0.clone(); - - RecycledConn(Some(conn), sender) - } - - fn try_take(&self) -> Option { - match self.1.try_recv() { - Ok(conn) => Some(conn), - Err(TryRecvError::Empty) => None, - // as this is pretty impossible, a panic is warranted if it ever occurs - Err(TryRecvError::Disconnected) => panic!("Receiving channel was disconnected. A a sender is owned by the current struct, this should never happen(!!!)") - } - } -} - -struct RecycledConn( - Option, // To allow moving out of the struct when `Drop` is called. - ChannelSender, -); - -impl Deref for RecycledConn { - type Target = Connection; - - fn deref(&self) -> &Self::Target { - self.0 - .as_ref() - .expect("RecycledConn does not have a connection in Option<>") - } -} - -impl Drop for RecycledConn { +struct NonAliasingBox(*mut T); +impl Drop for NonAliasingBox { fn drop(&mut self) { - if let Some(conn) = self.0.take() { - debug!("Recycled connection"); - if let Err(e) = self.1.send(conn) { - warn!("Recycling a connection led to the following error: {:?}", e) - } - } + unsafe { Box::from_raw(self.0) }; } } -impl Pool { - fn new>(path: P, num_readers: usize, total_cache_size_mb: f64) -> Result { - // calculates cache-size per permanent connection - // 1. convert MB to KiB - // 2. divide by permanent connections - // 3. round down to nearest integer - let cache_size: u32 = ((total_cache_size_mb * 1024.0) / (num_readers + 1) as f64) as u32; +pub struct Engine { + writer: Mutex, - let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); + path: PathBuf, + cache_size_per_thread: u32, +} - let mut readers = Vec::new(); +impl Engine { + fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { + let conn = Connection::open(&path)?; - for _ in 0..num_readers { - readers.push(Mutex::new(Self::prepare_conn(&path, Some(cache_size))?)) - } - - Ok(Self { - writer, - readers, - spills: ConnectionRecycler::new(), - spill_tracker: Arc::new(()), - path: path.as_ref().to_path_buf(), - }) - } - - fn prepare_conn>(path: P, cache_size: Option) -> Result { - let conn = Connection::open(path)?; - - conn.pragma_update(Some(Main), "journal_mode", &"WAL".to_owned())?; - - // conn.pragma_update(Some(Main), "wal_autocheckpoint", &250)?; - - // conn.pragma_update(Some(Main), "wal_checkpoint", &"FULL".to_owned())?; - - conn.pragma_update(Some(Main), "synchronous", &"OFF".to_owned())?; - - if let Some(cache_kib) = cache_size { - conn.pragma_update(Some(Main), "cache_size", &(-Into::::into(cache_kib)))?; - } + conn.pragma_update(Some(Main), "page_size", &32768)?; + conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; + conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; + conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; Ok(conn) } @@ -141,71 +63,52 @@ impl Pool { self.writer.lock() } - fn read_lock(&self) -> HoldingConn<'_> { - // First try to get a connection from the permanent pool - for r in &self.readers { - if let Some(reader) = r.try_lock() { - return HoldingConn::FromGuard(reader); + fn read_lock(&self) -> &'static Connection { + READ_CONNECTION.with(|cell| { + let connection = &mut cell.borrow_mut(); + + if (*connection).is_none() { + let c = Box::leak(Box::new( + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), + )); + **connection = Some(c); } - } - debug!("read_lock: All permanent readers locked, obtaining spillover reader..."); - - // We didn't get a connection from the permanent pool, so we'll dumpster-dive for recycled connections. - // Either we have a connection or we dont, if we don't, we make a new one. - let conn = match self.spills.try_take() { - Some(conn) => conn, - None => { - debug!("read_lock: No recycled connections left, creating new one..."); - Self::prepare_conn(&self.path, None).unwrap() - } - }; - - // Clone the spill Arc to mark how many spilled connections actually exist. - let spill_arc = Arc::clone(&self.spill_tracker); - - // Get a sense of how many connections exist now. - let now_count = Arc::strong_count(&spill_arc) - 1 /* because one is held by the pool */; - - // If the spillover readers are more than the number of total readers, there might be a problem. - if now_count > self.readers.len() { - warn!( - "Database is under high load. Consider increasing sqlite_read_pool_size ({} spillover readers exist)", - now_count - ); - } - - // Return the recyclable connection. - HoldingConn::FromRecycled(self.spills.recycle(conn), spill_arc) + connection.unwrap() + }) } -} -pub struct Engine { - pool: Pool, - iter_pool: Mutex, + pub fn flush_wal(self: &Arc) -> Result<()> { + self.write_lock() + .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; + Ok(()) + } } impl DatabaseEngine for Engine { fn open(config: &Config) -> Result> { - let pool = Pool::new( - Path::new(&config.database_path).join("conduit.db"), - config.sqlite_read_pool_size, - config.db_cache_capacity_mb, - )?; + let path = Path::new(&config.database_path).join("conduit.db"); - pool.write_lock() - .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?; + // calculates cache-size per permanent connection + // 1. convert MB to KiB + // 2. divide by permanent connections + // 3. round down to nearest integer + let cache_size_per_thread: u32 = + ((config.db_cache_capacity_mb * 1024.0) / (num_cpus::get().max(1) + 1) as f64) as u32; + + let writer = Mutex::new(Self::prepare_conn(&path, cache_size_per_thread)?); let arc = Arc::new(Engine { - pool, - iter_pool: Mutex::new(ThreadPool::new(10)), + writer, + path, + cache_size_per_thread, }); Ok(arc) } fn open_tree(self: &Arc, name: &str) -> Result> { - self.pool.write_lock().execute(format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name).as_str(), [])?; + self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), @@ -215,55 +118,8 @@ impl DatabaseEngine for Engine { } fn flush(self: &Arc) -> Result<()> { - self.pool - .write_lock() - .execute_batch( - " - PRAGMA synchronous=FULL; - BEGIN; - DELETE FROM _noop; - INSERT INTO _noop VALUES (1); - COMMIT; - PRAGMA synchronous=OFF; - ", - ) - .map_err(Into::into) - } -} - -impl Engine { - pub fn flush_wal(self: &Arc) -> Result<()> { - self.pool - .write_lock() - .execute_batch( - " - PRAGMA synchronous=FULL; PRAGMA wal_checkpoint=TRUNCATE; - BEGIN; - DELETE FROM _noop; - INSERT INTO _noop VALUES (1); - COMMIT; - PRAGMA wal_checkpoint=PASSIVE; PRAGMA synchronous=OFF; - ", - ) - .map_err(Into::into) - } - - // Reaps (at most) (.len() * `fraction`) (rounded down, min 1) connections. - pub fn reap_spillover_by_fraction(&self, fraction: f64) { - let mut reaped = 0; - - let spill_amount = self.pool.spills.1.len() as f64; - let fraction = fraction.clamp(0.01, 1.0); - - let amount = (spill_amount * fraction).max(1.0) as u32; - - for _ in 0..amount { - if self.pool.spills.try_take().is_some() { - reaped += 1; - } - } - - debug!("Reaped {} connections", reaped); + // we enabled PRAGMA synchronous=normal, so this should not be necessary + Ok(()) } } @@ -288,7 +144,7 @@ impl SqliteTable { fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { guard.execute( format!( - "INSERT INTO {} (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", + "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name ) .as_str(), @@ -296,70 +152,17 @@ impl SqliteTable { )?; Ok(()) } - - #[tracing::instrument(skip(self, sql, param))] - fn iter_from_thread( - &self, - sql: String, - param: Option>, - ) -> Box + Send + Sync> { - let (s, r) = bounded::(5); - - let engine = Arc::clone(&self.engine); - - let lock = self.engine.iter_pool.lock(); - if lock.active_count() < lock.max_count() { - lock.execute(move || { - if let Some(param) = param { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, [param]); - } else { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, []); - } - }); - } else { - std::thread::spawn(move || { - if let Some(param) = param { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, [param]); - } else { - iter_from_thread_work(&engine.pool.read_lock(), &s, &sql, []); - } - }); - } - - Box::new(r.into_iter()) - } -} - -fn iter_from_thread_work

( - guard: &HoldingConn<'_>, - s: &ChannelSender<(Vec, Vec)>, - sql: &str, - params: P, -) where - P: Params, -{ - for bob in guard - .prepare(sql) - .unwrap() - .query_map(params, |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(|r| r.unwrap()) - { - if s.send(bob).is_err() { - return; - } - } } impl Tree for SqliteTable { #[tracing::instrument(skip(self, key))] fn get(&self, key: &[u8]) -> Result>> { - self.get_with_guard(&self.engine.pool.read_lock(), key) + self.get_with_guard(&self.engine.read_lock(), key) } #[tracing::instrument(skip(self, key, value))] fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.write_lock(); let start = Instant::now(); @@ -367,7 +170,7 @@ impl Tree for SqliteTable { let elapsed = start.elapsed(); if elapsed > MILLI { - debug!("insert: took {:012?} : {}", elapsed, &self.name); + warn!("insert took {:?} : {}", elapsed, &self.name); } drop(guard); @@ -397,9 +200,24 @@ impl Tree for SqliteTable { Ok(()) } + #[tracing::instrument(skip(self, iter))] + fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + let guard = self.engine.write_lock(); + + guard.execute("BEGIN", [])?; + for (key, value) in iter { + self.insert_with_guard(&guard, &key, &value)?; + } + guard.execute("COMMIT", [])?; + + drop(guard); + + Ok(()) + } + #[tracing::instrument(skip(self, key))] fn remove(&self, key: &[u8]) -> Result<()> { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.write_lock(); let start = Instant::now(); @@ -419,9 +237,31 @@ impl Tree for SqliteTable { } #[tracing::instrument(skip(self))] - fn iter<'a>(&'a self) -> Box + Send + 'a> { - let name = self.name.clone(); - self.iter_from_thread(format!("SELECT key, value FROM {}", name), None) + fn iter<'a>(&'a self) -> Box + 'a> { + let guard = self.engine.read_lock(); + + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) } #[tracing::instrument(skip(self, from, backwards))] @@ -429,31 +269,61 @@ impl Tree for SqliteTable { &'a self, from: &[u8], backwards: bool, - ) -> Box + Send + 'a> { - let name = self.name.clone(); + ) -> Box + 'a> { + let guard = self.engine.read_lock(); let from = from.to_vec(); // TODO change interface? + if backwards { - self.iter_from_thread( - format!( - "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", - name - ), - Some(from), - ) + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) } else { - self.iter_from_thread( - format!( - "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", - name - ), - Some(from), - ) + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) } } #[tracing::instrument(skip(self, key))] fn increment(&self, key: &[u8]) -> Result> { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.write_lock(); let start = Instant::now(); @@ -475,10 +345,7 @@ impl Tree for SqliteTable { } #[tracing::instrument(skip(self, prefix))] - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box + Send + 'a> { + fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { // let name = self.name.clone(); // self.iter_from_thread( // format!( @@ -513,25 +380,9 @@ impl Tree for SqliteTable { fn clear(&self) -> Result<()> { debug!("clear: running"); self.engine - .pool .write_lock() .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; debug!("clear: ran"); Ok(()) } } - -// TODO -// struct Pool { -// writer: Mutex, -// readers: [Mutex; NUM_READERS], -// } - -// // then, to pick a reader: -// for r in &pool.readers { -// if let Ok(reader) = r.try_lock() { -// // use reader -// } -// } -// // none unlocked, pick the next reader -// pool.readers[pool.counter.fetch_add(1, Relaxed) % NUM_READERS].lock() diff --git a/src/database/admin.rs b/src/database/admin.rs index e1b24d08..424e6746 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -84,15 +84,15 @@ impl Admin { tokio::select! { Some(event) = receiver.next() => { let guard = db.read().await; - let mutex = Arc::clone( + let mutex_state = Arc::clone( guard.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(conduit_room.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; match event { AdminCommand::RegisterAppservice(yaml) => { @@ -106,17 +106,17 @@ impl Admin { count, appservices.into_iter().filter_map(|r| r.ok()).collect::>().join(", ") ); - send_message(message::MessageEventContent::text_plain(output), guard, &mutex_lock); + send_message(message::MessageEventContent::text_plain(output), guard, &state_lock); } else { - send_message(message::MessageEventContent::text_plain("Failed to get appservices."), guard, &mutex_lock); + send_message(message::MessageEventContent::text_plain("Failed to get appservices."), guard, &state_lock); } } AdminCommand::SendMessage(message) => { - send_message(message, guard, &mutex_lock); + send_message(message, guard, &state_lock); } } - drop(mutex_lock); + drop(state_lock); } } } diff --git a/src/database/appservice.rs b/src/database/appservice.rs index f39520c7..7cc91372 100644 --- a/src/database/appservice.rs +++ b/src/database/appservice.rs @@ -49,22 +49,23 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result> + Send + '_> { + pub fn iter_ids(&self) -> Result> + '_> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) })) } - pub fn iter_all( - &self, - ) -> Result> + '_ + Send> { - Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { - Ok(( - id.clone(), - self.get_registration(&id)? - .expect("iter_ids only returns appservices that exist"), - )) - })) + pub fn all(&self) -> Result> { + self.iter_ids()? + .filter_map(|id| id.ok()) + .map(move |id| { + Ok(( + id.clone(), + self.get_registration(&id)? + .expect("iter_ids only returns appservices that exist"), + )) + }) + .collect() } } diff --git a/src/database/globals.rs b/src/database/globals.rs index 0edb9ca2..823ce349 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -12,10 +12,10 @@ use std::{ fs, future::Future, path::PathBuf, - sync::{Arc, RwLock}, + sync::{Arc, Mutex, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::{broadcast, watch::Receiver, Mutex, Semaphore}; +use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; use tracing::{error, info}; use trust_dns_resolver::TokioAsyncResolver; @@ -45,8 +45,9 @@ pub struct Globals { pub bad_signature_ratelimiter: Arc, RateLimitState>>>, pub servername_ratelimiter: Arc, Arc>>>, pub sync_receivers: RwLock), SyncHandle>>, - pub roomid_mutex: RwLock>>>, - pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer + pub roomid_mutex_insert: RwLock>>>, + pub roomid_mutex_state: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer pub rotate: RotationHandler, } @@ -200,7 +201,8 @@ impl Globals { bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - roomid_mutex: RwLock::new(HashMap::new()), + roomid_mutex_state: RwLock::new(HashMap::new()), + roomid_mutex_insert: RwLock::new(HashMap::new()), roomid_mutex_federation: RwLock::new(HashMap::new()), sync_receivers: RwLock::new(HashMap::new()), rotate: RotationHandler::new(), diff --git a/src/database/media.rs b/src/database/media.rs index f576ca4a..a9bb42b0 100644 --- a/src/database/media.rs +++ b/src/database/media.rs @@ -101,8 +101,8 @@ impl Media { prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail prefix.push(0xff); - let mut iter = self.mediaid_file.scan_prefix(prefix); - if let Some((key, _)) = iter.next() { + let first = self.mediaid_file.scan_prefix(prefix).next(); + if let Some((key, _)) = first { let path = globals.get_media_file(&key); let mut file = Vec::new(); File::open(path).await?.read_to_end(&mut file).await?; @@ -190,7 +190,9 @@ impl Media { original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail original_prefix.push(0xff); - if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { + let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next(); + let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next(); + if let Some((key, _)) = first_thumbnailprefix { // Using saved thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); @@ -225,7 +227,7 @@ impl Media { content_type, file: file.to_vec(), })) - } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { + } else if let Some((key, _)) = first_originalprefix { // Generate a thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 79bb059d..10a6215d 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -2,7 +2,6 @@ mod edus; pub use edus::RoomEdus; use member::MembershipState; -use tokio::sync::MutexGuard; use crate::{pdu::PduBuilder, utils, Database, Error, PduEvent, Result}; use lru_cache::LruCache; @@ -28,6 +27,7 @@ use std::{ mem, sync::{Arc, Mutex}, }; +use tokio::sync::MutexGuard; use tracing::{debug, error, warn}; use super::{abstraction::Tree, admin::AdminCommand, pusher}; @@ -55,6 +55,7 @@ pub struct Rooms { pub(super) userroomid_joined: Arc, pub(super) roomuserid_joined: Arc, + pub(super) roomid_joinedcount: Arc, pub(super) roomuseroncejoinedids: Arc, pub(super) userroomid_invitestate: Arc, // InviteState = Vec> pub(super) roomuserid_invitecount: Arc, // InviteCount = Count @@ -87,7 +88,7 @@ pub struct Rooms { pub(super) referencedevents: Arc, pub(super) pdu_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex>>, + pub(super) auth_chain_cache: Mutex, HashSet>>, } impl Rooms { @@ -313,41 +314,50 @@ impl Rooms { let new_state = if !already_existed { let mut new_state = HashSet::new(); - for ((event_type, state_key), eventid) in state { - new_state.insert(eventid.clone()); + let batch = state + .iter() + .filter_map(|((event_type, state_key), eventid)| { + new_state.insert(eventid.clone()); - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey)? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = db.globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes())?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { + Some(shortstatekey) => shortstatekey.to_vec(), + None => { + let shortstatekey = db.globals.next_count().ok()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes()) + .ok()?; + shortstatekey.to_be_bytes().to_vec() + } + }; - let shorteventid = match self.eventid_shorteventid.get(eventid.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = db.globals.next_count()?; - self.eventid_shorteventid - .insert(eventid.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), eventid.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = + match self.eventid_shorteventid.get(eventid.as_bytes()).ok()? { + Some(shorteventid) => shorteventid.to_vec(), + None => { + let shorteventid = db.globals.next_count().ok()?; + self.eventid_shorteventid + .insert(eventid.as_bytes(), &shorteventid.to_be_bytes()) + .ok()?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), eventid.as_bytes()) + .ok()?; + shorteventid.to_be_bytes().to_vec() + } + }; - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); + let mut state_id = shortstatehash.to_be_bytes().to_vec(); + state_id.extend_from_slice(&shortstatekey); - self.stateid_shorteventid - .insert(&state_id, &*shorteventid)?; - } + Some((state_id, shorteventid)) + }) + .collect::>(); + + self.stateid_shorteventid + .insert_batch(&mut batch.into_iter())?; new_state } else { @@ -736,6 +746,16 @@ impl Rooms { self.replace_pdu_leaves(&pdu.room_id, leaves)?; + let mutex_insert = Arc::clone( + db.globals + .roomid_mutex_insert + .write() + .unwrap() + .entry(pdu.room_id.clone()) + .or_default(), + ); + let insert_lock = mutex_insert.lock().unwrap(); + let count1 = db.globals.next_count()?; // Mark as read first so the sending client doesn't get a notification even if appending // fails @@ -750,6 +770,8 @@ impl Rooms { // There's a brief moment of time here where the count is updated but the pdu does not // exist. This could theoretically lead to dropped pdus, but it's extremely rare + // + // Update: We fixed this using insert_lock self.pduid_pdu.insert( &pdu_id, @@ -761,6 +783,8 @@ impl Rooms { self.eventid_pduid .insert(pdu.event_id.as_bytes(), &pdu_id)?; + drop(insert_lock); + // See if the event matches any known pushers let power_levels: PowerLevelsEventContent = db .rooms @@ -779,7 +803,7 @@ impl Rooms { .room_members(&pdu.room_id) .filter_map(|r| r.ok()) .filter(|user_id| user_id.server_name() == db.globals.server_name()) - .filter(|user_id| !db.users.is_deactivated(user_id).unwrap_or(false)) + .filter(|user_id| !db.users.is_deactivated(user_id).unwrap_or(true)) { // Don't notify the user of their own events if user == pdu.sender { @@ -882,18 +906,20 @@ impl Rooms { } EventType::RoomMessage => { if let Some(body) = pdu.content.get("body").and_then(|b| b.as_str()) { - for word in body + let mut batch = body .split_terminator(|c: char| !c.is_alphanumeric()) .filter(|word| word.len() <= 50) .map(str::to_lowercase) - { - let mut key = pdu.room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(word.as_bytes()); - key.push(0xff); - key.extend_from_slice(&pdu_id); - self.tokenids.insert(&key, &[])?; - } + .map(|word| { + let mut key = pdu.room_id.as_bytes().to_vec(); + key.push(0xff); + key.extend_from_slice(word.as_bytes()); + key.push(0xff); + key.extend_from_slice(&pdu_id); + (key, Vec::new()) + }); + + self.tokenids.insert_batch(&mut batch)?; if body.starts_with(&format!("@conduit:{}: ", db.globals.server_name())) && self @@ -1106,39 +1132,51 @@ impl Rooms { } }; - for ((event_type, state_key), pdu) in state { - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let batch = state + .iter() + .filter_map(|((event_type, state_key), pdu)| { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey)? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes())?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { + Some(shortstatekey) => shortstatekey.to_vec(), + None => { + let shortstatekey = globals.next_count().ok()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes()) + .ok()?; + shortstatekey.to_be_bytes().to_vec() + } + }; - let shorteventid = match self.eventid_shorteventid.get(pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = match self + .eventid_shorteventid + .get(pdu.event_id.as_bytes()) + .ok()? + { + Some(shorteventid) => shorteventid.to_vec(), + None => { + let shorteventid = globals.next_count().ok()?; + self.eventid_shorteventid + .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes()) + .ok()?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes()) + .ok()?; + shorteventid.to_be_bytes().to_vec() + } + }; - let mut state_id = shortstatehash.clone(); - state_id.extend_from_slice(&shortstatekey); + let mut state_id = shortstatehash.clone(); + state_id.extend_from_slice(&shortstatekey); - self.stateid_shorteventid - .insert(&*state_id, &*shorteventid)?; - } + Some((state_id, shorteventid)) + }) + .collect::>(); + + self.stateid_shorteventid + .insert_batch(&mut batch.into_iter())?; self.shorteventid_shortstatehash .insert(&shorteventid, &*shortstatehash)?; @@ -1243,11 +1281,13 @@ impl Rooms { } }; - for (shortstatekey, shorteventid) in new_state { + let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| { let mut state_id = shortstatehash.to_be_bytes().to_vec(); state_id.extend_from_slice(&shortstatekey); - self.stateid_shorteventid.insert(&state_id, &shorteventid)?; - } + (state_id, shorteventid) + }); + + self.stateid_shorteventid.insert_batch(&mut batch)?; Ok(shortstatehash) } else { @@ -1464,13 +1504,6 @@ impl Rooms { self.shorteventid_eventid .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())?; - // Increment the last index and use that - // This is also the next_batch/since value - let count = db.globals.next_count()?; - let mut pdu_id = room_id.as_bytes().to_vec(); - pdu_id.push(0xff); - pdu_id.extend_from_slice(&count.to_be_bytes()); - // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. let statehashid = self.append_to_state(&pdu, &db.globals)?; @@ -1496,7 +1529,7 @@ impl Rooms { db.sending.send_pdu(&server, &pdu_id)?; } - for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { + for appservice in db.appservice.all()? { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces .get("users") @@ -1874,9 +1907,18 @@ impl Rooms { _ => {} } + self.update_joined_count(room_id)?; + Ok(()) } + pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + self.roomid_joinedcount.insert( + room_id.as_bytes(), + &(self.room_members(&room_id).count() as u64).to_be_bytes(), + ) + } + pub async fn leave_room( &self, user_id: &UserId, @@ -1904,15 +1946,15 @@ impl Rooms { db, )?; } else { - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; let mut event = serde_json::from_value::>( self.room_state_get(room_id, &EventType::RoomMember, &user_id.to_string())? @@ -1941,7 +1983,7 @@ impl Rooms { user_id, room_id, db, - &mutex_lock, + &state_lock, )?; } @@ -2338,6 +2380,17 @@ impl Rooms { }) } + pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { + Ok(self + .roomid_joinedcount + .get(room_id.as_bytes())? + .map(|b| { + utils::u64_from_bytes(&b) + .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) + }) + .transpose()?) + } + /// Returns an iterator over all User IDs who ever joined a room. #[tracing::instrument(skip(self))] pub fn room_useroncejoined<'a>( @@ -2586,7 +2639,7 @@ impl Rooms { #[tracing::instrument(skip(self))] pub fn auth_chain_cache( &self, - ) -> std::sync::MutexGuard<'_, LruCache>> { + ) -> std::sync::MutexGuard<'_, LruCache, HashSet>> { self.auth_chain_cache.lock().unwrap() } } diff --git a/src/database/rooms/edus.rs b/src/database/rooms/edus.rs index 664c1710..ff28436b 100644 --- a/src/database/rooms/edus.rs +++ b/src/database/rooms/edus.rs @@ -422,7 +422,7 @@ impl RoomEdus { } /// Sets all users to offline who have been quiet for too long. - pub fn presence_maintain( + fn presence_maintain( &self, rooms: &super::Rooms, globals: &super::super::globals::Globals, @@ -497,7 +497,7 @@ impl RoomEdus { rooms: &super::Rooms, globals: &super::super::globals::Globals, ) -> Result> { - self.presence_maintain(rooms, globals)?; + //self.presence_maintain(rooms, globals)?; let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/error.rs b/src/error.rs index 24e52ecc..1ecef3aa 100644 --- a/src/error.rs +++ b/src/error.rs @@ -30,12 +30,6 @@ pub enum Error { #[from] source: sled::Error, }, - #[cfg(feature = "rocksdb")] - #[error("There was a problem with the connection to the rocksdb database: {source}")] - RocksDbError { - #[from] - source: rocksdb::Error, - }, #[cfg(feature = "sqlite")] #[error("There was a problem with the connection to the sqlite database: {source}")] SqliteError { diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 21214399..56811942 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -75,9 +75,9 @@ where registration, )) = db .appservice - .iter_all() + .all() .unwrap() - .filter_map(|r| r.ok()) + .iter() .find(|(_id, registration)| { registration .get("as_token") diff --git a/src/server_server.rs b/src/server_server.rs index 232c5d46..45d90226 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -806,7 +806,7 @@ pub async fn send_transaction_message_route( } } - db.flush().await?; + db.flush()?; Ok(send_transaction_message::v1::Response { pdus: resolved_map }.into()) } @@ -1159,15 +1159,15 @@ pub fn handle_incoming_pdu<'a>( // We start looking at current room state now, so lets lock the room - let mutex = Arc::clone( + let mutex_state = Arc::clone( db.globals - .roomid_mutex + .roomid_mutex_state .write() .unwrap() .entry(room_id.clone()) .or_default(), ); - let mutex_lock = mutex.lock().await; + let state_lock = mutex_state.lock().await; // Now we calculate the set of extremities this room has after the incoming event has been // applied. We start with the previous extremities (aka leaves) @@ -1267,10 +1267,10 @@ pub fn handle_incoming_pdu<'a>( // 14. Use state resolution to find new room state let new_room_state = if fork_states.is_empty() { return Err("State is empty.".to_owned()); - } else if fork_states.len() == 1 { + } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { // There was only one state, so it has to be the room's current state (because that is // always included) - debug!("Skipping stateres because there is no new state."); + warn!("Skipping stateres because there is no new state."); fork_states[0] .iter() .map(|(k, pdu)| (k.clone(), pdu.event_id.clone())) @@ -1341,9 +1341,8 @@ pub fn handle_incoming_pdu<'a>( val, extremities, &state_at_incoming_event, - &mutex_lock, + &state_lock, ) - .await .map_err(|_| "Failed to add pdu to db.".to_owned())?, ); debug!("Appended incoming pdu."); @@ -1365,7 +1364,7 @@ pub fn handle_incoming_pdu<'a>( } // Event has passed all auth/stateres checks - drop(mutex_lock); + drop(state_lock); Ok(pdu_id) }) } @@ -1643,7 +1642,7 @@ pub(crate) async fn fetch_signing_keys( /// Append the incoming event setting the state snapshot to the state from the /// server that sent the event. #[tracing::instrument(skip(db, pdu, pdu_json, new_room_leaves, state, _mutex_lock))] -async fn append_incoming_pdu( +fn append_incoming_pdu( db: &Database, pdu: &PduEvent, pdu_json: CanonicalJsonObject, @@ -1663,7 +1662,7 @@ async fn append_incoming_pdu( &db, )?; - for appservice in db.appservice.iter_all()?.filter_map(|r| r.ok()) { + for appservice in db.appservice.all()? { if let Some(namespaces) = appservice.1.get("namespaces") { let users = namespaces .get("users") @@ -1728,39 +1727,44 @@ fn get_auth_chain(starting_events: Vec, db: &Database) -> Result Result> { - let mut auth_chain = HashSet::new(); - +fn get_auth_chain_recursive( + event_id: &EventId, + mut found: HashSet, + db: &Database, +) -> Result> { if let Some(pdu) = db.rooms.get_pdu(&event_id)? { - auth_chain.extend(pdu.auth_events.iter().cloned()); for auth_event in &pdu.auth_events { - auth_chain.extend(get_auth_chain_recursive(&auth_event, db)?); + if !found.contains(auth_event) { + found.insert(auth_event.clone()); + found = get_auth_chain_recursive(&auth_event, found, db)?; + } } } else { warn!("Could not find pdu mentioned in auth events."); } - Ok(auth_chain) + Ok(found) } #[cfg_attr( @@ -2208,7 +2212,7 @@ pub async fn create_join_event_route( db.sending.send_pdu(&server, &pdu_id)?; } - db.flush().await?; + db.flush()?; Ok(create_join_event::v2::Response { room_state: RoomState { @@ -2327,7 +2331,7 @@ pub async fn create_invite_route( )?; } - db.flush().await?; + db.flush()?; Ok(create_invite::v2::Response { event: PduEvent::convert_to_outgoing_federation_event(signed_event), @@ -2464,7 +2468,7 @@ pub async fn get_keys_route( ) .await?; - db.flush().await?; + db.flush()?; Ok(get_keys::v1::Response { device_keys: result.device_keys, @@ -2489,7 +2493,7 @@ pub async fn claim_keys_route( let result = claim_keys_helper(&body.one_time_keys, &db).await?; - db.flush().await?; + db.flush()?; Ok(claim_keys::v1::Response { one_time_keys: result.one_time_keys, diff --git a/src/utils.rs b/src/utils.rs index 60a4e0cf..d21395e1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -17,15 +17,6 @@ pub fn millis_since_unix_epoch() -> u64 { .as_millis() as u64 } -#[cfg(feature = "rocksdb")] -pub fn increment_rocksdb( - _new_key: &[u8], - old: Option<&[u8]>, - _operands: &mut rocksdb::MergeOperands, -) -> Option> { - increment(old) -} - pub fn increment(old: Option<&[u8]>) -> Option> { let number = match old.map(|bytes| bytes.try_into()) { Some(Ok(bytes)) => {