zero-to-axum/src/server/session/mod.rs

110 lines
2.6 KiB
Rust

use std::{convert::Infallible, fmt};
use axum::{
extract::FromRequestParts,
http::{request::Parts, StatusCode},
response::IntoResponseParts,
};
use axum_extra::extract::{cookie::Cookie, CookieJar};
use rand::{distr::Alphanumeric, rng, Rng as _};
use serde::{Deserialize, Serialize};
use tower_sessions::Session;
use uuid::Uuid;
/// Serializable user record.
///
/// The `Auth` extractor in this module reads a value of this type from the
/// session under the `"user"` key.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct User {
    /// Identifier of the user.
    pub id: Uuid,
    // roles: Vec<String>,
}

impl User {
    /// Wraps the given identifier in a `User`.
    pub fn new(id: Uuid) -> Self {
        Self { id }
    }
}
/// Request extractor bundling the request's `Session` with the user read from it.
pub struct Auth {
// id: Uuid,
/// Session handle extracted from the request.
pub session: Session,
/// Value read from the `"user"` session key; `None` when the key is absent
/// or the read failed.
pub user: Option<User>,
}
impl<S> FromRequestParts<S> for Auth
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    /// Extracts the `Session` from the request and reads its `"user"` entry.
    ///
    /// A failure to extract the session is returned as the rejection. A failed
    /// read of the `"user"` entry is not treated as an error: it yields
    /// `user: None`, exactly as a missing entry does.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let session = Session::from_request_parts(parts, state).await?;

        // Best-effort lookup: any read error collapses to "no user".
        let user = match session.get::<User>("user").await {
            Ok(stored) => stored,
            Err(_) => None,
        };

        Ok(Self { session, user })
    }
}
/// A cookie jar paired with the CSRF token value associated with the request.
pub struct CsrfCookie {
    /// Jar that is written back into the response parts.
    jar: CookieJar,
    /// Value of the `csrf-token` cookie (pre-existing or freshly generated).
    token: String,
}

impl CsrfCookie {
    /// Borrows the CSRF token value.
    pub fn token(&self) -> &str {
        self.token.as_str()
    }
}
impl<S> FromRequestParts<S> for CsrfCookie
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    /// Reads the `csrf-token` cookie from the request, creating one when absent.
    ///
    /// When the cookie is already present its value is reused and the jar is
    /// left unchanged. Otherwise a new token is generated and added to the jar
    /// as an HTTP-only, `SameSite=Strict` cookie scoped to path `/`.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Cookie jar extraction is expected to be infallible, hence the unwrap.
        let jar = CookieJar::from_request_parts(parts, state).await.unwrap();

        // Copy the value out first so the jar is no longer borrowed below.
        let existing = jar
            .get("csrf-token")
            .map(|cookie| cookie.value().to_string());

        let (jar, token) = match existing {
            Some(token) => (jar, token),
            None => {
                let token = generate_csrf_token();
                let cookie = Cookie::build(("csrf-token", token.clone()))
                    .http_only(true)
                    .path("/")
                    .same_site(tower_sessions::cookie::SameSite::Strict);
                (jar.add(cookie), token)
            }
        };

        Ok(CsrfCookie { jar, token })
    }
}
impl IntoResponseParts for CsrfCookie {
    type Error = Infallible;

    /// Writes the wrapped cookie jar into the response parts.
    ///
    /// Only the jar contributes to the response; the token string is dropped.
    fn into_response_parts(
        self,
        res: axum::response::ResponseParts,
    ) -> Result<axum::response::ResponseParts, Self::Error> {
        let Self { jar, .. } = self;
        jar.into_response_parts(res)
    }
}
/// Number of characters in a generated CSRF token.
const CSRF_TOKEN_LENGTH: usize = 64;

/// Builds a random token of `CSRF_TOKEN_LENGTH` characters, each sampled
/// from the `Alphanumeric` distribution using the generator returned by `rng()`.
fn generate_csrf_token() -> String {
    let mut source = rng();
    (0..CSRF_TOKEN_LENGTH)
        .map(|_| {
            let byte: u8 = source.sample(Alphanumeric);
            char::from(byte)
        })
        .collect()
}
impl fmt::Debug for CsrfCookie {
    /// Formats as the tuple `CsrfCookie(<token>)`; the cookie jar is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("CsrfCookie");
        tuple.field(&self.token);
        tuple.finish()
    }
}