diff --git a/Cargo.lock b/Cargo.lock index 6ed4ee7b..83e21a3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,7 +248,7 @@ dependencies = [ "jsonwebtoken", "lru-cache", "num_cpus", - "opentelemetry", + "opentelemetry 0.16.0", "opentelemetry-jaeger", "parking_lot", "pretty_env_logger", @@ -1466,16 +1466,46 @@ dependencies = [ ] [[package]] -name = "opentelemetry-jaeger" -version = "0.14.0" +name = "opentelemetry" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a9fc8192722e7daa0c56e59e2336b797122fb8598383dcb11c8852733b435c" +checksum = "e1cf9b1c4e9a6c4de793c632496fa490bdc0e1eea73f0c91394f7b6990935d22" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures", + "js-sys", + "lazy_static", + "percent-encoding", + "pin-project", + "rand 0.8.4", + "thiserror", + "tokio", + "tokio-stream", +] + +[[package]] +name = "opentelemetry-jaeger" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db22f492873ea037bc267b35a0e8e4fb846340058cb7c864efe3d0bf23684593" dependencies = [ "async-trait", "lazy_static", - "opentelemetry", + "opentelemetry 0.16.0", + "opentelemetry-semantic-conventions", "thiserror", "thrift", + "tokio", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffeac823339e8b0f27b961f4385057bf9f97f2863bc745bd015fd6091f2270e9" +dependencies = [ + "opentelemetry 0.16.0", ] [[package]] @@ -2014,8 +2044,8 @@ dependencies = [ [[package]] name = "ruma" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "assign", "js_int", @@ -2035,8 +2065,8 @@ dependencies = [ [[package]] name = "ruma-api" -version = "0.17.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.18.3" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "bytes", "http", @@ -2051,8 +2081,8 @@ dependencies = [ [[package]] name = "ruma-api-macros" -version = "0.17.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.18.3" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2062,8 +2092,8 @@ dependencies = [ [[package]] name = "ruma-appservice-api" -version = "0.3.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.4.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "ruma-api", "ruma-common", @@ -2076,8 +2106,8 @@ dependencies = [ [[package]] name = "ruma-client-api" -version = "0.11.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.12.2" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "assign", "bytes", @@ -2096,8 +2126,8 @@ dependencies = [ [[package]] name = "ruma-common" -version = "0.5.4" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.6.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "indexmap", "js_int", @@ -2111,8 +2141,8 @@ dependencies = [ [[package]] name = "ruma-events" -version = "0.23.2" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.24.4" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "indoc", "js_int", @@ -2127,8 +2157,8 @@ dependencies = [ [[package]] name = "ruma-events-macros" -version = "0.23.2" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.24.4" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2138,8 +2168,8 @@ dependencies = [ [[package]] name = "ruma-federation-api" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "js_int", "ruma-api", @@ -2153,8 +2183,8 @@ dependencies = [ [[package]] name = "ruma-identifiers" -version = "0.19.4" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.20.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "paste", "rand 0.8.4", @@ -2167,8 +2197,8 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" -version = "0.19.4" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.20.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "quote", "ruma-identifiers-validation", @@ -2177,13 +2207,13 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" -version = "0.4.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" [[package]] name = "ruma-identity-service-api" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "js_int", "ruma-api", @@ -2195,8 +2225,8 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "js_int", "ruma-api", @@ -2210,8 +2240,8 @@ dependencies = [ [[package]] name = "ruma-serde" -version = "0.4.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "bytes", "form_urlencoded", @@ -2224,8 +2254,8 @@ dependencies = [ [[package]] name = "ruma-serde-macros" -version = "0.4.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2235,8 +2265,8 @@ dependencies = [ [[package]] name = "ruma-signatures" -version = "0.8.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.9.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "base64 0.13.0", "ed25519-dalek", @@ -2252,8 +2282,8 @@ dependencies = [ [[package]] name = "ruma-state-res" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "itertools 0.10.1", "js_int", @@ -3022,7 +3052,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c47440f2979c4cd3138922840eec122e3c0ba2148bc290f756bd7fd60fc97fff" dependencies = [ - "opentelemetry", + "opentelemetry 0.15.0", "tracing", "tracing-core", "tracing-log", diff --git a/Cargo.toml b/Cargo.toml index 47bbd2f7..d28e0b79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,8 +18,8 @@ edition = "2018" rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests # Used for matrix spec type definitions and helpers -#ruma = { git = "https://github.com/ruma/ruma", rev = "eb19b0e08a901b87d11b3be0890ec788cc760492", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } -ruma = { git = "https://github.com/timokoesters/ruma", rev = "a2d93500e1dbc87e7032a3c74f3b2479a7f84e93", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +#ruma = { git = "https://github.com/timokoesters/ruma", rev = "995ccea20f5f6d4a8fb22041749ed4de22fa1b6a", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } # Used for long polling and federation sender, should be the same as rocket::tokio @@ -66,11 +66,11 @@ regex = "1.5.4" jsonwebtoken = "7.2.0" # Performance measurements tracing = { version = "0.1.26", features = ["release_max_level_warn"] } -opentelemetry = "0.15.0" tracing-subscriber = "0.2.19" tracing-opentelemetry = "0.14.0" tracing-flame = "0.1.0" -opentelemetry-jaeger = "0.14.0" +opentelemetry = { version = "0.16.0", features = ["rt-tokio"] } +opentelemetry-jaeger = { version = "0.15.0", features = ["rt-tokio"] } pretty_env_logger = "0.4.0" lru-cache = "0.1.2" rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] } diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 48159c98..e68c957f 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -249,6 +249,8 @@ pub async fn register_route( let room_id = RoomId::new(db.globals.server_name()); + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let mutex_state = Arc::clone( db.globals .roomid_mutex_state @@ -290,6 +292,7 @@ pub async fn register_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -455,6 +458,7 @@ pub async fn register_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -476,6 +480,7 @@ pub async fn register_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -681,6 +686,7 @@ pub async fn deactivate_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }; let mutex_state = Arc::clone( @@ -731,7 +737,7 @@ pub async fn deactivate_route( pub async fn third_party_route( body: Ruma, ) -> ConduitResult { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_contacts::Response::new(Vec::new()).into()) } diff --git a/src/client_server/context.rs b/src/client_server/context.rs index dbc121e3..701e584d 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -44,7 +44,7 @@ pub async fn get_context_route( let events_before = db .rooms - .pdus_until(&sender_user, &body.room_id, base_token) + .pdus_until(&sender_user, &body.room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") @@ -66,7 +66,7 @@ pub async fn get_context_route( let events_after = db .rooms - .pdus_after(&sender_user, &body.room_id, base_token) + .pdus_after(&sender_user, &body.room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 716a615f..222d204f 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -262,6 +262,7 @@ pub async fn ban_user_route( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&body.user_id)?, + reason: None, }), |event| { let mut event = serde_json::from_value::>( @@ -563,6 +564,7 @@ async fn join_room_by_id_helper( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }) .expect("event is valid, we just created it"), ); @@ -609,6 +611,8 @@ async fn join_room_by_id_helper( ) .await?; + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; @@ -693,6 +697,7 @@ async fn join_room_by_id_helper( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }; db.rooms.build_and_append_pdu( @@ -844,6 +849,7 @@ pub async fn invite_helper<'a>( membership: MembershipState::Invite, third_party_invite: None, blurhash: None, + reason: None, }) .expect("member event is valid value"); @@ -1038,6 +1044,7 @@ pub async fn invite_helper<'a>( is_direct: Some(is_direct), third_party_invite: None, blurhash: db.users.blurhash(&user_id)?, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 9cb6faab..70cc00f7 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -128,7 +128,7 @@ pub async fn get_message_events_route( get_message_events::Direction::Forward => { let events_after = db .rooms - .pdus_after(&sender_user, &body.room_id, from) + .pdus_after(&sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { @@ -158,7 +158,7 @@ pub async fn get_message_events_route( get_message_events::Direction::Backward => { let events_before = db .rooms - .pdus_until(&sender_user, &body.room_id, from) + .pdus_until(&sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 89241f5c..25412788 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -33,6 +33,8 @@ pub async fn create_room_route( let room_id = RoomId::new(db.globals.server_name()); + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let mutex_state = Arc::clone( db.globals .roomid_mutex_state @@ -105,6 +107,7 @@ pub async fn create_room_route( is_direct: Some(body.is_direct), third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -173,7 +176,6 @@ pub async fn create_room_route( )?; // 4. Canonical room alias - if let Some(room_alias_id) = &alias { db.rooms.build_and_append_pdu( PduBuilder { @@ -193,7 +195,7 @@ pub async fn create_room_route( &room_id, &db, &state_lock, - ); + )?; } // 5. Events set by preset @@ -516,6 +518,7 @@ pub async fn upgrade_room_route( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, diff --git a/src/client_server/session.rs b/src/client_server/session.rs index d4d3c033..dada2d50 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -3,7 +3,10 @@ use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, - r0::session::{get_login_types, login, logout, logout_all}, + r0::{ + session::{get_login_types, login, logout, logout_all}, + uiaa::IncomingUserIdentifier, + }, }, UserId, }; @@ -60,7 +63,7 @@ pub async fn login_route( identifier, password, } => { - let username = if let login::IncomingUserIdentifier::MatrixId(matrix_id) = identifier { + let username = if let IncomingUserIdentifier::MatrixId(matrix_id) = identifier { matrix_id } else { return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 937a252e..c196b2a9 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -205,7 +205,7 @@ async fn sync_helper( let mut non_timeline_pdus = db .rooms - .pdus_until(&sender_user, &room_id, u64::MAX) + .pdus_until(&sender_user, &room_id, u64::MAX)? .filter_map(|r| { // Filter out buggy events if r.is_err() { @@ -248,13 +248,13 @@ async fn sync_helper( let first_pdu_before_since = db .rooms - .pdus_until(&sender_user, &room_id, since) + .pdus_until(&sender_user, &room_id, since)? .next() .transpose()?; let pdus_after_since = db .rooms - .pdus_after(&sender_user, &room_id, since) + .pdus_after(&sender_user, &room_id, since)? .next() .is_some(); @@ -286,7 +286,7 @@ async fn sync_helper( for hero in db .rooms - .all_pdus(&sender_user, &room_id) + .all_pdus(&sender_user, &room_id)? .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .map(|(_, pdu)| { @@ -328,11 +328,11 @@ async fn sync_helper( } } - ( + Ok::<_, Error>(( Some(joined_member_count), Some(invited_member_count), heroes, - ) + )) }; let ( @@ -343,7 +343,7 @@ async fn sync_helper( state_events, ) = if since_shortstatehash.is_none() { // Probably since = 0, we will do an initial sync - let (joined_member_count, invited_member_count, heroes) = calculate_counts(); + let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; let state_events = current_state_ids @@ -510,7 +510,7 @@ async fn sync_helper( } let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts() + calculate_counts()? } else { (None, None, Vec::new()) }; diff --git a/src/database.rs b/src/database.rs index bdff386f..3d1324e6 100644 --- a/src/database.rs +++ b/src/database.rs @@ -24,13 +24,14 @@ use rocket::{ request::{FromRequest, Request}, Shutdown, State, }; -use ruma::{DeviceId, RoomId, ServerName, UserId}; +use ruma::{DeviceId, EventId, RoomId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; use std::{ - collections::{BTreeMap, HashMap}, - convert::TryFrom, + collections::{BTreeMap, HashMap, HashSet}, + convert::{TryFrom, TryInto}, fs::{self, remove_dir_all}, io::Write, + mem::size_of, ops::Deref, path::Path, sync::{Arc, Mutex, RwLock}, @@ -107,7 +108,7 @@ fn default_db_cache_capacity_mb() -> f64 { } fn default_sqlite_wal_clean_second_interval() -> u32 { - 15 * 60 // every 15 minutes + 1 * 60 // every minute } fn default_max_request_size() -> u32 { @@ -261,7 +262,11 @@ impl Database { userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, + + shortroomid_roomid: builder.open_tree("shortroomid_roomid")?, + roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, @@ -270,8 +275,12 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, referencedevents: builder.open_tree("referencedevents")?, - pdu_cache: Mutex::new(LruCache::new(1_000_000)), - auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), + pdu_cache: Mutex::new(LruCache::new(100_000)), + auth_chain_cache: Mutex::new(LruCache::new(100_000)), + shorteventid_cache: Mutex::new(LruCache::new(1_000_000)), + eventidshort_cache: Mutex::new(LruCache::new(1_000_000)), + statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), + stateinfo_cache: Mutex::new(LruCache::new(50)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, @@ -424,7 +433,6 @@ impl Database { } if db.globals.database_version()? < 6 { - // TODO update to 6 // Set room member count for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { let room_id = @@ -437,6 +445,261 @@ impl Database { println!("Migration: 5 -> 6 finished"); } + + if db.globals.database_version()? < 7 { + // Upgrade state store + let mut last_roomstates: HashMap = HashMap::new(); + let mut current_sstatehash: Option = None; + let mut current_room = None; + let mut current_state = HashSet::new(); + let mut counter = 0; + + let mut handle_state = + |current_sstatehash: u64, + current_room: &RoomId, + current_state: HashSet<_>, + last_roomstates: &mut HashMap<_, _>| { + counter += 1; + println!("counter: {}", counter); + let last_roomsstatehash = last_roomstates.get(current_room); + + let states_parents = last_roomsstatehash.map_or_else( + || Ok(Vec::new()), + |&last_roomsstatehash| { + db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash)) + }, + )?; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = current_state + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); + + let statediffremoved = parent_stateinfo + .1 + .difference(¤t_state) + .cloned() + .collect::>(); + + (statediffnew, statediffremoved) + } else { + (current_state, HashSet::new()) + }; + + db.rooms.save_state_from_diff( + dbg!(current_sstatehash), + statediffnew, + statediffremoved, + 2, // every state change is 2 event changes on average + states_parents, + )?; + + /* + let mut tmp = db.rooms.load_shortstatehash_info(¤t_sstatehash, &db)?; + let state = tmp.pop().unwrap(); + println!( + "{}\t{}{:?}: {:?} + {:?} - {:?}", + current_room, + " ".repeat(tmp.len()), + utils::u64_from_bytes(¤t_sstatehash).unwrap(), + tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), + state + .2 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>(), + state + .3 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>() + ); + */ + + Ok::<_, Error>(()) + }; + + for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { + let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) + .expect("number of bytes is correct"); + let sstatekey = k[size_of::()..].to_vec(); + if Some(sstatehash) != current_sstatehash { + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_ref().unwrap(), + current_state, + &mut last_roomstates, + )?; + last_roomstates + .insert(current_room.clone().unwrap(), current_sstatehash); + } + current_state = HashSet::new(); + current_sstatehash = Some(sstatehash); + + let event_id = db + .rooms + .shorteventid_eventid + .get(&seventid) + .unwrap() + .unwrap(); + let event_id = + EventId::try_from(utils::string_from_bytes(&event_id).unwrap()) + .unwrap(); + let pdu = db.rooms.get_pdu(&event_id).unwrap().unwrap(); + + if Some(&pdu.room_id) != current_room.as_ref() { + current_room = Some(pdu.room_id.clone()); + } + } + + let mut val = sstatekey; + val.extend_from_slice(&seventid); + current_state.insert(val.try_into().expect("size is correct")); + } + + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_ref().unwrap(), + current_state, + &mut last_roomstates, + )?; + } + + db.globals.bump_database_version(7)?; + + println!("Migration: 6 -> 7 finished"); + } + + if db.globals.database_version()? < 8 { + // Generate short room ids for all rooms + for (room_id, _) in db.rooms.roomid_shortstatehash.iter() { + let shortroomid = db.globals.next_count()?.to_be_bytes(); + db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?; + db.rooms.shortroomid_roomid.insert(&shortroomid, &room_id)?; + println!("Migration: 8"); + } + // Update pduids db layout + let mut batch = db.rooms.pduid_pdu.iter().filter_map(|(key, v)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); + + let mut new_key = short_room_id; + new_key.extend_from_slice(count); + + Some((new_key, v)) + }); + + db.rooms.pduid_pdu.insert_batch(&mut batch)?; + + let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| { + if !value.starts_with(b"!") { + return None; + } + let mut parts = value.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); + + let mut new_value = short_room_id; + new_value.extend_from_slice(count); + + Some((k, new_value)) + }); + + db.rooms.eventid_pduid.insert_batch(&mut batch2)?; + + db.globals.bump_database_version(8)?; + + println!("Migration: 7 -> 8 finished"); + } + + if db.globals.database_version()? < 9 { + // Update tokenids db layout + let batch = db + .rooms + .tokenids + .iter() + .filter_map(|(key, _)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(4, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let word = parts.next().unwrap(); + let _pdu_id_room = parts.next().unwrap(); + let pdu_id_count = parts.next().unwrap(); + + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); + let mut new_key = short_room_id; + new_key.extend_from_slice(word); + new_key.push(0xff); + new_key.extend_from_slice(pdu_id_count); + println!("old {:?}", key); + println!("new {:?}", new_key); + Some((new_key, Vec::new())) + }) + .collect::>(); + + let mut iter = batch.into_iter().peekable(); + + while iter.peek().is_some() { + db.rooms + .tokenids + .insert_batch(&mut iter.by_ref().take(1000))?; + println!("smaller batch done"); + } + + println!("Deleting starts"); + + let batch2 = db + .rooms + .tokenids + .iter() + .filter_map(|(key, _)| { + if key.starts_with(b"!") { + println!("del {:?}", key); + Some(key) + } else { + None + } + }) + .collect::>(); + + for key in batch2 { + println!("del"); + db.rooms.tokenids.remove(&key)?; + } + + db.globals.bump_database_version(9)?; + + println!("Migration: 8 -> 9 finished"); + } } let guard = db.read().await; diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index f381ce9f..5b941fbe 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -35,6 +35,7 @@ pub trait Tree: Send + Sync { ) -> Box, Vec)> + 'a>; fn increment(&self, key: &[u8]) -> Result>; + fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()>; fn scan_prefix<'a>( &'a self, diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 72fb5f7f..5b895c78 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -9,15 +9,13 @@ use std::{ path::{Path, PathBuf}, pin::Pin, sync::Arc, - time::{Duration, Instant}, }; use tokio::sync::oneshot::Sender; -use tracing::{debug, warn}; - -pub const MILLI: Duration = Duration::from_millis(1); +use tracing::debug; thread_local! { static READ_CONNECTION: RefCell> = RefCell::new(None); + static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); } struct PreparedStatementIterator<'a> { @@ -51,11 +49,11 @@ impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { let conn = Connection::open(&path)?; - conn.pragma_update(Some(Main), "page_size", &32768)?; + conn.pragma_update(Some(Main), "page_size", &2048)?; conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; - conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; + conn.pragma_update(Some(Main), "wal_autocheckpoint", &2000)?; Ok(conn) } @@ -79,9 +77,25 @@ impl Engine { }) } + fn read_lock_iterator(&self) -> &'static Connection { + READ_CONNECTION_ITERATOR.with(|cell| { + let connection = &mut cell.borrow_mut(); + + if (*connection).is_none() { + let c = Box::leak(Box::new( + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), + )); + **connection = Some(c); + } + + connection.unwrap() + }) + } + pub fn flush_wal(self: &Arc) -> Result<()> { - self.write_lock() - .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; + // We use autocheckpoints + //self.write_lock() + //.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; Ok(()) } } @@ -153,6 +167,34 @@ impl SqliteTable { )?; Ok(()) } + + pub fn iter_with_guard<'a>( + &'a self, + guard: &'a Connection, + ) -> Box + 'a> { + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) + } } impl Tree for SqliteTable { @@ -164,16 +206,7 @@ impl Tree for SqliteTable { #[tracing::instrument(skip(self, key, value))] fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); - - let start = Instant::now(); - self.insert_with_guard(&guard, key, value)?; - - let elapsed = start.elapsed(); - if elapsed > MILLI { - warn!("insert took {:?} : {}", elapsed, &self.name); - } - drop(guard); let watchers = self.watchers.read(); @@ -216,53 +249,41 @@ impl Tree for SqliteTable { Ok(()) } + #[tracing::instrument(skip(self, iter))] + fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()> { + let guard = self.engine.write_lock(); + + guard.execute("BEGIN", [])?; + for key in iter { + let old = self.get_with_guard(&guard, &key)?; + let new = crate::utils::increment(old.as_deref()) + .expect("utils::increment always returns Some"); + self.insert_with_guard(&guard, &key, &new)?; + } + guard.execute("COMMIT", [])?; + + drop(guard); + + Ok(()) + } + #[tracing::instrument(skip(self, key))] fn remove(&self, key: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); - let start = Instant::now(); - guard.execute( format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key], )?; - let elapsed = start.elapsed(); - - if elapsed > MILLI { - debug!("remove: took {:012?} : {}", elapsed, &self.name); - } - // debug!("remove key: {:?}", &key); - Ok(()) } #[tracing::instrument(skip(self))] fn iter<'a>(&'a self) -> Box + 'a> { - let guard = self.engine.read_lock(); + let guard = self.engine.read_lock_iterator(); - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} ORDER BY key ASC", - &self.name - )) - .unwrap(), - )); - - let statement_ref = NonAliasingBox(statement); - - let iterator = Box::new( - statement - .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(|r| r.unwrap()), - ); - - Box::new(PreparedStatementIterator { - iterator, - statement_ref, - }) + self.iter_with_guard(&guard) } #[tracing::instrument(skip(self, from, backwards))] @@ -271,7 +292,7 @@ impl Tree for SqliteTable { from: &[u8], backwards: bool, ) -> Box + 'a> { - let guard = self.engine.read_lock(); + let guard = self.engine.read_lock_iterator(); let from = from.to_vec(); // TODO change interface? if backwards { @@ -326,8 +347,6 @@ impl Tree for SqliteTable { fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); - let start = Instant::now(); - let old = self.get_with_guard(&guard, key)?; let new = @@ -335,26 +354,11 @@ impl Tree for SqliteTable { self.insert_with_guard(&guard, key, &new)?; - let elapsed = start.elapsed(); - - if elapsed > MILLI { - debug!("increment: took {:012?} : {}", elapsed, &self.name); - } - // debug!("increment key: {:?}", &key); - Ok(new) } #[tracing::instrument(skip(self, prefix))] fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { - // let name = self.name.clone(); - // self.iter_from_thread( - // format!( - // "SELECT key, value FROM {} WHERE key BETWEEN ?1 AND ?1 || X'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF' ORDER BY key ASC", - // name - // ) - // [prefix] - // ) Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 0f422350..d3600f18 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -24,7 +24,7 @@ use ruma::{ use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, - mem, + mem::size_of, sync::{Arc, Mutex}, }; use tokio::sync::MutexGuard; @@ -37,17 +37,18 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher}; /// This is created when a state group is added to the database by /// hashing the entire state. pub type StateHashId = Vec; +pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Rooms { pub edus: edus::RoomEdus, - pub(super) pduid_pdu: Arc, // PduId = RoomId + Count + pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count pub(super) eventid_pduid: Arc, pub(super) roomid_pduleaves: Arc, pub(super) alias_roomid: Arc, pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count pub(super) publicroomids: Arc, - pub(super) tokenids: Arc, // TokenId = RoomId + Token + PduId + pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount /// Participating servers in a room. pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName @@ -71,14 +72,15 @@ pub struct Rooms { pub(super) shorteventid_shortstatehash: Arc, /// StateKey = EventType + StateKey, ShortStateKey = Count pub(super) statekey_shortstatekey: Arc, + + pub(super) shortroomid_roomid: Arc, + pub(super) roomid_shortroomid: Arc, + pub(super) shorteventid_eventid: Arc, - /// ShortEventId = Count pub(super) eventid_shorteventid: Arc, - /// ShortEventId = Count + pub(super) statehash_shortstatehash: Arc, - /// ShortStateHash = Count - /// StateId = ShortStateHash + ShortStateKey - pub(super) stateid_shorteventid: Arc, + pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) /// RoomId + EventId -> outlier PDU. /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. @@ -88,43 +90,52 @@ pub struct Rooms { pub(super) referencedevents: Arc, pub(super) pdu_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex, HashSet>>, + pub(super) auth_chain_cache: Mutex>>, + pub(super) shorteventid_cache: Mutex>, + pub(super) eventidshort_cache: Mutex>, + pub(super) statekeyshort_cache: Mutex>, + pub(super) stateinfo_cache: Mutex< + LruCache< + u64, + Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + >, + >, } impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. + #[tracing::instrument(skip(self))] pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { - Ok(self - .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) - .flatten() - .map(|bytes| { - EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in stateid_shorteventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in stateid_shorteventid is invalid.")) - }) - .filter_map(|r| r.ok()) - .collect()) + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + full_state + .into_iter() + .map(|compressed| self.parse_compressed_state_event(compressed)) + .collect() } + #[tracing::instrument(skip(self))] pub fn state_full( &self, shortstatehash: u64, ) -> Result>> { - let state = self - .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) - .flatten() - .map(|bytes| { - EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in stateid_shorteventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in stateid_shorteventid is invalid.")) - }) + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .map(|compressed| self.parse_compressed_state_event(compressed)) .filter_map(|r| r.ok()) .map(|eventid| self.get_pdu(&eventid)) .filter_map(|r| r.ok().flatten()) @@ -141,9 +152,7 @@ impl Rooms { )) }) .filter_map(|r| r.ok()) - .collect(); - - Ok(state) + .collect()) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -154,32 +163,19 @@ impl Rooms { event_type: &EventType, state_key: &str, ) -> Result> { - let mut key = event_type.as_ref().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&state_key.as_bytes()); - - let shortstatekey = self.statekey_shortstatekey.get(&key)?; - - if let Some(shortstatekey) = shortstatekey { - let mut stateid = shortstatehash.to_be_bytes().to_vec(); - stateid.extend_from_slice(&shortstatekey); - - Ok(self - .stateid_shorteventid - .get(&stateid)? - .map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten()) - .flatten() - .map(|bytes| { - EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in stateid_shorteventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in stateid_shorteventid is invalid.")) - }) - .map(|r| r.ok()) - .flatten()) - } else { - Ok(None) - } + let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| self.parse_compressed_state_event(compressed).ok())) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -226,6 +222,7 @@ impl Rooms { } /// This fetches auth events from the current state. + #[tracing::instrument(skip(self))] pub fn get_auth_events( &self, room_id: &RoomId, @@ -267,9 +264,12 @@ impl Rooms { } /// Checks if a room exists. + #[tracing::instrument(skip(self))] pub fn exists(&self, room_id: &RoomId) -> Result { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = match self.get_shortroomid(room_id)? { + Some(b) => b.to_be_bytes().to_vec(), + None => return Ok(false), + }; // Look for PDUs in that room. Ok(self @@ -281,9 +281,13 @@ impl Rooms { } /// Checks if a room exists. + #[tracing::instrument(skip(self))] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); // Look for PDUs in that room. self.pduid_pdu @@ -300,97 +304,79 @@ impl Rooms { /// Force the creation of a new StateHash and insert it into the db. /// - /// Whatever `state` is supplied to `force_state` __is__ the current room state snapshot. + /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot. + #[tracing::instrument(skip(self, new_state, db))] pub fn force_state( &self, room_id: &RoomId, - state: HashMap<(EventType, String), EventId>, + new_state: HashMap<(EventType, String), EventId>, db: &Database, ) -> Result<()> { + let previous_shortstatehash = self.current_shortstatehash(&room_id)?; + + let new_state_ids_compressed = new_state + .iter() + .filter_map(|((event_type, state_key), event_id)| { + let shortstatekey = self + .get_or_create_shortstatekey(event_type, state_key, &db.globals) + .ok()?; + Some( + self.compress_state_event(shortstatekey, event_id, &db.globals) + .ok()?, + ) + }) + .collect::>(); + let state_hash = self.calculate_hash( - &state + &new_state .values() .map(|event_id| event_id.as_bytes()) .collect::>(), ); - let (shortstatehash, already_existed) = - match self.statehash_shortstatehash.get(&state_hash)? { - Some(shortstatehash) => ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ), - None => { - let shortstatehash = db.globals.next_count()?; - self.statehash_shortstatehash - .insert(&state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - } - }; + let (new_shortstatehash, already_existed) = + self.get_or_create_shortstatehash(&state_hash, &db.globals)?; - let new_state = if !already_existed { - let mut new_state = HashSet::new(); + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok(()); + } - let batch = state - .iter() - .filter_map(|((event_type, state_key), eventid)| { - new_state.insert(eventid.clone()); + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() + { + let statediffnew = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = db.globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let statediffremoved = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .cloned() + .collect::>(); - let shorteventid = - match self.eventid_shorteventid.get(eventid.as_bytes()).ok()? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = db.globals.next_count().ok()?; - self.eventid_shorteventid - .insert(eventid.as_bytes(), &shorteventid.to_be_bytes()) - .ok()?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), eventid.as_bytes()) - .ok()?; - shorteventid.to_be_bytes().to_vec() - } - }; - - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - - Some((state_id, shorteventid)) - }) - .collect::>(); - - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; - - new_state + (statediffnew, statediffremoved) } else { - self.state_full_ids(shortstatehash)?.into_iter().collect() + (new_state_ids_compressed, HashSet::new()) }; - let old_state = self - .current_shortstatehash(&room_id)? - .map(|s| self.state_full_ids(s)) - .transpose()? - .map(|vec| vec.into_iter().collect::>()) - .unwrap_or_default(); + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + }; - for event_id in new_state.difference(&old_state) { - if let Some(pdu) = self.get_pdu_json(event_id)? { + for event_id in statediffnew + .into_iter() + .filter_map(|new| self.parse_compressed_state_event(new).ok()) + { + if let Some(pdu) = self.get_pdu_json(&event_id)? { if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") { if let Ok(pdu) = serde_json::from_value::( serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"), @@ -414,6 +400,7 @@ impl Rooms { &pdu.sender, None, db, + false, )?; } } @@ -422,12 +409,443 @@ impl Rooms { } } + self.update_joined_count(room_id)?; + self.roomid_shortstatehash - .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; Ok(()) } + /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. + #[tracing::instrument(skip(self))] + pub fn load_shortstatehash_info( + &self, + shortstatehash: u64, + ) -> Result< + Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + > { + if let Some(r) = self + .stateinfo_cache + .lock() + .unwrap() + .get_mut(&shortstatehash) + { + return Ok(r.clone()); + } + + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = + utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::(); + } + + if parent != 0_u64 { + let mut response = self.load_shortstatehash_info(parent)?; + let mut state = response.last().unwrap().1.clone(); + state.extend(added.iter().cloned()); + for r in &removed { + state.remove(r); + } + + response.push((shortstatehash, state, added, removed)); + + Ok(response) + } else { + let mut response = Vec::new(); + response.push((shortstatehash, added.clone(), added, removed)); + self.stateinfo_cache + .lock() + .unwrap() + .insert(shortstatehash, response.clone()); + Ok(response) + } + } + + #[tracing::instrument(skip(self, globals))] + pub fn compress_state_event( + &self, + shortstatekey: u64, + event_id: &EventId, + globals: &super::globals::Globals, + ) -> Result { + let mut v = shortstatekey.to_be_bytes().to_vec(); + v.extend_from_slice( + &self + .get_or_create_shorteventid(event_id, globals)? + .to_be_bytes(), + ); + Ok(v.try_into().expect("we checked the size above")) + } + + #[tracing::instrument(skip(self, compressed_event))] + pub fn parse_compressed_state_event( + &self, + compressed_event: CompressedStateEvent, + ) -> Result { + self.get_eventid_from_short( + utils::u64_from_bytes(&compressed_event[size_of::()..]) + .expect("bytes have right length"), + ) + } + + /// Creates a new shortstatehash that often is just a diff to an already existing + /// shortstatehash and therefore very efficient. + /// + /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer + /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0 + /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. + /// + /// * `shortstatehash` - Shortstatehash of this state + /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid + /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer + #[tracing::instrument(skip( + self, + statediffnew, + statediffremoved, + diff_to_sibling, + parent_states + ))] + pub fn save_state_from_diff( + &self, + shortstatehash: u64, + statediffnew: HashSet, + statediffremoved: HashSet, + diff_to_sibling: usize, + mut parent_states: Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + ) -> Result<()> { + let diffsum = statediffnew.len() + statediffremoved.len(); + + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().unwrap(); + + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + // It was not added in the parent and we removed it + parent_removed.insert(removed); + } + // Else it was added in the parent and we removed it again. We can forget this change + } + + for new in statediffnew { + if !parent_removed.remove(&new) { + // It was not touched in the parent and we added it + parent_new.insert(new); + } + // Else it was removed in the parent and we added it again. We can forget this change + } + + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + + return Ok(()); + } + + if parent_states.len() == 0 { + // There is no parent layer, create a new state + let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent + for new in &statediffnew { + value.extend_from_slice(&new[..]); + } + + if !statediffremoved.is_empty() { + warn!("Tried to create new state with removals"); + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + + return Ok(()); + }; + + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above + + let parent = parent_states.pop().unwrap(); + let parent_diff = parent.2.len() + parent.3.len(); + + if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + // Diff too big, we replace above layer(s) + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + // It was not added in the parent and we removed it + parent_removed.insert(removed); + } + // Else it was added in the parent and we removed it again. We can forget this change + } + + for new in statediffnew { + if !parent_removed.remove(&new) { + // It was not touched in the parent and we added it + parent_new.insert(new); + } + // Else it was removed in the parent and we added it again. We can forget this change + } + + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + let mut value = parent.0.to_be_bytes().to_vec(); + for new in &statediffnew { + value.extend_from_slice(&new[..]); + } + + if !statediffremoved.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in &statediffremoved { + value.extend_from_slice(&removed[..]); + } + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + } + + Ok(()) + } + + /// Returns (shortstatehash, already_existed) + #[tracing::instrument(skip(self, globals))] + fn get_or_create_shortstatehash( + &self, + state_hash: &StateHashId, + globals: &super::globals::Globals, + ) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(&state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = globals.next_count()?; + self.statehash_shortstatehash + .insert(&state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + } + }) + } + + #[tracing::instrument(skip(self, globals))] + pub fn get_or_create_shorteventid( + &self, + event_id: &EventId, + globals: &super::globals::Globals, + ) -> Result { + if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(&event_id) { + return Ok(*short); + } + + let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => utils::u64_from_bytes(&shorteventid) + .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + None => { + let shorteventid = globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + } + }; + + self.eventidshort_cache + .lock() + .unwrap() + .insert(event_id.clone(), short); + + Ok(short) + } + + #[tracing::instrument(skip(self))] + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + self.roomid_shortroomid + .get(&room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + pub fn get_shortstatekey( + &self, + event_type: &EventType, + state_key: &str, + ) -> Result> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(Some(*short)); + } + + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + let short = self + .statekey_shortstatekey + .get(&statekey)? + .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose()?; + + if let Some(s) = short { + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), s); + } + + Ok(short) + } + + #[tracing::instrument(skip(self, globals))] + pub fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + globals: &super::globals::Globals, + ) -> Result { + Ok(match self.roomid_shortroomid.get(&room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = globals.next_count()?; + self.roomid_shortroomid + .insert(&room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } + + #[tracing::instrument(skip(self, globals))] + pub fn get_or_create_shortstatekey( + &self, + event_type: &EventType, + state_key: &str, + globals: &super::globals::Globals, + ) -> Result { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(*short); + } + + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + let short = match self.statekey_shortstatekey.get(&statekey)? { + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes())?; + shortstatekey + } + }; + + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), short); + + Ok(short) + } + + #[tracing::instrument(skip(self))] + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { + if let Some(id) = self + .shorteventid_cache + .lock() + .unwrap() + .get_mut(&shorteventid) + { + return Ok(id.clone()); + } + + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + let event_id = + EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?; + + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, event_id.clone()); + + Ok(event_id) + } + /// Returns the full room state. #[tracing::instrument(skip(self))] pub fn room_state_full( @@ -475,21 +893,26 @@ impl Rooms { #[tracing::instrument(skip(self))] pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { Ok( - utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::()..pdu_id.len()]) + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?, ) } /// Returns the `count` of this pdu's id. + #[tracing::instrument(skip(self))] pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map_or(Ok(None), |pdu_id| self.pdu_count(&pdu_id).map(Some)) } + #[tracing::instrument(skip(self))] pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -504,6 +927,7 @@ impl Rooms { } /// Returns the json of a pdu. + #[tracing::instrument(skip(self))] pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -522,6 +946,18 @@ impl Rooms { } /// Returns the json of a pdu. + #[tracing::instrument(skip(self))] + pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + + /// Returns the json of a pdu. + #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu_json( &self, event_id: &EventId, @@ -543,6 +979,7 @@ impl Rooms { } /// Returns the pdu's id. + #[tracing::instrument(skip(self))] pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid .get(event_id.as_bytes())? @@ -552,6 +989,7 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -572,6 +1010,7 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + #[tracing::instrument(skip(self))] pub fn get_pdu(&self, event_id: &EventId) -> Result>> { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(&event_id) { return Ok(Some(Arc::clone(p))); @@ -611,6 +1050,7 @@ impl Rooms { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. + #[tracing::instrument(skip(self))] pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -621,6 +1061,7 @@ impl Rooms { } /// Returns the pdu as a `BTreeMap`. + #[tracing::instrument(skip(self))] pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -631,6 +1072,7 @@ impl Rooms { } /// Removes a pdu and creates a new one with the same id. + #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { if self.pduid_pdu.get(&pdu_id)?.is_some() { self.pduid_pdu.insert( @@ -719,6 +1161,8 @@ impl Rooms { /// /// By this point the incoming event should be fully authenticated, no auth happens /// in `append_pdu`. + /// + /// Returns pdu id #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] pub fn append_pdu( &self, @@ -727,7 +1171,8 @@ impl Rooms { leaves: &[EventId], db: &Database, ) -> Result> { - // returns pdu id + let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); + // Make unsigned fields correct. This is not properly documented in the spec, but state // events need to have previous content in the unsigned field, so clients can easily // interpret things like membership changes @@ -782,8 +1227,7 @@ impl Rooms { self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; let count2 = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); // There's a brief moment of time here where the count is updated but the pdu does not @@ -796,10 +1240,9 @@ impl Rooms { &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), )?; - // This also replaces the eventid of any outliers with the correct - // pduid, removing the place holder. self.eventid_pduid .insert(pdu.event_id.as_bytes(), &pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; drop(insert_lock); @@ -816,6 +1259,9 @@ impl Rooms { let sync_pdu = pdu.to_sync_room_event(); + let mut notifies = Vec::new(); + let mut highlights = Vec::new(); + for user in db .rooms .room_members(&pdu.room_id) @@ -861,11 +1307,11 @@ impl Rooms { userroom_id.extend_from_slice(pdu.room_id.as_bytes()); if notify { - self.userroomid_notificationcount.increment(&userroom_id)?; + notifies.push(userroom_id.clone()); } if highlight { - self.userroomid_highlightcount.increment(&userroom_id)?; + highlights.push(userroom_id); } for senderkey in db.pusher.get_pusher_senderkeys(&user) { @@ -873,6 +1319,11 @@ impl Rooms { } } + self.userroomid_notificationcount + .increment_batch(&mut notifies.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights.into_iter())?; + match pdu.kind { EventType::RoomRedaction => { if let Some(redact_id) = &pdu.redacts { @@ -919,6 +1370,7 @@ impl Rooms { &pdu.sender, invite_state, db, + true, )?; } } @@ -926,11 +1378,11 @@ impl Rooms { if let Some(body) = pdu.content.get("body").and_then(|b| b.as_str()) { let mut batch = body .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) .filter(|word| word.len() <= 50) .map(str::to_lowercase) .map(|word| { - let mut key = pdu.room_id.as_bytes().to_vec(); - key.push(0xff); + let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xff); key.extend_from_slice(&pdu_id); @@ -1113,20 +1565,26 @@ impl Rooms { pub fn set_event_state( &self, event_id: &EventId, + room_id: &RoomId, state: &StateMap>, globals: &super::globals::Globals, ) -> Result<()> { - let shorteventid = match self.eventid_shorteventid.get(event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; + + let previous_shortstatehash = self.current_shortstatehash(&room_id)?; + + let state_ids_compressed = state + .iter() + .filter_map(|((event_type, state_key), pdu)| { + let shortstatekey = self + .get_or_create_shortstatekey(event_type, state_key, globals) + .ok()?; + Some( + self.compress_state_event(shortstatekey, &pdu.event_id, globals) + .ok()?, + ) + }) + .collect::>(); let state_hash = self.calculate_hash( &state @@ -1135,69 +1593,41 @@ impl Rooms { .collect::>(), ); - let shortstatehash = match self.statehash_shortstatehash.get(&state_hash)? { - Some(shortstatehash) => { - // State already existed in db - self.shorteventid_shortstatehash - .insert(&shorteventid, &*shortstatehash)?; - return Ok(()); - } - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(&state_hash, &shortstatehash.to_be_bytes())?; - shortstatehash.to_be_bytes().to_vec() - } - }; + let (shortstatehash, already_existed) = + self.get_or_create_shortstatehash(&state_hash, globals)?; - let batch = state - .iter() - .filter_map(|((event_type, state_key), pdu)| { - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + if !already_existed { + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = state_ids_compressed + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); + + let statediffremoved = parent_stateinfo + .1 + .difference(&state_ids_compressed) + .cloned() + .collect::>(); + + (statediffnew, statediffremoved) + } else { + (state_ids_compressed, HashSet::new()) }; - - let shorteventid = match self - .eventid_shorteventid - .get(pdu.event_id.as_bytes()) - .ok()? - { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count().ok()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes()) - .ok()?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes()) - .ok()?; - shorteventid.to_be_bytes().to_vec() - } - }; - - let mut state_id = shortstatehash.clone(); - state_id.extend_from_slice(&shortstatekey); - - Some((state_id, shorteventid)) - }) - .collect::>(); - - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; + self.save_state_from_diff( + shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 1_000_000, // high number because no state will be based on this one + states_parents, + )?; + } self.shorteventid_shortstatehash - .insert(&shorteventid, &*shortstatehash)?; + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } @@ -1212,106 +1642,59 @@ impl Rooms { new_pdu: &PduEvent, globals: &super::globals::Globals, ) -> Result { - let old_state = if let Some(old_shortstatehash) = - self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? - { - // Store state for event. The state does not include the event itself. - // Instead it's the state before the pdu, so the room's old state. + let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; + if let Some(p) = previous_shortstatehash { self.shorteventid_shortstatehash - .insert(&shorteventid, &old_shortstatehash)?; - if new_pdu.state_key.is_none() { - return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") - }); - } - - self.stateid_shorteventid - .scan_prefix(old_shortstatehash.clone()) - // Chop the old_shortstatehash out leaving behind the short state key - .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) - .collect::, Vec>>() - } else { - HashMap::new() - }; + .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; + } if let Some(state_key) = &new_pdu.state_key { - let mut new_state: HashMap, Vec> = old_state; + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); - new_state_key.push(0xff); - new_state_key.extend_from_slice(state_key.as_bytes()); + let shortstatekey = + self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?; - let shortstatekey = match self.statekey_shortstatekey.get(&new_state_key)? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count()?; - self.statekey_shortstatekey - .insert(&new_state_key, &shortstatekey.to_be_bytes())?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; - let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let replaces = states_parents + .last() + .map(|info| { + info.1 + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + }) + .unwrap_or_default(); - new_state.insert(shortstatekey, shorteventid); + if Some(&new) == replaces { + return Ok(previous_shortstatehash.expect("must exist")); + } - let new_state_hash = self.calculate_hash( - &new_state - .values() - .map(|event_id| &**event_id) - .collect::>(), - ); + // TODO: statehash with deterministic inputs + let shortstatehash = globals.next_count()?; - let shortstatehash = match self.statehash_shortstatehash.get(&new_state_hash)? { - Some(shortstatehash) => { - warn!("state hash already existed?!"); - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))? - } - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(&new_state_hash, &shortstatehash.to_be_bytes())?; - shortstatehash - } - }; + let mut statediffnew = HashSet::new(); + statediffnew.insert(new); - let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| { - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - (state_id, shorteventid) - }); + let mut statediffremoved = HashSet::new(); + if let Some(replaces) = replaces { + statediffremoved.insert(replaces.clone()); + } - self.stateid_shorteventid.insert_batch(&mut batch)?; + self.save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 2, + states_parents, + )?; Ok(shortstatehash) } else { - Err(Error::bad_database( - "Tried to insert non-state event into room without a state.", - )) + Ok(previous_shortstatehash.expect("first event in room must be a state event")) } } @@ -1516,11 +1899,7 @@ impl Rooms { ); // Generate short event id - let shorteventid = db.globals.next_count()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())?; + let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. @@ -1618,7 +1997,7 @@ impl Rooms { &'a self, user_id: &UserId, room_id: &RoomId, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { self.pdus_since(user_id, room_id, 0) } @@ -1630,16 +2009,21 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> impl Iterator, PduEvent)>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + ) -> Result, PduEvent)>> + 'a> { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); // Skip the first pdu if it's exactly at since, because we sent that last time let mut first_pdu_id = prefix.clone(); first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(&first_pdu_id, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1649,7 +2033,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -1660,10 +2044,13 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` @@ -1671,7 +2058,9 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(current, true) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1681,7 +2070,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Returns an iterator over all events and their token in a room that happened after the event @@ -1692,10 +2081,13 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event @@ -1703,7 +2095,9 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(current, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1713,7 +2107,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Replace a PDU with the redacted form. @@ -1744,6 +2138,7 @@ impl Rooms { sender: &UserId, last_state: Option>>, db: &Database, + update_joined_count: bool, ) -> Result<()> { // Keep track what remote users exist by adding them as "deactivated" users if user_id.server_name() != db.globals.server_name() { @@ -1861,8 +2256,10 @@ impl Rooms { } } - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; + if update_joined_count { + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -1887,8 +2284,10 @@ impl Rooms { return Ok(()); } - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; + if update_joined_count { + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } self.userroomid_invitestate.insert( &userroom_id, &serde_json::to_vec(&last_state.unwrap_or_default()) @@ -1902,14 +2301,16 @@ impl Rooms { self.roomuserid_leftcount.remove(&roomuser_id)?; } member::MembershipState::Leave | member::MembershipState::Ban => { - if self - .room_members(room_id) - .chain(self.room_members_invited(room_id)) - .filter_map(|r| r.ok()) - .all(|u| u.server_name() != user_id.server_name()) - { - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; + if update_joined_count { + if self + .room_members(room_id) + .chain(self.room_members_invited(room_id)) + .filter_map(|r| r.ok()) + .all(|u| u.server_name() != user_id.server_name()) + { + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } } self.userroomid_leftstate.insert( &userroom_id, @@ -1925,18 +2326,64 @@ impl Rooms { _ => {} } - self.update_joined_count(room_id)?; + if update_joined_count { + self.update_joined_count(room_id)?; + } Ok(()) } + #[tracing::instrument(skip(self))] pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - self.roomid_joinedcount.insert( - room_id.as_bytes(), - &(self.room_members(&room_id).count() as u64).to_be_bytes(), - ) + let mut joinedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + for joined in self.room_members(&room_id).filter_map(|r| r.ok()) { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount += 1; + } + + for invited in self.room_members_invited(&room_id).filter_map(|r| r.ok()) { + joined_servers.insert(invited.server_name().to_owned()); + } + + self.roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; + + for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { + if !joined_servers.remove(&old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xff); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } + } + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xff); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } + + Ok(()) } + #[tracing::instrument(skip(self, db))] pub async fn leave_room( &self, user_id: &UserId, @@ -1962,6 +2409,7 @@ impl Rooms { user_id, last_state, db, + true, )?; } else { let mutex_state = Arc::clone( @@ -2008,6 +2456,7 @@ impl Rooms { Ok(()) } + #[tracing::instrument(skip(self, db))] async fn remote_leave_room( &self, user_id: &UserId, @@ -2239,16 +2688,22 @@ impl Rooms { }) } + #[tracing::instrument(skip(self))] pub fn search_pdus<'a>( &'a self, room_id: &RoomId, search_string: &str, ) -> Result<(impl Iterator> + 'a, Vec)> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + let prefix_clone = prefix.clone(); let words = search_string .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) .map(str::to_lowercase) .collect::>(); @@ -2264,16 +2719,7 @@ impl Rooms { .iter_from(&last_possible_id, true) // Newest pdus first .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(|(key, _)| { - let pduid_index = key - .iter() - .enumerate() - .filter(|(_, &b)| b == 0xff) - .nth(1) - .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? - .0 - + 1; // +1 because the pdu id starts AFTER the separator - - let pdu_id = key[pduid_index..].to_vec(); + let pdu_id = key[key.len() - size_of::()..].to_vec(); Ok::<_, Error>(pdu_id) }) @@ -2285,7 +2731,12 @@ impl Rooms { // We compare b with a because we reversed the iterator earlier b.cmp(a) }) - .unwrap(), + .unwrap() + .map(move |id| { + let mut pduid = prefix_clone.clone(); + pduid.extend_from_slice(&id); + pduid + }), words, )) } @@ -2398,6 +2849,7 @@ impl Rooms { }) } + #[tracing::instrument(skip(self))] pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { Ok(self .roomid_joinedcount @@ -2655,9 +3107,7 @@ impl Rooms { } #[tracing::instrument(skip(self))] - pub fn auth_chain_cache( - &self, - ) -> std::sync::MutexGuard<'_, LruCache, HashSet>> { + pub fn auth_chain_cache(&self) -> std::sync::MutexGuard<'_, LruCache>> { self.auth_chain_cache.lock().unwrap() } } diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs index 1372fef5..8a3fe4f0 100644 --- a/src/database/uiaa.rs +++ b/src/database/uiaa.rs @@ -4,11 +4,14 @@ use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, - r0::uiaa::{IncomingAuthData, UiaaInfo}, + r0::uiaa::{ + IncomingAuthData, IncomingPassword, IncomingUserIdentifier::MatrixId, UiaaInfo, + }, }, signatures::CanonicalJsonValue, DeviceId, UserId, }; +use tracing::error; use super::abstraction::Tree; @@ -49,126 +52,91 @@ impl Uiaa { users: &super::users::Users, globals: &super::globals::Globals, ) -> Result<(bool, UiaaInfo)> { - if let IncomingAuthData::DirectRequest { - kind, - session, - auth_parameters, - } = &auth - { - let mut uiaainfo = session - .as_ref() - .map(|session| self.get_uiaa_session(&user_id, &device_id, session)) - .unwrap_or_else(|| Ok(uiaainfo.clone()))?; + let mut uiaainfo = auth + .session() + .map(|session| self.get_uiaa_session(&user_id, &device_id, session)) + .unwrap_or_else(|| Ok(uiaainfo.clone()))?; - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } + match auth { // Find out what the user completed - match &**kind { - "m.login.password" => { - let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "m.login.password needs identifier.", - ))?; - - let identifier_type = identifier.get("type").ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "Identifier needs a type.", - ))?; - - if identifier_type != "m.id.user" { + IncomingAuthData::Password(IncomingPassword { + identifier, + password, + .. + }) => { + let username = match identifier { + MatrixId(username) => username, + _ => { return Err(Error::BadRequest( ErrorKind::Unrecognized, "Identifier type not recognized.", - )); + )) } + }; - let username = identifier - .get("user") - .ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "Identifier needs user field.", - ))? - .as_str() - .ok_or(Error::BadRequest( - ErrorKind::BadJson, - "User is not a string.", - ))?; - - let user_id = UserId::parse_with_server_name(username, globals.server_name()) + let user_id = + UserId::parse_with_server_name(username.clone(), globals.server_name()) .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") - })?; + Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") + })?; - let password = auth_parameters - .get("password") - .ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "Password is missing.", - ))? - .as_str() - .ok_or(Error::BadRequest( - ErrorKind::BadJson, - "Password is not a string.", - ))?; + // Check if password is correct + if let Some(hash) = users.password_hash(&user_id)? { + let hash_matches = + argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - // Check if password is correct - if let Some(hash) = users.password_hash(&user_id)? { - let hash_matches = - argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push("m.login.password".to_owned()); - } - "m.login.dummy" => { - uiaainfo.completed.push("m.login.dummy".to_owned()); - } - k => panic!("type not supported: {}", k), - } - - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; + if !hash_matches { + uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); } } - // We didn't break, so this flow succeeded! - completed = true; - } - if !completed { - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push("m.login.password".to_owned()); } + IncomingAuthData::Dummy(_) => { + uiaainfo.completed.push("m.login.dummy".to_owned()); + } + k => error!("type not supported: {:?}", k), + } - // UIAA was successful! Remove this session and return true + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! + completed = true; + } + + if !completed { self.update_uiaa_session( user_id, device_id, uiaainfo.session.as_ref().expect("session is always set"), - None, + Some(&uiaainfo), )?; - Ok((true, uiaainfo)) - } else { - panic!("FallbackAcknowledgement is not supported yet"); + return Ok((false, uiaainfo)); } + + // UIAA was successful! Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + )?; + Ok((true, uiaainfo)) } fn set_uiaa_request( diff --git a/src/main.rs b/src/main.rs index 5a6f8c76..72f753fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use database::Config; pub use database::Database; pub use error::{Error, Result}; -use opentelemetry::trace::Tracer; +use opentelemetry::trace::{FutureExt, Tracer}; pub use pdu::PduEvent; pub use rocket::State; use ruma::api::client::error::ErrorKind; @@ -220,14 +220,17 @@ async fn main() { }; if config.allow_jaeger { + opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); let tracer = opentelemetry_jaeger::new_pipeline() - .with_service_name("conduit") - .install_simple() + .install_batch(opentelemetry::runtime::Tokio) .unwrap(); let span = tracer.start("conduit"); - start.await; + start.with_current_context().await; drop(span); + + println!("exporting"); + opentelemetry::global::shutdown_tracer_provider(); } else { std::env::set_var("RUST_LOG", &config.log); diff --git a/src/pdu.rs b/src/pdu.rs index 00eda5b1..1016fe66 100644 --- a/src/pdu.rs +++ b/src/pdu.rs @@ -12,7 +12,7 @@ use ruma::{ use serde::{Deserialize, Serialize}; use serde_json::json; use std::{cmp::Ordering, collections::BTreeMap, convert::TryFrom}; -use tracing::error; +use tracing::warn; #[derive(Clone, Deserialize, Serialize, Debug)] pub struct PduEvent { @@ -322,7 +322,7 @@ pub(crate) fn gen_event_id_canonical_json( pdu: &Raw, ) -> crate::Result<(EventId, CanonicalJsonObject)> { let value = serde_json::from_str(pdu.json().get()).map_err(|e| { - error!("{:?}: {:?}", pdu, e); + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); Error::BadServerResponse("Invalid PDU in server response") })?; diff --git a/src/server_server.rs b/src/server_server.rs index bf5e4f34..5299e1f0 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -111,7 +111,7 @@ impl FedDest { } } -#[tracing::instrument(skip(globals))] +#[tracing::instrument(skip(globals, request))] pub async fn send_request( globals: &crate::database::globals::Globals, destination: &ServerName, @@ -254,7 +254,7 @@ where }); // TODO: handle timeout if status != 200 { - info!( + warn!( "{} {}: {}", url, status, @@ -272,14 +272,20 @@ where if status == 200 { let response = T::IncomingResponse::try_from_http_response(http_response); response.map_err(|e| { - warn!("Invalid 200 response from {}: {}", &destination, e); + warn!( + "Invalid 200 response from {} on: {} {}", + &destination, url, e + ); Error::BadServerResponse("Server returned bad 200 response.") }) } else { Err(Error::FederationError( destination.to_owned(), RumaError::try_from_http_response(http_response).map_err(|e| { - warn!("Server returned bad error response: {}", e); + warn!( + "Invalid {} response from {} on: {} {}", + status, &destination, url, e + ); Error::BadServerResponse("Server returned bad error response.") })?, )) @@ -495,7 +501,7 @@ pub fn get_server_keys_route(db: DatabaseGuard) -> Json { ) .unwrap(); - Json(ruma::serde::to_canonical_json_string(&response).expect("JSON is canonical")) + Json(serde_json::to_string(&response).expect("JSON is canonical")) } #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] @@ -668,7 +674,7 @@ pub async fn send_transaction_message_route( let elapsed = start_time.elapsed(); warn!( - "Handling event {} took {}m{}s", + "Handling transaction of event {} took {}m{}s", event_id, elapsed.as_secs() / 60, elapsed.as_secs() % 60 @@ -721,7 +727,8 @@ pub async fn send_transaction_message_route( &db.globals, )?; } else { - warn!("No known event ids in read receipt: {:?}", user_updates); + // TODO fetch missing events + debug!("No known event ids in read receipt: {:?}", user_updates); } } } @@ -839,7 +846,7 @@ type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; /// 14. Use state resolution to find new room state // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively #[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] -pub fn handle_incoming_pdu<'a>( +pub async fn handle_incoming_pdu<'a>( origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, @@ -847,20 +854,174 @@ pub fn handle_incoming_pdu<'a>( is_timeline_event: bool, db: &'a Database, pub_key_map: &'a RwLock>>, -) -> AsyncRecursiveType<'a, StdResult>, String>> { +) -> StdResult>, String> { + match db.rooms.exists(&room_id) { + Ok(true) => {} + _ => { + return Err("Room is unknown to this server.".to_string()); + } + } + + // 1. Skip the PDU if we already have it as a timeline event + if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { + return Ok(Some(pdu_id.to_vec())); + } + + let create_event = db + .rooms + .room_state_get(&room_id, &EventType::RoomCreate, "") + .map_err(|_| "Failed to ask database for event.".to_owned())? + .ok_or_else(|| "Failed to find create event in db.".to_owned())?; + + let (incoming_pdu, val) = handle_outlier_pdu( + origin, + &create_event, + event_id, + room_id, + value, + db, + pub_key_map, + ) + .await?; + + // 8. if not timeline event: stop + if !is_timeline_event { + return Ok(None); + } + + // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + let mut graph = HashMap::new(); + let mut eventid_info = HashMap::new(); + let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); + + let mut amount = 0; + + while let Some(prev_event_id) = todo_outlier_stack.pop() { + if let Some((pdu, json_opt)) = fetch_and_handle_outliers( + db, + origin, + &[prev_event_id.clone()], + &create_event, + &room_id, + pub_key_map, + ) + .await + .pop() + { + if amount > 100 { + // Max limit reached + warn!("Max prev event limit reached!"); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue; + } + + if let Some(json) = + json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) + { + if pdu.origin_server_ts + > db.rooms + .first_pdu_in_room(&room_id) + .map_err(|_| "Error loading first room event.".to_owned())? + .expect("Room exists") + .origin_server_ts + { + amount += 1; + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(dbg!(prev_prev.clone())); + } + } + + graph.insert( + prev_event_id.clone(), + pdu.prev_events.iter().cloned().collect(), + ); + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } else { + // Time based check failed + graph.insert(prev_event_id.clone(), HashSet::new()); + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } + } else { + // Get json failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } else { + // Fetch and handle failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } + + let sorted = + state_res::StateResolution::lexicographical_topological_sort(dbg!(&graph), |event_id| { + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + println!("{}", event_id); + Ok(( + 0, + MilliSecondsSinceUnixEpoch( + eventid_info + .get(event_id) + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts.clone()), + ), + ruma::event_id!("$notimportant"), + )) + }) + .map_err(|_| "Error sorting prev events".to_owned())?; + + for prev_id in dbg!(sorted) { + if let Some((pdu, json)) = eventid_info.remove(&prev_id) { + let start_time = Instant::now(); + let event_id = pdu.event_id.clone(); + if let Err(e) = upgrade_outlier_to_timeline_pdu( + pdu, + json, + &create_event, + origin, + db, + room_id, + pub_key_map, + ) + .await + { + warn!("Prev event {} failed: {}", event_id, e); + } + let elapsed = start_time.elapsed(); + warn!( + "Handling prev event {} took {}m{}s", + event_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + } + + upgrade_outlier_to_timeline_pdu( + incoming_pdu, + val, + &create_event, + origin, + db, + room_id, + pub_key_map, + ) + .await +} + +#[tracing::instrument(skip(origin, create_event, event_id, room_id, value, db, pub_key_map))] +fn handle_outlier_pdu<'a>( + origin: &'a ServerName, + create_event: &'a PduEvent, + event_id: &'a EventId, + room_id: &'a RoomId, + value: BTreeMap, + db: &'a Database, + pub_key_map: &'a RwLock>>, +) -> AsyncRecursiveType<'a, StdResult<(Arc, BTreeMap), String>> +{ Box::pin(async move { // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - match db.rooms.exists(&room_id) { - Ok(true) => {} - _ => { - return Err("Room is unknown to this server.".to_string()); - } - } - - // 1. Skip the PDU if we already have it as a timeline event - if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { - return Ok(Some(pdu_id.to_vec())); - } // We go through all the signatures we see on the value and fetch the corresponding signing // keys @@ -870,11 +1031,6 @@ pub fn handle_incoming_pdu<'a>( // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match - let create_event = db - .rooms - .room_state_get(&room_id, &EventType::RoomCreate, "") - .map_err(|_| "Failed to ask database for event.".to_owned())? - .ok_or_else(|| "Failed to find create event in db.".to_owned())?; let create_event_content = serde_json::from_value::>(create_event.content.clone()) @@ -921,13 +1077,13 @@ pub fn handle_incoming_pdu<'a>( // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // EDIT: Step 5 is not applied anymore because it failed too often debug!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_events( + fetch_and_handle_outliers( db, origin, &incoming_pdu.auth_events, + &create_event, &room_id, pub_key_map, - false, ) .await; @@ -1010,200 +1166,234 @@ pub fn handle_incoming_pdu<'a>( .map_err(|_| "Failed to add pdu as outlier.".to_owned())?; debug!("Added pdu as outlier."); - // 8. if not timeline event: stop - if !is_timeline_event - || incoming_pdu.origin_server_ts - < db.rooms - .first_pdu_in_room(&room_id) - .map_err(|_| "Error loading first room event.".to_owned())? - .expect("Room exists") - .origin_server_ts - { - return Ok(None); - } + Ok((incoming_pdu, val)) + }) +} - // Load missing prev events first - fetch_and_handle_events( - db, - origin, - &incoming_pdu.prev_events, - &room_id, - pub_key_map, - true, - ) - .await; +#[tracing::instrument(skip(incoming_pdu, val, create_event, origin, db, room_id, pub_key_map))] +async fn upgrade_outlier_to_timeline_pdu( + incoming_pdu: Arc, + val: BTreeMap, + create_event: &PduEvent, + origin: &ServerName, + db: &Database, + room_id: &RoomId, + pub_key_map: &RwLock>>, +) -> StdResult>, String> { + if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { + return Ok(Some(pduid)); + } + // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities + // doing all the checks in this list starting at 1. These are not timeline events. - // TODO: 9. fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + // TODO: if we know the prev_events of the incoming event we can avoid the request and build + // the state from a known point and resolve if > 1 prev_event - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - // doing all the checks in this list starting at 1. These are not timeline events. + debug!("Requesting state at event."); + let mut state_at_incoming_event = None; - // TODO: if we know the prev_events of the incoming event we can avoid the request and build - // the state from a known point and resolve if > 1 prev_event + if incoming_pdu.prev_events.len() == 1 { + let prev_event = &incoming_pdu.prev_events[0]; + let prev_event_sstatehash = db + .rooms + .pdu_shortstatehash(prev_event) + .map_err(|_| "Failed talking to db".to_owned())?; - debug!("Requesting state at event."); - let mut state_at_incoming_event = None; + let state = + prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); - if incoming_pdu.prev_events.len() == 1 { - let prev_event = &incoming_pdu.prev_events[0]; - let state = db - .rooms - .pdu_shortstatehash(prev_event) - .map_err(|_| "Failed talking to db".to_owned())? - .map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok()) - .flatten(); - if let Some(state) = state { - let mut state = fetch_and_handle_events( - db, - origin, - &state.into_iter().collect::>(), - &room_id, - pub_key_map, - false, - ) - .await - .into_iter() - .map(|pdu| { + if let Some(Ok(state)) = state { + warn!("Using cached state"); + let mut state = fetch_and_handle_outliers( + db, + origin, + &state.into_iter().collect::>(), + &create_event, + &room_id, + pub_key_map, + ) + .await + .into_iter() + .map(|(pdu, _)| { + ( ( - ( - pdu.kind.clone(), - pdu.state_key - .clone() - .expect("events from state_full_ids are state events"), - ), - pdu, - ) - }) - .collect::>(); + pdu.kind.clone(), + pdu.state_key + .clone() + .expect("events from state_full_ids are state events"), + ), + pdu, + ) + }) + .collect::>(); - let prev_pdu = db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { + let prev_pdu = + db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { "Could not find prev event, but we know the state.".to_owned() })?; - if let Some(state_key) = &prev_pdu.state_key { - state.insert((prev_pdu.kind.clone(), state_key.clone()), prev_pdu); + if let Some(state_key) = &prev_pdu.state_key { + state.insert((prev_pdu.kind.clone(), state_key.clone()), prev_pdu); + } + + state_at_incoming_event = Some(state); + } + // TODO: set incoming_auth_events? + } + + if state_at_incoming_event.is_none() { + warn!("Calling /state_ids"); + // Call /state_ids to find out what the state at this pdu is. We trust the server's + // response to some extend, but we still do a lot of checks on the events + match db + .sending + .send_federation_request( + &db.globals, + origin, + get_room_state_ids::v1::Request { + room_id: &room_id, + event_id: &incoming_pdu.event_id, + }, + ) + .await + { + Ok(res) => { + debug!("Fetching state events at event."); + let state_vec = fetch_and_handle_outliers( + &db, + origin, + &res.pdu_ids, + &create_event, + &room_id, + pub_key_map, + ) + .await; + + let mut state = HashMap::new(); + for (pdu, _) in state_vec { + match state.entry(( + pdu.kind.clone(), + pdu.state_key + .clone() + .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?, + )) { + Entry::Vacant(v) => { + v.insert(pdu); + } + Entry::Occupied(_) => return Err( + "State event's type and state_key combination exists multiple times." + .to_owned(), + ), + } } + // The original create event must still be in the state + if state + .get(&(EventType::RoomCreate, "".to_owned())) + .map(|a| a.as_ref()) + != Some(&create_event) + { + return Err("Incoming event refers to wrong create event.".to_owned()); + } + + debug!("Fetching auth chain events at event."); + fetch_and_handle_outliers( + &db, + origin, + &res.auth_chain_ids, + &create_event, + &room_id, + pub_key_map, + ) + .await; + state_at_incoming_event = Some(state); } - // TODO: set incoming_auth_events? - } - - if state_at_incoming_event.is_none() { - // Call /state_ids to find out what the state at this pdu is. We trust the server's - // response to some extend, but we still do a lot of checks on the events - match db - .sending - .send_federation_request( - &db.globals, - origin, - get_room_state_ids::v1::Request { - room_id: &room_id, - event_id: &incoming_pdu.event_id, - }, - ) - .await - { - Ok(res) => { - debug!("Fetching state events at event."); - let state_vec = fetch_and_handle_events( - &db, - origin, - &res.pdu_ids, - &room_id, - pub_key_map, - false, - ) - .await; - - let mut state = HashMap::new(); - for pdu in state_vec { - match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) { - Entry::Vacant(v) => { - v.insert(pdu); - } - Entry::Occupied(_) => { - return Err( - "State event's type and state_key combination exists multiple times.".to_owned(), - ) - } - } - } - - // The original create event must still be in the state - if state - .get(&(EventType::RoomCreate, "".to_owned())) - .map(|a| a.as_ref()) - != Some(&create_event) - { - return Err("Incoming event refers to wrong create event.".to_owned()); - } - - debug!("Fetching auth chain events at event."); - fetch_and_handle_events( - &db, - origin, - &res.auth_chain_ids, - &room_id, - pub_key_map, - false, - ) - .await; - - state_at_incoming_event = Some(state); - } - Err(_) => { - return Err("Fetching state for event failed".into()); - } - }; - } - - let state_at_incoming_event = - state_at_incoming_event.expect("we always set this to some above"); - - // 11. Check the auth of the event passes based on the state of the event - if !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - previous_create.clone(), - &state_at_incoming_event, - None, // TODO: third party invite - ) - .map_err(|_e| "Auth check failed.".to_owned())? - { - return Err("Event has failed auth check with state at the event.".into()); - } - debug!("Auth check succeeded."); - - // We start looking at current room state now, so lets lock the room - - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Now we calculate the set of extremities this room has after the incoming event has been - // applied. We start with the previous extremities (aka leaves) - let mut extremities = db - .rooms - .get_pdu_leaves(&room_id) - .map_err(|_| "Failed to load room leaves".to_owned())?; - - // Remove any forward extremities that are referenced by this incoming event's prev_events - for prev_event in &incoming_pdu.prev_events { - if extremities.contains(prev_event) { - extremities.remove(prev_event); + Err(_) => { + return Err("Fetching state for event failed".into()); } + }; + } + + let state_at_incoming_event = + state_at_incoming_event.expect("we always set this to some above"); + + // 11. Check the auth of the event passes based on the state of the event + let create_event_content = + serde_json::from_value::>(create_event.content.clone()) + .expect("Raw::from_value always works.") + .deserialize() + .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; + + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + + // If the previous event was the create event special rules apply + let previous_create = if incoming_pdu.auth_events.len() == 1 + && incoming_pdu.prev_events == incoming_pdu.auth_events + { + db.rooms + .get_pdu(&incoming_pdu.auth_events[0]) + .map_err(|e| e.to_string())? + .filter(|maybe_create| **maybe_create == *create_event) + } else { + None + }; + + if !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + previous_create.clone(), + &state_at_incoming_event, + None, // TODO: third party invite + ) + .map_err(|_e| "Auth check failed.".to_owned())? + { + return Err("Event has failed auth check with state at the event.".into()); + } + debug!("Auth check succeeded."); + + // We start looking at current room state now, so lets lock the room + + let mutex_state = Arc::clone( + db.globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Now we calculate the set of extremities this room has after the incoming event has been + // applied. We start with the previous extremities (aka leaves) + let mut extremities = db + .rooms + .get_pdu_leaves(&room_id) + .map_err(|_| "Failed to load room leaves".to_owned())?; + + // Remove any forward extremities that are referenced by this incoming event's prev_events + for prev_event in &incoming_pdu.prev_events { + if extremities.contains(prev_event) { + extremities.remove(prev_event); } + } - // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(db.rooms.is_event_referenced(&room_id, id), Ok(true))); + // Only keep those extremities were not referenced yet + extremities.retain(|id| !matches!(db.rooms.is_event_referenced(&room_id, id), Ok(true))); + let current_statehash = db + .rooms + .current_shortstatehash(&room_id) + .map_err(|_| "Failed to load current state hash.".to_owned())? + .expect("every room has state"); + + let current_state = db + .rooms + .state_full(current_statehash) + .map_err(|_| "Failed to load room state.")?; + + if incoming_pdu.state_key.is_some() { let mut extremity_statehashes = Vec::new(); for id in &extremities { @@ -1239,16 +1429,6 @@ pub fn handle_incoming_pdu<'a>( // don't just trust a set of state we got from a remote). // We do this by adding the current state to the list of fork states - let current_statehash = db - .rooms - .current_shortstatehash(&room_id) - .map_err(|_| "Failed to load current state hash.".to_owned())? - .expect("every room has state"); - - let current_state = db - .rooms - .state_full(current_statehash) - .map_err(|_| "Failed to load room state.")?; extremity_statehashes.push((current_statehash.clone(), None)); @@ -1271,7 +1451,6 @@ pub fn handle_incoming_pdu<'a>( } // We also add state after incoming event to the fork states - extremities.insert(incoming_pdu.event_id.clone()); let mut state_after = state_at_incoming_event.clone(); if let Some(state_key) = &incoming_pdu.state_key { state_after.insert( @@ -1309,7 +1488,8 @@ pub fn handle_incoming_pdu<'a>( for state in fork_states { auth_chain_sets.push( get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) - .map_err(|_| "Failed to load auth chain.".to_owned())?, + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), ); } @@ -1335,38 +1515,6 @@ pub fn handle_incoming_pdu<'a>( state }; - debug!("starting soft fail auth check"); - // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it - let soft_fail = !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - previous_create, - ¤t_state, - None, - ) - .map_err(|_e| "Auth check failed.".to_owned())?; - - let mut pdu_id = None; - if !soft_fail { - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. - pdu_id = Some( - append_incoming_pdu( - &db, - &incoming_pdu, - val, - extremities, - &state_at_incoming_event, - &state_lock, - ) - .map_err(|_| "Failed to add pdu to db.".to_owned())?, - ); - debug!("Appended incoming pdu."); - } else { - warn!("Event was soft failed: {:?}", incoming_pdu); - } - // Set the new room state to the resolved state if update_state { db.rooms @@ -1374,34 +1522,70 @@ pub fn handle_incoming_pdu<'a>( .map_err(|_| "Failed to set new room state.".to_owned())?; } debug!("Updated resolved state"); + } - if soft_fail { - // Soft fail, we leave the event as an outlier but don't add it to the timeline - return Err("Event has been soft failed".into()); - } + extremities.insert(incoming_pdu.event_id.clone()); - // Event has passed all auth/stateres checks - drop(state_lock); - Ok(pdu_id) - }) + debug!("starting soft fail auth check"); + // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it + let soft_fail = !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + previous_create, + ¤t_state, + None, + ) + .map_err(|_e| "Auth check failed.".to_owned())?; + + let mut pdu_id = None; + if !soft_fail { + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. + pdu_id = Some( + append_incoming_pdu( + &db, + &incoming_pdu, + val, + extremities, + &state_at_incoming_event, + &state_lock, + ) + .map_err(|_| "Failed to add pdu to db.".to_owned())?, + ); + debug!("Appended incoming pdu."); + } else { + warn!("Event was soft failed: {:?}", incoming_pdu); + } + + if soft_fail { + // Soft fail, we leave the event as an outlier but don't add it to the timeline + return Err("Event has been soft failed".into()); + } + + // Event has passed all auth/stateres checks + drop(state_lock); + Ok(pdu_id) } /// Find the event and auth it. Once the event is validated (steps 1 - 8) /// it is appended to the outliers Tree. /// +/// Returns pdu and if we fetched it over federation the raw json. +/// /// a. Look in the main timeline (pduid_pdu tree) /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? -//#[tracing::instrument(skip(db, key_map, auth_cache))] -pub(crate) fn fetch_and_handle_events<'a>( +#[tracing::instrument(skip(db, origin, events, create_event, room_id, pub_key_map))] +pub(crate) fn fetch_and_handle_outliers<'a>( db: &'a Database, origin: &'a ServerName, events: &'a [EventId], + create_event: &'a PduEvent, room_id: &'a RoomId, pub_key_map: &'a RwLock>>, - are_timeline_events: bool, -) -> AsyncRecursiveType<'a, Vec>> { +) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> { Box::pin(async move { let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { Entry::Vacant(e) => { @@ -1412,35 +1596,32 @@ pub(crate) fn fetch_and_handle_events<'a>( let mut pdus = vec![]; for id in events { + info!("loading {}", id); if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&id) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", id); + info!("Backing off from {}", id); continue; } } // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree - // (get_pdu checks both) - let local_pdu = if are_timeline_events { - db.rooms.get_non_outlier_pdu(&id).map(|o| o.map(Arc::new)) - } else { - db.rooms.get_pdu(&id) - }; + // (get_pdu_json checks both) + let local_pdu = db.rooms.get_pdu(&id); let pdu = match local_pdu { Ok(Some(pdu)) => { trace!("Found {} in db", id); - pdu + (pdu, None) } Ok(None) => { // c. Ask origin server over federation - debug!("Fetching {} over federation.", id); + info!("Fetching {} over federation.", id); match db .sending .send_federation_request( @@ -1451,41 +1632,29 @@ pub(crate) fn fetch_and_handle_events<'a>( .await { Ok(res) => { - debug!("Got {} over federation", id); - let (event_id, mut value) = + info!("Got {} over federation", id); + let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&res.pdu) { Ok(t) => t, - Err(_) => continue, + Err(_) => { + back_off(id.clone()); + continue; + } }; // This will also fetch the auth chain - match handle_incoming_pdu( + match handle_outlier_pdu( origin, + create_event, &event_id, &room_id, value.clone(), - are_timeline_events, db, pub_key_map, ) .await { - Ok(_) => { - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.into()), - ); - - Arc::new( - serde_json::from_value( - serde_json::to_value(value) - .expect("canonicaljsonobject is valid value"), - ) - .expect( - "This is possible because handle_incoming_pdu worked", - ), - ) - } + Ok((pdu, json)) => (pdu, Some(json)), Err(e) => { warn!("Authentication of event {} failed: {:?}", id, e); back_off(id.clone()); @@ -1501,7 +1670,7 @@ pub(crate) fn fetch_and_handle_events<'a>( } } Err(e) => { - debug!("Error loading {}: {}", id, e); + warn!("Error loading {}: {}", id, e); continue; } }; @@ -1513,7 +1682,7 @@ pub(crate) fn fetch_and_handle_events<'a>( /// Search the DB for the signing keys of the given server, if we don't have them /// fetch them from the server and save to our DB. -#[tracing::instrument(skip(db))] +#[tracing::instrument(skip(db, origin, signature_ids))] pub(crate) async fn fetch_signing_keys( db: &Database, origin: &ServerName, @@ -1684,7 +1853,7 @@ fn append_incoming_pdu( // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. db.rooms - .set_event_state(&pdu.event_id, state, &db.globals)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state, &db.globals)?; let pdu_id = db.rooms.append_pdu( pdu, @@ -1754,51 +1923,72 @@ fn append_incoming_pdu( Ok(pdu_id) } -fn get_auth_chain(starting_events: Vec, db: &Database) -> Result> { +#[tracing::instrument(skip(starting_events, db))] +fn get_auth_chain( + starting_events: Vec, + db: &Database, +) -> Result + '_> { let mut full_auth_chain = HashSet::new(); + let starting_events = starting_events + .iter() + .map(|id| { + db.rooms + .get_or_create_shorteventid(id, &db.globals) + .map(|s| (s, id)) + }) + .collect::>>()?; + let mut cache = db.rooms.auth_chain_cache(); - for event_id in &starting_events { - if let Some(cached) = cache.get_mut(&[event_id.clone()][..]) { + for (sevent_id, event_id) in starting_events { + if let Some(cached) = cache.get_mut(&sevent_id) { full_auth_chain.extend(cached.iter().cloned()); } else { drop(cache); - let mut auth_chain = HashSet::new(); - get_auth_chain_recursive(&event_id, &mut auth_chain, db)?; + let auth_chain = get_auth_chain_inner(&event_id, db)?; cache = db.rooms.auth_chain_cache(); - cache.insert(vec![event_id.clone()], auth_chain.clone()); + cache.insert(sevent_id, auth_chain.clone()); full_auth_chain.extend(auth_chain); }; } - Ok(full_auth_chain) + drop(cache); + + Ok(full_auth_chain + .into_iter() + .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) } -fn get_auth_chain_recursive( - event_id: &EventId, - found: &mut HashSet, - db: &Database, -) -> Result<()> { - let r = db.rooms.get_pdu(&event_id); - match r { - Ok(Some(pdu)) => { - for auth_event in &pdu.auth_events { - if !found.contains(auth_event) { - found.insert(auth_event.clone()); - get_auth_chain_recursive(&auth_event, found, db)?; +#[tracing::instrument(skip(event_id, db))] +fn get_auth_chain_inner(event_id: &EventId, db: &Database) -> Result> { + let mut todo = vec![event_id.clone()]; + let mut found = HashSet::new(); + + while let Some(event_id) = todo.pop() { + match db.rooms.get_pdu(&event_id) { + Ok(Some(pdu)) => { + for auth_event in &pdu.auth_events { + let sauthevent = db + .rooms + .get_or_create_shorteventid(auth_event, &db.globals)?; + + if !found.contains(&sauthevent) { + found.insert(sauthevent); + todo.push(auth_event.clone()); + } } } - } - Ok(None) => { - warn!("Could not find pdu mentioned in auth events."); - } - Err(e) => { - warn!("Could not load event in auth chain: {}", e); + Ok(None) => { + warn!("Could not find pdu mentioned in auth events: {}", event_id); + } + Err(e) => { + warn!("Could not load event in auth chain: {} {}", event_id, e); + } } } - Ok(()) + Ok(found) } #[cfg_attr( @@ -1892,7 +2082,6 @@ pub fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .into_iter() .filter_map(|id| Some(db.rooms.get_pdu_json(&id).ok()??)) .map(|event| PduEvent::convert_to_outgoing_federation_event(event)) .collect(), @@ -1936,7 +2125,6 @@ pub fn get_room_state_route( Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids - .into_iter() .map(|id| { Ok::<_, Error>(PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&id)?.unwrap(), @@ -1979,7 +2167,7 @@ pub fn get_room_state_ids_route( let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.into_iter().collect(), + auth_chain_ids: auth_chain_ids.collect(), pdu_ids, } .into()) @@ -2056,6 +2244,7 @@ pub fn create_join_event_template_route( is_direct: None, membership: MembershipState::Join, third_party_invite: None, + reason: None, }) .expect("member event is valid value"); @@ -2248,7 +2437,6 @@ pub async fn create_join_event_route( Ok(create_join_event::v2::Response { room_state: RoomState { auth_chain: auth_chain_ids - .iter() .filter_map(|id| db.rooms.get_pdu_json(&id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), @@ -2359,6 +2547,7 @@ pub async fn create_invite_route( &sender, Some(invite_state), &db, + true, )?; } @@ -2532,6 +2721,7 @@ pub async fn claim_keys_route( .into()) } +#[tracing::instrument(skip(event, pub_key_map, db))] pub async fn fetch_required_signing_keys( event: &BTreeMap, pub_key_map: &RwLock>>,