Created
September 16, 2023 22:41
-
-
Save ShigeoTejima/78e2efc5bf900c85b9d3127aec0fd7c5 to your computer and use it in GitHub Desktop.
axumでQueryのvalidationをしてみた。bar.rsの方はもっと簡潔にできないか...
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* @see https://github.com/tokio-rs/axum/tree/main/examples/validator | |
* @see https://github.com/gengteng/axum-valid/blob/main/src/lib.rs | |
*/ | |
use axum::{extract::{Query, FromRequestParts}, http::StatusCode, response::{IntoResponse, Response}, async_trait, http::request::Parts}; | |
use serde::Deserialize; | |
use validator::{Validate, ValidationErrors}; | |
pub async fn handler(Valid(Query(params)): Valid<Query<Params>>) -> String { | |
format!("{:?}", params) | |
} | |
#[derive(Debug, Validate, Deserialize)] | |
pub struct Params { | |
#[validate(length(min = 1, message = "message must not be empty"))] | |
#[validate(required)] | |
message: Option<String>, | |
} | |
pub trait HasValidate { | |
/// Inner type that can be validated for correctness | |
type Validate: Validate; | |
/// Get the inner value | |
fn get_validate(&self) -> &Self::Validate; | |
} | |
impl<T: Validate> HasValidate for Query<T> { | |
type Validate = T; | |
fn get_validate(&self) -> &T { | |
&self.0 | |
} | |
} | |
/// Extractor adapter: extracts the inner `E`, then validates its payload via
/// [`HasValidate`] before handing it to the handler.
#[derive(Clone, Copy, Debug, Default)]
pub struct Valid<E>(pub E);
#[derive(Debug)] | |
pub enum ValidRejection<E> { | |
Valid(ValidationErrors), | |
Inner(E), | |
} | |
impl<E> From<ValidationErrors> for ValidRejection<E> { | |
fn from(value: ValidationErrors) -> Self { | |
Self::Valid(value) | |
} | |
} | |
pub const VALIDATION_ERROR_STATUS: StatusCode = StatusCode::BAD_REQUEST; | |
impl<E: IntoResponse> IntoResponse for ValidRejection<E> { | |
fn into_response(self) -> Response { | |
match self { | |
ValidRejection::Valid(validate_error) => { | |
{ | |
(VALIDATION_ERROR_STATUS, validate_error.to_string()).into_response() | |
} | |
} | |
ValidRejection::Inner(json_error) => json_error.into_response(), | |
} | |
} | |
} | |
#[async_trait] | |
impl<S, E> FromRequestParts<S> for Valid<E> | |
where | |
S: Send + Sync + 'static, | |
E: HasValidate + FromRequestParts<S>, | |
E::Validate: Validate, | |
{ | |
type Rejection = ValidRejection<<E as FromRequestParts<S>>::Rejection>; | |
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { | |
let inner = E::from_request_parts(parts, state) | |
.await | |
.map_err(ValidRejection::Inner)?; | |
inner.get_validate().validate()?; | |
Ok(Valid(inner)) | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[package]
name = "demo"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.6.20"
hyper = "0.14"
serde = { version = "1.0", features = ["derive"] }
tokio = { version = "1.32", features = ["full"] }
tower = { version = "0.4", features = ["util"] }
validator = { version = "0.16", features = ["derive"] }
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use core::fmt; | |
use std::str::FromStr; | |
use axum::extract::Query; | |
use serde::{Deserialize, Deserializer, de}; | |
pub async fn handler(params: Query<Params>) -> String { | |
format!("{:?}", params) | |
} | |
#[derive(Debug, Deserialize)] | |
pub struct Params { | |
#[serde(deserialize_with = "name_required")] | |
name: String, | |
year: i32 | |
} | |
fn name_required<'de, D, T>(de: D) -> Result<T, D::Error> | |
where | |
D: Deserializer<'de>, | |
T: FromStr, | |
T::Err: fmt::Display, | |
{ | |
let opt = String::deserialize(de)?; | |
println!("opt: {:?}", opt); | |
if opt.is_empty() { | |
Err(de::Error::custom("name required")) | |
} else { | |
FromStr::from_str(opt.as_str()).map_err(de::Error::custom) | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub mod foo; | |
pub mod bar; |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::net::SocketAddr; | |
use axum::{Router, routing::get}; | |
use demo::{foo, bar}; | |
#[tokio::main] | |
async fn main() { | |
println!("Hello, world!"); | |
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); | |
axum::Server::bind(&addr) | |
.serve(app().into_make_service()) | |
.await | |
.unwrap(); | |
} | |
fn app() -> Router { | |
Router::new() | |
.route("/foo", get(foo::handler)) | |
.route("/bar", get(bar::handler)) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment