diff --git a/src/crawl.rs b/src/crawl.rs index 7d2ad39..a178c1a 100644 --- a/src/crawl.rs +++ b/src/crawl.rs @@ -3,52 +3,75 @@ use crate::node_info::NodeInfo; use crate::REQUEST_TIMEOUT; use anyhow::anyhow; use anyhow::Error; -use futures::try_join; +use futures::executor::block_on_stream; +use futures::future::join_all; +use futures::stream::FuturesUnordered; +use futures::{future, stream, try_join, StreamExt, TryStreamExt}; use reqwest::Client; use serde::Serialize; +use std::cmp::max; use std::collections::VecDeque; +use std::future::Future; +use std::sync::{Arc, Mutex}; pub async fn crawl( start_instances: Vec<String>, exclude: Vec<String>, max_depth: i32, ) -> Result<(Vec<InstanceDetails>, i32), Error> { - let mut pending_instances: VecDeque<CrawlInstance> = start_instances + let exclude = Arc::new(exclude); + let mut pending_instances: VecDeque<CrawlInstanceTask> = start_instances .iter() - .map(|s| CrawlInstance::new(s.to_string(), 0)) + .map(|s| CrawlInstanceTask::new(s.to_string(), 0, exclude.clone())) .collect(); - let mut crawled_instances = vec![]; - let mut instance_details = vec![]; - let mut failed_instances = 0; - while let Some(current_instance) = pending_instances.pop_back() { - crawled_instances.push(current_instance.domain.clone()); - if current_instance.depth > max_depth || exclude.contains(&current_instance.domain) { - continue; - } - match fetch_instance_details(&current_instance.domain).await { - Ok(details) => { - instance_details.push(details.to_owned()); - for i in details.linked_instances { - let is_in_crawled = crawled_instances.contains(&i); - let is_in_pending = pending_instances.iter().any(|p| p.domain == i); - if !is_in_crawled && !is_in_pending { - let ci = CrawlInstance::new(i, current_instance.depth + 1); - pending_instances.push_back(ci); - } - } - } - Err(e) => { - failed_instances += 1; - eprintln!("Failed to crawl {}: {}", current_instance.domain, e) - } - } - } + let mut crawled_instances = Mutex::new(vec![]); + //let mut instance_details = vec![]; + //let mut failed_instances = 0; + + let stream = 
Box::pin( + stream::iter(pending_instances) + .then(|task: CrawlInstanceTask| async { + crawled_instances.lock().unwrap().push(task.domain.clone()); + crawl_instance(task, max_depth).await.unwrap() + }) + .flat_map(|(instance_details, task)| { + let futures = instance_details.linked_instances.iter().map(|i| { + crawled_instances.lock().unwrap().push(i.clone()); + crawl_instance( + CrawlInstanceTask::new(i.clone(), task.depth + 1, task.exclude.clone()), + max_depth, + ) + }); + + stream::iter(futures) + }), + ); + + let crawl_result: Vec<Result<InstanceDetails, Error>> = stream + .buffer_unordered(10) + .map_ok(|(details, _)| details) + .collect() + .await; + + todo!() + /* // Sort by active monthly users descending - instance_details.sort_by_key(|i| i.users_active_month); - instance_details.reverse(); + crawl_result.sort_by_key(|i| i.users_active_month); + crawl_result.reverse(); - Ok((instance_details, failed_instances)) + Ok((crawl_result, failed_instances)) + */ +} + +async fn crawl_instance( + task: CrawlInstanceTask, + max_depth: i32, +) -> Result<(InstanceDetails, CrawlInstanceTask), Error> { + if task.depth > max_depth || task.exclude.contains(&task.domain) { + return Err(anyhow!("max depth reached")); + } + Ok((fetch_instance_details(&task.domain).await?, task)) } #[derive(Serialize, Clone)] @@ -70,14 +93,19 @@ pub struct InstanceDetails { pub linked_instances: Vec<String>, } -struct CrawlInstance { +struct CrawlInstanceTask { domain: String, depth: i32, + exclude: Arc<Vec<String>>, } -impl CrawlInstance { - pub fn new(domain: String, depth: i32) -> CrawlInstance { - CrawlInstance { domain, depth } +impl CrawlInstanceTask { + pub fn new(domain: String, depth: i32, exclude: Arc<Vec<String>>) -> CrawlInstanceTask { + CrawlInstanceTask { + domain, + depth, + exclude, + } } }