use std::{convert::Infallible, fmt}; use axum::{ extract::FromRequestParts, http::{request::Parts, StatusCode}, response::IntoResponseParts, }; use axum_extra::extract::{cookie::Cookie, CookieJar}; use rand::{distr::Alphanumeric, rng, Rng as _}; use serde::{Deserialize, Serialize}; use tower_sessions::Session; use uuid::Uuid; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct User { pub id: Uuid, // roles: Vec, } impl User { pub fn new(id: Uuid) -> User { User { id } } } pub struct Auth { // id: Uuid, pub session: Session, pub user: Option, } impl FromRequestParts for Auth where S: Send + Sync, { type Rejection = (StatusCode, &'static str); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let session = Session::from_request_parts(parts, state).await?; let user: Option = session.get("user").await.unwrap_or_default(); // session.insert("user", &data).await.unwrap(); Ok(Self { session, user }) } } pub struct CsrfCookie { jar: CookieJar, token: String, } impl CsrfCookie { pub fn token(&self) -> &str { &self.token } } impl FromRequestParts for CsrfCookie where S: Send + Sync, { type Rejection = (StatusCode, &'static str); async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { let mut jar = CookieJar::from_request_parts(parts, state).await.unwrap(); // infalible result let token = if let Some(token) = jar.get("csrf-token") { token.value().to_string() } else { let token = generate_csrf_token(); jar = jar.add( Cookie::build(("csrf-token", token.clone())) .http_only(true) .path("/") .same_site(tower_sessions::cookie::SameSite::Strict), ); token }; Ok(CsrfCookie { jar, token }) } } impl IntoResponseParts for CsrfCookie { type Error = Infallible; fn into_response_parts( self, res: axum::response::ResponseParts, ) -> Result { self.jar.into_response_parts(res) } } const CSRF_TOKEN_LENGTH: usize = 64; fn generate_csrf_token() -> String { rng() .sample_iter(&Alphanumeric) .take(CSRF_TOKEN_LENGTH) .map(char::from) .collect() } impl fmt::Debug for CsrfCookie { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("CsrfCookie").field(&self.token).finish() } }