diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs index e50ebab..692106c 100644 --- a/src/server/routes/auth/mod.rs +++ b/src/server/routes/auth/mod.rs @@ -17,7 +17,7 @@ use tracing::{error, info, warn}; use uuid::Uuid; use crate::server::{ - session::{Auth, User}, + session::{Auth, CsrfCookie, User}, AppState, }; @@ -32,19 +32,30 @@ pub fn build() -> Router { #[template(path = "signup.html")] struct SignupPage { error: Option, + csrf_token: String, } #[tracing::instrument] -pub async fn signup_page() -> SignupPage { +pub async fn signup_page(csrf: CsrfCookie) -> impl IntoResponse { info!("get signup page"); - SignupPage::default() + let csrf_token = csrf.token().to_string(); + + ( + csrf, + SignupPage { + error: None, + csrf_token, + }, + ) } #[derive(Deserialize)] +#[serde(rename_all = "kebab-case")] pub struct SignupForm { email: String, password: String, + csrf_token: String, } impl fmt::Debug for SignupForm { @@ -60,10 +71,15 @@ impl fmt::Debug for SignupForm { pub async fn signup( State(AppState { db, .. }): State, mut user: Auth, + csrf: CsrfCookie, Form(form): Form, ) -> Result { info!("signup attempt"); + if form.csrf_token != csrf.token() { + return Err(SignupError::CsrfValidationFailed); + } + info!("hash password: {}", &form.password); let password_hash: String = tokio::task::spawn_blocking(async move || { let salt = SaltString::generate(&mut OsRng); @@ -103,6 +119,8 @@ pub async fn signup( #[derive(thiserror::Error, Debug)] pub enum SignupError { + #[error("CSRF Validation Failed")] + CsrfValidationFailed, #[error("Unknown Error: {0}")] Unknown(#[from] anyhow::Error), } @@ -110,6 +128,9 @@ pub enum SignupError { impl IntoResponse for SignupError { fn into_response(self) -> axum::response::Response { match self { + SignupError::CsrfValidationFailed => { + (StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again") + } SignupError::Unknown(e) => { error!(?e, "returning INTERNAL SERVER ERROR"); (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") @@ -123,19 +144,30 @@ impl IntoResponse for SignupError { #[template(path = "login.html")] struct LoginPage { error: Option, + csrf_token: String, } #[tracing::instrument] -pub async fn login_page() -> LoginPage { +pub async fn login_page(csrf: CsrfCookie) -> impl IntoResponse { info!("get login page"); - Default::default() + let csrf_token = csrf.token().to_string(); + + ( + csrf, + LoginPage { + error: None, + csrf_token, + }, + ) } #[derive(Deserialize)] +#[serde(rename_all = "kebab-case")] pub struct LoginForm { email: String, password: String, + csrf_token: String, } impl fmt::Debug for LoginForm { @@ -151,10 +183,15 @@ impl fmt::Debug for LoginForm { pub async fn login( State(AppState { db, .. }): State, mut auth: Auth, + csrf: CsrfCookie, Form(form): Form, ) -> Result { info!("login attempt"); + if form.csrf_token != csrf.token() { + return Err(LoginError::CsrfValidationFailed); + } + let user = sqlx::query!( r#" SELECT id, password FROM users WHERE email = $1 LIMIT 1; @@ -203,6 +240,8 @@ pub enum LoginError { InvalidPassword, #[error("Unknown User")] UnknownUser, + #[error("CSRF Validation Failed")] + CsrfValidationFailed, #[error("Unknown Error: {0}")] Unknown(#[from] anyhow::Error), } @@ -212,6 +251,9 @@ impl IntoResponse for LoginError { let (status, message) = match self { LoginError::InvalidPassword => (StatusCode::UNAUTHORIZED, "Invalid Password"), LoginError::UnknownUser => (StatusCode::UNAUTHORIZED, "Unknown User"), + LoginError::CsrfValidationFailed => { + (StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again") + } LoginError::Unknown(e) => { error!(?e, "returning INTERNAL SERVER ERROR"); (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") @@ -222,6 +264,7 @@ impl IntoResponse for LoginError { status, LoginPage { error: Some(message.to_string()), + csrf_token: "".to_string(), }, ) .into_response() @@ -230,18 +273,36 @@ impl IntoResponse for LoginError { #[derive(Template, WebTemplate)] #[template(path = "logout.html")] -struct LogoutPage; - -#[tracing::instrument] -pub async fn logout_page() -> LogoutPage { - info!("get logout page"); - - LogoutPage +struct LogoutPage { + csrf_token: String, } -pub async fn logout(mut user: Auth) -> Result { +#[tracing::instrument] +pub async fn logout_page(csrf: CsrfCookie) -> impl IntoResponse { + info!("get logout page"); + + let csrf_token = csrf.token().to_string(); + + (csrf, LogoutPage { csrf_token }) +} + +#[derive(Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct LogoutForm { + csrf_token: String, +} + +pub async fn logout( + mut user: Auth, + csrf: CsrfCookie, + Form(form): Form, +) -> Result { info!("logout attempt"); + if form.csrf_token != csrf.token() { + return Err(LogoutError::CsrfValidationFailed); + } + if user.user.is_none() { return Err(LogoutError::NotLoggedIn); } @@ -256,6 +317,8 @@ pub async fn logout(mut user: Auth) -> Result { pub enum LogoutError { #[error("Not Logged In")] NotLoggedIn, + #[error("CSRF Validation Failed")] + CsrfValidationFailed, #[error("Unknown Error: {0}")] Unknown(#[from] anyhow::Error), } @@ -264,6 +327,9 @@ impl IntoResponse for LogoutError { fn into_response(self) -> axum::response::Response { match self { LogoutError::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Unknown User"), + LogoutError::CsrfValidationFailed => { + (StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again") + } LogoutError::Unknown(e) => { error!(?e, "returning INTERNAL SERVER ERROR"); (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") diff --git a/src/server/session/mod.rs b/src/server/session/mod.rs index fb67389..0f836a2 100644 --- a/src/server/session/mod.rs +++ b/src/server/session/mod.rs @@ -1,7 +1,12 @@ +use std::{convert::Infallible, fmt}; + use axum::{ extract::FromRequestParts, http::{request::Parts, StatusCode}, + response::IntoResponseParts, }; +use axum_extra::extract::{cookie::Cookie, CookieJar}; +use rand::{distr::Alphanumeric, rng, Rng as _}; use serde::{Deserialize, Serialize}; use tower_sessions::Session; use uuid::Uuid; @@ -40,3 +45,66 @@ where Ok(Self { session, user }) } } + +pub struct CsrfCookie { + jar: CookieJar, + token: String, +} + +impl CsrfCookie { + pub fn token(&self) -> &str { + &self.token + } +} + +impl FromRequestParts for CsrfCookie +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let mut jar = CookieJar::from_request_parts(parts, state).await.unwrap(); // infalible result + + let token = if let Some(token) = jar.get("csrf-token") { + token.value().to_string() + } else { + let token = generate_csrf_token(); + jar = jar.add( + Cookie::build(("csrf-token", token.clone())) + .http_only(true) + .path("/") + .same_site(tower_sessions::cookie::SameSite::Strict), + ); + token + }; + + Ok(CsrfCookie { jar, token }) + } +} + +impl IntoResponseParts for CsrfCookie { + type Error = Infallible; + + fn into_response_parts( + self, + res: axum::response::ResponseParts, + ) -> Result { + self.jar.into_response_parts(res) + } +} + +const CSRF_TOKEN_LENGTH: usize = 64; +fn generate_csrf_token() -> String { + rng() + .sample_iter(&Alphanumeric) + .take(CSRF_TOKEN_LENGTH) + .map(char::from) + .collect() +} + +impl fmt::Debug for CsrfCookie { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("CsrfCookie").field(&self.token).finish() + } +} diff --git a/templates/login.html b/templates/login.html index 1feeb9e..a57980c 100644 --- a/templates/login.html +++ b/templates/login.html @@ -13,6 +13,7 @@ Password + diff --git a/templates/logout.html b/templates/logout.html index 2465ff0..c5b9bd3 100644 --- a/templates/logout.html +++ b/templates/logout.html @@ -2,6 +2,7 @@ Logout
+
diff --git a/templates/signup.html b/templates/signup.html index 6d998f1..6da9faa 100644 --- a/templates/signup.html +++ b/templates/signup.html @@ -13,6 +13,7 @@ Password +