diff --git a/crates/routes/src/feeds.rs b/crates/routes/src/feeds.rs index 5d1f285c..74821354 100644 --- a/crates/routes/src/feeds.rs +++ b/crates/routes/src/feeds.rs @@ -31,13 +31,30 @@ use rss::{ }; use serde::Deserialize; use std::{collections::BTreeMap, str::FromStr}; -use strum::ParseError; const RSS_FETCH_LIMIT: i64 = 20; #[derive(Deserialize)] struct Params { sort: Option, + limit: Option, + page: Option, +} + +impl Params { + fn sort_type(&self) -> Result { + let sort_query = self + .sort + .clone() + .unwrap_or_else(|| SortType::Hot.to_string()); + SortType::from_str(&sort_query).map_err(ErrorBadRequest) + } + fn get_limit(&self) -> i64 { + self.limit.unwrap_or(RSS_FETCH_LIMIT) + } + fn get_page(&self) -> i64 { + self.page.unwrap_or(1) + } } enum RequestType { @@ -68,8 +85,16 @@ async fn get_all_feed( info: web::Query, context: web::Data, ) -> Result { - let sort_type = get_sort_type(info).map_err(ErrorBadRequest)?; - Ok(get_feed_data(&context, ListingType::All, sort_type).await?) + Ok( + get_feed_data( + &context, + ListingType::All, + info.sort_type()?, + info.get_limit(), + info.get_page(), + ) + .await?, + ) } #[tracing::instrument(skip_all)] @@ -77,8 +102,16 @@ async fn get_local_feed( info: web::Query, context: web::Data, ) -> Result { - let sort_type = get_sort_type(info).map_err(ErrorBadRequest)?; - Ok(get_feed_data(&context, ListingType::Local, sort_type).await?) + Ok( + get_feed_data( + &context, + ListingType::Local, + info.sort_type()?, + info.get_limit(), + info.get_page(), + ) + .await?, + ) } #[tracing::instrument(skip_all)] @@ -86,6 +119,8 @@ async fn get_feed_data( context: &LemmyContext, listing_type: ListingType, sort_type: SortType, + limit: i64, + page: i64, ) -> Result { let site_view = SiteView::read_local(context.pool()).await?; @@ -93,7 +128,8 @@ async fn get_feed_data( .pool(context.pool()) .listing_type(Some(listing_type)) .sort(Some(sort_type)) - .limit(Some(RSS_FETCH_LIMIT)) + .limit(Some(limit)) + .page(Some(page)) .build() .list() .await?; @@ -125,8 +161,6 @@ async fn get_feed( info: web::Query, context: web::Data, ) -> Result { - let sort_type = get_sort_type(info).map_err(ErrorBadRequest)?; - let req_type: String = req.match_info().get("type").unwrap_or("none").parse()?; let param: String = req.match_info().get("name").unwrap_or("none").parse()?; @@ -143,16 +177,34 @@ async fn get_feed( let builder = match request_type { RequestType::User => { - get_feed_user(context.pool(), &sort_type, ¶m, &protocol_and_hostname).await + get_feed_user( + context.pool(), + &info.sort_type()?, + &info.get_limit(), + &info.get_page(), + ¶m, + &protocol_and_hostname, + ) + .await } RequestType::Community => { - get_feed_community(context.pool(), &sort_type, ¶m, &protocol_and_hostname).await + get_feed_community( + context.pool(), + &info.sort_type()?, + &info.get_limit(), + &info.get_page(), + ¶m, + &protocol_and_hostname, + ) + .await } RequestType::Front => { get_feed_front( context.pool(), &jwt_secret, - &sort_type, + &info.sort_type()?, + &info.get_limit(), + &info.get_page(), ¶m, &protocol_and_hostname, ) @@ -173,18 +225,12 @@ async fn get_feed( ) } -fn get_sort_type(info: web::Query) -> Result { - let sort_query = info - .sort - .clone() - .unwrap_or_else(|| SortType::Hot.to_string()); - SortType::from_str(&sort_query) -} - #[tracing::instrument(skip_all)] async fn get_feed_user( pool: &DbPool, sort_type: &SortType, + limit: &i64, + page: &i64, user_name: &str, protocol_and_hostname: &str, ) -> Result { @@ -196,7 +242,8 @@ async fn get_feed_user( .listing_type(Some(ListingType::All)) .sort(Some(*sort_type)) .creator_id(Some(person.id)) - .limit(Some(RSS_FETCH_LIMIT)) + .limit(Some(*limit)) + .page(Some(*page)) .build() .list() .await?; @@ -217,6 +264,8 @@ async fn get_feed_user( async fn get_feed_community( pool: &DbPool, sort_type: &SortType, + limit: &i64, + page: &i64, community_name: &str, protocol_and_hostname: &str, ) -> Result { @@ -227,7 +276,8 @@ async fn get_feed_community( .pool(pool) .sort(Some(*sort_type)) .community_id(Some(community.id)) - .limit(Some(RSS_FETCH_LIMIT)) + .limit(Some(*limit)) + .page(Some(*page)) .build() .list() .await?; @@ -253,6 +303,8 @@ async fn get_feed_front( pool: &DbPool, jwt_secret: &str, sort_type: &SortType, + limit: &i64, + page: &i64, jwt: &str, protocol_and_hostname: &str, ) -> Result { @@ -265,7 +317,8 @@ async fn get_feed_front( .listing_type(Some(ListingType::Subscribed)) .local_user(Some(&local_user)) .sort(Some(*sort_type)) - .limit(Some(RSS_FETCH_LIMIT)) + .limit(Some(*limit)) + .page(Some(*page)) .build() .list() .await?;