From e0d0fb4703de4fd245ad193bc8b109ccad1939da Mon Sep 17 00:00:00 2001 From: timokoesters Date: Wed, 29 Jul 2020 20:44:06 +0200 Subject: [PATCH] fix: only send device_one_time_keys_count when there are updates --- src/client_server.rs | 10 +++++++--- src/database.rs | 1 + src/database/users.rs | 25 +++++++++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/client_server.rs b/src/client_server.rs index 0536c8ea..4a2f33b3 100644 --- a/src/client_server.rs +++ b/src/client_server.rs @@ -890,7 +890,7 @@ pub fn upload_keys_route( if let Some(one_time_keys) = &body.one_time_keys { for (key_key, key_value) in one_time_keys { db.users - .add_one_time_key(sender_id, device_id, key_key, key_value)?; + .add_one_time_key(sender_id, device_id, key_key, key_value, &db.globals)?; } } @@ -1002,7 +1002,7 @@ pub fn claim_keys_route( for (device_id, key_algorithm) in map { if let Some(one_time_keys) = db.users - .take_one_time_key(user_id, device_id, key_algorithm)? + .take_one_time_key(user_id, device_id, key_algorithm, &db.globals)? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); @@ -2912,7 +2912,11 @@ pub async fn sync_events_route( changed: device_list_updates.into_iter().collect(), left: Vec::new(), // TODO }, - device_one_time_keys_count: Default::default(), // TODO + device_one_time_keys_count: if db.users.last_one_time_keys_update(sender_id)? > since { + db.users.count_one_time_keys(sender_id, device_id)? + } else { + BTreeMap::new() + }, to_device: sync_events::ToDevice { events: db.users.get_to_device_events(sender_id, device_id)?, }, diff --git a/src/database.rs b/src/database.rs index a8376388..5a1ed0f6 100644 --- a/src/database.rs +++ b/src/database.rs @@ -75,6 +75,7 @@ impl Database { userdeviceid_metadata: db.open_tree("userdeviceid_metadata")?, token_userdeviceid: db.open_tree("token_userdeviceid")?, onetimekeyid_onetimekeys: db.open_tree("onetimekeyid_onetimekeys")?, + userid_lastonetimekeyupdate: db.open_tree("userid_lastonetimekeyupdate")?, keychangeid_userid: db.open_tree("devicekeychangeid_userid")?, keyid_key: db.open_tree("keyid_key")?, userid_masterkeyid: db.open_tree("userid_masterkeyid")?, diff --git a/src/database/users.rs b/src/database/users.rs index 7fb97679..c7927677 100644 --- a/src/database/users.rs +++ b/src/database/users.rs @@ -22,6 +22,7 @@ pub struct Users { pub(super) token_userdeviceid: sled::Tree, pub(super) onetimekeyid_onetimekeys: sled::Tree, // OneTimeKeyId = UserId + AlgorithmAndDeviceId + pub(super) userid_lastonetimekeyupdate: sled::Tree, // LastOneTimeKeyUpdate = Count pub(super) keychangeid_userid: sled::Tree, // KeyChangeId = RoomId + Count pub(super) keyid_key: sled::Tree, // KeyId = UserId + KeyId (depends on key type) pub(super) userid_masterkeyid: sled::Tree, @@ -270,6 +271,7 @@ impl Users { device_id: &DeviceId, one_time_key_key: &AlgorithmAndDeviceId, one_time_key_value: &OneTimeKey, + globals: &super::globals::Globals, ) -> Result<()> { let mut key = user_id.to_string().as_bytes().to_vec(); key.push(0xff); @@ -294,14 +296,32 @@ impl Users { .expect("OneTimeKey::to_string always works"), )?; + self.userid_lastonetimekeyupdate.insert( + &user_id.to_string().as_bytes(), + &globals.next_count()?.to_be_bytes(), + )?; + Ok(()) } + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self + .userid_lastonetimekeyupdate + .get(&user_id.to_string().as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") + }) + }) + .unwrap_or(Ok(0)) + } + pub fn take_one_time_key( &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &KeyAlgorithm, + globals: &super::globals::Globals, ) -> Result> { let mut prefix = user_id.to_string().as_bytes().to_vec(); prefix.push(0xff); @@ -311,6 +331,11 @@ impl Users { prefix.extend_from_slice(key_algorithm.to_string().as_bytes()); prefix.push(b':'); + self.userid_lastonetimekeyupdate.insert( + &user_id.to_string().as_bytes(), + &globals.next_count()?.to_be_bytes(), + )?; + self.onetimekeyid_onetimekeys .scan_prefix(&prefix) .next()