add CSRF protection to signup, login, and logout

This commit is contained in:
azdle 2025-07-22 15:45:43 -05:00
parent 30f8d9ecf1
commit 8f18d71743
5 changed files with 150 additions and 13 deletions

View file

@ -17,7 +17,7 @@ use tracing::{error, info, warn};
use uuid::Uuid; use uuid::Uuid;
use crate::server::{ use crate::server::{
session::{Auth, User}, session::{Auth, CsrfCookie, User},
AppState, AppState,
}; };
@ -32,19 +32,30 @@ pub fn build() -> Router<AppState> {
#[template(path = "signup.html")] #[template(path = "signup.html")]
struct SignupPage { struct SignupPage {
error: Option<String>, error: Option<String>,
csrf_token: String,
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn signup_page() -> SignupPage { pub async fn signup_page(csrf: CsrfCookie) -> impl IntoResponse {
info!("get signup page"); info!("get signup page");
SignupPage::default() let csrf_token = csrf.token().to_string();
(
csrf,
SignupPage {
error: None,
csrf_token,
},
)
} }
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct SignupForm { pub struct SignupForm {
email: String, email: String,
password: String, password: String,
csrf_token: String,
} }
impl fmt::Debug for SignupForm { impl fmt::Debug for SignupForm {
@ -60,10 +71,15 @@ impl fmt::Debug for SignupForm {
pub async fn signup( pub async fn signup(
State(AppState { db, .. }): State<AppState>, State(AppState { db, .. }): State<AppState>,
mut user: Auth, mut user: Auth,
csrf: CsrfCookie,
Form(form): Form<SignupForm>, Form(form): Form<SignupForm>,
) -> Result<impl IntoResponse, SignupError> { ) -> Result<impl IntoResponse, SignupError> {
info!("signup attempt"); info!("signup attempt");
if form.csrf_token != csrf.token() {
return Err(SignupError::CsrfValidationFailed);
}
info!("hash password: {}", &form.password); info!("hash password: {}", &form.password);
let password_hash: String = tokio::task::spawn_blocking(async move || { let password_hash: String = tokio::task::spawn_blocking(async move || {
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
@ -103,6 +119,8 @@ pub async fn signup(
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
pub enum SignupError { pub enum SignupError {
#[error("CSRF Validation Failed")]
CsrfValidationFailed,
#[error("Unknown Error: {0}")] #[error("Unknown Error: {0}")]
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
@ -110,6 +128,9 @@ pub enum SignupError {
impl IntoResponse for SignupError { impl IntoResponse for SignupError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
SignupError::CsrfValidationFailed => {
(StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again")
}
SignupError::Unknown(e) => { SignupError::Unknown(e) => {
error!(?e, "returning INTERNAL SERVER ERROR"); error!(?e, "returning INTERNAL SERVER ERROR");
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error")
@ -123,19 +144,30 @@ impl IntoResponse for SignupError {
#[template(path = "login.html")] #[template(path = "login.html")]
struct LoginPage { struct LoginPage {
error: Option<String>, error: Option<String>,
csrf_token: String,
} }
#[tracing::instrument] #[tracing::instrument]
pub async fn login_page() -> LoginPage { pub async fn login_page(csrf: CsrfCookie) -> impl IntoResponse {
info!("get login page"); info!("get login page");
Default::default() let csrf_token = csrf.token().to_string();
(
csrf,
LoginPage {
error: None,
csrf_token,
},
)
} }
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct LoginForm { pub struct LoginForm {
email: String, email: String,
password: String, password: String,
csrf_token: String,
} }
impl fmt::Debug for LoginForm { impl fmt::Debug for LoginForm {
@ -151,10 +183,15 @@ impl fmt::Debug for LoginForm {
pub async fn login( pub async fn login(
State(AppState { db, .. }): State<AppState>, State(AppState { db, .. }): State<AppState>,
mut auth: Auth, mut auth: Auth,
csrf: CsrfCookie,
Form(form): Form<LoginForm>, Form(form): Form<LoginForm>,
) -> Result<impl IntoResponse, LoginError> { ) -> Result<impl IntoResponse, LoginError> {
info!("login attempt"); info!("login attempt");
if form.csrf_token != csrf.token() {
return Err(LoginError::CsrfValidationFailed);
}
let user = sqlx::query!( let user = sqlx::query!(
r#" r#"
SELECT id, password FROM users WHERE email = $1 LIMIT 1; SELECT id, password FROM users WHERE email = $1 LIMIT 1;
@ -203,6 +240,8 @@ pub enum LoginError {
InvalidPassword, InvalidPassword,
#[error("Unknown User")] #[error("Unknown User")]
UnknownUser, UnknownUser,
#[error("CSRF Validation Failed")]
CsrfValidationFailed,
#[error("Unknown Error: {0}")] #[error("Unknown Error: {0}")]
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
@ -212,6 +251,9 @@ impl IntoResponse for LoginError {
let (status, message) = match self { let (status, message) = match self {
LoginError::InvalidPassword => (StatusCode::UNAUTHORIZED, "Invalid Password"), LoginError::InvalidPassword => (StatusCode::UNAUTHORIZED, "Invalid Password"),
LoginError::UnknownUser => (StatusCode::UNAUTHORIZED, "Unknown User"), LoginError::UnknownUser => (StatusCode::UNAUTHORIZED, "Unknown User"),
LoginError::CsrfValidationFailed => {
(StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again")
}
LoginError::Unknown(e) => { LoginError::Unknown(e) => {
error!(?e, "returning INTERNAL SERVER ERROR"); error!(?e, "returning INTERNAL SERVER ERROR");
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error")
@ -222,6 +264,7 @@ impl IntoResponse for LoginError {
status, status,
LoginPage { LoginPage {
error: Some(message.to_string()), error: Some(message.to_string()),
csrf_token: "".to_string(),
}, },
) )
.into_response() .into_response()
@ -230,18 +273,36 @@ impl IntoResponse for LoginError {
#[derive(Template, WebTemplate)] #[derive(Template, WebTemplate)]
#[template(path = "logout.html")] #[template(path = "logout.html")]
struct LogoutPage; struct LogoutPage {
csrf_token: String,
#[tracing::instrument]
pub async fn logout_page() -> LogoutPage {
info!("get logout page");
LogoutPage
} }
pub async fn logout(mut user: Auth) -> Result<impl IntoResponse, LogoutError> { #[tracing::instrument]
pub async fn logout_page(csrf: CsrfCookie) -> impl IntoResponse {
info!("get logout page");
let csrf_token = csrf.token().to_string();
(csrf, LogoutPage { csrf_token })
}
#[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct LogoutForm {
csrf_token: String,
}
pub async fn logout(
mut user: Auth,
csrf: CsrfCookie,
Form(form): Form<LogoutForm>,
) -> Result<impl IntoResponse, LogoutError> {
info!("logout attempt"); info!("logout attempt");
if form.csrf_token != csrf.token() {
return Err(LogoutError::CsrfValidationFailed);
}
if user.user.is_none() { if user.user.is_none() {
return Err(LogoutError::NotLoggedIn); return Err(LogoutError::NotLoggedIn);
} }
@ -256,6 +317,8 @@ pub async fn logout(mut user: Auth) -> Result<impl IntoResponse, LogoutError> {
pub enum LogoutError { pub enum LogoutError {
#[error("Not Logged In")] #[error("Not Logged In")]
NotLoggedIn, NotLoggedIn,
#[error("CSRF Validation Failed")]
CsrfValidationFailed,
#[error("Unknown Error: {0}")] #[error("Unknown Error: {0}")]
Unknown(#[from] anyhow::Error), Unknown(#[from] anyhow::Error),
} }
@ -264,6 +327,9 @@ impl IntoResponse for LogoutError {
fn into_response(self) -> axum::response::Response { fn into_response(self) -> axum::response::Response {
match self { match self {
LogoutError::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Unknown User"), LogoutError::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Unknown User"),
LogoutError::CsrfValidationFailed => {
(StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again")
}
LogoutError::Unknown(e) => { LogoutError::Unknown(e) => {
error!(?e, "returning INTERNAL SERVER ERROR"); error!(?e, "returning INTERNAL SERVER ERROR");
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error")

View file

@ -1,7 +1,12 @@
use std::{convert::Infallible, fmt};
use axum::{ use axum::{
extract::FromRequestParts, extract::FromRequestParts,
http::{request::Parts, StatusCode}, http::{request::Parts, StatusCode},
response::IntoResponseParts,
}; };
use axum_extra::extract::{cookie::Cookie, CookieJar};
use rand::{distr::Alphanumeric, rng, Rng as _};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tower_sessions::Session; use tower_sessions::Session;
use uuid::Uuid; use uuid::Uuid;
@ -40,3 +45,66 @@ where
Ok(Self { session, user }) Ok(Self { session, user })
} }
} }
pub struct CsrfCookie {
jar: CookieJar,
token: String,
}
impl CsrfCookie {
pub fn token(&self) -> &str {
&self.token
}
}
impl<S> FromRequestParts<S> for CsrfCookie
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let mut jar = CookieJar::from_request_parts(parts, state).await.unwrap(); // infalible result
let token = if let Some(token) = jar.get("csrf-token") {
token.value().to_string()
} else {
let token = generate_csrf_token();
jar = jar.add(
Cookie::build(("csrf-token", token.clone()))
.http_only(true)
.path("/")
.same_site(tower_sessions::cookie::SameSite::Strict),
);
token
};
Ok(CsrfCookie { jar, token })
}
}
impl IntoResponseParts for CsrfCookie {
type Error = Infallible;
fn into_response_parts(
self,
res: axum::response::ResponseParts,
) -> Result<axum::response::ResponseParts, Self::Error> {
self.jar.into_response_parts(res)
}
}
const CSRF_TOKEN_LENGTH: usize = 64;
fn generate_csrf_token() -> String {
rng()
.sample_iter(&Alphanumeric)
.take(CSRF_TOKEN_LENGTH)
.map(char::from)
.collect()
}
impl fmt::Debug for CsrfCookie {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("CsrfCookie").field(&self.token).finish()
}
}

View file

@ -13,6 +13,7 @@
Password Password
<input type=password name=password> <input type=password name=password>
</label> </label>
<input type=hidden name=csrf-token value="{{csrf_token}}" />
<button type="submit">Login</button> <button type="submit">Login</button>
</form> </form>

View file

@ -2,6 +2,7 @@
<html lang="en"> <html lang="en">
<title>Logout</title> <title>Logout</title>
<form method="post"> <form method="post">
<input type=hidden name=csrf-token value="{{csrf_token}}" />
<button type="submit">Logout</button> <button type="submit">Logout</button>
</form> </form>
</html> </html>

View file

@ -13,6 +13,7 @@
Password Password
<input type=password name=password> <input type=password name=password>
</label> </label>
<input type=hidden name=csrf-token value="{{csrf_token}}" />
<button type="submit">Signup</button> <button type="submit">Signup</button>
</form> </form>