From 096e0971f1da629adeec53bbac0fd775cc46fa79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Wed, 11 Aug 2021 10:24:16 +0200 Subject: [PATCH 01/25] improvement: smaller cache, better prev event fetching --- src/database.rs | 4 ++-- src/server_server.rs | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/database.rs b/src/database.rs index bdff386f..1bf9434f 100644 --- a/src/database.rs +++ b/src/database.rs @@ -270,8 +270,8 @@ impl Database { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, referencedevents: builder.open_tree("referencedevents")?, - pdu_cache: Mutex::new(LruCache::new(1_000_000)), - auth_chain_cache: Mutex::new(LruCache::new(1_000_000)), + pdu_cache: Mutex::new(LruCache::new(100_000)), + auth_chain_cache: Mutex::new(LruCache::new(100_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/server_server.rs b/src/server_server.rs index bf5e4f34..68adcd0a 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1022,7 +1022,7 @@ pub fn handle_incoming_pdu<'a>( return Ok(None); } - // Load missing prev events first + // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events fetch_and_handle_events( db, origin, @@ -1033,8 +1033,6 @@ pub fn handle_incoming_pdu<'a>( ) .await; - // TODO: 9. fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities // doing all the checks in this list starting at 1. These are not timeline events. From c2c6a8673e2a8a3fca079e20692c2da76e9622d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Wed, 11 Aug 2021 19:15:38 +0200 Subject: [PATCH 02/25] improvement: use u64s in auth chain cache --- src/database/rooms.rs | 230 ++++++++++++++++++------------------------ src/server_server.rs | 43 +++++--- 2 files changed, 131 insertions(+), 142 deletions(-) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 0f422350..c53fa9ea 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -88,7 +88,7 @@ pub struct Rooms { pub(super) referencedevents: Arc, pub(super) pdu_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex, HashSet>>, + pub(super) auth_chain_cache: Mutex>>, } impl Rooms { @@ -315,19 +315,7 @@ impl Rooms { ); let (shortstatehash, already_existed) = - match self.statehash_shortstatehash.get(&state_hash)? { - Some(shortstatehash) => ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ), - None => { - let shortstatehash = db.globals.next_count()?; - self.statehash_shortstatehash - .insert(&state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - } - }; + self.get_or_create_shortstatehash(&state_hash, &db.globals)?; let new_state = if !already_existed { let mut new_state = HashSet::new(); @@ -352,25 +340,14 @@ impl Rooms { } }; - let shorteventid = - match self.eventid_shorteventid.get(eventid.as_bytes()).ok()? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = db.globals.next_count().ok()?; - self.eventid_shorteventid - .insert(eventid.as_bytes(), &shorteventid.to_be_bytes()) - .ok()?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), eventid.as_bytes()) - .ok()?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self + .get_or_create_shorteventid(&eventid, &db.globals) + .ok()?; let mut state_id = shortstatehash.to_be_bytes().to_vec(); state_id.extend_from_slice(&shortstatekey); - Some((state_id, shorteventid)) + Some((state_id, shorteventid.to_be_bytes().to_vec())) }) .collect::>(); @@ -428,6 +405,61 @@ impl Rooms { Ok(()) } + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash( + &self, + state_hash: &StateHashId, + globals: &super::globals::Globals, + ) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(&state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = globals.next_count()?; + self.statehash_shortstatehash + .insert(&state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + } + }) + } + + /// Returns (shortstatehash, already_existed) + pub fn get_or_create_shorteventid( + &self, + event_id: &EventId, + globals: &super::globals::Globals, + ) -> Result { + Ok(match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => utils::u64_from_bytes(&shorteventid) + .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + None => { + let shorteventid = globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + } + }) + } + + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + EventId::try_from( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + } + /// Returns the full room state. #[tracing::instrument(skip(self))] pub fn room_state_full( @@ -1116,17 +1148,7 @@ impl Rooms { state: &StateMap>, globals: &super::globals::Globals, ) -> Result<()> { - let shorteventid = match self.eventid_shorteventid.get(event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; let state_hash = self.calculate_hash( &state @@ -1135,69 +1157,45 @@ impl Rooms { .collect::>(), ); - let shortstatehash = match self.statehash_shortstatehash.get(&state_hash)? { - Some(shortstatehash) => { - // State already existed in db - self.shorteventid_shortstatehash - .insert(&shorteventid, &*shortstatehash)?; - return Ok(()); - } - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(&state_hash, &shortstatehash.to_be_bytes())?; - shortstatehash.to_be_bytes().to_vec() - } - }; + let (shortstatehash, already_existed) = + self.get_or_create_shortstatehash(&state_hash, globals)?; - let batch = state - .iter() - .filter_map(|((event_type, state_key), pdu)| { - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + if !already_existed { + let batch = state + .iter() + .filter_map(|((event_type, state_key), pdu)| { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { + Some(shortstatekey) => shortstatekey.to_vec(), + None => { + let shortstatekey = globals.next_count().ok()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes()) + .ok()?; + shortstatekey.to_be_bytes().to_vec() + } + }; - let shorteventid = match self - .eventid_shorteventid - .get(pdu.event_id.as_bytes()) - .ok()? - { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count().ok()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes()) - .ok()?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes()) - .ok()?; - shorteventid.to_be_bytes().to_vec() - } - }; + let shorteventid = self + .get_or_create_shorteventid(&pdu.event_id, globals) + .ok()?; - let mut state_id = shortstatehash.clone(); - state_id.extend_from_slice(&shortstatekey); + let mut state_id = shortstatehash.to_be_bytes().to_vec(); + state_id.extend_from_slice(&shortstatekey); - Some((state_id, shorteventid)) - }) - .collect::>(); + Some((state_id, shorteventid.to_be_bytes().to_vec())) + }) + .collect::>(); - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; + self.stateid_shorteventid + .insert_batch(&mut batch.into_iter())?; + } self.shorteventid_shortstatehash - .insert(&shorteventid, &*shortstatehash)?; + .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; Ok(()) } @@ -1212,26 +1210,16 @@ impl Rooms { new_pdu: &PduEvent, globals: &super::globals::Globals, ) -> Result { + let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; + let old_state = if let Some(old_shortstatehash) = self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? { // Store state for event. The state does not include the event itself. // Instead it's the state before the pdu, so the room's old state. - - let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; - self.shorteventid_shortstatehash - .insert(&shorteventid, &old_shortstatehash)?; + .insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?; + if new_pdu.state_key.is_none() { return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") @@ -1264,19 +1252,7 @@ impl Rooms { } }; - let shorteventid = match self.eventid_shorteventid.get(new_pdu.event_id.as_bytes())? { - Some(shorteventid) => shorteventid.to_vec(), - None => { - let shorteventid = globals.next_count()?; - self.eventid_shorteventid - .insert(new_pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), new_pdu.event_id.as_bytes())?; - shorteventid.to_be_bytes().to_vec() - } - }; - - new_state.insert(shortstatekey, shorteventid); + new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec()); let new_state_hash = self.calculate_hash( &new_state @@ -1516,11 +1492,7 @@ impl Rooms { ); // Generate short event id - let shorteventid = db.globals.next_count()?; - self.eventid_shorteventid - .insert(pdu.event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), pdu.event_id.as_bytes())?; + let _shorteventid = self.get_or_create_shorteventid(&pdu.event_id, &db.globals)?; // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. @@ -2655,9 +2627,7 @@ impl Rooms { } #[tracing::instrument(skip(self))] - pub fn auth_chain_cache( - &self, - ) -> std::sync::MutexGuard<'_, LruCache, HashSet>> { + pub fn auth_chain_cache(&self) -> std::sync::MutexGuard<'_, LruCache>> { self.auth_chain_cache.lock().unwrap() } } diff --git a/src/server_server.rs b/src/server_server.rs index 68adcd0a..23c80ee6 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1044,13 +1044,16 @@ pub fn handle_incoming_pdu<'a>( if incoming_pdu.prev_events.len() == 1 { let prev_event = &incoming_pdu.prev_events[0]; - let state = db + let prev_event_sstatehash = db .rooms .pdu_shortstatehash(prev_event) - .map_err(|_| "Failed talking to db".to_owned())? - .map(|shortstatehash| db.rooms.state_full_ids(shortstatehash).ok()) - .flatten(); - if let Some(state) = state { + .map_err(|_| "Failed talking to db".to_owned())?; + + let state = + prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); + + if let Some(Ok(state)) = state { + warn!("Using cached state"); let mut state = fetch_and_handle_events( db, origin, @@ -1088,6 +1091,7 @@ pub fn handle_incoming_pdu<'a>( } if state_at_incoming_event.is_none() { + warn!("Calling /state_ids"); // Call /state_ids to find out what the state at this pdu is. We trust the server's // response to some extend, but we still do a lot of checks on the events match db @@ -1755,35 +1759,50 @@ fn append_incoming_pdu( fn get_auth_chain(starting_events: Vec, db: &Database) -> Result> { let mut full_auth_chain = HashSet::new(); + let starting_events = starting_events + .iter() + .map(|id| { + (db.rooms + .get_or_create_shorteventid(id, &db.globals) + .map(|s| (s, id))) + }) + .collect::>>()?; + let mut cache = db.rooms.auth_chain_cache(); - for event_id in &starting_events { - if let Some(cached) = cache.get_mut(&[event_id.clone()][..]) { + for (sevent_id, event_id) in starting_events { + if let Some(cached) = cache.get_mut(&sevent_id) { full_auth_chain.extend(cached.iter().cloned()); } else { drop(cache); let mut auth_chain = HashSet::new(); get_auth_chain_recursive(&event_id, &mut auth_chain, db)?; cache = db.rooms.auth_chain_cache(); - cache.insert(vec![event_id.clone()], auth_chain.clone()); + cache.insert(sevent_id, auth_chain.clone()); full_auth_chain.extend(auth_chain); }; } - Ok(full_auth_chain) + full_auth_chain + .into_iter() + .map(|sid| db.rooms.get_eventid_from_short(sid)) + .collect() } fn get_auth_chain_recursive( event_id: &EventId, - found: &mut HashSet, + found: &mut HashSet, db: &Database, ) -> Result<()> { let r = db.rooms.get_pdu(&event_id); match r { Ok(Some(pdu)) => { for auth_event in &pdu.auth_events { - if !found.contains(auth_event) { - found.insert(auth_event.clone()); + let sauthevent = db + .rooms + .get_or_create_shorteventid(auth_event, &db.globals)?; + if !found.contains(&sauthevent) { + found.insert(sauthevent); get_auth_chain_recursive(&auth_event, found, db)?; } } From 5173d0deb59a4bdaf03026f6894610e396ddbec4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Wed, 11 Aug 2021 21:14:22 +0200 Subject: [PATCH 03/25] improvement: cache for short event ids --- src/database.rs | 1 + src/database/rooms.rs | 16 ++++++++++++++-- src/server_server.rs | 20 +++++++++----------- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/database.rs b/src/database.rs index 1bf9434f..7996057d 100644 --- a/src/database.rs +++ b/src/database.rs @@ -272,6 +272,7 @@ impl Database { referencedevents: builder.open_tree("referencedevents")?, pdu_cache: Mutex::new(LruCache::new(100_000)), auth_chain_cache: Mutex::new(LruCache::new(100_000)), + shorteventid_cache: Mutex::new(LruCache::new(100_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index c53fa9ea..246aa0ba 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -89,6 +89,7 @@ pub struct Rooms { pub(super) pdu_cache: Mutex>>, pub(super) auth_chain_cache: Mutex>>, + pub(super) shorteventid_cache: Mutex>, } impl Rooms { @@ -447,17 +448,28 @@ impl Rooms { } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { + if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { + return Ok(id.clone()); + } + let bytes = self .shorteventid_eventid .get(&shorteventid.to_be_bytes())? .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - EventId::try_from( + let event_id = EventId::try_from( utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") })?, ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?; + + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, event_id.clone()); + + Ok(event_id) } /// Returns the full room state. diff --git a/src/server_server.rs b/src/server_server.rs index 23c80ee6..0b9a7e68 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1311,7 +1311,7 @@ pub fn handle_incoming_pdu<'a>( for state in fork_states { auth_chain_sets.push( get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) - .map_err(|_| "Failed to load auth chain.".to_owned())?, + .map_err(|_| "Failed to load auth chain.".to_owned())?.collect(), ); } @@ -1756,15 +1756,15 @@ fn append_incoming_pdu( Ok(pdu_id) } -fn get_auth_chain(starting_events: Vec, db: &Database) -> Result> { +fn get_auth_chain(starting_events: Vec, db: &Database) -> Result + '_> { let mut full_auth_chain = HashSet::new(); let starting_events = starting_events .iter() .map(|id| { - (db.rooms + db.rooms .get_or_create_shorteventid(id, &db.globals) - .map(|s| (s, id))) + .map(|s| (s, id)) }) .collect::>>()?; @@ -1783,10 +1783,11 @@ fn get_auth_chain(starting_events: Vec, db: &Database) -> Result(PduEvent::convert_to_outgoing_federation_event( db.rooms.get_pdu_json(&id)?.unwrap(), @@ -1996,7 +1995,7 @@ pub fn get_room_state_ids_route( let auth_chain_ids = get_auth_chain(vec![body.event_id.clone()], &db)?; Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.into_iter().collect(), + auth_chain_ids: auth_chain_ids.collect(), pdu_ids, } .into()) @@ -2265,7 +2264,6 @@ pub async fn create_join_event_route( Ok(create_join_event::v2::Response { room_state: RoomState { auth_chain: auth_chain_ids - .iter() .filter_map(|id| db.rooms.get_pdu_json(&id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), From 665aee11c04672a1c928d381b7b691e41785c7e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Wed, 11 Aug 2021 21:17:01 +0200 Subject: [PATCH 04/25] less warnings --- src/database/abstraction/sqlite.rs | 39 +----------------------------- src/server_server.rs | 3 ++- 2 files changed, 3 insertions(+), 39 deletions(-) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 72fb5f7f..f4200212 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -12,9 +12,7 @@ use std::{ time::{Duration, Instant}, }; use tokio::sync::oneshot::Sender; -use tracing::{debug, warn}; - -pub const MILLI: Duration = Duration::from_millis(1); +use tracing::debug; thread_local! { static READ_CONNECTION: RefCell> = RefCell::new(None); @@ -164,16 +162,7 @@ impl Tree for SqliteTable { #[tracing::instrument(skip(self, key, value))] fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); - - let start = Instant::now(); - self.insert_with_guard(&guard, key, value)?; - - let elapsed = start.elapsed(); - if elapsed > MILLI { - warn!("insert took {:?} : {}", elapsed, &self.name); - } - drop(guard); let watchers = self.watchers.read(); @@ -220,20 +209,11 @@ impl Tree for SqliteTable { fn remove(&self, key: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); - let start = Instant::now(); - guard.execute( format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key], )?; - let elapsed = start.elapsed(); - - if elapsed > MILLI { - debug!("remove: took {:012?} : {}", elapsed, &self.name); - } - // debug!("remove key: {:?}", &key); - Ok(()) } @@ -326,8 +306,6 @@ impl Tree for SqliteTable { fn increment(&self, key: &[u8]) -> Result> { let guard = self.engine.write_lock(); - let start = Instant::now(); - let old = self.get_with_guard(&guard, key)?; let new = @@ -335,26 +313,11 @@ impl Tree for SqliteTable { self.insert_with_guard(&guard, key, &new)?; - let elapsed = start.elapsed(); - - if elapsed > MILLI { - debug!("increment: took {:012?} : {}", elapsed, &self.name); - } - // debug!("increment key: {:?}", &key); - Ok(new) } #[tracing::instrument(skip(self, prefix))] fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { - // let name = self.name.clone(); - // self.iter_from_thread( - // format!( - // "SELECT key, value FROM {} WHERE key BETWEEN ?1 AND ?1 || X'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF' ORDER BY key ASC", - // name - // ) - // [prefix] - // ) Box::new( self.iter_from(&prefix, false) .take_while(move |(key, _)| key.starts_with(&prefix)), diff --git a/src/server_server.rs b/src/server_server.rs index 0b9a7e68..a4c90a71 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -721,7 +721,8 @@ pub async fn send_transaction_message_route( &db.globals, )?; } else { - warn!("No known event ids in read receipt: {:?}", user_updates); + // TODO fetch missing events + debug!("No known event ids in read receipt: {:?}", user_updates); } } } From 9410d3ef9cc2a19096149ac69ba3422231c760b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 12 Aug 2021 17:55:16 +0200 Subject: [PATCH 05/25] fix: long prev event fetch times for huge rooms --- src/database.rs | 2 +- src/database/rooms.rs | 49 ++++++++++++++++++------------------------- src/server_server.rs | 36 ++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 38 deletions(-) diff --git a/src/database.rs b/src/database.rs index 7996057d..7a17e53e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -272,7 +272,7 @@ impl Database { referencedevents: builder.open_tree("referencedevents")?, pdu_cache: Mutex::new(LruCache::new(100_000)), auth_chain_cache: Mutex::new(LruCache::new(100_000)), - shorteventid_cache: Mutex::new(LruCache::new(100_000)), + shorteventid_cache: Mutex::new(LruCache::new(1_000_000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 246aa0ba..88878e9d 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -99,15 +99,11 @@ impl Rooms { Ok(self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) - .flatten() - .map(|bytes| { - EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in stateid_shorteventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in stateid_shorteventid is invalid.")) + .map(|(_, bytes)| { + self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) + .ok() }) - .filter_map(|r| r.ok()) + .flatten() .collect()) } @@ -118,15 +114,11 @@ impl Rooms { let state = self .stateid_shorteventid .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| self.shorteventid_eventid.get(&bytes).ok().flatten()) - .flatten() - .map(|bytes| { - EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in stateid_shorteventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in stateid_shorteventid is invalid.")) + .map(|(_, bytes)| { + self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) + .ok() }) - .filter_map(|r| r.ok()) + .flatten() .map(|eventid| self.get_pdu(&eventid)) .filter_map(|r| r.ok().flatten()) .map(|pdu| { @@ -168,15 +160,10 @@ impl Rooms { Ok(self .stateid_shorteventid .get(&stateid)? - .map(|bytes| self.shorteventid_eventid.get(&bytes).ok().flatten()) - .flatten() .map(|bytes| { - EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in stateid_shorteventid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in stateid_shorteventid is invalid.")) + self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) + .ok() }) - .map(|r| r.ok()) .flatten()) } else { Ok(None) @@ -448,7 +435,12 @@ impl Rooms { } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { - if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { + if let Some(id) = self + .shorteventid_cache + .lock() + .unwrap() + .get_mut(&shorteventid) + { return Ok(id.clone()); } @@ -457,12 +449,11 @@ impl Rooms { .get(&shorteventid.to_be_bytes())? .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - let event_id = EventId::try_from( - utils::string_from_bytes(&bytes).map_err(|_| { + let event_id = + EventId::try_from(utils::string_from_bytes(&bytes).map_err(|_| { Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?; + })?) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))?; self.shorteventid_cache .lock() diff --git a/src/server_server.rs b/src/server_server.rs index a4c90a71..b3f0353a 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -668,7 +668,7 @@ pub async fn send_transaction_message_route( let elapsed = start_time.elapsed(); warn!( - "Handling event {} took {}m{}s", + "Handling transaction of event {} took {}m{}s", event_id, elapsed.as_secs() / 60, elapsed.as_secs() % 60 @@ -850,6 +850,8 @@ pub fn handle_incoming_pdu<'a>( pub_key_map: &'a RwLock>>, ) -> AsyncRecursiveType<'a, StdResult>, String>> { Box::pin(async move { + let start_time = Instant::now(); + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json match db.rooms.exists(&room_id) { Ok(true) => {} @@ -1014,12 +1016,18 @@ pub fn handle_incoming_pdu<'a>( // 8. if not timeline event: stop if !is_timeline_event || incoming_pdu.origin_server_ts - < db.rooms - .first_pdu_in_room(&room_id) - .map_err(|_| "Error loading first room event.".to_owned())? - .expect("Room exists") - .origin_server_ts + < (utils::millis_since_unix_epoch() - 1000 * 60 * 20) + .try_into() + .expect("time is valid") + // Not older than 20 mins { + let elapsed = start_time.elapsed(); + warn!( + "Handling outlier event {} took {}m{}s", + event_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); return Ok(None); } @@ -1312,7 +1320,8 @@ pub fn handle_incoming_pdu<'a>( for state in fork_states { auth_chain_sets.push( get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) - .map_err(|_| "Failed to load auth chain.".to_owned())?.collect(), + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), ); } @@ -1385,6 +1394,14 @@ pub fn handle_incoming_pdu<'a>( // Event has passed all auth/stateres checks drop(state_lock); + + let elapsed = start_time.elapsed(); + warn!( + "Handling timeline event {} took {}m{}s", + event_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); Ok(pdu_id) }) } @@ -1757,7 +1774,10 @@ fn append_incoming_pdu( Ok(pdu_id) } -fn get_auth_chain(starting_events: Vec, db: &Database) -> Result + '_> { +fn get_auth_chain( + starting_events: Vec, + db: &Database, +) -> Result + '_> { let mut full_auth_chain = HashSet::new(); let starting_events = starting_events From 41dd620d74276a4619bdc878307b136a5710bbae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 1 Aug 2021 15:14:54 +0200 Subject: [PATCH 06/25] WIP improvement: much better state storage --- src/database.rs | 338 ++++++++++++++++++++++++++++- src/database/abstraction/sqlite.rs | 5 +- src/database/rooms.rs | 12 +- 3 files changed, 345 insertions(+), 10 deletions(-) diff --git a/src/database.rs b/src/database.rs index 7a17e53e..e0f9eec7 100644 --- a/src/database.rs +++ b/src/database.rs @@ -24,13 +24,14 @@ use rocket::{ request::{FromRequest, Request}, Shutdown, State, }; -use ruma::{DeviceId, RoomId, ServerName, UserId}; +use ruma::{DeviceId, EventId, RoomId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; use std::{ - collections::{BTreeMap, HashMap}, + collections::{BTreeMap, HashMap, HashSet}, convert::TryFrom, fs::{self, remove_dir_all}, io::Write, + mem::size_of, ops::Deref, path::Path, sync::{Arc, Mutex, RwLock}, @@ -261,7 +262,12 @@ impl Database { userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, + + shortroomid_roomid: builder.open_tree("shortroomid_roomid")?, + roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, @@ -438,6 +444,334 @@ impl Database { println!("Migration: 5 -> 6 finished"); } + + fn load_shortstatehash_info( + shortstatehash: &[u8], + db: &Database, + lru: &mut LruCache< + Vec, + Vec<( + Vec, + HashSet>, + HashSet>, + HashSet>, + )>, + >, + ) -> Result< + Vec<( + Vec, // sstatehash + HashSet>, // full state + HashSet>, // added + HashSet>, // removed + )>, + > { + if let Some(result) = lru.get_mut(shortstatehash) { + return Ok(result.clone()); + } + + let value = db + .rooms + .shortstatehash_statediff + .get(shortstatehash)? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = value[0..size_of::()].to_vec(); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.to_vec()); + } else { + removed.insert(v.to_vec()); + } + i += 2 * size_of::(); + } + + if parent != 0_u64.to_be_bytes() { + let mut response = load_shortstatehash_info(&parent, db, lru)?; + let mut state = response.last().unwrap().1.clone(); + state.extend(added.iter().cloned()); + for r in &removed { + state.remove(r); + } + + response.push((shortstatehash.to_vec(), state, added, removed)); + + lru.insert(shortstatehash.to_vec(), response.clone()); + Ok(response) + } else { + let mut response = Vec::new(); + response.push((shortstatehash.to_vec(), added.clone(), added, removed)); + lru.insert(shortstatehash.to_vec(), response.clone()); + Ok(response) + } + } + + fn update_shortstatehash_level( + current_shortstatehash: &[u8], + statediffnew: HashSet>, + statediffremoved: HashSet>, + diff_to_sibling: usize, + mut parent_states: Vec<( + Vec, // sstatehash + HashSet>, // full state + HashSet>, // added + HashSet>, // removed + )>, + db: &Database, + ) -> Result<()> { + let diffsum = statediffnew.len() + statediffremoved.len(); + + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().unwrap(); + + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + parent_removed.insert(removed); + } + } + parent_new.extend(statediffnew); + + update_shortstatehash_level( + current_shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + db, + )?; + + return Ok(()); + } + + if parent_states.len() == 0 { + // There is no parent layer, create a new state + let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent + for new in &statediffnew { + value.extend_from_slice(&new); + } + + if !statediffremoved.is_empty() { + warn!("Tried to create new state with removals"); + } + + db.rooms + .shortstatehash_statediff + .insert(¤t_shortstatehash, &value)?; + + return Ok(()); + }; + + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above + + let parent = parent_states.pop().unwrap(); + let parent_diff = parent.2.len() + parent.3.len(); + + if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + // Diff too big, we replace above layer(s) + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + parent_removed.insert(removed); + } + } + + parent_new.extend(statediffnew); + update_shortstatehash_level( + current_shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + db, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + let mut value = parent.0.clone(); + for new in &statediffnew { + value.extend_from_slice(&new); + } + + if !statediffremoved.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in &statediffremoved { + value.extend_from_slice(&removed); + } + } + + db.rooms + .shortstatehash_statediff + .insert(¤t_shortstatehash, &value)?; + } + + Ok(()) + } + + if db.globals.database_version()? < 7 { + // Upgrade state store + let mut lru = LruCache::new(1000); + let mut last_roomstates: HashMap> = HashMap::new(); + let mut current_sstatehash: Vec = Vec::new(); + let mut current_room = None; + let mut current_state = HashSet::new(); + let mut counter = 0; + for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { + let sstatehash = k[0..size_of::()].to_vec(); + let sstatekey = k[size_of::()..].to_vec(); + if sstatehash != current_sstatehash { + if !current_sstatehash.is_empty() { + counter += 1; + println!("counter: {}", counter); + let current_room = current_room.as_ref().unwrap(); + let last_roomsstatehash = last_roomstates.get(¤t_room); + + let states_parents = last_roomsstatehash.map_or_else( + || Ok(Vec::new()), + |last_roomsstatehash| { + load_shortstatehash_info(&last_roomsstatehash, &db, &mut lru) + }, + )?; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = current_state + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); + + let statediffremoved = parent_stateinfo + .1 + .difference(¤t_state) + .cloned() + .collect::>(); + + (statediffnew, statediffremoved) + } else { + (current_state, HashSet::new()) + }; + + update_shortstatehash_level( + ¤t_sstatehash, + statediffnew, + statediffremoved, + 2, // every state change is 2 event changes on average + states_parents, + &db, + )?; + + /* + let mut tmp = load_shortstatehash_info(¤t_sstatehash, &db)?; + let state = tmp.pop().unwrap(); + println!( + "{}\t{}{:?}: {:?} + {:?} - {:?}", + current_room, + " ".repeat(tmp.len()), + utils::u64_from_bytes(¤t_sstatehash).unwrap(), + tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), + state + .2 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>(), + state + .3 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>() + ); + */ + + last_roomstates.insert(current_room.clone(), current_sstatehash); + } + current_state = HashSet::new(); + current_sstatehash = sstatehash; + + let event_id = db + .rooms + .shorteventid_eventid + .get(&seventid) + .unwrap() + .unwrap(); + let event_id = + EventId::try_from(utils::string_from_bytes(&event_id).unwrap()) + .unwrap(); + let pdu = db.rooms.get_pdu(&event_id).unwrap().unwrap(); + + if Some(&pdu.room_id) != current_room.as_ref() { + current_room = Some(pdu.room_id.clone()); + } + } + + let mut val = sstatekey; + val.extend_from_slice(&seventid); + current_state.insert(val); + } + + db.globals.bump_database_version(7)?; + + println!("Migration: 6 -> 7 finished"); + } + + if db.globals.database_version()? < 8 { + // Generate short room ids for all rooms + for (room_id, _) in db.rooms.roomid_shortstatehash.iter() { + let shortroomid = db.globals.next_count()?.to_be_bytes(); + db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?; + db.rooms.shortroomid_roomid.insert(&shortroomid, &room_id)?; + } + // Update pduids db layout + for (key, v) in db.rooms.pduid_pdu.iter() { + let mut parts = key.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db.rooms.roomid_shortroomid.get(&room_id)?.unwrap(); + + let mut new_key = short_room_id; + new_key.extend_from_slice(count); + + println!("{:?}", new_key); + } + + // Update tokenids db layout + for (key, _) in db.rooms.tokenids.iter() { + let mut parts = key.splitn(4, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let word = parts.next().unwrap(); + let _pdu_id_room = parts.next().unwrap(); + let pdu_id_count = parts.next().unwrap(); + + let short_room_id = db.rooms.roomid_shortroomid.get(&room_id)?.unwrap(); + let mut new_key = short_room_id; + new_key.extend_from_slice(word); + new_key.push(0xff); + new_key.extend_from_slice(pdu_id_count); + println!("{:?}", new_key); + } + + db.globals.bump_database_version(8)?; + + println!("Migration: 7 -> 8 finished"); + } + + panic!(); } let guard = db.read().await; diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index f4200212..f37dd9e7 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -223,10 +223,7 @@ impl Tree for SqliteTable { let statement = Box::leak(Box::new( guard - .prepare(&format!( - "SELECT key, value FROM {} ORDER BY key ASC", - &self.name - )) + .prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)) .unwrap(), )); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 88878e9d..a3a1c411 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -47,7 +47,7 @@ pub struct Rooms { pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count pub(super) publicroomids: Arc, - pub(super) tokenids: Arc, // TokenId = RoomId + Token + PduId + pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount /// Participating servers in a room. pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName @@ -71,14 +71,18 @@ pub struct Rooms { pub(super) shorteventid_shortstatehash: Arc, /// StateKey = EventType + StateKey, ShortStateKey = Count pub(super) statekey_shortstatekey: Arc, + + pub(super) shortroomid_roomid: Arc, + pub(super) roomid_shortroomid: Arc, + pub(super) shorteventid_eventid: Arc, - /// ShortEventId = Count pub(super) eventid_shorteventid: Arc, - /// ShortEventId = Count + pub(super) statehash_shortstatehash: Arc, /// ShortStateHash = Count - /// StateId = ShortStateHash + ShortStateKey + /// StateId = ShortStateHash pub(super) stateid_shorteventid: Arc, + pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) /// RoomId + EventId -> outlier PDU. /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. From 31f60ad6fd7feee087f5d340ac5858f70d675e48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Mon, 2 Aug 2021 22:32:28 +0200 Subject: [PATCH 07/25] improvement: migrations, batch inserts --- src/database.rs | 57 +++++++++++++++++++++++++----- src/database/abstraction/sqlite.rs | 5 ++- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/src/database.rs b/src/database.rs index e0f9eec7..2e471d83 100644 --- a/src/database.rs +++ b/src/database.rs @@ -735,40 +735,81 @@ impl Database { let shortroomid = db.globals.next_count()?.to_be_bytes(); db.rooms.roomid_shortroomid.insert(&room_id, &shortroomid)?; db.rooms.shortroomid_roomid.insert(&shortroomid, &room_id)?; + println!("Migration: 8"); } // Update pduids db layout - for (key, v) in db.rooms.pduid_pdu.iter() { + let mut batch = db.rooms.pduid_pdu.iter().filter_map(|(key, v)| { + if !key.starts_with(b"!") { + return None; + } let mut parts = key.splitn(2, |&b| b == 0xff); let room_id = parts.next().unwrap(); let count = parts.next().unwrap(); - let short_room_id = db.rooms.roomid_shortroomid.get(&room_id)?.unwrap(); + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); let mut new_key = short_room_id; new_key.extend_from_slice(count); - println!("{:?}", new_key); + Some((new_key, v)) + }); + + db.rooms.pduid_pdu.insert_batch(&mut batch)?; + + for (key, _) in db.rooms.pduid_pdu.iter() { + if key.starts_with(b"!") { + db.rooms.pduid_pdu.remove(&key); + } } + db.globals.bump_database_version(8)?; + + println!("Migration: 7 -> 8 finished"); + } + + if db.globals.database_version()? < 9 { // Update tokenids db layout - for (key, _) in db.rooms.tokenids.iter() { + let mut batch = db.rooms.tokenids.iter().filter_map(|(key, _)| { + if !key.starts_with(b"!") { + return None; + } let mut parts = key.splitn(4, |&b| b == 0xff); let room_id = parts.next().unwrap(); let word = parts.next().unwrap(); let _pdu_id_room = parts.next().unwrap(); let pdu_id_count = parts.next().unwrap(); - let short_room_id = db.rooms.roomid_shortroomid.get(&room_id)?.unwrap(); + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); let mut new_key = short_room_id; new_key.extend_from_slice(word); new_key.push(0xff); new_key.extend_from_slice(pdu_id_count); - println!("{:?}", new_key); + println!("old {:?}", key); + println!("new {:?}", new_key); + Some((new_key, Vec::new())) + }); + + db.rooms.tokenids.insert_batch(&mut batch)?; + + for (key, _) in db.rooms.tokenids.iter() { + if key.starts_with(b"!") { + db.rooms.pduid_pdu.remove(&key)?; + } } - db.globals.bump_database_version(8)?; + db.globals.bump_database_version(9)?; - println!("Migration: 7 -> 8 finished"); + println!("Migration: 8 -> 9 finished"); } panic!(); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index f37dd9e7..f4200212 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -223,7 +223,10 @@ impl Tree for SqliteTable { let statement = Box::leak(Box::new( guard - .prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)) + .prepare(&format!( + "SELECT key, value FROM {} ORDER BY key ASC", + &self.name + )) .unwrap(), )); From 3eabaa2a95645378b130134220f264a23e0fd7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 12 Aug 2021 23:04:00 +0200 Subject: [PATCH 08/25] finish implementing better state store --- src/client_server/account.rs | 2 + src/client_server/context.rs | 4 +- src/client_server/membership.rs | 2 + src/client_server/message.rs | 4 +- src/client_server/room.rs | 5 +- src/client_server/sync.rs | 16 +- src/database.rs | 373 +++++----------- src/database/abstraction/sqlite.rs | 72 +-- src/database/rooms.rs | 691 +++++++++++++++++++---------- src/server_server.rs | 2 +- 10 files changed, 645 insertions(+), 526 deletions(-) diff --git a/src/client_server/account.rs b/src/client_server/account.rs index 48159c98..d4f103c7 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -249,6 +249,8 @@ pub async fn register_route( let room_id = RoomId::new(db.globals.server_name()); + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let mutex_state = Arc::clone( db.globals .roomid_mutex_state diff --git a/src/client_server/context.rs b/src/client_server/context.rs index dbc121e3..701e584d 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -44,7 +44,7 @@ pub async fn get_context_route( let events_before = db .rooms - .pdus_until(&sender_user, &body.room_id, base_token) + .pdus_until(&sender_user, &body.room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") @@ -66,7 +66,7 @@ pub async fn get_context_route( let events_after = db .rooms - .pdus_after(&sender_user, &body.room_id, base_token) + .pdus_after(&sender_user, &body.room_id, base_token)? .take( u32::try_from(body.limit).map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid.") diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 716a615f..de6fa5a1 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -609,6 +609,8 @@ async fn join_room_by_id_helper( ) .await?; + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let pdu = PduEvent::from_id_val(&event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 9cb6faab..70cc00f7 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -128,7 +128,7 @@ pub async fn get_message_events_route( get_message_events::Direction::Forward => { let events_after = db .rooms - .pdus_after(&sender_user, &body.room_id, from) + .pdus_after(&sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { @@ -158,7 +158,7 @@ pub async fn get_message_events_route( get_message_events::Direction::Backward => { let events_before = db .rooms - .pdus_until(&sender_user, &body.room_id, from) + .pdus_until(&sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { diff --git a/src/client_server/room.rs b/src/client_server/room.rs index 89241f5c..c323be4a 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -33,6 +33,8 @@ pub async fn create_room_route( let room_id = RoomId::new(db.globals.server_name()); + db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + let mutex_state = Arc::clone( db.globals .roomid_mutex_state @@ -173,7 +175,6 @@ pub async fn create_room_route( )?; // 4. Canonical room alias - if let Some(room_alias_id) = &alias { db.rooms.build_and_append_pdu( PduBuilder { @@ -193,7 +194,7 @@ pub async fn create_room_route( &room_id, &db, &state_lock, - ); + )?; } // 5. Events set by preset diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 937a252e..c196b2a9 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -205,7 +205,7 @@ async fn sync_helper( let mut non_timeline_pdus = db .rooms - .pdus_until(&sender_user, &room_id, u64::MAX) + .pdus_until(&sender_user, &room_id, u64::MAX)? .filter_map(|r| { // Filter out buggy events if r.is_err() { @@ -248,13 +248,13 @@ async fn sync_helper( let first_pdu_before_since = db .rooms - .pdus_until(&sender_user, &room_id, since) + .pdus_until(&sender_user, &room_id, since)? .next() .transpose()?; let pdus_after_since = db .rooms - .pdus_after(&sender_user, &room_id, since) + .pdus_after(&sender_user, &room_id, since)? .next() .is_some(); @@ -286,7 +286,7 @@ async fn sync_helper( for hero in db .rooms - .all_pdus(&sender_user, &room_id) + .all_pdus(&sender_user, &room_id)? .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter(|(_, pdu)| pdu.kind == EventType::RoomMember) .map(|(_, pdu)| { @@ -328,11 +328,11 @@ async fn sync_helper( } } - ( + Ok::<_, Error>(( Some(joined_member_count), Some(invited_member_count), heroes, - ) + )) }; let ( @@ -343,7 +343,7 @@ async fn sync_helper( state_events, ) = if since_shortstatehash.is_none() { // Probably since = 0, we will do an initial sync - let (joined_member_count, invited_member_count, heroes) = calculate_counts(); + let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let current_state_ids = db.rooms.state_full_ids(current_shortstatehash)?; let state_events = current_state_ids @@ -510,7 +510,7 @@ async fn sync_helper( } let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts() + calculate_counts()? } else { (None, None, Vec::new()) }; diff --git a/src/database.rs b/src/database.rs index 2e471d83..4e340195 100644 --- a/src/database.rs +++ b/src/database.rs @@ -28,7 +28,7 @@ use ruma::{DeviceId, EventId, RoomId, ServerName, UserId}; use serde::{de::IgnoredAny, Deserialize}; use std::{ collections::{BTreeMap, HashMap, HashSet}, - convert::TryFrom, + convert::{TryFrom, TryInto}, fs::{self, remove_dir_all}, io::Write, mem::size_of, @@ -266,7 +266,6 @@ impl Database { shortroomid_roomid: builder.open_tree("shortroomid_roomid")?, roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, - stateid_shorteventid: builder.open_tree("stateid_shorteventid")?, shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, @@ -431,7 +430,6 @@ impl Database { } if db.globals.database_version()? < 6 { - // TODO update to 6 // Set room member count for (roomid, _) in db.rooms.roomid_shortstatehash.iter() { let room_id = @@ -445,263 +443,98 @@ impl Database { println!("Migration: 5 -> 6 finished"); } - fn load_shortstatehash_info( - shortstatehash: &[u8], - db: &Database, - lru: &mut LruCache< - Vec, - Vec<( - Vec, - HashSet>, - HashSet>, - HashSet>, - )>, - >, - ) -> Result< - Vec<( - Vec, // sstatehash - HashSet>, // full state - HashSet>, // added - HashSet>, // removed - )>, - > { - if let Some(result) = lru.get_mut(shortstatehash) { - return Ok(result.clone()); - } - - let value = db - .rooms - .shortstatehash_statediff - .get(shortstatehash)? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = value[0..size_of::()].to_vec(); - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i += size_of::(); - continue; - } - if add_mode { - added.insert(v.to_vec()); - } else { - removed.insert(v.to_vec()); - } - i += 2 * size_of::(); - } - - if parent != 0_u64.to_be_bytes() { - let mut response = load_shortstatehash_info(&parent, db, lru)?; - let mut state = response.last().unwrap().1.clone(); - state.extend(added.iter().cloned()); - for r in &removed { - state.remove(r); - } - - response.push((shortstatehash.to_vec(), state, added, removed)); - - lru.insert(shortstatehash.to_vec(), response.clone()); - Ok(response) - } else { - let mut response = Vec::new(); - response.push((shortstatehash.to_vec(), added.clone(), added, removed)); - lru.insert(shortstatehash.to_vec(), response.clone()); - Ok(response) - } - } - - fn update_shortstatehash_level( - current_shortstatehash: &[u8], - statediffnew: HashSet>, - statediffremoved: HashSet>, - diff_to_sibling: usize, - mut parent_states: Vec<( - Vec, // sstatehash - HashSet>, // full state - HashSet>, // added - HashSet>, // removed - )>, - db: &Database, - ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); - - if parent_states.len() > 3 { - // Number of layers - // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); - - let mut parent_new = parent.2; - let mut parent_removed = parent.3; - - for removed in statediffremoved { - if !parent_new.remove(&removed) { - parent_removed.insert(removed); - } - } - parent_new.extend(statediffnew); - - update_shortstatehash_level( - current_shortstatehash, - parent_new, - parent_removed, - diffsum, - parent_states, - db, - )?; - - return Ok(()); - } - - if parent_states.len() == 0 { - // There is no parent layer, create a new state - let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent - for new in &statediffnew { - value.extend_from_slice(&new); - } - - if !statediffremoved.is_empty() { - warn!("Tried to create new state with removals"); - } - - db.rooms - .shortstatehash_statediff - .insert(¤t_shortstatehash, &value)?; - - return Ok(()); - }; - - // Else we have two options. - // 1. We add the current diff on top of the parent layer. - // 2. We replace a layer above - - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); - - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { - // Diff too big, we replace above layer(s) - let mut parent_new = parent.2; - let mut parent_removed = parent.3; - - for removed in statediffremoved { - if !parent_new.remove(&removed) { - parent_removed.insert(removed); - } - } - - parent_new.extend(statediffnew); - update_shortstatehash_level( - current_shortstatehash, - parent_new, - parent_removed, - diffsum, - parent_states, - db, - )?; - } else { - // Diff small enough, we add diff as layer on top of parent - let mut value = parent.0.clone(); - for new in &statediffnew { - value.extend_from_slice(&new); - } - - if !statediffremoved.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in &statediffremoved { - value.extend_from_slice(&removed); - } - } - - db.rooms - .shortstatehash_statediff - .insert(¤t_shortstatehash, &value)?; - } - - Ok(()) - } - if db.globals.database_version()? < 7 { // Upgrade state store - let mut lru = LruCache::new(1000); - let mut last_roomstates: HashMap> = HashMap::new(); - let mut current_sstatehash: Vec = Vec::new(); + let mut last_roomstates: HashMap = HashMap::new(); + let mut current_sstatehash: Option = None; let mut current_room = None; let mut current_state = HashSet::new(); let mut counter = 0; + + let mut handle_state = + |current_sstatehash: u64, + current_room: &RoomId, + current_state: HashSet<_>, + last_roomstates: &mut HashMap<_, _>| { + counter += 1; + println!("counter: {}", counter); + let last_roomsstatehash = last_roomstates.get(current_room); + + let states_parents = last_roomsstatehash.map_or_else( + || Ok(Vec::new()), + |&last_roomsstatehash| { + db.rooms.load_shortstatehash_info(dbg!(last_roomsstatehash)) + }, + )?; + + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = current_state + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); + + let statediffremoved = parent_stateinfo + .1 + .difference(¤t_state) + .cloned() + .collect::>(); + + (statediffnew, statediffremoved) + } else { + (current_state, HashSet::new()) + }; + + db.rooms.save_state_from_diff( + dbg!(current_sstatehash), + statediffnew, + statediffremoved, + 2, // every state change is 2 event changes on average + states_parents, + )?; + + /* + let mut tmp = db.rooms.load_shortstatehash_info(¤t_sstatehash, &db)?; + let state = tmp.pop().unwrap(); + println!( + "{}\t{}{:?}: {:?} + {:?} - {:?}", + current_room, + " ".repeat(tmp.len()), + utils::u64_from_bytes(¤t_sstatehash).unwrap(), + tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), + state + .2 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>(), + state + .3 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>() + ); + */ + + Ok::<_, Error>(()) + }; + for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() { - let sstatehash = k[0..size_of::()].to_vec(); + let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) + .expect("number of bytes is correct"); let sstatekey = k[size_of::()..].to_vec(); - if sstatehash != current_sstatehash { - if !current_sstatehash.is_empty() { - counter += 1; - println!("counter: {}", counter); - let current_room = current_room.as_ref().unwrap(); - let last_roomsstatehash = last_roomstates.get(¤t_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |last_roomsstatehash| { - load_shortstatehash_info(&last_roomsstatehash, &db, &mut lru) - }, + if Some(sstatehash) != current_sstatehash { + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_ref().unwrap(), + current_state, + &mut last_roomstates, )?; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .cloned() - .collect::>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .cloned() - .collect::>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - update_shortstatehash_level( - ¤t_sstatehash, - statediffnew, - statediffremoved, - 2, // every state change is 2 event changes on average - states_parents, - &db, - )?; - - /* - let mut tmp = load_shortstatehash_info(¤t_sstatehash, &db)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>() - ); - */ - - last_roomstates.insert(current_room.clone(), current_sstatehash); + last_roomstates + .insert(current_room.clone().unwrap(), current_sstatehash); } current_state = HashSet::new(); - current_sstatehash = sstatehash; + current_sstatehash = Some(sstatehash); let event_id = db .rooms @@ -721,7 +554,16 @@ impl Database { let mut val = sstatekey; val.extend_from_slice(&seventid); - current_state.insert(val); + current_state.insert(val.try_into().expect("size is correct")); + } + + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_ref().unwrap(), + current_state, + &mut last_roomstates, + )?; } db.globals.bump_database_version(7)?; @@ -761,11 +603,28 @@ impl Database { db.rooms.pduid_pdu.insert_batch(&mut batch)?; - for (key, _) in db.rooms.pduid_pdu.iter() { - if key.starts_with(b"!") { - db.rooms.pduid_pdu.remove(&key); + let mut batch2 = db.rooms.eventid_pduid.iter().filter_map(|(k, value)| { + if !value.starts_with(b"!") { + return None; } - } + let mut parts = value.splitn(2, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); + + let mut new_value = short_room_id; + new_value.extend_from_slice(count); + + Some((k, new_value)) + }); + + db.rooms.eventid_pduid.insert_batch(&mut batch2)?; db.globals.bump_database_version(8)?; @@ -803,7 +662,7 @@ impl Database { for (key, _) in db.rooms.tokenids.iter() { if key.starts_with(b"!") { - db.rooms.pduid_pdu.remove(&key)?; + db.rooms.tokenids.remove(&key)?; } } @@ -811,8 +670,6 @@ impl Database { println!("Migration: 8 -> 9 finished"); } - - panic!(); } let guard = db.read().await; diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index f4200212..3c4ae9c6 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -9,13 +9,13 @@ use std::{ path::{Path, PathBuf}, pin::Pin, sync::Arc, - time::{Duration, Instant}, }; use tokio::sync::oneshot::Sender; use tracing::debug; thread_local! { static READ_CONNECTION: RefCell> = RefCell::new(None); + static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); } struct PreparedStatementIterator<'a> { @@ -77,6 +77,21 @@ impl Engine { }) } + fn read_lock_iterator(&self) -> &'static Connection { + READ_CONNECTION_ITERATOR.with(|cell| { + let connection = &mut cell.borrow_mut(); + + if (*connection).is_none() { + let c = Box::leak(Box::new( + Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap(), + )); + **connection = Some(c); + } + + connection.unwrap() + }) + } + pub fn flush_wal(self: &Arc) -> Result<()> { self.write_lock() .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; @@ -151,6 +166,34 @@ impl SqliteTable { )?; Ok(()) } + + pub fn iter_with_guard<'a>( + &'a self, + guard: &'a Connection, + ) -> Box + 'a> { + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); + + let statement_ref = NonAliasingBox(statement); + + let iterator = Box::new( + statement + .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()), + ); + + Box::new(PreparedStatementIterator { + iterator, + statement_ref, + }) + } } impl Tree for SqliteTable { @@ -219,30 +262,9 @@ impl Tree for SqliteTable { #[tracing::instrument(skip(self))] fn iter<'a>(&'a self) -> Box + 'a> { - let guard = self.engine.read_lock(); + let guard = self.engine.read_lock_iterator(); - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} ORDER BY key ASC", - &self.name - )) - .unwrap(), - )); - - let statement_ref = NonAliasingBox(statement); - - let iterator = Box::new( - statement - .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(|r| r.unwrap()), - ); - - Box::new(PreparedStatementIterator { - iterator, - statement_ref, - }) + self.iter_with_guard(&guard) } #[tracing::instrument(skip(self, from, backwards))] @@ -251,7 +273,7 @@ impl Tree for SqliteTable { from: &[u8], backwards: bool, ) -> Box + 'a> { - let guard = self.engine.read_lock(); + let guard = self.engine.read_lock_iterator(); let from = from.to_vec(); // TODO change interface? if backwards { diff --git a/src/database/rooms.rs b/src/database/rooms.rs index a3a1c411..fc01e8ae 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -24,7 +24,7 @@ use ruma::{ use std::{ collections::{BTreeMap, BTreeSet, HashMap, HashSet}, convert::{TryFrom, TryInto}, - mem, + mem::size_of, sync::{Arc, Mutex}, }; use tokio::sync::MutexGuard; @@ -37,10 +37,11 @@ use super::{abstraction::Tree, admin::AdminCommand, pusher}; /// This is created when a state group is added to the database by /// hashing the entire state. pub type StateHashId = Vec; +pub type CompressedStateEvent = [u8; 2 * size_of::()]; pub struct Rooms { pub edus: edus::RoomEdus, - pub(super) pduid_pdu: Arc, // PduId = RoomId + Count + pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count pub(super) eventid_pduid: Arc, pub(super) roomid_pduleaves: Arc, pub(super) alias_roomid: Arc, @@ -79,9 +80,6 @@ pub struct Rooms { pub(super) eventid_shorteventid: Arc, pub(super) statehash_shortstatehash: Arc, - /// ShortStateHash = Count - /// StateId = ShortStateHash - pub(super) stateid_shorteventid: Arc, pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) /// RoomId + EventId -> outlier PDU. @@ -100,29 +98,30 @@ impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { - Ok(self - .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| { - self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) - .ok() - }) - .flatten() - .collect()) + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + full_state + .into_iter() + .map(|compressed| self.parse_compressed_state_event(compressed)) + .collect() } pub fn state_full( &self, shortstatehash: u64, ) -> Result>> { - let state = self - .stateid_shorteventid - .scan_prefix(shortstatehash.to_be_bytes().to_vec()) - .map(|(_, bytes)| { - self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) - .ok() - }) - .flatten() + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .map(|compressed| self.parse_compressed_state_event(compressed)) + .filter_map(|r| r.ok()) .map(|eventid| self.get_pdu(&eventid)) .filter_map(|r| r.ok().flatten()) .map(|pdu| { @@ -138,9 +137,7 @@ impl Rooms { )) }) .filter_map(|r| r.ok()) - .collect(); - - Ok(state) + .collect()) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -151,27 +148,19 @@ impl Rooms { event_type: &EventType, state_key: &str, ) -> Result> { - let mut key = event_type.as_ref().as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&state_key.as_bytes()); - - let shortstatekey = self.statekey_shortstatekey.get(&key)?; - - if let Some(shortstatekey) = shortstatekey { - let mut stateid = shortstatehash.to_be_bytes().to_vec(); - stateid.extend_from_slice(&shortstatekey); - - Ok(self - .stateid_shorteventid - .get(&stateid)? - .map(|bytes| { - self.get_eventid_from_short(utils::u64_from_bytes(&bytes).unwrap()) - .ok() - }) - .flatten()) - } else { - Ok(None) - } + let shortstatekey = match self.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = self + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok(full_state + .into_iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + .and_then(|compressed| self.parse_compressed_state_event(compressed).ok())) } /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). @@ -260,8 +249,7 @@ impl Rooms { /// Checks if a room exists. pub fn exists(&self, room_id: &RoomId) -> Result { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); // Look for PDUs in that room. Ok(self @@ -274,8 +262,7 @@ impl Rooms { /// Checks if a room exists. pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); // Look for PDUs in that room. self.pduid_pdu @@ -292,74 +279,78 @@ impl Rooms { /// Force the creation of a new StateHash and insert it into the db. /// - /// Whatever `state` is supplied to `force_state` __is__ the current room state snapshot. + /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot. pub fn force_state( &self, room_id: &RoomId, - state: HashMap<(EventType, String), EventId>, + new_state: HashMap<(EventType, String), EventId>, db: &Database, ) -> Result<()> { + let previous_shortstatehash = self.current_shortstatehash(&room_id)?; + + let new_state_ids_compressed = new_state + .iter() + .filter_map(|((event_type, state_key), event_id)| { + let shortstatekey = self + .get_or_create_shortstatekey(event_type, state_key, &db.globals) + .ok()?; + Some( + self.compress_state_event(shortstatekey, event_id, &db.globals) + .ok()?, + ) + }) + .collect::>(); + let state_hash = self.calculate_hash( - &state + &new_state .values() .map(|event_id| event_id.as_bytes()) .collect::>(), ); - let (shortstatehash, already_existed) = + let (new_shortstatehash, already_existed) = self.get_or_create_shortstatehash(&state_hash, &db.globals)?; - let new_state = if !already_existed { - let mut new_state = HashSet::new(); + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok(()); + } - let batch = state - .iter() - .filter_map(|((event_type, state_key), eventid)| { - new_state.insert(eventid.clone()); + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() + { + let statediffnew = new_state_ids_compressed + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = db.globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let statediffremoved = parent_stateinfo + .1 + .difference(&new_state_ids_compressed) + .cloned() + .collect::>(); - let shorteventid = self - .get_or_create_shorteventid(&eventid, &db.globals) - .ok()?; - - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - - Some((state_id, shorteventid.to_be_bytes().to_vec())) - }) - .collect::>(); - - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; - - new_state + (statediffnew, statediffremoved) } else { - self.state_full_ids(shortstatehash)?.into_iter().collect() + (new_state_ids_compressed, HashSet::new()) }; - let old_state = self - .current_shortstatehash(&room_id)? - .map(|s| self.state_full_ids(s)) - .transpose()? - .map(|vec| vec.into_iter().collect::>()) - .unwrap_or_default(); + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + }; - for event_id in new_state.difference(&old_state) { - if let Some(pdu) = self.get_pdu_json(event_id)? { + for event_id in statediffnew + .into_iter() + .filter_map(|new| self.parse_compressed_state_event(new).ok()) + { + if let Some(pdu) = self.get_pdu_json(&event_id)? { if pdu.get("type").and_then(|val| val.as_str()) == Some("m.room.member") { if let Ok(pdu) = serde_json::from_value::( serde_json::to_value(&pdu).expect("CanonicalJsonObj is a valid JsonValue"), @@ -392,7 +383,206 @@ impl Rooms { } self.roomid_shortstatehash - .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?; + .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + + Ok(()) + } + + /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. + pub fn load_shortstatehash_info( + &self, + shortstatehash: u64, + ) -> Result< + Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + > { + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? + .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = + utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::(); + } + + if parent != 0_u64 { + let mut response = self.load_shortstatehash_info(parent)?; + let mut state = response.last().unwrap().1.clone(); + state.extend(added.iter().cloned()); + for r in &removed { + state.remove(r); + } + + response.push((shortstatehash, state, added, removed)); + + Ok(response) + } else { + let mut response = Vec::new(); + response.push((shortstatehash, added.clone(), added, removed)); + Ok(response) + } + } + + pub fn compress_state_event( + &self, + shortstatekey: u64, + event_id: &EventId, + globals: &super::globals::Globals, + ) -> Result { + let mut v = shortstatekey.to_be_bytes().to_vec(); + v.extend_from_slice( + &self + .get_or_create_shorteventid(event_id, globals)? + .to_be_bytes(), + ); + Ok(v.try_into().expect("we checked the size above")) + } + + pub fn parse_compressed_state_event( + &self, + compressed_event: CompressedStateEvent, + ) -> Result { + self.get_eventid_from_short( + utils::u64_from_bytes(&compressed_event[size_of::()..]) + .expect("bytes have right length"), + ) + } + + /// Creates a new shortstatehash that often is just a diff to an already existing + /// shortstatehash and therefore very efficient. + /// + /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer + /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. If layer n > 0 + /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. + /// + /// * `shortstatehash` - Shortstatehash of this state + /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid + /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer + pub fn save_state_from_diff( + &self, + shortstatehash: u64, + statediffnew: HashSet, + statediffremoved: HashSet, + diff_to_sibling: usize, + mut parent_states: Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + ) -> Result<()> { + let diffsum = statediffnew.len() + statediffremoved.len(); + + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().unwrap(); + + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + parent_removed.insert(removed); + } + } + parent_new.extend(statediffnew); + + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + + return Ok(()); + } + + if parent_states.len() == 0 { + // There is no parent layer, create a new state + let mut value = 0_u64.to_be_bytes().to_vec(); // 0 means no parent + for new in &statediffnew { + value.extend_from_slice(&new[..]); + } + + if !statediffremoved.is_empty() { + warn!("Tried to create new state with removals"); + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + + return Ok(()); + }; + + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. We replace a layer above + + let parent = parent_states.pop().unwrap(); + let parent_diff = parent.2.len() + parent.3.len(); + + if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + // Diff too big, we replace above layer(s) + let mut parent_new = parent.2; + let mut parent_removed = parent.3; + + for removed in statediffremoved { + if !parent_new.remove(&removed) { + parent_removed.insert(removed); + } + } + + parent_new.extend(statediffnew); + self.save_state_from_diff( + shortstatehash, + parent_new, + parent_removed, + diffsum, + parent_states, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + let mut value = parent.0.to_be_bytes().to_vec(); + for new in &statediffnew { + value.extend_from_slice(&new[..]); + } + + if !statediffremoved.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in &statediffremoved { + value.extend_from_slice(&removed[..]); + } + } + + self.shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value)?; + } Ok(()) } @@ -418,7 +608,6 @@ impl Rooms { }) } - /// Returns (shortstatehash, already_existed) pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -438,6 +627,71 @@ impl Rooms { }) } + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result { + let bytes = self + .roomid_shortroomid + .get(&room_id.as_bytes())? + .expect("every room has a shortroomid"); + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + } + + pub fn get_shortstatekey( + &self, + event_type: &EventType, + state_key: &str, + ) -> Result> { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + self.statekey_shortstatekey + .get(&statekey)? + .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose() + } + + pub fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + globals: &super::globals::Globals, + ) -> Result { + Ok(match self.roomid_shortroomid.get(&room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = globals.next_count()?; + self.roomid_shortroomid + .insert(&room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } + + pub fn get_or_create_shortstatekey( + &self, + event_type: &EventType, + state_key: &str, + globals: &super::globals::Globals, + ) -> Result { + let mut statekey = event_type.as_ref().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(&state_key.as_bytes()); + + Ok(match self.statekey_shortstatekey.get(&statekey)? { + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes())?; + shortstatekey + } + }) + } + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { if let Some(id) = self .shorteventid_cache @@ -514,7 +768,7 @@ impl Rooms { #[tracing::instrument(skip(self))] pub fn pdu_count(&self, pdu_id: &[u8]) -> Result { Ok( - utils::u64_from_bytes(&pdu_id[pdu_id.len() - mem::size_of::()..pdu_id.len()]) + utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?, ) } @@ -527,8 +781,7 @@ impl Rooms { } pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -758,6 +1011,8 @@ impl Rooms { /// /// By this point the incoming event should be fully authenticated, no auth happens /// in `append_pdu`. + /// + /// Returns pdu id #[tracing::instrument(skip(self, pdu, pdu_json, leaves, db))] pub fn append_pdu( &self, @@ -766,7 +1021,8 @@ impl Rooms { leaves: &[EventId], db: &Database, ) -> Result> { - // returns pdu id + let shortroomid = self.get_shortroomid(&pdu.room_id)?; + // Make unsigned fields correct. This is not properly documented in the spec, but state // events need to have previous content in the unsigned field, so clients can easily // interpret things like membership changes @@ -821,8 +1077,7 @@ impl Rooms { self.reset_notification_counts(&pdu.sender, &pdu.room_id)?; let count2 = db.globals.next_count()?; - let mut pdu_id = pdu.room_id.as_bytes().to_vec(); - pdu_id.push(0xff); + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); pdu_id.extend_from_slice(&count2.to_be_bytes()); // There's a brief moment of time here where the count is updated but the pdu does not @@ -968,8 +1223,7 @@ impl Rooms { .filter(|word| word.len() <= 50) .map(str::to_lowercase) .map(|word| { - let mut key = pdu.room_id.as_bytes().to_vec(); - key.push(0xff); + let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xff); key.extend_from_slice(&pdu_id); @@ -1152,11 +1406,27 @@ impl Rooms { pub fn set_event_state( &self, event_id: &EventId, + room_id: &RoomId, state: &StateMap>, globals: &super::globals::Globals, ) -> Result<()> { let shorteventid = self.get_or_create_shorteventid(&event_id, globals)?; + let previous_shortstatehash = self.current_shortstatehash(&room_id)?; + + let state_ids_compressed = state + .iter() + .filter_map(|((event_type, state_key), pdu)| { + let shortstatekey = self + .get_or_create_shortstatekey(event_type, state_key, globals) + .ok()?; + Some( + self.compress_state_event(shortstatekey, &pdu.event_id, globals) + .ok()?, + ) + }) + .collect::>(); + let state_hash = self.calculate_hash( &state .values() @@ -1168,37 +1438,33 @@ impl Rooms { self.get_or_create_shortstatehash(&state_hash, globals)?; if !already_existed { - let batch = state - .iter() - .filter_map(|((event_type, state_key), pdu)| { - let mut statekey = event_type.as_ref().as_bytes().to_vec(); - statekey.push(0xff); - statekey.extend_from_slice(&state_key.as_bytes()); + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let shortstatekey = match self.statekey_shortstatekey.get(&statekey).ok()? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count().ok()?; - self.statekey_shortstatekey - .insert(&statekey, &shortstatekey.to_be_bytes()) - .ok()?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let (statediffnew, statediffremoved) = + if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = state_ids_compressed + .difference(&parent_stateinfo.1) + .cloned() + .collect::>(); - let shorteventid = self - .get_or_create_shorteventid(&pdu.event_id, globals) - .ok()?; + let statediffremoved = parent_stateinfo + .1 + .difference(&state_ids_compressed) + .cloned() + .collect::>(); - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - - Some((state_id, shorteventid.to_be_bytes().to_vec())) - }) - .collect::>(); - - self.stateid_shorteventid - .insert_batch(&mut batch.into_iter())?; + (statediffnew, statediffremoved) + } else { + (state_ids_compressed, HashSet::new()) + }; + self.save_state_from_diff( + shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 1_000_000, // high number because no state will be based on this one + states_parents, + )?; } self.shorteventid_shortstatehash @@ -1219,82 +1485,52 @@ impl Rooms { ) -> Result { let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id, globals)?; - let old_state = if let Some(old_shortstatehash) = - self.roomid_shortstatehash.get(new_pdu.room_id.as_bytes())? - { - // Store state for event. The state does not include the event itself. - // Instead it's the state before the pdu, so the room's old state. + let previous_shortstatehash = self.current_shortstatehash(&new_pdu.room_id)?; + + if let Some(p) = previous_shortstatehash { self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &old_shortstatehash)?; - - if new_pdu.state_key.is_none() { - return utils::u64_from_bytes(&old_shortstatehash).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash.") - }); - } - - self.stateid_shorteventid - .scan_prefix(old_shortstatehash.clone()) - // Chop the old_shortstatehash out leaving behind the short state key - .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) - .collect::, Vec>>() - } else { - HashMap::new() - }; + .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?; + } if let Some(state_key) = &new_pdu.state_key { - let mut new_state: HashMap, Vec> = old_state; + let states_parents = previous_shortstatehash + .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); - new_state_key.push(0xff); - new_state_key.extend_from_slice(state_key.as_bytes()); + let shortstatekey = + self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?; - let shortstatekey = match self.statekey_shortstatekey.get(&new_state_key)? { - Some(shortstatekey) => shortstatekey.to_vec(), - None => { - let shortstatekey = globals.next_count()?; - self.statekey_shortstatekey - .insert(&new_state_key, &shortstatekey.to_be_bytes())?; - shortstatekey.to_be_bytes().to_vec() - } - }; + let replaces = states_parents + .last() + .map(|info| { + info.1 + .iter() + .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) + }) + .unwrap_or_default(); - new_state.insert(shortstatekey, shorteventid.to_be_bytes().to_vec()); + // TODO: statehash with deterministic inputs + let shortstatehash = globals.next_count()?; - let new_state_hash = self.calculate_hash( - &new_state - .values() - .map(|event_id| &**event_id) - .collect::>(), - ); + let mut statediffnew = HashSet::new(); + let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; + statediffnew.insert(new); - let shortstatehash = match self.statehash_shortstatehash.get(&new_state_hash)? { - Some(shortstatehash) => { - warn!("state hash already existed?!"); - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))? - } - None => { - let shortstatehash = globals.next_count()?; - self.statehash_shortstatehash - .insert(&new_state_hash, &shortstatehash.to_be_bytes())?; - shortstatehash - } - }; + let mut statediffremoved = HashSet::new(); + if let Some(replaces) = replaces { + statediffremoved.insert(replaces.clone()); + } - let mut batch = new_state.into_iter().map(|(shortstatekey, shorteventid)| { - let mut state_id = shortstatehash.to_be_bytes().to_vec(); - state_id.extend_from_slice(&shortstatekey); - (state_id, shorteventid) - }); - - self.stateid_shorteventid.insert_batch(&mut batch)?; + self.save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 2, + states_parents, + )?; Ok(shortstatehash) } else { - Err(Error::bad_database( - "Tried to insert non-state event into room without a state.", - )) + Ok(previous_shortstatehash.expect("first event in room must be a state event")) } } @@ -1597,7 +1833,7 @@ impl Rooms { &'a self, user_id: &UserId, room_id: &RoomId, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { self.pdus_since(user_id, room_id, 0) } @@ -1609,16 +1845,17 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> impl Iterator, PduEvent)>> + 'a { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + ) -> Result, PduEvent)>> + 'a> { + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); // Skip the first pdu if it's exactly at since, because we sent that last time let mut first_pdu_id = prefix.clone(); first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes()); let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(&first_pdu_id, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1628,7 +1865,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Returns an iterator over all events and their tokens in a room that happened before the @@ -1639,10 +1876,9 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` @@ -1650,7 +1886,9 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(current, true) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1660,7 +1898,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Returns an iterator over all events and their token in a room that happened after the event @@ -1671,10 +1909,9 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event @@ -1682,7 +1919,9 @@ impl Rooms { let current: &[u8] = ¤t; let user_id = user_id.clone(); - self.pduid_pdu + + Ok(self + .pduid_pdu .iter_from(current, false) .take_while(move |(k, _)| k.starts_with(&prefix)) .map(move |(pdu_id, v)| { @@ -1692,7 +1931,7 @@ impl Rooms { pdu.unsigned.remove("transaction_id"); } Ok((pdu_id, pdu)) - }) + })) } /// Replace a PDU with the redacted form. @@ -2223,8 +2462,8 @@ impl Rooms { room_id: &RoomId, search_string: &str, ) -> Result<(impl Iterator> + 'a, Vec)> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix_clone = prefix.clone(); let words = search_string .split_terminator(|c: char| !c.is_alphanumeric()) @@ -2243,16 +2482,7 @@ impl Rooms { .iter_from(&last_possible_id, true) // Newest pdus first .take_while(move |(k, _)| k.starts_with(&prefix2)) .map(|(key, _)| { - let pduid_index = key - .iter() - .enumerate() - .filter(|(_, &b)| b == 0xff) - .nth(1) - .ok_or_else(|| Error::bad_database("Invalid tokenid in db."))? - .0 - + 1; // +1 because the pdu id starts AFTER the separator - - let pdu_id = key[pduid_index..].to_vec(); + let pdu_id = key[key.len() - size_of::()..].to_vec(); Ok::<_, Error>(pdu_id) }) @@ -2264,7 +2494,12 @@ impl Rooms { // We compare b with a because we reversed the iterator earlier b.cmp(a) }) - .unwrap(), + .unwrap() + .map(move |id| { + let mut pduid = prefix_clone.clone(); + pduid.extend_from_slice(&id); + pduid + }), words, )) } diff --git a/src/server_server.rs b/src/server_server.rs index b3f0353a..0c226aca 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1704,7 +1704,7 @@ fn append_incoming_pdu( // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. db.rooms - .set_event_state(&pdu.event_id, state, &db.globals)?; + .set_event_state(&pdu.event_id, &pdu.room_id, state, &db.globals)?; let pdu_id = db.rooms.append_pdu( pdu, From 3cf0145bc5b564e1230417bf98b9aeffde1f1085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 08:26:45 +0200 Subject: [PATCH 09/25] fix: room exists panic --- src/database/rooms.rs | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index fc01e8ae..400ce386 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -249,7 +249,10 @@ impl Rooms { /// Checks if a room exists. pub fn exists(&self, room_id: &RoomId) -> Result { - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = match self.get_shortroomid(room_id)? { + Some(b) => b.to_be_bytes().to_vec(), + None => return Ok(false), + }; // Look for PDUs in that room. Ok(self @@ -262,7 +265,7 @@ impl Rooms { /// Checks if a room exists. pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); // Look for PDUs in that room. self.pduid_pdu @@ -627,12 +630,12 @@ impl Rooms { }) } - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result { - let bytes = self + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + self .roomid_shortroomid .get(&room_id.as_bytes())? - .expect("every room has a shortroomid"); - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + .map(|bytes| + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))).transpose() } pub fn get_shortstatekey( @@ -781,7 +784,7 @@ impl Rooms { } pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -1021,7 +1024,7 @@ impl Rooms { leaves: &[EventId], db: &Database, ) -> Result> { - let shortroomid = self.get_shortroomid(&pdu.room_id)?; + let shortroomid = self.get_shortroomid(&pdu.room_id)?.expect("room exists"); // Make unsigned fields correct. This is not properly documented in the spec, but state // events need to have previous content in the unsigned field, so clients can easily @@ -1846,7 +1849,7 @@ impl Rooms { room_id: &RoomId, since: u64, ) -> Result, PduEvent)>> + 'a> { - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); // Skip the first pdu if it's exactly at since, because we sent that last time let mut first_pdu_id = prefix.clone(); @@ -1878,7 +1881,7 @@ impl Rooms { until: u64, ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` @@ -1911,7 +1914,7 @@ impl Rooms { from: u64, ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event @@ -2462,7 +2465,7 @@ impl Rooms { room_id: &RoomId, search_string: &str, ) -> Result<(impl Iterator> + 'a, Vec)> { - let prefix = self.get_shortroomid(room_id)?.to_be_bytes().to_vec(); + let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); let prefix_clone = prefix.clone(); let words = search_string From 38effda799570e1a0837673d671f364d9c7bd4b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 19:07:50 +0200 Subject: [PATCH 10/25] fix: delta calculation --- src/database.rs | 2 +- src/database/rooms.rs | 22 ++++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/database.rs b/src/database.rs index 4e340195..0bf2a440 100644 --- a/src/database.rs +++ b/src/database.rs @@ -108,7 +108,7 @@ fn default_db_cache_capacity_mb() -> f64 { } fn default_sqlite_wal_clean_second_interval() -> u32 { - 15 * 60 // every 15 minutes + 1 * 60 // every minute } fn default_max_request_size() -> u32 { diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 400ce386..fc2bd051 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -510,10 +510,19 @@ impl Rooms { for removed in statediffremoved { if !parent_new.remove(&removed) { + // It was not added in the parent and we removed it parent_removed.insert(removed); } + // Else it was added in the parent and we removed it again. We can forget this change + } + + for new in statediffnew { + if !parent_removed.remove(&new) { + // It was not touched in the parent and we added it + parent_new.insert(new); + } + // Else it was removed in the parent and we added it again. We can forget this change } - parent_new.extend(statediffnew); self.save_state_from_diff( shortstatehash, @@ -557,11 +566,20 @@ impl Rooms { for removed in statediffremoved { if !parent_new.remove(&removed) { + // It was not added in the parent and we removed it parent_removed.insert(removed); } + // Else it was added in the parent and we removed it again. We can forget this change + } + + for new in statediffnew { + if !parent_removed.remove(&new) { + // It was not touched in the parent and we added it + parent_new.insert(new); + } + // Else it was removed in the parent and we added it again. We can forget this change } - parent_new.extend(statediffnew); self.save_state_from_diff( shortstatehash, parent_new, From 0cb22996be02f38dff2e049cfeae351f607aeb38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 19:47:16 +0200 Subject: [PATCH 11/25] remove prev event fetch limit --- src/server_server.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/server_server.rs b/src/server_server.rs index 0c226aca..cbbb8507 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1016,10 +1016,11 @@ pub fn handle_incoming_pdu<'a>( // 8. if not timeline event: stop if !is_timeline_event || incoming_pdu.origin_server_ts - < (utils::millis_since_unix_epoch() - 1000 * 60 * 20) - .try_into() - .expect("time is valid") - // Not older than 20 mins + < db.rooms + .first_pdu_in_room(&room_id) + .map_err(|_| "Error loading first room event.".to_owned())? + .expect("Room exists") + .origin_server_ts { let elapsed = start_time.elapsed(); warn!( @@ -1031,6 +1032,7 @@ pub fn handle_incoming_pdu<'a>( return Ok(None); } + // TODO: make not recursive // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events fetch_and_handle_events( db, From 1e3a8ca35dff2287963da7cd4fae30539790fa76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 19:47:49 +0200 Subject: [PATCH 12/25] fmt --- src/database/rooms.rs | 46 +++++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index fc2bd051..99f0c83e 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -265,7 +265,11 @@ impl Rooms { /// Checks if a room exists. pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); // Look for PDUs in that room. self.pduid_pdu @@ -649,11 +653,13 @@ impl Rooms { } pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self - .roomid_shortroomid + self.roomid_shortroomid .get(&room_id.as_bytes())? - .map(|bytes| - utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))).transpose() + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + }) + .transpose() } pub fn get_shortstatekey( @@ -802,7 +808,11 @@ impl Rooms { } pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut last_possible_key = prefix.clone(); last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); @@ -1867,7 +1877,11 @@ impl Rooms { room_id: &RoomId, since: u64, ) -> Result, PduEvent)>> + 'a> { - let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); // Skip the first pdu if it's exactly at since, because we sent that last time let mut first_pdu_id = prefix.clone(); @@ -1899,7 +1913,11 @@ impl Rooms { until: u64, ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until` @@ -1932,7 +1950,11 @@ impl Rooms { from: u64, ) -> Result, PduEvent)>> + 'a> { // Create the first part of the full pdu id - let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let mut current = prefix.clone(); current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event @@ -2483,7 +2505,11 @@ impl Rooms { room_id: &RoomId, search_string: &str, ) -> Result<(impl Iterator> + 'a, Vec)> { - let prefix = self.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); let prefix_clone = prefix.clone(); let words = search_string From 1d465699296837fe147f855d95fcfef57775a3a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 21:30:14 +0200 Subject: [PATCH 13/25] fix: don't use recursion for prev events --- src/server_server.rs | 228 ++++++++++++++++++++++++------------------- 1 file changed, 129 insertions(+), 99 deletions(-) diff --git a/src/server_server.rs b/src/server_server.rs index cbbb8507..5b2acd00 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -840,7 +840,7 @@ type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; /// 14. Use state resolution to find new room state // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively #[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] -pub fn handle_incoming_pdu<'a>( +pub async fn handle_incoming_pdu<'a>( origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, @@ -848,22 +848,76 @@ pub fn handle_incoming_pdu<'a>( is_timeline_event: bool, db: &'a Database, pub_key_map: &'a RwLock>>, -) -> AsyncRecursiveType<'a, StdResult>, String>> { +) -> StdResult>, String> { + match db.rooms.exists(&room_id) { + Ok(true) => {} + _ => { + return Err("Room is unknown to this server.".to_string()); + } + } + + // 1. Skip the PDU if we already have it as a timeline event + if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { + return Ok(Some(pdu_id.to_vec())); + } + + let create_event = db + .rooms + .room_state_get(&room_id, &EventType::RoomCreate, "") + .map_err(|_| "Failed to ask database for event.".to_owned())? + .ok_or_else(|| "Failed to find create event in db.".to_owned())?; + + let (incoming_pdu, val) = handle_outlier_pdu(origin, &create_event, event_id, room_id, value, db, pub_key_map).await?; + + // 8. if not timeline event: stop + if !is_timeline_event + || incoming_pdu.origin_server_ts + < db.rooms + .first_pdu_in_room(&room_id) + .map_err(|_| "Error loading first room event.".to_owned())? + .expect("Room exists") + .origin_server_ts + { + return Ok(None); + } + + // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); + let mut todo_timeline_stack = Vec::new(); + while let Some(prev_event_id) = todo_outlier_stack.pop() { + if let Some((pdu, Some(json))) = fetch_and_handle_outliers( + db, + origin, + &[prev_event_id], + &create_event, + &room_id, + pub_key_map, + ) + .await.pop() { + todo_timeline_stack.push((pdu, json)); + } + } + + while let Some(prev) = todo_timeline_stack.pop() { + upgrade_outlier_to_timeline_pdu(prev.0, prev.1, &create_event, origin, db, room_id, pub_key_map).await?; + } + + upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, db, room_id, pub_key_map).await +} + +fn handle_outlier_pdu<'a>( + origin: &'a ServerName, + create_event: &'a PduEvent, + event_id: &'a EventId, + room_id: &'a RoomId, + value: BTreeMap, + db: &'a Database, + pub_key_map: &'a RwLock>>, +) -> AsyncRecursiveType<'a, StdResult<(Arc, BTreeMap), String>> { Box::pin(async move { let start_time = Instant::now(); // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - match db.rooms.exists(&room_id) { - Ok(true) => {} - _ => { - return Err("Room is unknown to this server.".to_string()); - } - } - - // 1. Skip the PDU if we already have it as a timeline event - if let Ok(Some(pdu_id)) = db.rooms.get_pdu_id(&event_id) { - return Ok(Some(pdu_id.to_vec())); - } // We go through all the signatures we see on the value and fetch the corresponding signing // keys @@ -873,17 +927,12 @@ pub fn handle_incoming_pdu<'a>( // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match - let create_event = db - .rooms - .room_state_get(&room_id, &EventType::RoomCreate, "") - .map_err(|_| "Failed to ask database for event.".to_owned())? - .ok_or_else(|| "Failed to find create event in db.".to_owned())?; - let create_event_content = - serde_json::from_value::>(create_event.content.clone()) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; + let create_event_content = + serde_json::from_value::>(create_event.content.clone()) + .expect("Raw::from_value always works.") + .deserialize() + .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; let room_version_id = &create_event_content.room_version; let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); @@ -924,13 +973,13 @@ pub fn handle_incoming_pdu<'a>( // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // EDIT: Step 5 is not applied anymore because it failed too often debug!("Fetching auth events for {}", incoming_pdu.event_id); - fetch_and_handle_events( + fetch_and_handle_outliers( db, origin, &incoming_pdu.auth_events, + &create_event, &room_id, pub_key_map, - false, ) .await; @@ -1013,37 +1062,20 @@ pub fn handle_incoming_pdu<'a>( .map_err(|_| "Failed to add pdu as outlier.".to_owned())?; debug!("Added pdu as outlier."); - // 8. if not timeline event: stop - if !is_timeline_event - || incoming_pdu.origin_server_ts - < db.rooms - .first_pdu_in_room(&room_id) - .map_err(|_| "Error loading first room event.".to_owned())? - .expect("Room exists") - .origin_server_ts - { - let elapsed = start_time.elapsed(); - warn!( - "Handling outlier event {} took {}m{}s", - event_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - return Ok(None); - } + Ok((incoming_pdu,val)) + }) - // TODO: make not recursive - // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - fetch_and_handle_events( - db, - origin, - &incoming_pdu.prev_events, - &room_id, - pub_key_map, - true, - ) - .await; +} +async fn upgrade_outlier_to_timeline_pdu( + incoming_pdu: Arc, + val: BTreeMap, + create_event: &PduEvent, + origin: &ServerName, + db: &Database, + room_id: &RoomId, + pub_key_map: &RwLock>>, +) -> StdResult>, String> { // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities // doing all the checks in this list starting at 1. These are not timeline events. @@ -1065,17 +1097,17 @@ pub fn handle_incoming_pdu<'a>( if let Some(Ok(state)) = state { warn!("Using cached state"); - let mut state = fetch_and_handle_events( + let mut state = fetch_and_handle_outliers( db, origin, &state.into_iter().collect::>(), + &create_event, &room_id, pub_key_map, - false, ) .await .into_iter() - .map(|pdu| { + .map(|(pdu,_)| { ( ( pdu.kind.clone(), @@ -1119,18 +1151,18 @@ pub fn handle_incoming_pdu<'a>( { Ok(res) => { debug!("Fetching state events at event."); - let state_vec = fetch_and_handle_events( + let state_vec = fetch_and_handle_outliers( &db, origin, &res.pdu_ids, + &create_event, &room_id, pub_key_map, - false, ) .await; let mut state = HashMap::new(); - for pdu in state_vec { + for (pdu, _) in state_vec { match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) { Entry::Vacant(v) => { v.insert(pdu); @@ -1153,13 +1185,13 @@ pub fn handle_incoming_pdu<'a>( } debug!("Fetching auth chain events at event."); - fetch_and_handle_events( + fetch_and_handle_outliers( &db, origin, &res.auth_chain_ids, + &create_event, &room_id, pub_key_map, - false, ) .await; @@ -1175,6 +1207,28 @@ pub fn handle_incoming_pdu<'a>( state_at_incoming_event.expect("we always set this to some above"); // 11. Check the auth of the event passes based on the state of the event + let create_event_content = + serde_json::from_value::>(create_event.content.clone()) + .expect("Raw::from_value always works.") + .deserialize() + .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; + + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + + // If the previous event was the create event special rules apply + let previous_create = if incoming_pdu.auth_events.len() == 1 + && incoming_pdu.prev_events == incoming_pdu.auth_events + { + db.rooms + .get_pdu(&incoming_pdu.auth_events[0]) + .map_err(|e| e.to_string())? + .filter(|maybe_create| **maybe_create == *create_event) + } else { + None + }; + + if !state_res::event_auth::auth_check( &room_version, &incoming_pdu, @@ -1396,34 +1450,27 @@ pub fn handle_incoming_pdu<'a>( // Event has passed all auth/stateres checks drop(state_lock); - - let elapsed = start_time.elapsed(); - warn!( - "Handling timeline event {} took {}m{}s", - event_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); Ok(pdu_id) - }) } /// Find the event and auth it. Once the event is validated (steps 1 - 8) /// it is appended to the outliers Tree. /// +/// Returns pdu and if we fetched it over federation the raw json. +/// /// a. Look in the main timeline (pduid_pdu tree) /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? //#[tracing::instrument(skip(db, key_map, auth_cache))] -pub(crate) fn fetch_and_handle_events<'a>( +pub(crate) fn fetch_and_handle_outliers<'a>( db: &'a Database, origin: &'a ServerName, events: &'a [EventId], + create_event: &'a PduEvent, room_id: &'a RoomId, pub_key_map: &'a RwLock>>, - are_timeline_events: bool, -) -> AsyncRecursiveType<'a, Vec>> { +) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> { Box::pin(async move { let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { Entry::Vacant(e) => { @@ -1449,16 +1496,12 @@ pub(crate) fn fetch_and_handle_events<'a>( // a. Look in the main timeline (pduid_pdu tree) // b. Look at outlier pdu tree - // (get_pdu checks both) - let local_pdu = if are_timeline_events { - db.rooms.get_non_outlier_pdu(&id).map(|o| o.map(Arc::new)) - } else { - db.rooms.get_pdu(&id) - }; + // (get_pdu_json checks both) + let local_pdu = db.rooms.get_pdu(&id); let pdu = match local_pdu { Ok(Some(pdu)) => { trace!("Found {} in db", id); - pdu + (pdu, None) } Ok(None) => { // c. Ask origin server over federation @@ -1474,39 +1517,26 @@ pub(crate) fn fetch_and_handle_events<'a>( { Ok(res) => { debug!("Got {} over federation", id); - let (event_id, mut value) = + let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&res.pdu) { Ok(t) => t, Err(_) => continue, }; // This will also fetch the auth chain - match handle_incoming_pdu( + match handle_outlier_pdu( origin, + create_event, &event_id, &room_id, value.clone(), - are_timeline_events, db, pub_key_map, ) .await { - Ok(_) => { - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.into()), - ); - - Arc::new( - serde_json::from_value( - serde_json::to_value(value) - .expect("canonicaljsonobject is valid value"), - ) - .expect( - "This is possible because handle_incoming_pdu worked", - ), - ) + Ok((pdu, json)) => { + (pdu, Some(json)) } Err(e) => { warn!("Authentication of event {} failed: {:?}", id, e); From ecd1e45a449036e8d3166e1c59b51a161bf31f12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 21:56:15 +0200 Subject: [PATCH 14/25] fix: fetch more than one prev event --- src/server_server.rs | 745 ++++++++++++++++++++++--------------------- 1 file changed, 389 insertions(+), 356 deletions(-) diff --git a/src/server_server.rs b/src/server_server.rs index 5b2acd00..e7221261 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -867,17 +867,19 @@ pub async fn handle_incoming_pdu<'a>( .map_err(|_| "Failed to ask database for event.".to_owned())? .ok_or_else(|| "Failed to find create event in db.".to_owned())?; - let (incoming_pdu, val) = handle_outlier_pdu(origin, &create_event, event_id, room_id, value, db, pub_key_map).await?; + let (incoming_pdu, val) = handle_outlier_pdu( + origin, + &create_event, + event_id, + room_id, + value, + db, + pub_key_map, + ) + .await?; // 8. if not timeline event: stop - if !is_timeline_event - || incoming_pdu.origin_server_ts - < db.rooms - .first_pdu_in_room(&room_id) - .map_err(|_| "Error loading first room event.".to_owned())? - .expect("Room exists") - .origin_server_ts - { + if !is_timeline_event { return Ok(None); } @@ -893,16 +895,45 @@ pub async fn handle_incoming_pdu<'a>( &room_id, pub_key_map, ) - .await.pop() { - todo_timeline_stack.push((pdu, json)); + .await + .pop() + { + if incoming_pdu.origin_server_ts + > db.rooms + .first_pdu_in_room(&room_id) + .map_err(|_| "Error loading first room event.".to_owned())? + .expect("Room exists") + .origin_server_ts + { + todo_outlier_stack.extend(pdu.prev_events.iter().cloned()); + todo_timeline_stack.push((pdu, json)); + } } } while let Some(prev) = todo_timeline_stack.pop() { - upgrade_outlier_to_timeline_pdu(prev.0, prev.1, &create_event, origin, db, room_id, pub_key_map).await?; + upgrade_outlier_to_timeline_pdu( + prev.0, + prev.1, + &create_event, + origin, + db, + room_id, + pub_key_map, + ) + .await?; } - upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, db, room_id, pub_key_map).await + upgrade_outlier_to_timeline_pdu( + incoming_pdu, + val, + &create_event, + origin, + db, + room_id, + pub_key_map, + ) + .await } fn handle_outlier_pdu<'a>( @@ -913,7 +944,8 @@ fn handle_outlier_pdu<'a>( value: BTreeMap, db: &'a Database, pub_key_map: &'a RwLock>>, -) -> AsyncRecursiveType<'a, StdResult<(Arc, BTreeMap), String>> { +) -> AsyncRecursiveType<'a, StdResult<(Arc, BTreeMap), String>> +{ Box::pin(async move { let start_time = Instant::now(); @@ -928,11 +960,11 @@ fn handle_outlier_pdu<'a>( // 2. Check signatures, otherwise drop // 3. check content hash, redact if doesn't match - let create_event_content = - serde_json::from_value::>(create_event.content.clone()) - .expect("Raw::from_value always works.") - .deserialize() - .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; + let create_event_content = + serde_json::from_value::>(create_event.content.clone()) + .expect("Raw::from_value always works.") + .deserialize() + .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; let room_version_id = &create_event_content.room_version; let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); @@ -1062,9 +1094,8 @@ fn handle_outlier_pdu<'a>( .map_err(|_| "Failed to add pdu as outlier.".to_owned())?; debug!("Added pdu as outlier."); - Ok((incoming_pdu,val)) + Ok((incoming_pdu, val)) }) - } async fn upgrade_outlier_to_timeline_pdu( @@ -1076,381 +1107,385 @@ async fn upgrade_outlier_to_timeline_pdu( room_id: &RoomId, pub_key_map: &RwLock>>, ) -> StdResult>, String> { - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - // doing all the checks in this list starting at 1. These are not timeline events. + // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities + // doing all the checks in this list starting at 1. These are not timeline events. - // TODO: if we know the prev_events of the incoming event we can avoid the request and build - // the state from a known point and resolve if > 1 prev_event + // TODO: if we know the prev_events of the incoming event we can avoid the request and build + // the state from a known point and resolve if > 1 prev_event - debug!("Requesting state at event."); - let mut state_at_incoming_event = None; + debug!("Requesting state at event."); + let mut state_at_incoming_event = None; - if incoming_pdu.prev_events.len() == 1 { - let prev_event = &incoming_pdu.prev_events[0]; - let prev_event_sstatehash = db - .rooms - .pdu_shortstatehash(prev_event) - .map_err(|_| "Failed talking to db".to_owned())?; + if incoming_pdu.prev_events.len() == 1 { + let prev_event = &incoming_pdu.prev_events[0]; + let prev_event_sstatehash = db + .rooms + .pdu_shortstatehash(prev_event) + .map_err(|_| "Failed talking to db".to_owned())?; - let state = - prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); + let state = + prev_event_sstatehash.map(|shortstatehash| db.rooms.state_full_ids(shortstatehash)); - if let Some(Ok(state)) = state { - warn!("Using cached state"); - let mut state = fetch_and_handle_outliers( - db, + if let Some(Ok(state)) = state { + warn!("Using cached state"); + let mut state = fetch_and_handle_outliers( + db, + origin, + &state.into_iter().collect::>(), + &create_event, + &room_id, + pub_key_map, + ) + .await + .into_iter() + .map(|(pdu, _)| { + ( + ( + pdu.kind.clone(), + pdu.state_key + .clone() + .expect("events from state_full_ids are state events"), + ), + pdu, + ) + }) + .collect::>(); + + let prev_pdu = + db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { + "Could not find prev event, but we know the state.".to_owned() + })?; + + if let Some(state_key) = &prev_pdu.state_key { + state.insert((prev_pdu.kind.clone(), state_key.clone()), prev_pdu); + } + + state_at_incoming_event = Some(state); + } + // TODO: set incoming_auth_events? + } + + if state_at_incoming_event.is_none() { + warn!("Calling /state_ids"); + // Call /state_ids to find out what the state at this pdu is. We trust the server's + // response to some extend, but we still do a lot of checks on the events + match db + .sending + .send_federation_request( + &db.globals, + origin, + get_room_state_ids::v1::Request { + room_id: &room_id, + event_id: &incoming_pdu.event_id, + }, + ) + .await + { + Ok(res) => { + debug!("Fetching state events at event."); + let state_vec = fetch_and_handle_outliers( + &db, origin, - &state.into_iter().collect::>(), + &res.pdu_ids, &create_event, &room_id, pub_key_map, ) - .await - .into_iter() - .map(|(pdu,_)| { - ( - ( - pdu.kind.clone(), - pdu.state_key - .clone() - .expect("events from state_full_ids are state events"), + .await; + + let mut state = HashMap::new(); + for (pdu, _) in state_vec { + match state.entry(( + pdu.kind.clone(), + pdu.state_key + .clone() + .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?, + )) { + Entry::Vacant(v) => { + v.insert(pdu); + } + Entry::Occupied(_) => return Err( + "State event's type and state_key combination exists multiple times." + .to_owned(), ), - pdu, - ) - }) - .collect::>(); - - let prev_pdu = db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { - "Could not find prev event, but we know the state.".to_owned() - })?; - - if let Some(state_key) = &prev_pdu.state_key { - state.insert((prev_pdu.kind.clone(), state_key.clone()), prev_pdu); + } } + // The original create event must still be in the state + if state + .get(&(EventType::RoomCreate, "".to_owned())) + .map(|a| a.as_ref()) + != Some(&create_event) + { + return Err("Incoming event refers to wrong create event.".to_owned()); + } + + debug!("Fetching auth chain events at event."); + fetch_and_handle_outliers( + &db, + origin, + &res.auth_chain_ids, + &create_event, + &room_id, + pub_key_map, + ) + .await; + state_at_incoming_event = Some(state); } - // TODO: set incoming_auth_events? - } + Err(_) => { + return Err("Fetching state for event failed".into()); + } + }; + } - if state_at_incoming_event.is_none() { - warn!("Calling /state_ids"); - // Call /state_ids to find out what the state at this pdu is. We trust the server's - // response to some extend, but we still do a lot of checks on the events - match db - .sending - .send_federation_request( - &db.globals, - origin, - get_room_state_ids::v1::Request { - room_id: &room_id, - event_id: &incoming_pdu.event_id, - }, - ) - .await - { - Ok(res) => { - debug!("Fetching state events at event."); - let state_vec = fetch_and_handle_outliers( - &db, - origin, - &res.pdu_ids, - &create_event, - &room_id, - pub_key_map, - ) - .await; + let state_at_incoming_event = + state_at_incoming_event.expect("we always set this to some above"); - let mut state = HashMap::new(); - for (pdu, _) in state_vec { - match state.entry((pdu.kind.clone(), pdu.state_key.clone().ok_or_else(|| "Found non-state pdu in state events.".to_owned())?)) { - Entry::Vacant(v) => { - v.insert(pdu); - } - Entry::Occupied(_) => { - return Err( - "State event's type and state_key combination exists multiple times.".to_owned(), - ) - } - } - } - - // The original create event must still be in the state - if state - .get(&(EventType::RoomCreate, "".to_owned())) - .map(|a| a.as_ref()) - != Some(&create_event) - { - return Err("Incoming event refers to wrong create event.".to_owned()); - } - - debug!("Fetching auth chain events at event."); - fetch_and_handle_outliers( - &db, - origin, - &res.auth_chain_ids, - &create_event, - &room_id, - pub_key_map, - ) - .await; - - state_at_incoming_event = Some(state); - } - Err(_) => { - return Err("Fetching state for event failed".into()); - } - }; - } - - let state_at_incoming_event = - state_at_incoming_event.expect("we always set this to some above"); - - // 11. Check the auth of the event passes based on the state of the event + // 11. Check the auth of the event passes based on the state of the event let create_event_content = serde_json::from_value::>(create_event.content.clone()) .expect("Raw::from_value always works.") .deserialize() .map_err(|_| "Invalid PowerLevels event in db.".to_owned())?; - let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); - // If the previous event was the create event special rules apply - let previous_create = if incoming_pdu.auth_events.len() == 1 - && incoming_pdu.prev_events == incoming_pdu.auth_events - { - db.rooms - .get_pdu(&incoming_pdu.auth_events[0]) - .map_err(|e| e.to_string())? - .filter(|maybe_create| **maybe_create == *create_event) - } else { - None - }; + // If the previous event was the create event special rules apply + let previous_create = if incoming_pdu.auth_events.len() == 1 + && incoming_pdu.prev_events == incoming_pdu.auth_events + { + db.rooms + .get_pdu(&incoming_pdu.auth_events[0]) + .map_err(|e| e.to_string())? + .filter(|maybe_create| **maybe_create == *create_event) + } else { + None + }; + if !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + previous_create.clone(), + &state_at_incoming_event, + None, // TODO: third party invite + ) + .map_err(|_e| "Auth check failed.".to_owned())? + { + return Err("Event has failed auth check with state at the event.".into()); + } + debug!("Auth check succeeded."); - if !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - previous_create.clone(), - &state_at_incoming_event, - None, // TODO: third party invite - ) - .map_err(|_e| "Auth check failed.".to_owned())? - { - return Err("Event has failed auth check with state at the event.".into()); + // We start looking at current room state now, so lets lock the room + + let mutex_state = Arc::clone( + db.globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.clone()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + // Now we calculate the set of extremities this room has after the incoming event has been + // applied. We start with the previous extremities (aka leaves) + let mut extremities = db + .rooms + .get_pdu_leaves(&room_id) + .map_err(|_| "Failed to load room leaves".to_owned())?; + + // Remove any forward extremities that are referenced by this incoming event's prev_events + for prev_event in &incoming_pdu.prev_events { + if extremities.contains(prev_event) { + extremities.remove(prev_event); } - debug!("Auth check succeeded."); + } - // We start looking at current room state now, so lets lock the room + // Only keep those extremities were not referenced yet + extremities.retain(|id| !matches!(db.rooms.is_event_referenced(&room_id, id), Ok(true))); - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), + let mut extremity_statehashes = Vec::new(); + + for id in &extremities { + match db + .rooms + .get_pdu(&id) + .map_err(|_| "Failed to ask db for pdu.".to_owned())? + { + Some(leaf_pdu) => { + extremity_statehashes.push(( + db.rooms + .pdu_shortstatehash(&leaf_pdu.event_id) + .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? + .ok_or_else(|| { + error!( + "Found extremity pdu with no statehash in db: {:?}", + leaf_pdu + ); + "Found pdu with no statehash in db.".to_owned() + })?, + Some(leaf_pdu), + )); + } + _ => { + error!("Missing state snapshot for {:?}", id); + return Err("Missing state snapshot.".to_owned()); + } + } + } + + // 12. Ensure that the state is derived from the previous current state (i.e. we calculated + // by doing state res where one of the inputs was a previously trusted set of state, + // don't just trust a set of state we got from a remote). + + // We do this by adding the current state to the list of fork states + let current_statehash = db + .rooms + .current_shortstatehash(&room_id) + .map_err(|_| "Failed to load current state hash.".to_owned())? + .expect("every room has state"); + + let current_state = db + .rooms + .state_full(current_statehash) + .map_err(|_| "Failed to load room state.")?; + + extremity_statehashes.push((current_statehash.clone(), None)); + + let mut fork_states = Vec::new(); + for (statehash, leaf_pdu) in extremity_statehashes { + let mut leaf_state = db + .rooms + .state_full(statehash) + .map_err(|_| "Failed to ask db for room state.".to_owned())?; + + if let Some(leaf_pdu) = leaf_pdu { + if let Some(state_key) = &leaf_pdu.state_key { + // Now it's the state after + let key = (leaf_pdu.kind.clone(), state_key.clone()); + leaf_state.insert(key, leaf_pdu); + } + } + + fork_states.push(leaf_state); + } + + // We also add state after incoming event to the fork states + extremities.insert(incoming_pdu.event_id.clone()); + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + state_after.insert( + (incoming_pdu.kind.clone(), state_key.clone()), + incoming_pdu.clone(), ); - let state_lock = mutex_state.lock().await; + } + fork_states.push(state_after.clone()); - // Now we calculate the set of extremities this room has after the incoming event has been - // applied. We start with the previous extremities (aka leaves) - let mut extremities = db - .rooms - .get_pdu_leaves(&room_id) - .map_err(|_| "Failed to load room leaves".to_owned())?; + let mut update_state = false; + // 14. Use state resolution to find new room state + let new_room_state = if fork_states.is_empty() { + return Err("State is empty.".to_owned()); + } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { + // There was only one state, so it has to be the room's current state (because that is + // always included) + fork_states[0] + .iter() + .map(|(k, pdu)| (k.clone(), pdu.event_id.clone())) + .collect() + } else { + // We do need to force an update to this room's state + update_state = true; - // Remove any forward extremities that are referenced by this incoming event's prev_events - for prev_event in &incoming_pdu.prev_events { - if extremities.contains(prev_event) { - extremities.remove(prev_event); - } - } + let fork_states = &fork_states + .into_iter() + .map(|map| { + map.into_iter() + .map(|(k, v)| (k, v.event_id.clone())) + .collect::>() + }) + .collect::>(); - // Only keep those extremities were not referenced yet - extremities.retain(|id| !matches!(db.rooms.is_event_referenced(&room_id, id), Ok(true))); - - let mut extremity_statehashes = Vec::new(); - - for id in &extremities { - match db - .rooms - .get_pdu(&id) - .map_err(|_| "Failed to ask db for pdu.".to_owned())? - { - Some(leaf_pdu) => { - extremity_statehashes.push(( - db.rooms - .pdu_shortstatehash(&leaf_pdu.event_id) - .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? - .ok_or_else(|| { - error!( - "Found extremity pdu with no statehash in db: {:?}", - leaf_pdu - ); - "Found pdu with no statehash in db.".to_owned() - })?, - Some(leaf_pdu), - )); - } - _ => { - error!("Missing state snapshot for {:?}", id); - return Err("Missing state snapshot.".to_owned()); - } - } - } - - // 12. Ensure that the state is derived from the previous current state (i.e. we calculated - // by doing state res where one of the inputs was a previously trusted set of state, - // don't just trust a set of state we got from a remote). - - // We do this by adding the current state to the list of fork states - let current_statehash = db - .rooms - .current_shortstatehash(&room_id) - .map_err(|_| "Failed to load current state hash.".to_owned())? - .expect("every room has state"); - - let current_state = db - .rooms - .state_full(current_statehash) - .map_err(|_| "Failed to load room state.")?; - - extremity_statehashes.push((current_statehash.clone(), None)); - - let mut fork_states = Vec::new(); - for (statehash, leaf_pdu) in extremity_statehashes { - let mut leaf_state = db - .rooms - .state_full(statehash) - .map_err(|_| "Failed to ask db for room state.".to_owned())?; - - if let Some(leaf_pdu) = leaf_pdu { - if let Some(state_key) = &leaf_pdu.state_key { - // Now it's the state after - let key = (leaf_pdu.kind.clone(), state_key.clone()); - leaf_state.insert(key, leaf_pdu); - } - } - - fork_states.push(leaf_state); - } - - // We also add state after incoming event to the fork states - extremities.insert(incoming_pdu.event_id.clone()); - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - state_after.insert( - (incoming_pdu.kind.clone(), state_key.clone()), - incoming_pdu.clone(), + let mut auth_chain_sets = Vec::new(); + for state in fork_states { + auth_chain_sets.push( + get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), ); } - fork_states.push(state_after.clone()); - let mut update_state = false; - // 14. Use state resolution to find new room state - let new_room_state = if fork_states.is_empty() { - return Err("State is empty.".to_owned()); - } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { - // There was only one state, so it has to be the room's current state (because that is - // always included) - fork_states[0] - .iter() - .map(|(k, pdu)| (k.clone(), pdu.event_id.clone())) - .collect() - } else { - // We do need to force an update to this room's state - update_state = true; - - let fork_states = &fork_states - .into_iter() - .map(|map| { - map.into_iter() - .map(|(k, v)| (k, v.event_id.clone())) - .collect::>() - }) - .collect::>(); - - let mut auth_chain_sets = Vec::new(); - for state in fork_states { - auth_chain_sets.push( - get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), - ); - } - - let state = match state_res::StateResolution::resolve( - &room_id, - room_version_id, - fork_states, - auth_chain_sets, - |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); - } - res.ok().flatten() - }, - ) { - Ok(new_state) => new_state, - Err(_) => { - return Err("State resolution failed, either an event could not be found or deserialization".into()); + let state = match state_res::StateResolution::resolve( + &room_id, + room_version_id, + fork_states, + auth_chain_sets, + |id| { + let res = db.rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); } - }; - - state + res.ok().flatten() + }, + ) { + Ok(new_state) => new_state, + Err(_) => { + return Err("State resolution failed, either an event could not be found or deserialization".into()); + } }; - debug!("starting soft fail auth check"); - // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it - let soft_fail = !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - previous_create, - ¤t_state, - None, - ) - .map_err(|_e| "Auth check failed.".to_owned())?; + state + }; - let mut pdu_id = None; - if !soft_fail { - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. - pdu_id = Some( - append_incoming_pdu( - &db, - &incoming_pdu, - val, - extremities, - &state_at_incoming_event, - &state_lock, - ) - .map_err(|_| "Failed to add pdu to db.".to_owned())?, - ); - debug!("Appended incoming pdu."); - } else { - warn!("Event was soft failed: {:?}", incoming_pdu); - } + debug!("starting soft fail auth check"); + // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it + let soft_fail = !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + previous_create, + ¤t_state, + None, + ) + .map_err(|_e| "Auth check failed.".to_owned())?; - // Set the new room state to the resolved state - if update_state { - db.rooms - .force_state(&room_id, new_room_state, &db) - .map_err(|_| "Failed to set new room state.".to_owned())?; - } - debug!("Updated resolved state"); + let mut pdu_id = None; + if !soft_fail { + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. + pdu_id = Some( + append_incoming_pdu( + &db, + &incoming_pdu, + val, + extremities, + &state_at_incoming_event, + &state_lock, + ) + .map_err(|_| "Failed to add pdu to db.".to_owned())?, + ); + debug!("Appended incoming pdu."); + } else { + warn!("Event was soft failed: {:?}", incoming_pdu); + } - if soft_fail { - // Soft fail, we leave the event as an outlier but don't add it to the timeline - return Err("Event has been soft failed".into()); - } + // Set the new room state to the resolved state + if update_state { + db.rooms + .force_state(&room_id, new_room_state, &db) + .map_err(|_| "Failed to set new room state.".to_owned())?; + } + debug!("Updated resolved state"); - // Event has passed all auth/stateres checks - drop(state_lock); - Ok(pdu_id) + if soft_fail { + // Soft fail, we leave the event as an outlier but don't add it to the timeline + return Err("Event has been soft failed".into()); + } + + // Event has passed all auth/stateres checks + drop(state_lock); + Ok(pdu_id) } /// Find the event and auth it. Once the event is validated (steps 1 - 8) @@ -1535,9 +1570,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( ) .await { - Ok((pdu, json)) => { - (pdu, Some(json)) - } + Ok((pdu, json)) => (pdu, Some(json)), Err(e) => { warn!("Authentication of event {} failed: {:?}", id, e); back_off(id.clone()); From f9a2edc0dda2b44fd04d20d946ea2b5431cf5d29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 22:50:45 +0200 Subject: [PATCH 15/25] fix: also fetch prev events that are outliers already --- src/database/rooms.rs | 13 +++++++++++++ src/server_server.rs | 24 +++++++++++++----------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 99f0c83e..4a3ab71a 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -844,6 +844,19 @@ impl Rooms { .transpose() } + /// Returns the json of a pdu. + pub fn get_outlier_pdu_json( + &self, + event_id: &EventId, + ) -> Result> { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map(|pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + .transpose() + } + /// Returns the json of a pdu. pub fn get_non_outlier_pdu_json( &self, diff --git a/src/server_server.rs b/src/server_server.rs index e7221261..a293d1b1 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -887,10 +887,10 @@ pub async fn handle_incoming_pdu<'a>( let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); let mut todo_timeline_stack = Vec::new(); while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, Some(json))) = fetch_and_handle_outliers( + if let Some((pdu, json_opt)) = fetch_and_handle_outliers( db, origin, - &[prev_event_id], + &[prev_event_id.clone()], &create_event, &room_id, pub_key_map, @@ -898,15 +898,17 @@ pub async fn handle_incoming_pdu<'a>( .await .pop() { - if incoming_pdu.origin_server_ts - > db.rooms - .first_pdu_in_room(&room_id) - .map_err(|_| "Error loading first room event.".to_owned())? - .expect("Room exists") - .origin_server_ts - { - todo_outlier_stack.extend(pdu.prev_events.iter().cloned()); - todo_timeline_stack.push((pdu, json)); + if let Some(json) = json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) { + if incoming_pdu.origin_server_ts + > db.rooms + .first_pdu_in_room(&room_id) + .map_err(|_| "Error loading first room event.".to_owned())? + .expect("Room exists") + .origin_server_ts + { + todo_outlier_stack.extend(pdu.prev_events.iter().cloned()); + todo_timeline_stack.push((pdu, json)); + } } } } From 5bd5b41c70791e9664d55425066cfe0be8282e35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 14 Aug 2021 23:29:25 +0200 Subject: [PATCH 16/25] fix: fetch event multiple times --- src/database/rooms.rs | 8 ++------ src/server_server.rs | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 4a3ab71a..2832cc2d 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -845,10 +845,7 @@ impl Rooms { } /// Returns the json of a pdu. - pub fn get_outlier_pdu_json( - &self, - event_id: &EventId, - ) -> Result> { + pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? .map(|pdu| { @@ -1134,10 +1131,9 @@ impl Rooms { &serde_json::to_vec(&pdu_json).expect("CanonicalJsonObject is always a valid"), )?; - // This also replaces the eventid of any outliers with the correct - // pduid, removing the place holder. self.eventid_pduid .insert(pdu.event_id.as_bytes(), &pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; drop(insert_lock); diff --git a/src/server_server.rs b/src/server_server.rs index a293d1b1..a4bfda58 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -884,9 +884,15 @@ pub async fn handle_incoming_pdu<'a>( } // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events + let mut visited = HashSet::new(); let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); let mut todo_timeline_stack = Vec::new(); while let Some(prev_event_id) = todo_outlier_stack.pop() { + if visited.contains(&prev_event_id) { + continue; + } + visited.insert(prev_event_id.clone()); + if let Some((pdu, json_opt)) = fetch_and_handle_outliers( db, origin, @@ -898,7 +904,9 @@ pub async fn handle_incoming_pdu<'a>( .await .pop() { - if let Some(json) = json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) { + if let Some(json) = + json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) + { if incoming_pdu.origin_server_ts > db.rooms .first_pdu_in_room(&room_id) @@ -949,8 +957,6 @@ fn handle_outlier_pdu<'a>( ) -> AsyncRecursiveType<'a, StdResult<(Arc, BTreeMap), String>> { Box::pin(async move { - let start_time = Instant::now(); - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json // We go through all the signatures we see on the value and fetch the corresponding signing @@ -1109,6 +1115,9 @@ async fn upgrade_outlier_to_timeline_pdu( room_id: &RoomId, pub_key_map: &RwLock>>, ) -> StdResult>, String> { + if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { + return Ok(Some(pduid)); + } // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities // doing all the checks in this list starting at 1. These are not timeline events. From a4310f840ef4c6041b6098c87e4bcce697217209 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 15 Aug 2021 06:46:00 +0200 Subject: [PATCH 17/25] improvement: state info cache --- src/database.rs | 1 + src/database/rooms.rs | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/database.rs b/src/database.rs index 0bf2a440..e66ff04d 100644 --- a/src/database.rs +++ b/src/database.rs @@ -278,6 +278,7 @@ impl Database { pdu_cache: Mutex::new(LruCache::new(100_000)), auth_chain_cache: Mutex::new(LruCache::new(100_000)), shorteventid_cache: Mutex::new(LruCache::new(1_000_000)), + stateinfo_cache: Mutex::new(LruCache::new(1000)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 2832cc2d..5baadf9e 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -92,6 +92,13 @@ pub struct Rooms { pub(super) pdu_cache: Mutex>>, pub(super) auth_chain_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>, + pub(super) stateinfo_cache: Mutex, // full state + HashSet, // added + HashSet, // removed + )>>>, } impl Rooms { @@ -407,6 +414,14 @@ impl Rooms { HashSet, // removed )>, > { + if let Some(r) = self.stateinfo_cache + .lock() + .unwrap() + .get_mut(&shortstatehash) + { + return Ok(r.clone()); + } + let value = self .shortstatehash_statediff .get(&shortstatehash.to_be_bytes())? @@ -443,10 +458,18 @@ impl Rooms { response.push((shortstatehash, state, added, removed)); + self.stateinfo_cache + .lock() + .unwrap() + .insert(shortstatehash, response.clone()); Ok(response) } else { let mut response = Vec::new(); response.push((shortstatehash, added.clone(), added, removed)); + self.stateinfo_cache + .lock() + .unwrap() + .insert(shortstatehash, response.clone()); Ok(response) } } From 2c3bee34a0a9da43374d48527c2208104ea1c95c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 15 Aug 2021 13:17:42 +0200 Subject: [PATCH 18/25] improvement: better sqlite --- src/database/abstraction.rs | 1 + src/database/abstraction/sqlite.rs | 27 ++++++++++++++++++---- src/database/rooms.rs | 37 +++++++++++++++++++----------- 3 files changed, 47 insertions(+), 18 deletions(-) diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index f381ce9f..5b941fbe 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -35,6 +35,7 @@ pub trait Tree: Send + Sync { ) -> Box, Vec)> + 'a>; fn increment(&self, key: &[u8]) -> Result>; + fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()>; fn scan_prefix<'a>( &'a self, diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 3c4ae9c6..1e55418b 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -49,11 +49,11 @@ impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { let conn = Connection::open(&path)?; - conn.pragma_update(Some(Main), "page_size", &32768)?; + conn.pragma_update(Some(Main), "page_size", &1024)?; conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; - conn.pragma_update(Some(Main), "wal_autocheckpoint", &0)?; + conn.pragma_update(Some(Main), "wal_autocheckpoint", &8000)?; Ok(conn) } @@ -93,8 +93,9 @@ impl Engine { } pub fn flush_wal(self: &Arc) -> Result<()> { - self.write_lock() - .pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; + // We use autocheckpoints + //self.write_lock() + //.pragma_update(Some(Main), "wal_checkpoint", &"TRUNCATE")?; Ok(()) } } @@ -248,6 +249,24 @@ impl Tree for SqliteTable { Ok(()) } + #[tracing::instrument(skip(self, iter))] + fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()> { + let guard = self.engine.write_lock(); + + guard.execute("BEGIN", [])?; + for key in iter { + let old = self.get_with_guard(&guard, &key)?; + let new = crate::utils::increment(old.as_deref()) + .expect("utils::increment always returns Some"); + self.insert_with_guard(&guard, &key, &new)?; + } + guard.execute("COMMIT", [])?; + + drop(guard); + + Ok(()) + } + #[tracing::instrument(skip(self, key))] fn remove(&self, key: &[u8]) -> Result<()> { let guard = self.engine.write_lock(); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 5baadf9e..d648b7d8 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -92,13 +92,17 @@ pub struct Rooms { pub(super) pdu_cache: Mutex>>, pub(super) auth_chain_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>, - pub(super) stateinfo_cache: Mutex, // full state - HashSet, // added - HashSet, // removed - )>>>, + pub(super) stateinfo_cache: Mutex< + LruCache< + u64, + Vec<( + u64, // sstatehash + HashSet, // full state + HashSet, // added + HashSet, // removed + )>, + >, + >, } impl Rooms { @@ -414,7 +418,8 @@ impl Rooms { HashSet, // removed )>, > { - if let Some(r) = self.stateinfo_cache + if let Some(r) = self + .stateinfo_cache .lock() .unwrap() .get_mut(&shortstatehash) @@ -458,10 +463,6 @@ impl Rooms { response.push((shortstatehash, state, added, removed)); - self.stateinfo_cache - .lock() - .unwrap() - .insert(shortstatehash, response.clone()); Ok(response) } else { let mut response = Vec::new(); @@ -1173,6 +1174,9 @@ impl Rooms { let sync_pdu = pdu.to_sync_room_event(); + let mut notifies = Vec::new(); + let mut highlights = Vec::new(); + for user in db .rooms .room_members(&pdu.room_id) @@ -1218,11 +1222,11 @@ impl Rooms { userroom_id.extend_from_slice(pdu.room_id.as_bytes()); if notify { - self.userroomid_notificationcount.increment(&userroom_id)?; + notifies.push(userroom_id.clone()); } if highlight { - self.userroomid_highlightcount.increment(&userroom_id)?; + highlights.push(userroom_id); } for senderkey in db.pusher.get_pusher_senderkeys(&user) { @@ -1230,6 +1234,11 @@ impl Rooms { } } + self.userroomid_notificationcount + .increment_batch(&mut notifies.into_iter())?; + self.userroomid_highlightcount + .increment_batch(&mut highlights.into_iter())?; + match pdu.kind { EventType::RoomRedaction => { if let Some(redact_id) = &pdu.redacts { From 0823506d05b63efe7ef3a3c6611f066f24ec5bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Mon, 16 Aug 2021 23:24:52 +0200 Subject: [PATCH 19/25] fix: don't load endless prev events and fix room join bug --- src/database/rooms.rs | 7 ++++++- src/server_server.rs | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index d648b7d8..75251aa2 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -1571,6 +1571,8 @@ impl Rooms { let shortstatekey = self.get_or_create_shortstatekey(&new_pdu.kind, &state_key, globals)?; + let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; + let replaces = states_parents .last() .map(|info| { @@ -1580,11 +1582,14 @@ impl Rooms { }) .unwrap_or_default(); + if Some(&new) == replaces { + return Ok(previous_shortstatehash.expect("must exist")); + } + // TODO: statehash with deterministic inputs let shortstatehash = globals.next_count()?; let mut statediffnew = HashSet::new(); - let new = self.compress_state_event(shortstatekey, &new_pdu.event_id, globals)?; statediffnew.insert(new); let mut statediffremoved = HashSet::new(); diff --git a/src/server_server.rs b/src/server_server.rs index a4bfda58..7a28e5d9 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -907,7 +907,7 @@ pub async fn handle_incoming_pdu<'a>( if let Some(json) = json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) { - if incoming_pdu.origin_server_ts + if pdu.origin_server_ts > db.rooms .first_pdu_in_room(&room_id) .map_err(|_| "Error loading first room event.".to_owned())? From 75ba8bb5657231ebc8ea1b54b49cb667e11ad3a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Tue, 17 Aug 2021 00:22:52 +0200 Subject: [PATCH 20/25] fix: faster room joins --- src/database/abstraction/sqlite.rs | 4 ++-- src/database/rooms.rs | 10 +++++++++- src/server_server.rs | 1 + 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 1e55418b..5b895c78 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -49,11 +49,11 @@ impl Engine { fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { let conn = Connection::open(&path)?; - conn.pragma_update(Some(Main), "page_size", &1024)?; + conn.pragma_update(Some(Main), "page_size", &2048)?; conn.pragma_update(Some(Main), "journal_mode", &"WAL")?; conn.pragma_update(Some(Main), "synchronous", &"NORMAL")?; conn.pragma_update(Some(Main), "cache_size", &(-i64::from(cache_size_kb)))?; - conn.pragma_update(Some(Main), "wal_autocheckpoint", &8000)?; + conn.pragma_update(Some(Main), "wal_autocheckpoint", &2000)?; Ok(conn) } diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 75251aa2..e2415a41 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -392,6 +392,7 @@ impl Rooms { &pdu.sender, None, db, + false, )?; } } @@ -400,6 +401,8 @@ impl Rooms { } } + self.update_joined_count(room_id)?; + self.roomid_shortstatehash .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; @@ -1285,6 +1288,7 @@ impl Rooms { &pdu.sender, invite_state, db, + true, )?; } } @@ -2051,6 +2055,7 @@ impl Rooms { sender: &UserId, last_state: Option>>, db: &Database, + update_joined_count: bool, ) -> Result<()> { // Keep track what remote users exist by adding them as "deactivated" users if user_id.server_name() != db.globals.server_name() { @@ -2232,7 +2237,9 @@ impl Rooms { _ => {} } - self.update_joined_count(room_id)?; + if update_joined_count { + self.update_joined_count(room_id)?; + } Ok(()) } @@ -2269,6 +2276,7 @@ impl Rooms { user_id, last_state, db, + true, )?; } else { let mutex_state = Arc::clone( diff --git a/src/server_server.rs b/src/server_server.rs index 7a28e5d9..de3eef57 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -2471,6 +2471,7 @@ pub async fn create_invite_route( &sender, Some(invite_state), &db, + true, )?; } From bf7e019a686c4263163a3b278431c9fbc184d74b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Tue, 17 Aug 2021 16:06:09 +0200 Subject: [PATCH 21/25] improvement: better prev event fetching, perf improvements --- src/client_server/account.rs | 2 +- src/database.rs | 2 + src/database/rooms.rs | 140 +++++++++++++++++++++++++++++------ src/server_server.rs | 127 +++++++++++++++++++------------ 4 files changed, 202 insertions(+), 69 deletions(-) diff --git a/src/client_server/account.rs b/src/client_server/account.rs index d4f103c7..b00882a6 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -733,7 +733,7 @@ pub async fn deactivate_route( pub async fn third_party_route( body: Ruma, ) -> ConduitResult { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_contacts::Response::new(Vec::new()).into()) } diff --git a/src/database.rs b/src/database.rs index e66ff04d..5ad2add8 100644 --- a/src/database.rs +++ b/src/database.rs @@ -278,6 +278,8 @@ impl Database { pdu_cache: Mutex::new(LruCache::new(100_000)), auth_chain_cache: Mutex::new(LruCache::new(100_000)), shorteventid_cache: Mutex::new(LruCache::new(1_000_000)), + eventidshort_cache: Mutex::new(LruCache::new(1_000_000)), + statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), stateinfo_cache: Mutex::new(LruCache::new(1000)), }, account_data: account_data::AccountData { diff --git a/src/database/rooms.rs b/src/database/rooms.rs index e2415a41..600566c1 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -92,6 +92,8 @@ pub struct Rooms { pub(super) pdu_cache: Mutex>>, pub(super) auth_chain_cache: Mutex>>, pub(super) shorteventid_cache: Mutex>, + pub(super) eventidshort_cache: Mutex>, + pub(super) statekeyshort_cache: Mutex>, pub(super) stateinfo_cache: Mutex< LruCache< u64, @@ -665,7 +667,11 @@ impl Rooms { event_id: &EventId, globals: &super::globals::Globals, ) -> Result { - Ok(match self.eventid_shorteventid.get(event_id.as_bytes())? { + if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(&event_id) { + return Ok(*short); + } + + let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { Some(shorteventid) => utils::u64_from_bytes(&shorteventid) .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, None => { @@ -676,7 +682,14 @@ impl Rooms { .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; shorteventid } - }) + }; + + self.eventidshort_cache + .lock() + .unwrap() + .insert(event_id.clone(), short); + + Ok(short) } pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { @@ -694,17 +707,36 @@ impl Rooms { event_type: &EventType, state_key: &str, ) -> Result> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(Some(*short)); + } + let mut statekey = event_type.as_ref().as_bytes().to_vec(); statekey.push(0xff); statekey.extend_from_slice(&state_key.as_bytes()); - self.statekey_shortstatekey + let short = self + .statekey_shortstatekey .get(&statekey)? .map(|shortstatekey| { utils::u64_from_bytes(&shortstatekey) .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) }) - .transpose() + .transpose()?; + + if let Some(s) = short { + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), s); + } + + Ok(short) } pub fn get_or_create_shortroomid( @@ -730,11 +762,20 @@ impl Rooms { state_key: &str, globals: &super::globals::Globals, ) -> Result { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(*short); + } + let mut statekey = event_type.as_ref().as_bytes().to_vec(); statekey.push(0xff); statekey.extend_from_slice(&state_key.as_bytes()); - Ok(match self.statekey_shortstatekey.get(&statekey)? { + let short = match self.statekey_shortstatekey.get(&statekey)? { Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, None => { @@ -743,7 +784,14 @@ impl Rooms { .insert(&statekey, &shortstatekey.to_be_bytes())?; shortstatekey } - }) + }; + + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), short); + + Ok(short) } pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { @@ -2173,8 +2221,10 @@ impl Rooms { } } - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; + if update_joined_count { + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } self.userroomid_joined.insert(&userroom_id, &[])?; self.roomuserid_joined.insert(&roomuser_id, &[])?; self.userroomid_invitestate.remove(&userroom_id)?; @@ -2199,8 +2249,10 @@ impl Rooms { return Ok(()); } - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; + if update_joined_count { + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } self.userroomid_invitestate.insert( &userroom_id, &serde_json::to_vec(&last_state.unwrap_or_default()) @@ -2214,14 +2266,16 @@ impl Rooms { self.roomuserid_leftcount.remove(&roomuser_id)?; } member::MembershipState::Leave | member::MembershipState::Ban => { - if self - .room_members(room_id) - .chain(self.room_members_invited(room_id)) - .filter_map(|r| r.ok()) - .all(|u| u.server_name() != user_id.server_name()) - { - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; + if update_joined_count { + if self + .room_members(room_id) + .chain(self.room_members_invited(room_id)) + .filter_map(|r| r.ok()) + .all(|u| u.server_name() != user_id.server_name()) + { + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } } self.userroomid_leftstate.insert( &userroom_id, @@ -2245,10 +2299,52 @@ impl Rooms { } pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - self.roomid_joinedcount.insert( - room_id.as_bytes(), - &(self.room_members(&room_id).count() as u64).to_be_bytes(), - ) + let mut joinedcount = 0_u64; + let mut joined_servers = HashSet::new(); + + for joined in self.room_members(&room_id).filter_map(|r| r.ok()) { + joined_servers.insert(joined.server_name().to_owned()); + joinedcount += 1; + } + + for invited in self.room_members_invited(&room_id).filter_map(|r| r.ok()) { + joined_servers.insert(invited.server_name().to_owned()); + } + + self.roomid_joinedcount + .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; + + for old_joined_server in self.room_servers(room_id).filter_map(|r| r.ok()) { + if !joined_servers.remove(&old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xff); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } + } + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xff); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xff); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } + + Ok(()) } pub async fn leave_room( diff --git a/src/server_server.rs b/src/server_server.rs index de3eef57..49f225f1 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -272,14 +272,20 @@ where if status == 200 { let response = T::IncomingResponse::try_from_http_response(http_response); response.map_err(|e| { - warn!("Invalid 200 response from {}: {}", &destination, e); + warn!( + "Invalid 200 response from {} on: {} {}", + &destination, url, e + ); Error::BadServerResponse("Server returned bad 200 response.") }) } else { Err(Error::FederationError( destination.to_owned(), RumaError::try_from_http_response(http_response).map_err(|e| { - warn!("Server returned bad error response: {}", e); + warn!( + "Invalid {} response from {} on: {} {}", + status, &destination, url, e + ); Error::BadServerResponse("Server returned bad error response.") })?, )) @@ -884,15 +890,10 @@ pub async fn handle_incoming_pdu<'a>( } // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - let mut visited = HashSet::new(); + let mut graph = HashMap::new(); + let mut eventid_info = HashMap::new(); let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); - let mut todo_timeline_stack = Vec::new(); while let Some(prev_event_id) = todo_outlier_stack.pop() { - if visited.contains(&prev_event_id) { - continue; - } - visited.insert(prev_event_id.clone()); - if let Some((pdu, json_opt)) = fetch_and_handle_outliers( db, origin, @@ -914,24 +915,58 @@ pub async fn handle_incoming_pdu<'a>( .expect("Room exists") .origin_server_ts { - todo_outlier_stack.extend(pdu.prev_events.iter().cloned()); - todo_timeline_stack.push((pdu, json)); + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(dbg!(prev_prev.clone())); + } + } + + graph.insert( + prev_event_id.clone(), + pdu.prev_events.iter().cloned().collect(), + ); + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } else { + graph.insert(prev_event_id.clone(), HashSet::new()); + eventid_info.insert(prev_event_id.clone(), (pdu, json)); } + } else { + graph.insert(prev_event_id.clone(), HashSet::new()); } } } - while let Some(prev) = todo_timeline_stack.pop() { - upgrade_outlier_to_timeline_pdu( - prev.0, - prev.1, - &create_event, - origin, - db, - room_id, - pub_key_map, - ) - .await?; + let sorted = + state_res::StateResolution::lexicographical_topological_sort(dbg!(&graph), |event_id| { + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. + println!("{}", event_id); + Ok(( + 0, + MilliSecondsSinceUnixEpoch( + eventid_info + .get(event_id) + .map_or_else(|| uint!(0), |info| info.0.origin_server_ts.clone()), + ), + ruma::event_id!("$notimportant"), + )) + }) + .map_err(|_| "Error sorting prev events".to_owned())?; + + for prev_id in dbg!(sorted) { + if let Some((pdu, json)) = eventid_info.remove(&prev_id) { + upgrade_outlier_to_timeline_pdu( + pdu, + json, + &create_event, + origin, + db, + room_id, + pub_key_map, + ) + .await?; + } } upgrade_outlier_to_timeline_pdu( @@ -1872,8 +1907,7 @@ fn get_auth_chain( full_auth_chain.extend(cached.iter().cloned()); } else { drop(cache); - let mut auth_chain = HashSet::new(); - get_auth_chain_recursive(&event_id, &mut auth_chain, db)?; + let auth_chain = get_auth_chain_inner(&event_id, db)?; cache = db.rooms.auth_chain_cache(); cache.insert(sevent_id, auth_chain.clone()); full_auth_chain.extend(auth_chain); @@ -1887,33 +1921,34 @@ fn get_auth_chain( .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) } -fn get_auth_chain_recursive( - event_id: &EventId, - found: &mut HashSet, - db: &Database, -) -> Result<()> { - let r = db.rooms.get_pdu(&event_id); - match r { - Ok(Some(pdu)) => { - for auth_event in &pdu.auth_events { - let sauthevent = db - .rooms - .get_or_create_shorteventid(auth_event, &db.globals)?; - if !found.contains(&sauthevent) { - found.insert(sauthevent); - get_auth_chain_recursive(&auth_event, found, db)?; +fn get_auth_chain_inner(event_id: &EventId, db: &Database) -> Result> { + let mut todo = vec![event_id.clone()]; + let mut found = HashSet::new(); + + while let Some(event_id) = todo.pop() { + match db.rooms.get_pdu(&event_id) { + Ok(Some(pdu)) => { + for auth_event in &pdu.auth_events { + let sauthevent = db + .rooms + .get_or_create_shorteventid(auth_event, &db.globals)?; + + if !found.contains(&sauthevent) { + found.insert(sauthevent); + todo.push(auth_event.clone()); + } } } - } - Ok(None) => { - warn!("Could not find pdu mentioned in auth events."); - } - Err(e) => { - warn!("Could not load event in auth chain: {}", e); + Ok(None) => { + warn!("Could not find pdu mentioned in auth events: {}", event_id); + } + Err(e) => { + warn!("Could not load event in auth chain: {} {}", event_id, e); + } } } - Ok(()) + Ok(found) } #[cfg_attr( From 46d8a46e1f79420ab14fd7c84c431671216039bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 19 Aug 2021 11:01:18 +0200 Subject: [PATCH 22/25] improvement: faster incoming transaction handling --- Cargo.lock | 114 ++++++++----- Cargo.toml | 8 +- src/client_server/account.rs | 4 + src/client_server/membership.rs | 5 + src/client_server/room.rs | 2 + src/client_server/session.rs | 7 +- src/database.rs | 2 +- src/database/rooms.rs | 39 +++++ src/database/uiaa.rs | 166 ++++++++----------- src/main.rs | 11 +- src/pdu.rs | 4 +- src/server_server.rs | 283 ++++++++++++++++++-------------- 12 files changed, 365 insertions(+), 280 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6ed4ee7b..83e21a3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,7 +248,7 @@ dependencies = [ "jsonwebtoken", "lru-cache", "num_cpus", - "opentelemetry", + "opentelemetry 0.16.0", "opentelemetry-jaeger", "parking_lot", "pretty_env_logger", @@ -1466,16 +1466,46 @@ dependencies = [ ] [[package]] -name = "opentelemetry-jaeger" -version = "0.14.0" +name = "opentelemetry" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09a9fc8192722e7daa0c56e59e2336b797122fb8598383dcb11c8852733b435c" +checksum = "e1cf9b1c4e9a6c4de793c632496fa490bdc0e1eea73f0c91394f7b6990935d22" +dependencies = [ + "async-trait", + "crossbeam-channel", + "futures", + "js-sys", + "lazy_static", + "percent-encoding", + "pin-project", + "rand 0.8.4", + "thiserror", + "tokio", + "tokio-stream", +] + +[[package]] +name = "opentelemetry-jaeger" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db22f492873ea037bc267b35a0e8e4fb846340058cb7c864efe3d0bf23684593" dependencies = [ "async-trait", "lazy_static", - "opentelemetry", + "opentelemetry 0.16.0", + "opentelemetry-semantic-conventions", "thiserror", "thrift", + "tokio", +] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffeac823339e8b0f27b961f4385057bf9f97f2863bc745bd015fd6091f2270e9" +dependencies = [ + "opentelemetry 0.16.0", ] [[package]] @@ -2014,8 +2044,8 @@ dependencies = [ [[package]] name = "ruma" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "assign", "js_int", @@ -2035,8 +2065,8 @@ dependencies = [ [[package]] name = "ruma-api" -version = "0.17.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.18.3" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "bytes", "http", @@ -2051,8 +2081,8 @@ dependencies = [ [[package]] name = "ruma-api-macros" -version = "0.17.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.18.3" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2062,8 +2092,8 @@ dependencies = [ [[package]] name = "ruma-appservice-api" -version = "0.3.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.4.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "ruma-api", "ruma-common", @@ -2076,8 +2106,8 @@ dependencies = [ [[package]] name = "ruma-client-api" -version = "0.11.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.12.2" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "assign", "bytes", @@ -2096,8 +2126,8 @@ dependencies = [ [[package]] name = "ruma-common" -version = "0.5.4" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.6.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "indexmap", "js_int", @@ -2111,8 +2141,8 @@ dependencies = [ [[package]] name = "ruma-events" -version = "0.23.2" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.24.4" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "indoc", "js_int", @@ -2127,8 +2157,8 @@ dependencies = [ [[package]] name = "ruma-events-macros" -version = "0.23.2" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.24.4" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2138,8 +2168,8 @@ dependencies = [ [[package]] name = "ruma-federation-api" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "js_int", "ruma-api", @@ -2153,8 +2183,8 @@ dependencies = [ [[package]] name = "ruma-identifiers" -version = "0.19.4" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.20.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "paste", "rand 0.8.4", @@ -2167,8 +2197,8 @@ dependencies = [ [[package]] name = "ruma-identifiers-macros" -version = "0.19.4" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.20.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "quote", "ruma-identifiers-validation", @@ -2177,13 +2207,13 @@ dependencies = [ [[package]] name = "ruma-identifiers-validation" -version = "0.4.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" [[package]] name = "ruma-identity-service-api" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "js_int", "ruma-api", @@ -2195,8 +2225,8 @@ dependencies = [ [[package]] name = "ruma-push-gateway-api" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "js_int", "ruma-api", @@ -2210,8 +2240,8 @@ dependencies = [ [[package]] name = "ruma-serde" -version = "0.4.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "bytes", "form_urlencoded", @@ -2224,8 +2254,8 @@ dependencies = [ [[package]] name = "ruma-serde-macros" -version = "0.4.1" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.5.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -2235,8 +2265,8 @@ dependencies = [ [[package]] name = "ruma-signatures" -version = "0.8.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.9.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "base64 0.13.0", "ed25519-dalek", @@ -2252,8 +2282,8 @@ dependencies = [ [[package]] name = "ruma-state-res" -version = "0.2.0" -source = "git+https://github.com/timokoesters/ruma?rev=a2d93500e1dbc87e7032a3c74f3b2479a7f84e93#a2d93500e1dbc87e7032a3c74f3b2479a7f84e93" +version = "0.3.0" +source = "git+https://github.com/ruma/ruma?rev=f5ab038e22421ed338396ece977b6b2844772ced#f5ab038e22421ed338396ece977b6b2844772ced" dependencies = [ "itertools 0.10.1", "js_int", @@ -3022,7 +3052,7 @@ version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c47440f2979c4cd3138922840eec122e3c0ba2148bc290f756bd7fd60fc97fff" dependencies = [ - "opentelemetry", + "opentelemetry 0.15.0", "tracing", "tracing-core", "tracing-log", diff --git a/Cargo.toml b/Cargo.toml index 3d18bfb7..69b54c86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,8 +18,8 @@ edition = "2018" rocket = { version = "0.5.0-rc.1", features = ["tls"] } # Used to handle requests # Used for matrix spec type definitions and helpers -#ruma = { git = "https://github.com/ruma/ruma", rev = "eb19b0e08a901b87d11b3be0890ec788cc760492", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } -ruma = { git = "https://github.com/timokoesters/ruma", rev = "a2d93500e1dbc87e7032a3c74f3b2479a7f84e93", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +ruma = { git = "https://github.com/ruma/ruma", rev = "f5ab038e22421ed338396ece977b6b2844772ced", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } +#ruma = { git = "https://github.com/timokoesters/ruma", rev = "995ccea20f5f6d4a8fb22041749ed4de22fa1b6a", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } #ruma = { path = "../ruma/crates/ruma", features = ["compat", "rand", "appservice-api-c", "client-api", "federation-api", "push-gateway-api-c", "state-res", "unstable-pre-spec", "unstable-exhaustive-types"] } # Used for long polling and federation sender, should be the same as rocket::tokio @@ -66,11 +66,11 @@ regex = "1.5.4" jsonwebtoken = "7.2.0" # Performance measurements tracing = { version = "0.1.26", features = ["release_max_level_warn"] } -opentelemetry = "0.15.0" tracing-subscriber = "0.2.19" tracing-opentelemetry = "0.14.0" tracing-flame = "0.1.0" -opentelemetry-jaeger = "0.14.0" +opentelemetry = { version = "0.16.0", features = ["rt-tokio"] } +opentelemetry-jaeger = { version = "0.15.0", features = ["rt-tokio"] } pretty_env_logger = "0.4.0" lru-cache = "0.1.2" rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] } diff --git a/src/client_server/account.rs b/src/client_server/account.rs index b00882a6..e68c957f 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -292,6 +292,7 @@ pub async fn register_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -457,6 +458,7 @@ pub async fn register_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -478,6 +480,7 @@ pub async fn register_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -683,6 +686,7 @@ pub async fn deactivate_route( is_direct: None, third_party_invite: None, blurhash: None, + reason: None, }; let mutex_state = Arc::clone( diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index de6fa5a1..222d204f 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -262,6 +262,7 @@ pub async fn ban_user_route( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&body.user_id)?, + reason: None, }), |event| { let mut event = serde_json::from_value::>( @@ -563,6 +564,7 @@ async fn join_room_by_id_helper( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }) .expect("event is valid, we just created it"), ); @@ -695,6 +697,7 @@ async fn join_room_by_id_helper( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }; db.rooms.build_and_append_pdu( @@ -846,6 +849,7 @@ pub async fn invite_helper<'a>( membership: MembershipState::Invite, third_party_invite: None, blurhash: None, + reason: None, }) .expect("member event is valid value"); @@ -1040,6 +1044,7 @@ pub async fn invite_helper<'a>( is_direct: Some(is_direct), third_party_invite: None, blurhash: db.users.blurhash(&user_id)?, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, diff --git a/src/client_server/room.rs b/src/client_server/room.rs index c323be4a..25412788 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -107,6 +107,7 @@ pub async fn create_room_route( is_direct: Some(body.is_direct), third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, @@ -517,6 +518,7 @@ pub async fn upgrade_room_route( is_direct: None, third_party_invite: None, blurhash: db.users.blurhash(&sender_user)?, + reason: None, }) .expect("event is valid, we just created it"), unsigned: None, diff --git a/src/client_server/session.rs b/src/client_server/session.rs index d4d3c033..dada2d50 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -3,7 +3,10 @@ use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, - r0::session::{get_login_types, login, logout, logout_all}, + r0::{ + session::{get_login_types, login, logout, logout_all}, + uiaa::IncomingUserIdentifier, + }, }, UserId, }; @@ -60,7 +63,7 @@ pub async fn login_route( identifier, password, } => { - let username = if let login::IncomingUserIdentifier::MatrixId(matrix_id) = identifier { + let username = if let IncomingUserIdentifier::MatrixId(matrix_id) = identifier { matrix_id } else { return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); diff --git a/src/database.rs b/src/database.rs index 5ad2add8..bfc33f26 100644 --- a/src/database.rs +++ b/src/database.rs @@ -280,7 +280,7 @@ impl Database { shorteventid_cache: Mutex::new(LruCache::new(1_000_000)), eventidshort_cache: Mutex::new(LruCache::new(1_000_000)), statekeyshort_cache: Mutex::new(LruCache::new(1_000_000)), - stateinfo_cache: Mutex::new(LruCache::new(1000)), + stateinfo_cache: Mutex::new(LruCache::new(50)), }, account_data: account_data::AccountData { roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 600566c1..adb748d0 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -110,6 +110,7 @@ pub struct Rooms { impl Rooms { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. + #[tracing::instrument(skip(self))] pub fn state_full_ids(&self, shortstatehash: u64) -> Result> { let full_state = self .load_shortstatehash_info(shortstatehash)? @@ -122,6 +123,7 @@ impl Rooms { .collect() } + #[tracing::instrument(skip(self))] pub fn state_full( &self, shortstatehash: u64, @@ -220,6 +222,7 @@ impl Rooms { } /// This fetches auth events from the current state. + #[tracing::instrument(skip(self))] pub fn get_auth_events( &self, room_id: &RoomId, @@ -261,6 +264,7 @@ impl Rooms { } /// Checks if a room exists. + #[tracing::instrument(skip(self))] pub fn exists(&self, room_id: &RoomId) -> Result { let prefix = match self.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), @@ -277,6 +281,7 @@ impl Rooms { } /// Checks if a room exists. + #[tracing::instrument(skip(self))] pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { let prefix = self .get_shortroomid(room_id)? @@ -300,6 +305,7 @@ impl Rooms { /// Force the creation of a new StateHash and insert it into the db. /// /// Whatever `state` is supplied to `force_state` becomes the new current room state snapshot. + #[tracing::instrument(skip(self, new_state, db))] pub fn force_state( &self, room_id: &RoomId, @@ -412,6 +418,7 @@ impl Rooms { } /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. + #[tracing::instrument(skip(self))] pub fn load_shortstatehash_info( &self, shortstatehash: u64, @@ -480,6 +487,7 @@ impl Rooms { } } + #[tracing::instrument(skip(self, globals))] pub fn compress_state_event( &self, shortstatekey: u64, @@ -495,6 +503,7 @@ impl Rooms { Ok(v.try_into().expect("we checked the size above")) } + #[tracing::instrument(skip(self, compressed_event))] pub fn parse_compressed_state_event( &self, compressed_event: CompressedStateEvent, @@ -518,6 +527,13 @@ impl Rooms { /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer + #[tracing::instrument(skip( + self, + statediffnew, + statediffremoved, + diff_to_sibling, + parent_states + ))] pub fn save_state_from_diff( &self, shortstatehash: u64, @@ -642,6 +658,7 @@ impl Rooms { } /// Returns (shortstatehash, already_existed) + #[tracing::instrument(skip(self, globals))] fn get_or_create_shortstatehash( &self, state_hash: &StateHashId, @@ -662,6 +679,7 @@ impl Rooms { }) } + #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shorteventid( &self, event_id: &EventId, @@ -692,6 +710,7 @@ impl Rooms { Ok(short) } + #[tracing::instrument(skip(self))] pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.roomid_shortroomid .get(&room_id.as_bytes())? @@ -702,6 +721,7 @@ impl Rooms { .transpose() } + #[tracing::instrument(skip(self))] pub fn get_shortstatekey( &self, event_type: &EventType, @@ -739,6 +759,7 @@ impl Rooms { Ok(short) } + #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shortroomid( &self, room_id: &RoomId, @@ -756,6 +777,7 @@ impl Rooms { }) } + #[tracing::instrument(skip(self, globals))] pub fn get_or_create_shortstatekey( &self, event_type: &EventType, @@ -794,6 +816,7 @@ impl Rooms { Ok(short) } + #[tracing::instrument(skip(self))] pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result { if let Some(id) = self .shorteventid_cache @@ -876,12 +899,14 @@ impl Rooms { } /// Returns the `count` of this pdu's id. + #[tracing::instrument(skip(self))] pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? .map_or(Ok(None), |pdu_id| self.pdu_count(&pdu_id).map(Some)) } + #[tracing::instrument(skip(self))] pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { let prefix = self .get_shortroomid(room_id)? @@ -902,6 +927,7 @@ impl Rooms { } /// Returns the json of a pdu. + #[tracing::instrument(skip(self))] pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -920,6 +946,7 @@ impl Rooms { } /// Returns the json of a pdu. + #[tracing::instrument(skip(self))] pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { self.eventid_outlierpdu .get(event_id.as_bytes())? @@ -930,6 +957,7 @@ impl Rooms { } /// Returns the json of a pdu. + #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu_json( &self, event_id: &EventId, @@ -951,6 +979,7 @@ impl Rooms { } /// Returns the pdu's id. + #[tracing::instrument(skip(self))] pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid .get(event_id.as_bytes())? @@ -960,6 +989,7 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + #[tracing::instrument(skip(self))] pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { self.eventid_pduid .get(event_id.as_bytes())? @@ -980,6 +1010,7 @@ impl Rooms { /// Returns the pdu. /// /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + #[tracing::instrument(skip(self))] pub fn get_pdu(&self, event_id: &EventId) -> Result>> { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(&event_id) { return Ok(Some(Arc::clone(p))); @@ -1019,6 +1050,7 @@ impl Rooms { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. + #[tracing::instrument(skip(self))] pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -1029,6 +1061,7 @@ impl Rooms { } /// Returns the pdu as a `BTreeMap`. + #[tracing::instrument(skip(self))] pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { Ok(Some( @@ -1039,6 +1072,7 @@ impl Rooms { } /// Removes a pdu and creates a new one with the same id. + #[tracing::instrument(skip(self))] fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> { if self.pduid_pdu.get(&pdu_id)?.is_some() { self.pduid_pdu.insert( @@ -2298,6 +2332,7 @@ impl Rooms { Ok(()) } + #[tracing::instrument(skip(self))] pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { let mut joinedcount = 0_u64; let mut joined_servers = HashSet::new(); @@ -2347,6 +2382,7 @@ impl Rooms { Ok(()) } + #[tracing::instrument(skip(self, db))] pub async fn leave_room( &self, user_id: &UserId, @@ -2419,6 +2455,7 @@ impl Rooms { Ok(()) } + #[tracing::instrument(skip(self, db))] async fn remote_leave_room( &self, user_id: &UserId, @@ -2650,6 +2687,7 @@ impl Rooms { }) } + #[tracing::instrument(skip(self))] pub fn search_pdus<'a>( &'a self, room_id: &RoomId, @@ -2809,6 +2847,7 @@ impl Rooms { }) } + #[tracing::instrument(skip(self))] pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { Ok(self .roomid_joinedcount diff --git a/src/database/uiaa.rs b/src/database/uiaa.rs index 1372fef5..8a3fe4f0 100644 --- a/src/database/uiaa.rs +++ b/src/database/uiaa.rs @@ -4,11 +4,14 @@ use crate::{client_server::SESSION_ID_LENGTH, utils, Error, Result}; use ruma::{ api::client::{ error::ErrorKind, - r0::uiaa::{IncomingAuthData, UiaaInfo}, + r0::uiaa::{ + IncomingAuthData, IncomingPassword, IncomingUserIdentifier::MatrixId, UiaaInfo, + }, }, signatures::CanonicalJsonValue, DeviceId, UserId, }; +use tracing::error; use super::abstraction::Tree; @@ -49,126 +52,91 @@ impl Uiaa { users: &super::users::Users, globals: &super::globals::Globals, ) -> Result<(bool, UiaaInfo)> { - if let IncomingAuthData::DirectRequest { - kind, - session, - auth_parameters, - } = &auth - { - let mut uiaainfo = session - .as_ref() - .map(|session| self.get_uiaa_session(&user_id, &device_id, session)) - .unwrap_or_else(|| Ok(uiaainfo.clone()))?; + let mut uiaainfo = auth + .session() + .map(|session| self.get_uiaa_session(&user_id, &device_id, session)) + .unwrap_or_else(|| Ok(uiaainfo.clone()))?; - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } + match auth { // Find out what the user completed - match &**kind { - "m.login.password" => { - let identifier = auth_parameters.get("identifier").ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "m.login.password needs identifier.", - ))?; - - let identifier_type = identifier.get("type").ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "Identifier needs a type.", - ))?; - - if identifier_type != "m.id.user" { + IncomingAuthData::Password(IncomingPassword { + identifier, + password, + .. + }) => { + let username = match identifier { + MatrixId(username) => username, + _ => { return Err(Error::BadRequest( ErrorKind::Unrecognized, "Identifier type not recognized.", - )); + )) } + }; - let username = identifier - .get("user") - .ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "Identifier needs user field.", - ))? - .as_str() - .ok_or(Error::BadRequest( - ErrorKind::BadJson, - "User is not a string.", - ))?; - - let user_id = UserId::parse_with_server_name(username, globals.server_name()) + let user_id = + UserId::parse_with_server_name(username.clone(), globals.server_name()) .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") - })?; + Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid.") + })?; - let password = auth_parameters - .get("password") - .ok_or(Error::BadRequest( - ErrorKind::MissingParam, - "Password is missing.", - ))? - .as_str() - .ok_or(Error::BadRequest( - ErrorKind::BadJson, - "Password is not a string.", - ))?; + // Check if password is correct + if let Some(hash) = users.password_hash(&user_id)? { + let hash_matches = + argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - // Check if password is correct - if let Some(hash) = users.password_hash(&user_id)? { - let hash_matches = - argon2::verify_encoded(&hash, password.as_bytes()).unwrap_or(false); - - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - - // Password was correct! Let's add it to `completed` - uiaainfo.completed.push("m.login.password".to_owned()); - } - "m.login.dummy" => { - uiaainfo.completed.push("m.login.dummy".to_owned()); - } - k => panic!("type not supported: {}", k), - } - - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; + if !hash_matches { + uiaainfo.auth_error = Some(ruma::api::client::error::ErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); } } - // We didn't break, so this flow succeeded! - completed = true; - } - if !completed { - self.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push("m.login.password".to_owned()); } + IncomingAuthData::Dummy(_) => { + uiaainfo.completed.push("m.login.dummy".to_owned()); + } + k => error!("type not supported: {:?}", k), + } - // UIAA was successful! Remove this session and return true + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! + completed = true; + } + + if !completed { self.update_uiaa_session( user_id, device_id, uiaainfo.session.as_ref().expect("session is always set"), - None, + Some(&uiaainfo), )?; - Ok((true, uiaainfo)) - } else { - panic!("FallbackAcknowledgement is not supported yet"); + return Ok((false, uiaainfo)); } + + // UIAA was successful! Remove this session and return true + self.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + )?; + Ok((true, uiaainfo)) } fn set_uiaa_request( diff --git a/src/main.rs b/src/main.rs index 5a6f8c76..72f753fe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use database::Config; pub use database::Database; pub use error::{Error, Result}; -use opentelemetry::trace::Tracer; +use opentelemetry::trace::{FutureExt, Tracer}; pub use pdu::PduEvent; pub use rocket::State; use ruma::api::client::error::ErrorKind; @@ -220,14 +220,17 @@ async fn main() { }; if config.allow_jaeger { + opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); let tracer = opentelemetry_jaeger::new_pipeline() - .with_service_name("conduit") - .install_simple() + .install_batch(opentelemetry::runtime::Tokio) .unwrap(); let span = tracer.start("conduit"); - start.await; + start.with_current_context().await; drop(span); + + println!("exporting"); + opentelemetry::global::shutdown_tracer_provider(); } else { std::env::set_var("RUST_LOG", &config.log); diff --git a/src/pdu.rs b/src/pdu.rs index 00eda5b1..1016fe66 100644 --- a/src/pdu.rs +++ b/src/pdu.rs @@ -12,7 +12,7 @@ use ruma::{ use serde::{Deserialize, Serialize}; use serde_json::json; use std::{cmp::Ordering, collections::BTreeMap, convert::TryFrom}; -use tracing::error; +use tracing::warn; #[derive(Clone, Deserialize, Serialize, Debug)] pub struct PduEvent { @@ -322,7 +322,7 @@ pub(crate) fn gen_event_id_canonical_json( pdu: &Raw, ) -> crate::Result<(EventId, CanonicalJsonObject)> { let value = serde_json::from_str(pdu.json().get()).map_err(|e| { - error!("{:?}: {:?}", pdu, e); + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); Error::BadServerResponse("Invalid PDU in server response") })?; diff --git a/src/server_server.rs b/src/server_server.rs index 49f225f1..56b28f27 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -111,7 +111,7 @@ impl FedDest { } } -#[tracing::instrument(skip(globals))] +#[tracing::instrument(skip(globals, request))] pub async fn send_request( globals: &crate::database::globals::Globals, destination: &ServerName, @@ -501,7 +501,7 @@ pub fn get_server_keys_route(db: DatabaseGuard) -> Json { ) .unwrap(); - Json(ruma::serde::to_canonical_json_string(&response).expect("JSON is canonical")) + Json(serde_json::to_string(&response).expect("JSON is canonical")) } #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] @@ -927,12 +927,17 @@ pub async fn handle_incoming_pdu<'a>( ); eventid_info.insert(prev_event_id.clone(), (pdu, json)); } else { + // Time based check failed graph.insert(prev_event_id.clone(), HashSet::new()); eventid_info.insert(prev_event_id.clone(), (pdu, json)); } } else { + // Get json failed graph.insert(prev_event_id.clone(), HashSet::new()); } + } else { + // Fetch and handle failed + graph.insert(prev_event_id.clone(), HashSet::new()); } } @@ -956,7 +961,9 @@ pub async fn handle_incoming_pdu<'a>( for prev_id in dbg!(sorted) { if let Some((pdu, json)) = eventid_info.remove(&prev_id) { - upgrade_outlier_to_timeline_pdu( + let start_time = Instant::now(); + let event_id = pdu.event_id.clone(); + if let Err(e) = upgrade_outlier_to_timeline_pdu( pdu, json, &create_event, @@ -965,7 +972,17 @@ pub async fn handle_incoming_pdu<'a>( room_id, pub_key_map, ) - .await?; + .await + { + warn!("Prev event {} failed: {}", event_id, e); + } + let elapsed = start_time.elapsed(); + warn!( + "Handling prev event {} took {}m{}s", + event_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); } } @@ -981,6 +998,7 @@ pub async fn handle_incoming_pdu<'a>( .await } +#[tracing::instrument(skip(origin, create_event, event_id, room_id, value, db, pub_key_map))] fn handle_outlier_pdu<'a>( origin: &'a ServerName, create_event: &'a PduEvent, @@ -1141,6 +1159,7 @@ fn handle_outlier_pdu<'a>( }) } +#[tracing::instrument(skip(incoming_pdu, val, create_event, origin, db, room_id, pub_key_map))] async fn upgrade_outlier_to_timeline_pdu( incoming_pdu: Arc, val: BTreeMap, @@ -1352,41 +1371,6 @@ async fn upgrade_outlier_to_timeline_pdu( // Only keep those extremities were not referenced yet extremities.retain(|id| !matches!(db.rooms.is_event_referenced(&room_id, id), Ok(true))); - let mut extremity_statehashes = Vec::new(); - - for id in &extremities { - match db - .rooms - .get_pdu(&id) - .map_err(|_| "Failed to ask db for pdu.".to_owned())? - { - Some(leaf_pdu) => { - extremity_statehashes.push(( - db.rooms - .pdu_shortstatehash(&leaf_pdu.event_id) - .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? - .ok_or_else(|| { - error!( - "Found extremity pdu with no statehash in db: {:?}", - leaf_pdu - ); - "Found pdu with no statehash in db.".to_owned() - })?, - Some(leaf_pdu), - )); - } - _ => { - error!("Missing state snapshot for {:?}", id); - return Err("Missing state snapshot.".to_owned()); - } - } - } - - // 12. Ensure that the state is derived from the previous current state (i.e. we calculated - // by doing state res where one of the inputs was a previously trusted set of state, - // don't just trust a set of state we got from a remote). - - // We do this by adding the current state to the list of fork states let current_statehash = db .rooms .current_shortstatehash(&room_id) @@ -1398,91 +1382,138 @@ async fn upgrade_outlier_to_timeline_pdu( .state_full(current_statehash) .map_err(|_| "Failed to load room state.")?; - extremity_statehashes.push((current_statehash.clone(), None)); + if incoming_pdu.state_key.is_some() { + let mut extremity_statehashes = Vec::new(); - let mut fork_states = Vec::new(); - for (statehash, leaf_pdu) in extremity_statehashes { - let mut leaf_state = db - .rooms - .state_full(statehash) - .map_err(|_| "Failed to ask db for room state.".to_owned())?; - - if let Some(leaf_pdu) = leaf_pdu { - if let Some(state_key) = &leaf_pdu.state_key { - // Now it's the state after - let key = (leaf_pdu.kind.clone(), state_key.clone()); - leaf_state.insert(key, leaf_pdu); + for id in &extremities { + match db + .rooms + .get_pdu(&id) + .map_err(|_| "Failed to ask db for pdu.".to_owned())? + { + Some(leaf_pdu) => { + extremity_statehashes.push(( + db.rooms + .pdu_shortstatehash(&leaf_pdu.event_id) + .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? + .ok_or_else(|| { + error!( + "Found extremity pdu with no statehash in db: {:?}", + leaf_pdu + ); + "Found pdu with no statehash in db.".to_owned() + })?, + Some(leaf_pdu), + )); + } + _ => { + error!("Missing state snapshot for {:?}", id); + return Err("Missing state snapshot.".to_owned()); + } } } - fork_states.push(leaf_state); - } + // 12. Ensure that the state is derived from the previous current state (i.e. we calculated + // by doing state res where one of the inputs was a previously trusted set of state, + // don't just trust a set of state we got from a remote). - // We also add state after incoming event to the fork states - extremities.insert(incoming_pdu.event_id.clone()); - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - state_after.insert( - (incoming_pdu.kind.clone(), state_key.clone()), - incoming_pdu.clone(), - ); - } - fork_states.push(state_after.clone()); + // We do this by adding the current state to the list of fork states - let mut update_state = false; - // 14. Use state resolution to find new room state - let new_room_state = if fork_states.is_empty() { - return Err("State is empty.".to_owned()); - } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { - // There was only one state, so it has to be the room's current state (because that is - // always included) - fork_states[0] - .iter() - .map(|(k, pdu)| (k.clone(), pdu.event_id.clone())) - .collect() - } else { - // We do need to force an update to this room's state - update_state = true; + extremity_statehashes.push((current_statehash.clone(), None)); - let fork_states = &fork_states - .into_iter() - .map(|map| { - map.into_iter() - .map(|(k, v)| (k, v.event_id.clone())) - .collect::>() - }) - .collect::>(); + let mut fork_states = Vec::new(); + for (statehash, leaf_pdu) in extremity_statehashes { + let mut leaf_state = db + .rooms + .state_full(statehash) + .map_err(|_| "Failed to ask db for room state.".to_owned())?; - let mut auth_chain_sets = Vec::new(); - for state in fork_states { - auth_chain_sets.push( - get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) - .map_err(|_| "Failed to load auth chain.".to_owned())? - .collect(), + if let Some(leaf_pdu) = leaf_pdu { + if let Some(state_key) = &leaf_pdu.state_key { + // Now it's the state after + let key = (leaf_pdu.kind.clone(), state_key.clone()); + leaf_state.insert(key, leaf_pdu); + } + } + + fork_states.push(leaf_state); + } + + // We also add state after incoming event to the fork states + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + state_after.insert( + (incoming_pdu.kind.clone(), state_key.clone()), + incoming_pdu.clone(), ); } + fork_states.push(state_after.clone()); - let state = match state_res::StateResolution::resolve( - &room_id, - room_version_id, - fork_states, - auth_chain_sets, - |id| { - let res = db.rooms.get_pdu(id); - if let Err(e) = &res { - error!("LOOK AT ME Failed to fetch event: {}", e); - } - res.ok().flatten() - }, - ) { - Ok(new_state) => new_state, - Err(_) => { - return Err("State resolution failed, either an event could not be found or deserialization".into()); + let mut update_state = false; + // 14. Use state resolution to find new room state + let new_room_state = if fork_states.is_empty() { + return Err("State is empty.".to_owned()); + } else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) { + // There was only one state, so it has to be the room's current state (because that is + // always included) + fork_states[0] + .iter() + .map(|(k, pdu)| (k.clone(), pdu.event_id.clone())) + .collect() + } else { + // We do need to force an update to this room's state + update_state = true; + + let fork_states = &fork_states + .into_iter() + .map(|map| { + map.into_iter() + .map(|(k, v)| (k, v.event_id.clone())) + .collect::>() + }) + .collect::>(); + + let mut auth_chain_sets = Vec::new(); + for state in fork_states { + auth_chain_sets.push( + get_auth_chain(state.iter().map(|(_, id)| id.clone()).collect(), db) + .map_err(|_| "Failed to load auth chain.".to_owned())? + .collect(), + ); } + + let state = match state_res::StateResolution::resolve( + &room_id, + room_version_id, + fork_states, + auth_chain_sets, + |id| { + let res = db.rooms.get_pdu(id); + if let Err(e) = &res { + error!("LOOK AT ME Failed to fetch event: {}", e); + } + res.ok().flatten() + }, + ) { + Ok(new_state) => new_state, + Err(_) => { + return Err("State resolution failed, either an event could not be found or deserialization".into()); + } + }; + + state }; - state - }; + // Set the new room state to the resolved state + if update_state { + db.rooms + .force_state(&room_id, new_room_state, &db) + .map_err(|_| "Failed to set new room state.".to_owned())?; + } + debug!("Updated resolved state"); + } + + extremities.insert(incoming_pdu.event_id.clone()); debug!("starting soft fail auth check"); // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it @@ -1516,14 +1547,6 @@ async fn upgrade_outlier_to_timeline_pdu( warn!("Event was soft failed: {:?}", incoming_pdu); } - // Set the new room state to the resolved state - if update_state { - db.rooms - .force_state(&room_id, new_room_state, &db) - .map_err(|_| "Failed to set new room state.".to_owned())?; - } - debug!("Updated resolved state"); - if soft_fail { // Soft fail, we leave the event as an outlier but don't add it to the timeline return Err("Event has been soft failed".into()); @@ -1543,7 +1566,7 @@ async fn upgrade_outlier_to_timeline_pdu( /// b. Look at outlier pdu tree /// c. Ask origin server over federation /// d. TODO: Ask other servers over federation? -//#[tracing::instrument(skip(db, key_map, auth_cache))] +#[tracing::instrument(skip(db, origin, events, create_event, room_id, pub_key_map))] pub(crate) fn fetch_and_handle_outliers<'a>( db: &'a Database, origin: &'a ServerName, @@ -1562,15 +1585,16 @@ pub(crate) fn fetch_and_handle_outliers<'a>( let mut pdus = vec![]; for id in events { + info!("loading {}", id); if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&id) { // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { min_elapsed_duration = Duration::from_secs(60 * 60 * 24); } if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", id); + info!("Backing off from {}", id); continue; } } @@ -1586,7 +1610,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( } Ok(None) => { // c. Ask origin server over federation - debug!("Fetching {} over federation.", id); + info!("Fetching {} over federation.", id); match db .sending .send_federation_request( @@ -1597,11 +1621,14 @@ pub(crate) fn fetch_and_handle_outliers<'a>( .await { Ok(res) => { - debug!("Got {} over federation", id); + info!("Got {} over federation", id); let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&res.pdu) { Ok(t) => t, - Err(_) => continue, + Err(_) => { + back_off(id.clone()); + continue; + } }; // This will also fetch the auth chain @@ -1632,7 +1659,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( } } Err(e) => { - debug!("Error loading {}: {}", id, e); + warn!("Error loading {}: {}", id, e); continue; } }; @@ -1644,7 +1671,7 @@ pub(crate) fn fetch_and_handle_outliers<'a>( /// Search the DB for the signing keys of the given server, if we don't have them /// fetch them from the server and save to our DB. -#[tracing::instrument(skip(db))] +#[tracing::instrument(skip(db, origin, signature_ids))] pub(crate) async fn fetch_signing_keys( db: &Database, origin: &ServerName, @@ -1885,6 +1912,7 @@ fn append_incoming_pdu( Ok(pdu_id) } +#[tracing::instrument(skip(starting_events, db))] fn get_auth_chain( starting_events: Vec, db: &Database, @@ -1921,6 +1949,7 @@ fn get_auth_chain( .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) } +#[tracing::instrument(skip(event_id, db))] fn get_auth_chain_inner(event_id: &EventId, db: &Database) -> Result> { let mut todo = vec![event_id.clone()]; let mut found = HashSet::new(); @@ -2204,6 +2233,7 @@ pub fn create_join_event_template_route( is_direct: None, membership: MembershipState::Join, third_party_invite: None, + reason: None, }) .expect("member event is valid value"); @@ -2680,6 +2710,7 @@ pub async fn claim_keys_route( .into()) } +#[tracing::instrument(skip(event, pub_key_map, db))] pub async fn fetch_required_signing_keys( event: &BTreeMap, pub_key_map: &RwLock>>, From b09499c2df67c4fc101fe8cdfb0e2baec91d5b2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Thu, 19 Aug 2021 14:05:23 +0200 Subject: [PATCH 23/25] fix: don't save empty tokens --- src/database/rooms.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/database/rooms.rs b/src/database/rooms.rs index adb748d0..d3600f18 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -1378,6 +1378,7 @@ impl Rooms { if let Some(body) = pdu.content.get("body").and_then(|b| b.as_str()) { let mut batch = body .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) .filter(|word| word.len() <= 50) .map(str::to_lowercase) .map(|word| { @@ -2702,6 +2703,7 @@ impl Rooms { let words = search_string .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) .map(str::to_lowercase) .collect::>(); From 4956fb9fbac289df09c92197ff5119913c896788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 21 Aug 2021 14:22:21 +0200 Subject: [PATCH 24/25] improvement: limit prev event fetching --- src/database.rs | 25 ++++++++++++++++++++----- src/server_server.rs | 13 ++++++++++++- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/database.rs b/src/database.rs index bfc33f26..23d3bdf9 100644 --- a/src/database.rs +++ b/src/database.rs @@ -636,7 +636,7 @@ impl Database { if db.globals.database_version()? < 9 { // Update tokenids db layout - let mut batch = db.rooms.tokenids.iter().filter_map(|(key, _)| { + let batch = db.rooms.tokenids.iter().filter_map(|(key, _)| { if !key.starts_with(b"!") { return None; } @@ -659,14 +659,29 @@ impl Database { println!("old {:?}", key); println!("new {:?}", new_key); Some((new_key, Vec::new())) - }); + }).collect::>(); - db.rooms.tokenids.insert_batch(&mut batch)?; + let mut iter = batch.into_iter().peekable(); - for (key, _) in db.rooms.tokenids.iter() { + while iter.peek().is_some() { + db.rooms.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; + println!("smaller batch done"); + } + + println!("Deleting starts"); + + let batch2 = db.rooms.tokenids.iter().filter_map(|(key, _)| { if key.starts_with(b"!") { - db.rooms.tokenids.remove(&key)?; + println!("del {:?}", key); + Some(key) + } else { + None } + }).collect::>(); + + for key in batch2 { + println!("del"); + db.rooms.tokenids.remove(&key)?; } db.globals.bump_database_version(9)?; diff --git a/src/server_server.rs b/src/server_server.rs index 56b28f27..5b09872d 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -254,7 +254,7 @@ where }); // TODO: handle timeout if status != 200 { - info!( + warn!( "{} {}: {}", url, status, @@ -893,6 +893,9 @@ pub async fn handle_incoming_pdu<'a>( let mut graph = HashMap::new(); let mut eventid_info = HashMap::new(); let mut todo_outlier_stack = incoming_pdu.prev_events.clone(); + + let mut amount = 0; + while let Some(prev_event_id) = todo_outlier_stack.pop() { if let Some((pdu, json_opt)) = fetch_and_handle_outliers( db, @@ -905,6 +908,13 @@ pub async fn handle_incoming_pdu<'a>( .await .pop() { + if amount > 100 { + // Max limit reached + warn!("Max prev event limit reached!"); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue + } + if let Some(json) = json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) { @@ -915,6 +925,7 @@ pub async fn handle_incoming_pdu<'a>( .expect("Room exists") .origin_server_ts { + amount += 1; for prev_prev in &pdu.prev_events { if !graph.contains_key(prev_prev) { todo_outlier_stack.push(dbg!(prev_prev.clone())); From 3b78e43a18b17729973d0620c344853eb2c2d610 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sat, 21 Aug 2021 14:24:10 +0200 Subject: [PATCH 25/25] fmt --- src/database.rs | 76 +++++++++++++++++++++++++------------------- src/server_server.rs | 2 +- 2 files changed, 45 insertions(+), 33 deletions(-) diff --git a/src/database.rs b/src/database.rs index 23d3bdf9..3d1324e6 100644 --- a/src/database.rs +++ b/src/database.rs @@ -636,48 +636,60 @@ impl Database { if db.globals.database_version()? < 9 { // Update tokenids db layout - let batch = db.rooms.tokenids.iter().filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); + let batch = db + .rooms + .tokenids + .iter() + .filter_map(|(key, _)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(4, |&b| b == 0xff); + let room_id = parts.next().unwrap(); + let word = parts.next().unwrap(); + let _pdu_id_room = parts.next().unwrap(); + let pdu_id_count = parts.next().unwrap(); - let short_room_id = db - .rooms - .roomid_shortroomid - .get(&room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id; - new_key.extend_from_slice(word); - new_key.push(0xff); - new_key.extend_from_slice(pdu_id_count); - println!("old {:?}", key); - println!("new {:?}", new_key); - Some((new_key, Vec::new())) - }).collect::>(); + let short_room_id = db + .rooms + .roomid_shortroomid + .get(&room_id) + .unwrap() + .expect("shortroomid should exist"); + let mut new_key = short_room_id; + new_key.extend_from_slice(word); + new_key.push(0xff); + new_key.extend_from_slice(pdu_id_count); + println!("old {:?}", key); + println!("new {:?}", new_key); + Some((new_key, Vec::new())) + }) + .collect::>(); let mut iter = batch.into_iter().peekable(); while iter.peek().is_some() { - db.rooms.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; + db.rooms + .tokenids + .insert_batch(&mut iter.by_ref().take(1000))?; println!("smaller batch done"); } println!("Deleting starts"); - let batch2 = db.rooms.tokenids.iter().filter_map(|(key, _)| { - if key.starts_with(b"!") { - println!("del {:?}", key); - Some(key) - } else { - None - } - }).collect::>(); + let batch2 = db + .rooms + .tokenids + .iter() + .filter_map(|(key, _)| { + if key.starts_with(b"!") { + println!("del {:?}", key); + Some(key) + } else { + None + } + }) + .collect::>(); for key in batch2 { println!("del"); diff --git a/src/server_server.rs b/src/server_server.rs index 5b09872d..5299e1f0 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -912,7 +912,7 @@ pub async fn handle_incoming_pdu<'a>( // Max limit reached warn!("Max prev event limit reached!"); graph.insert(prev_event_id.clone(), HashSet::new()); - continue + continue; } if let Some(json) =