diff --git a/Cargo.lock b/Cargo.lock index 426661bf..e99928e0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,17 +80,6 @@ version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f093eed78becd229346bf859eec0aa4dd7ddde0757287b2b4107a1f09c80002" -[[package]] -name = "async-recursion" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd55a5ba1179988837d24ab4c7cc8ed6efdeff578ede0416b4225a5fca35bd0" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.48", -] - [[package]] name = "async-trait" version = "0.1.77" @@ -385,7 +374,6 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" name = "conduit" version = "0.7.0-alpha" dependencies = [ - "async-recursion", "async-trait", "axum", "axum-server", diff --git a/Cargo.toml b/Cargo.toml index 93ff2f3b..e8c1c8bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,7 +115,6 @@ lazy_static = "1.4.0" async-trait = "0.1.68" sd-notify = { version = "0.4.1", optional = true } -async-recursion = "1.0.5" [target.'cfg(unix)'.dependencies] nix = { version = "0.26.2", features = ["resource"] } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 7cc662ee..1547d406 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -8,7 +8,6 @@ use std::{ time::{Duration, Instant, SystemTime}, }; -use async_recursion::async_recursion; use futures_util::{stream::FuturesUnordered, Future, StreamExt}; use ruma::{ api::{ @@ -1044,8 +1043,7 @@ impl Service { /// d. TODO: Ask other servers over federation? #[allow(clippy::type_complexity)] #[tracing::instrument(skip_all)] - #[async_recursion] - pub(crate) async fn fetch_and_handle_outliers<'a>( + pub(crate) fn fetch_and_handle_outliers<'a>( &'a self, origin: &'a ServerName, events: &'a [Arc], @@ -1053,175 +1051,180 @@ impl Service { room_id: &'a RoomId, room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, - ) -> Vec<(Arc, Option>)> { - let back_off = |id| async move { - match services() - .globals - .bad_event_ratelimiter - .write() - .await - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - } - }; - - let mut pdus = vec![]; - for id in events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - continue; - } - - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. - let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::new(); - let mut events_all = HashSet::new(); - let mut i = 0; - while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .await - .get(&*next_id) - { - // Exponential backoff - let mut min_elapsed_duration = - Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", next_id); - continue; - } - } - - if events_all.contains(&next_id) { - continue; - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - - if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { - trace!("Found {} in db", next_id); - continue; - } - - info!("Fetching {} over federation.", next_id); + ) -> AsyncRecursiveType<'a, Vec<(Arc, Option>)>> + { + Box::pin(async move { + let back_off = |id| async move { match services() - .sending - .send_federation_request( - origin, - get_event::v1::Request { - event_id: (*next_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - info!("Got {} over federation", next_id); - let (calculated_event_id, value) = - match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) { - Ok(t) => t, - Err(_) => { - back_off((*next_id).to_owned()).await; - continue; - } - }; - - if calculated_event_id != *next_id { - warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", - next_id, calculated_event_id, &res.pdu); - } - - if let Some(auth_events) = - value.get("auth_events").and_then(|c| c.as_array()) - { - for auth_event in auth_events { - if let Ok(auth_event) = - serde_json::from_value(auth_event.clone().into()) - { - let a: Arc = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } - } - } else { - warn!("Auth event list invalid"); - } - - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - } - Err(_) => { - warn!("Failed to fetch event: {}", next_id); - back_off((*next_id).to_owned()).await; - } - } - } - - for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = services() .globals .bad_event_ratelimiter - .read() + .write() .await - .get(&**next_id) + .entry(id) { - // Exponential backoff - let mut min_elapsed_duration = - Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + } + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1) + } + } + }; + + let mut pdus = vec![]; + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { + trace!("Found {} in db", id); + pdus.push((local_pdu, None)); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. + let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::new(); + let mut events_all = HashSet::new(); + let mut i = 0; + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&*next_id) + { + // Exponential backoff + let mut min_elapsed_duration = + Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", next_id); + continue; + } } - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", next_id); + if events_all.contains(&next_id) { continue; } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + + if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { + trace!("Found {} in db", next_id); + continue; + } + + info!("Fetching {} over federation.", next_id); + match services() + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + info!("Got {} over federation", next_id); + let (calculated_event_id, value) = + match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) { + Ok(t) => t, + Err(_) => { + back_off((*next_id).to_owned()).await; + continue; + } + }; + + if calculated_event_id != *next_id { + warn!("Server didn't return event id we requested: requested: {}, we got {}. Event: {:?}", + next_id, calculated_event_id, &res.pdu); + } + + if let Some(auth_events) = + value.get("auth_events").and_then(|c| c.as_array()) + { + for auth_event in auth_events { + if let Ok(auth_event) = + serde_json::from_value(auth_event.clone().into()) + { + let a: Arc = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); + } + } + } else { + warn!("Auth event list invalid"); + } + + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + } + Err(_) => { + warn!("Failed to fetch event: {}", next_id); + back_off((*next_id).to_owned()).await; + } + } } - match self - .handle_outlier_pdu( - origin, - create_event, - next_id, - room_id, - value.clone(), - true, - pub_key_map, - ) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); + for (next_id, value) in events_in_reverse_order.iter().rev() { + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(&**next_id) + { + // Exponential backoff + let mut min_elapsed_duration = + Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", next_id); + continue; } } - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()).await; + + match self + .handle_outlier_pdu( + origin, + create_event, + next_id, + room_id, + value.clone(), + true, + pub_key_map, + ) + .await + { + Ok((pdu, json)) => { + if next_id == id { + pdus.push((pdu, Some(json))); + } + } + Err(e) => { + warn!("Authentication of event {} failed: {:?}", next_id, e); + back_off((**next_id).to_owned()).await; + } } } } - } - pdus + pdus + }) } async fn fetch_unknown_prev_events(