diff --git a/crates/apub/src/http/inbox.rs b/crates/apub/src/http/inbox.rs index b8ef837df..71215c032 100644 --- a/crates/apub/src/http/inbox.rs +++ b/crates/apub/src/http/inbox.rs @@ -16,7 +16,7 @@ use once_cell::sync::Lazy; use std::{ cmp::Ordering, collections::BinaryHeap, - sync::{Arc, Mutex}, + sync::{Arc, RwLock}, thread::available_parallelism, time::Duration, }; @@ -38,26 +38,26 @@ pub async fn shared_inbox( request.method().clone(), request.uri().clone(), ); - ACTIVITY_QUEUE.lock().unwrap().push(InboxActivity { + ACTIVITY_QUEUE.write().unwrap().push(InboxActivity { request_parts, bytes, published, }); + Ok(HttpResponse::Ok().finish()) } None => { // no timestamp included, process immediately receive_activity::( request, bytes, &data, ) - .await?; + .await } - }; - Ok(HttpResponse::Ok().finish()) + } } /// Queue of incoming activities, ordered by oldest published first -static ACTIVITY_QUEUE: Lazy>>> = - Lazy::new(|| Arc::new(Mutex::new(BinaryHeap::new()))); +static ACTIVITY_QUEUE: Lazy>>> = + Lazy::new(|| Arc::new(RwLock::new(BinaryHeap::new()))); /// Minimum age of an activity before it gets processed. This ensures that an activity which was /// delayed still gets processed in correct order. @@ -100,11 +100,11 @@ pub fn handle_received_activities( } fn peek_queue_timestamp() -> Option> { - ACTIVITY_QUEUE.lock().unwrap().peek().map(|i| i.published) + ACTIVITY_QUEUE.read().unwrap().peek().map(|i| i.published) } fn pop_queue<'a>() -> Option { - ACTIVITY_QUEUE.lock().unwrap().pop() + ACTIVITY_QUEUE.write().unwrap().pop() } #[derive(Clone, Debug)] @@ -156,7 +156,7 @@ mod tests { bytes: Default::default(), published: Local::now().into(), }; - let mut lock = ACTIVITY_QUEUE.lock().unwrap(); + let mut lock = ACTIVITY_QUEUE.write().unwrap(); // insert in wrong order lock.push(activity3.clone()); diff --git a/crates/db_schema/src/source/activity.rs b/crates/db_schema/src/source/activity.rs index 6eb17f606..8438e0124 100644 --- a/crates/db_schema/src/source/activity.rs +++ b/crates/db_schema/src/source/activity.rs @@ -51,7 +51,7 @@ impl ActivitySendTargets { } } -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] #[cfg_attr(feature = "full", derive(Queryable, Selectable, Identifiable))] #[cfg_attr(feature = "full", diesel(check_for_backend(diesel::pg::Pg)))] #[cfg_attr(feature = "full", diesel(table_name = sent_activity))] diff --git a/crates/federate/src/worker.rs b/crates/federate/src/worker.rs index ff2a68e3c..4c2702343 100644 --- a/crates/federate/src/worker.rs +++ b/crates/federate/src/worker.rs @@ -30,9 +30,10 @@ use reqwest::Url; use std::{ collections::{HashMap, HashSet}, ops::{Add, Deref}, + sync::{Arc, RwLock}, time::Duration, }; -use tokio::{sync::mpsc::UnboundedSender, time::sleep}; +use tokio::{spawn, sync::mpsc::UnboundedSender, time::sleep}; use tokio_util::sync::CancellationToken; /// Check whether to save state to db every n sends if there's no failures (during failures state is saved after every attempt) @@ -57,8 +58,11 @@ static FOLLOW_ADDITIONS_RECHECK_DELAY: Lazy = Lazy::new(|| { /// This is expected to happen pretty rarely and updating it in a timely manner is not too important. static FOLLOW_REMOVALS_RECHECK_DELAY: Lazy = Lazy::new(|| chrono::TimeDelta::try_hours(1).expect("TimeDelta out of bounds")); + +const MAX_INFLIGHT_REQUESTS: u8 = 5; + pub(crate) struct InstanceWorker { - instance: Instance, + instance: Arc>, // load site lazily because if an instance is first seen due to being on allowlist, // the corresponding row in `site` may not exist yet since that is only added once // `fetch_instance_actor_for_object` is called. @@ -68,14 +72,26 @@ pub(crate) struct InstanceWorker { followed_communities: HashMap>, stop: CancellationToken, context: Data, - stats_sender: UnboundedSender<(String, FederationQueueState)>, last_full_communities_fetch: DateTime, last_incremental_communities_fetch: DateTime, + stats: Arc>, +} + +#[derive(Clone)] +struct InstanceStats { + stats_sender: UnboundedSender<(String, FederationQueueState)>, state: FederationQueueState, last_state_insert: DateTime, + inflight_requests: u8, } impl InstanceWorker { + fn stats(&self) -> InstanceStats { + self.stats.read().unwrap().clone() + } + fn instance(&self) -> Instance { + self.instance.read().unwrap().clone() + } pub(crate) async fn init_and_loop( instance: Instance, context: Data, @@ -85,17 +101,20 @@ impl InstanceWorker { ) -> Result<(), anyhow::Error> { let state = FederationQueueState::load(pool, instance.id).await?; let mut worker = InstanceWorker { - instance, + instance: Arc::new(RwLock::new(instance)), site_loaded: false, site: None, followed_communities: HashMap::new(), stop, context, - stats_sender, last_full_communities_fetch: Utc.timestamp_nanos(0), last_incremental_communities_fetch: Utc.timestamp_nanos(0), - state, - last_state_insert: Utc.timestamp_nanos(0), + stats: Arc::new(RwLock::new(InstanceStats { + stats_sender, + state, + last_state_insert: Utc.timestamp_nanos(0), + inflight_requests: 0, + })), }; worker.loop_until_stopped(pool).await } @@ -114,25 +133,26 @@ impl InstanceWorker { if self.stop.is_cancelled() { break; } - if (Utc::now() - self.last_state_insert) > save_state_every { - self.save_and_send_state(pool).await?; + if (Utc::now() - self.stats().last_state_insert) > save_state_every { + save_and_send_state(self.stats.clone(), &self.instance(), pool).await?; } self.update_communities(pool).await?; } // final update of state in db - self.save_and_send_state(pool).await?; + save_and_send_state(self.stats.clone(), &self.instance(), pool).await?; Ok(()) } async fn initial_fail_sleep(&mut self) -> Result<()> { // before starting queue, sleep remaining duration if last request failed - if self.state.fail_count > 0 { - let last_retry = self + let stats = self.stats(); + if stats.state.fail_count > 0 { + let last_retry = stats .state .last_retry .context("impossible: if fail count set last retry also set")?; let elapsed = (Utc::now() - last_retry).to_std()?; - let required = federate_retry_sleep_duration(self.state.fail_count); + let required = federate_retry_sleep_duration(stats.state.fail_count); if elapsed >= required { return Ok(()); } @@ -147,14 +167,16 @@ impl InstanceWorker { /// send out a batch of CHECK_SAVE_STATE_EVERY_IT activities async fn loop_batch(&mut self, pool: &mut DbPool<'_>) -> Result<()> { let latest_id = get_latest_activity_id(pool).await?; - let mut id = if let Some(id) = self.state.last_successful_id { + let mut id = if let Some(id) = self.stats().state.last_successful_id { id } else { // this is the initial creation (instance first seen) of the federation queue for this instance // skip all past activities: - self.state.last_successful_id = Some(latest_id); + { + self.stats.write().unwrap().state.last_successful_id = Some(latest_id); + } // save here to ensure it's not read as 0 again later if no activities have happened - self.save_and_send_state(pool).await?; + save_and_send_state(self.stats.clone(), &self.instance(), pool).await?; latest_id }; if id >= latest_id { @@ -170,30 +192,20 @@ impl InstanceWorker { && processed_activities < CHECK_SAVE_STATE_EVERY_IT && !self.stop.is_cancelled() { + while self.stats().inflight_requests >= MAX_INFLIGHT_REQUESTS { + sleep(Duration::from_millis(100)).await; + } id = ActivityId(id.0 + 1); processed_activities += 1; let Some(ele) = get_activity_cached(pool, id) .await .context("failed reading activity from db")? else { - tracing::debug!("{}: {:?} does not exist", self.instance.domain, id); - self.state.last_successful_id = Some(id); + tracing::debug!("{}: {:?} does not exist", self.instance().domain, id); + self.stats.write().unwrap().state.last_successful_id = Some(id); continue; }; - if let Err(e) = self.send_retry_loop(pool, &ele.0, &ele.1).await { - tracing::warn!( - "sending {} errored internally, skipping activity: {:?}", - ele.0.ap_id, - e - ); - } - if self.stop.is_cancelled() { - return Ok(()); - } - // send success! - self.state.last_successful_id = Some(id); - self.state.last_successful_published_time = Some(ele.0.published); - self.state.fail_count = 0; + self.send_retry_loop(pool, ele.0.clone(), &ele.1).await? } Ok(()) } @@ -203,17 +215,34 @@ impl InstanceWorker { async fn send_retry_loop( &mut self, pool: &mut DbPool<'_>, - activity: &SentActivity, + activity: SentActivity, object: &SharedInboxActivities, ) -> Result<()> { + let stats = self.stats(); + let retry_delay: Duration = federate_retry_sleep_duration(stats.state.fail_count); + tracing::info!( + "{}: retrying {:?} attempt {} with delay {retry_delay:.2?}", + self.instance().domain, + activity.id, + stats.state.fail_count + ); + tokio::select! { + () = sleep(retry_delay) => {}, + () = self.stop.cancelled() => { + // save state to db and exit + return Ok(()); + } + }; + let inbox_urls = self - .get_inbox_urls(pool, activity) + .get_inbox_urls(pool, &activity) .await .context("failed figuring out inbox urls")?; if inbox_urls.is_empty() { - tracing::debug!("{}: {:?} no inboxes", self.instance.domain, activity.id); - self.state.last_successful_id = Some(activity.id); - self.state.last_successful_published_time = Some(activity.published); + tracing::debug!("{}: {:?} no inboxes", self.instance().domain, activity.id); + let mut stats = self.stats.write().unwrap(); + stats.state.last_successful_id = Some(activity.id); + stats.state.last_successful_published_time = Some(activity.published); return Ok(()); } let Some(actor_apub_id) = &activity.actor_apub_id else { @@ -225,43 +254,63 @@ impl InstanceWorker { let object = WithContext::new(object.clone(), FEDERATION_CONTEXT.deref().clone()); let inbox_urls = inbox_urls.into_iter().collect(); - let requests = - SendActivityTask::prepare(&object, actor.as_ref(), inbox_urls, &self.context).await?; - for task in requests { + let context = self.context.reset_request_count(); + let stats = self.stats.clone(); + let instance = self.instance(); + let write_instance = self.instance.clone(); + spawn(async move { + { + stats.write().unwrap().inflight_requests += 1; + } + let requests = SendActivityTask::prepare(&object, actor.as_ref(), inbox_urls, &context) + .await + .unwrap(); // usually only one due to shared inbox - tracing::debug!("sending out {}", task); - while let Err(e) = task.sign_and_send(&self.context).await { - self.state.fail_count += 1; - self.state.last_retry = Some(Utc::now()); - let retry_delay: Duration = federate_retry_sleep_duration(self.state.fail_count); - tracing::info!( - "{}: retrying {:?} attempt {} with delay {retry_delay:.2?}. ({e})", - self.instance.domain, - activity.id, - self.state.fail_count - ); - self.save_and_send_state(pool).await?; - tokio::select! { - () = sleep(retry_delay) => {}, - () = self.stop.cancelled() => { - // save state to db and exit - return Ok(()); + for task in requests { + tracing::debug!("sending out {}", task); + let res = task.sign_and_send(&context).await; + match res { + Ok(_) => { + // send success! + { + let mut stats_ = stats.write().unwrap(); + stats_.inflight_requests -= 1; + stats_.state.last_successful_id = Some(activity.id); + stats_.state.last_successful_published_time = Some(activity.published); + stats_.state.fail_count = 0; + } + + // mark instance as alive if it hasn't been updated in a while. + let updated = instance.updated.unwrap_or(instance.published); + if updated.add(Days::new(1)) < Utc::now() { + { + write_instance.write().unwrap().updated = Some(Utc::now()); + } + + let form = InstanceForm::builder() + .domain(instance.domain.clone()) + .updated(Some(naive_now())) + .build(); + Instance::update(&mut context.pool(), instance.id, form) + .await + .unwrap(); + } + } + Err(e) => { + tracing::info!("{} send failed: {e}", instance.domain); + { + let mut stats_ = stats.write().unwrap(); + stats_.inflight_requests -= 1; + stats_.state.fail_count += 1; + stats_.state.last_retry = Some(Utc::now()); + } + save_and_send_state(stats.clone(), &instance, &mut context.pool()) + .await + .unwrap(); } } } - - // Activity send successful, mark instance as alive if it hasn't been updated in a while. - let updated = self.instance.updated.unwrap_or(self.instance.published); - if updated.add(Days::new(1)) < Utc::now() { - self.instance.updated = Some(Utc::now()); - - let form = InstanceForm::builder() - .domain(self.instance.domain.clone()) - .updated(Some(naive_now())) - .build(); - Instance::update(pool, self.instance.id, form).await?; - } - } + }); Ok(()) } @@ -278,7 +327,7 @@ impl InstanceWorker { if activity.send_all_instances { if !self.site_loaded { - self.site = Site::read_from_instance_id(pool, self.instance.id).await?; + self.site = Site::read_from_instance_id(pool, self.instance().id).await?; self.site_loaded = true; } if let Some(site) = &self.site { @@ -296,7 +345,7 @@ impl InstanceWorker { .send_inboxes .iter() .filter_map(std::option::Option::as_ref) - .filter(|&u| (u.domain() == Some(&self.instance.domain))) + .filter(|&u| (u.domain() == Some(&self.instance().domain))) .map(|u| u.inner().clone()), ); Ok(inbox_urls) @@ -306,7 +355,7 @@ impl InstanceWorker { if (Utc::now() - self.last_full_communities_fetch) > *FOLLOW_REMOVALS_RECHECK_DELAY { // process removals every hour (self.followed_communities, self.last_full_communities_fetch) = self - .get_communities(pool, self.instance.id, Utc.timestamp_nanos(0)) + .get_communities(pool, self.instance().id, Utc.timestamp_nanos(0)) .await?; self.last_incremental_communities_fetch = self.last_full_communities_fetch; } @@ -315,7 +364,7 @@ impl InstanceWorker { let (news, time) = self .get_communities( pool, - self.instance.id, + self.instance().id, self.last_incremental_communities_fetch, ) .await?; @@ -345,12 +394,21 @@ impl InstanceWorker { new_last_fetch, )) } - async fn save_and_send_state(&mut self, pool: &mut DbPool<'_>) -> Result<()> { - self.last_state_insert = Utc::now(); - FederationQueueState::upsert(pool, &self.state).await?; - self - .stats_sender - .send((self.instance.domain.clone(), self.state.clone()))?; - Ok(()) - } +} + +async fn save_and_send_state( + stats: Arc>, + instance: &Instance, + pool: &mut DbPool<'_>, +) -> Result<()> { + let stats = { + let mut lock = stats.write().unwrap(); + lock.last_state_insert = Utc::now(); + lock.clone() + }; + FederationQueueState::upsert(pool, &stats.state).await?; + stats + .stats_sender + .send((instance.domain.clone(), stats.state))?; + Ok(()) }