分享一个在 Rust Axum 项目中自行封装的简单分页查询函数，使用 sqlx 查询数据库。
方法
use std::{collections::HashMap, marker::PhantomData};
use chrono::{DateTime, Utc};
use serde::Serialize;
use sqlx::{
postgres::PgArguments,
query::{QueryAs, QueryScalar},
query_as, query_scalar, FromRow, PgPool, Postgres,
};
use uuid::Uuid;
use crate::utils::log;
use super::env;
/// One page of query results together with the metadata a client needs to paginate.
#[derive(Debug, Clone, Serialize)]
pub struct PaginatedResult<T> {
    /// Rows on this page.
    pub records: Vec<T>,
    /// Page size used for the query (the SQL `LIMIT`).
    pub limit: i64,
    /// 1-based number of this page.
    pub page: i64,
    /// Total number of rows matching the filter, across all pages.
    pub total: i64,
}
/// Generic paginated query builder for a single table.
///
/// `T` is the row type that results are decoded into. It is tracked only at the
/// type level, so the builder never stores a `T`.
pub struct PaginatedQueryBuilder<T> {
    /// Column list placed after `SELECT` (for example `"*"`).
    select_fields: &'static str,
    /// Table name placed after `FROM`.
    table_name: &'static str,
    /// Binds the builder to its row type without holding a value of it.
    _marker: PhantomData<T>,
}
impl<T> PaginatedQueryBuilder<T>
where
T: for<'r> FromRow<'r, sqlx::postgres::PgRow> + Unpin + Send + 'static + std::fmt::Debug,
{
pub fn new(table_name: &'static str, select_fields: &'static str) -> Self {
Self {
table_name,
select_fields,
_marker: PhantomData,
}
}
/// 构建查询条件
fn build_conditions_sql_and_args<'a>(
&'a self,
conditions: Option<&'a HashMap<String, PaginatedQueryValue>>,
) -> (String, Vec<&PaginatedQueryValue>) {
let mut query_str = String::new();
let mut args = Vec::new();
let mut index = 1;
if let Some(conds) = conditions {
if !conds.is_empty() {
query_str.push_str(" WHERE ");
for (key, value) in conds {
query_str.push_str(&format!("{} = ${} AND ", key, index));
args.push(value);
index += 1;
}
query_str.truncate(query_str.len() - 5); // 移除最后一个多余的 " AND "
}
};
(query_str, args)
}
pub async fn count(
&self,
pool: &PgPool,
conditions: Option<&HashMap<String, PaginatedQueryValue>>,
) -> Result<i64, sqlx::Error> {
let mut query_str = format!("SELECT COUNT(*) FROM {}", self.table_name);
let (where_str, args) = self.build_conditions_sql_and_args(conditions);
query_str.push_str(&where_str);
log::info("SQL", &query_str);
let mut query = query_scalar(&query_str);
for value in args {
query = bind_value_to_query_scalar(query, value);
}
let count: i64 = query.fetch_one(&*pool).await?;
Ok(count)
}
pub async fn paginate(
&self,
pool: &PgPool,
page_number: Option<i64>,
items_per_page: Option<i64>,
conditions: Option<&HashMap<String, PaginatedQueryValue>>,
sort_field: Option<&str>,
) -> Result<PaginatedResult<T>, sqlx::Error> {
let limit = items_per_page.unwrap_or(env::APP_PAGE_LIMIT.parse::<i64>().unwrap_or(10));
let offset = (page_number.unwrap_or(1) - 1) * limit;
// 默认排序
let order_by = sort_field.unwrap_or("created_at DESC");
let (where_str, args) = self.build_conditions_sql_and_args(conditions);
// 定义查询字段
let query_str = format!(
"SELECT {} FROM {}{} ORDER BY {} LIMIT ${} OFFSET ${}",
self.select_fields,
self.table_name,
where_str,
order_by,
args.len() + 1,
args.len() + 2
);
log::info("SQL", &query_str);
let mut query = query_as::<_, T>(&query_str);
for value in args {
query = bind_value_to_query(query, value);
}
let records = query.bind(limit).bind(offset).fetch_all(pool).await?;
// 查询总数
let total = self.count(pool, conditions).await?;
Ok(PaginatedResult {
records,
limit,
page: page_number.unwrap_or(1),
total,
})
}
}
/// A typed value to compare a column against in a pagination filter.
///
/// Each variant is bound as a query parameter of its own Rust type, so filter
/// values never become part of the SQL text.
#[derive(Debug, Clone, PartialEq)]
pub enum PaginatedQueryValue {
    /// A UTC timestamp.
    DateTime(DateTime<Utc>),
    /// A boolean flag.
    Bool(bool),
    /// A text value.
    String(String),
    /// A UUID.
    Uuid(Uuid),
    /// A 64-bit signed integer.
    I64(i64),
    /// A 16-bit signed integer.
    I16(i16),
}
/// 辅助函数来绑定不同类型的值到查询构建器
fn bind_value_to_query<'a, T>(
query_builder: QueryAs<'a, Postgres, T, PgArguments>,
value: &'a PaginatedQueryValue,
) -> QueryAs<'a, Postgres, T, PgArguments> {
match value {
PaginatedQueryValue::DateTime(val) => query_builder.bind(val),
PaginatedQueryValue::Bool(val) => query_builder.bind(val),
PaginatedQueryValue::String(val) => query_builder.bind(val),
PaginatedQueryValue::Uuid(val) => query_builder.bind(val),
PaginatedQueryValue::I64(val) => query_builder.bind(val),
PaginatedQueryValue::I16(val) => query_builder.bind(val),
}
}
/// 辅助函数来绑定不同类型的值到查询构建器
fn bind_value_to_query_scalar<'a, T>(
query_builder: QueryScalar<'a, Postgres, T, PgArguments>,
value: &'a PaginatedQueryValue,
) -> QueryScalar<'a, Postgres, T, PgArguments> {
match value {
PaginatedQueryValue::DateTime(val) => query_builder.bind(val),
PaginatedQueryValue::Bool(val) => query_builder.bind(val),
PaginatedQueryValue::String(val) => query_builder.bind(val),
PaginatedQueryValue::Uuid(val) => query_builder.bind(val),
PaginatedQueryValue::I64(val) => query_builder.bind(val),
PaginatedQueryValue::I16(val) => query_builder.bind(val),
}
}
其中 `log`（`crate::utils::log`）和 `env`（`super::env`）是本地的日志和环境变量模块，请自行替换成自己的实现或删除。数据库使用的是 PostgreSQL；如使用其他数据库，请自行更换对应的类型（如 `PgPool`、`PgArguments`、`Postgres`）。
使用
use std::collections::HashMap;
use sqlx::{Error, PgPool};
use crate::utils::{
self,
pagination::{PaginatedQueryBuilder, PaginatedQueryValue, PaginatedResult},
};
/// Fetches one page of users, optionally filtered by their `enabled` flag.
///
/// `payload.page` and `payload.limit` are forwarded unchanged to
/// `PaginatedQueryBuilder::paginate`. No sort field is passed, so that method's
/// default ordering applies.
pub async fn get_page(pool: &PgPool, payload: PageUser) -> Result<PaginatedResult<User>, Error> {
    let query_builder = PaginatedQueryBuilder::<User>::new("users", "*");
    let mut conditions: HashMap<String, PaginatedQueryValue> = HashMap::new();
    // Filter on the `enabled` column with the value the caller supplied.
    // NOTE(review): assumes `PageUser::enabled` is `Option<bool>`; confirm.
    if let Some(enabled) = payload.enabled {
        conditions.insert("enabled".to_string(), PaginatedQueryValue::Bool(enabled));
    }
    query_builder
        .paginate(pool, payload.page, payload.limit, Some(&conditions), None)
        .await
}
其中部分 struct（如 `User`、`PageUser`）的定义需要换成自己的。
完。如有更好的建议和意见，欢迎留言评论！