add CSRF protection to signup, login, and logout
This commit is contained in:
parent
30f8d9ecf1
commit
8f18d71743
5 changed files with 150 additions and 13 deletions
|
@ -17,7 +17,7 @@ use tracing::{error, info, warn};
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::server::{
|
||||
session::{Auth, User},
|
||||
session::{Auth, CsrfCookie, User},
|
||||
AppState,
|
||||
};
|
||||
|
||||
|
@ -32,19 +32,30 @@ pub fn build() -> Router<AppState> {
|
|||
#[template(path = "signup.html")]
|
||||
struct SignupPage {
|
||||
error: Option<String>,
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn signup_page() -> SignupPage {
|
||||
pub async fn signup_page(csrf: CsrfCookie) -> impl IntoResponse {
|
||||
info!("get signup page");
|
||||
|
||||
SignupPage::default()
|
||||
let csrf_token = csrf.token().to_string();
|
||||
|
||||
(
|
||||
csrf,
|
||||
SignupPage {
|
||||
error: None,
|
||||
csrf_token,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct SignupForm {
|
||||
email: String,
|
||||
password: String,
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
impl fmt::Debug for SignupForm {
|
||||
|
@ -60,10 +71,15 @@ impl fmt::Debug for SignupForm {
|
|||
pub async fn signup(
|
||||
State(AppState { db, .. }): State<AppState>,
|
||||
mut user: Auth,
|
||||
csrf: CsrfCookie,
|
||||
Form(form): Form<SignupForm>,
|
||||
) -> Result<impl IntoResponse, SignupError> {
|
||||
info!("signup attempt");
|
||||
|
||||
if form.csrf_token != csrf.token() {
|
||||
return Err(SignupError::CsrfValidationFailed);
|
||||
}
|
||||
|
||||
info!("hash password: {}", &form.password);
|
||||
let password_hash: String = tokio::task::spawn_blocking(async move || {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
|
@ -103,6 +119,8 @@ pub async fn signup(
|
|||
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum SignupError {
|
||||
#[error("CSRF Validation Failed")]
|
||||
CsrfValidationFailed,
|
||||
#[error("Unknown Error: {0}")]
|
||||
Unknown(#[from] anyhow::Error),
|
||||
}
|
||||
|
@ -110,6 +128,9 @@ pub enum SignupError {
|
|||
impl IntoResponse for SignupError {
|
||||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
SignupError::CsrfValidationFailed => {
|
||||
(StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again")
|
||||
}
|
||||
SignupError::Unknown(e) => {
|
||||
error!(?e, "returning INTERNAL SERVER ERROR");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error")
|
||||
|
@ -123,19 +144,30 @@ impl IntoResponse for SignupError {
|
|||
#[template(path = "login.html")]
|
||||
struct LoginPage {
|
||||
error: Option<String>,
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn login_page() -> LoginPage {
|
||||
pub async fn login_page(csrf: CsrfCookie) -> impl IntoResponse {
|
||||
info!("get login page");
|
||||
|
||||
Default::default()
|
||||
let csrf_token = csrf.token().to_string();
|
||||
|
||||
(
|
||||
csrf,
|
||||
LoginPage {
|
||||
error: None,
|
||||
csrf_token,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct LoginForm {
|
||||
email: String,
|
||||
password: String,
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
impl fmt::Debug for LoginForm {
|
||||
|
@ -151,10 +183,15 @@ impl fmt::Debug for LoginForm {
|
|||
pub async fn login(
|
||||
State(AppState { db, .. }): State<AppState>,
|
||||
mut auth: Auth,
|
||||
csrf: CsrfCookie,
|
||||
Form(form): Form<LoginForm>,
|
||||
) -> Result<impl IntoResponse, LoginError> {
|
||||
info!("login attempt");
|
||||
|
||||
if form.csrf_token != csrf.token() {
|
||||
return Err(LoginError::CsrfValidationFailed);
|
||||
}
|
||||
|
||||
let user = sqlx::query!(
|
||||
r#"
|
||||
SELECT id, password FROM users WHERE email = $1 LIMIT 1;
|
||||
|
@ -203,6 +240,8 @@ pub enum LoginError {
|
|||
InvalidPassword,
|
||||
#[error("Unknown User")]
|
||||
UnknownUser,
|
||||
#[error("CSRF Validation Failed")]
|
||||
CsrfValidationFailed,
|
||||
#[error("Unknown Error: {0}")]
|
||||
Unknown(#[from] anyhow::Error),
|
||||
}
|
||||
|
@ -212,6 +251,9 @@ impl IntoResponse for LoginError {
|
|||
let (status, message) = match self {
|
||||
LoginError::InvalidPassword => (StatusCode::UNAUTHORIZED, "Invalid Password"),
|
||||
LoginError::UnknownUser => (StatusCode::UNAUTHORIZED, "Unknown User"),
|
||||
LoginError::CsrfValidationFailed => {
|
||||
(StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again")
|
||||
}
|
||||
LoginError::Unknown(e) => {
|
||||
error!(?e, "returning INTERNAL SERVER ERROR");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error")
|
||||
|
@ -222,6 +264,7 @@ impl IntoResponse for LoginError {
|
|||
status,
|
||||
LoginPage {
|
||||
error: Some(message.to_string()),
|
||||
csrf_token: "".to_string(),
|
||||
},
|
||||
)
|
||||
.into_response()
|
||||
|
@ -230,18 +273,36 @@ impl IntoResponse for LoginError {
|
|||
|
||||
#[derive(Template, WebTemplate)]
|
||||
#[template(path = "logout.html")]
|
||||
struct LogoutPage;
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn logout_page() -> LogoutPage {
|
||||
info!("get logout page");
|
||||
|
||||
LogoutPage
|
||||
struct LogoutPage {
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
pub async fn logout(mut user: Auth) -> Result<impl IntoResponse, LogoutError> {
|
||||
#[tracing::instrument]
|
||||
pub async fn logout_page(csrf: CsrfCookie) -> impl IntoResponse {
|
||||
info!("get logout page");
|
||||
|
||||
let csrf_token = csrf.token().to_string();
|
||||
|
||||
(csrf, LogoutPage { csrf_token })
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub struct LogoutForm {
|
||||
csrf_token: String,
|
||||
}
|
||||
|
||||
pub async fn logout(
|
||||
mut user: Auth,
|
||||
csrf: CsrfCookie,
|
||||
Form(form): Form<LogoutForm>,
|
||||
) -> Result<impl IntoResponse, LogoutError> {
|
||||
info!("logout attempt");
|
||||
|
||||
if form.csrf_token != csrf.token() {
|
||||
return Err(LogoutError::CsrfValidationFailed);
|
||||
}
|
||||
|
||||
if user.user.is_none() {
|
||||
return Err(LogoutError::NotLoggedIn);
|
||||
}
|
||||
|
@ -256,6 +317,8 @@ pub async fn logout(mut user: Auth) -> Result<impl IntoResponse, LogoutError> {
|
|||
pub enum LogoutError {
|
||||
#[error("Not Logged In")]
|
||||
NotLoggedIn,
|
||||
#[error("CSRF Validation Failed")]
|
||||
CsrfValidationFailed,
|
||||
#[error("Unknown Error: {0}")]
|
||||
Unknown(#[from] anyhow::Error),
|
||||
}
|
||||
|
@ -264,6 +327,9 @@ impl IntoResponse for LogoutError {
|
|||
fn into_response(self) -> axum::response::Response {
|
||||
match self {
|
||||
LogoutError::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Unknown User"),
|
||||
LogoutError::CsrfValidationFailed => {
|
||||
(StatusCode::BAD_REQUEST, "CSRF Validation Failed, Try Again")
|
||||
}
|
||||
LogoutError::Unknown(e) => {
|
||||
error!(?e, "returning INTERNAL SERVER ERROR");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error")
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
use std::{convert::Infallible, fmt};
|
||||
|
||||
use axum::{
|
||||
extract::FromRequestParts,
|
||||
http::{request::Parts, StatusCode},
|
||||
response::IntoResponseParts,
|
||||
};
|
||||
use axum_extra::extract::{cookie::Cookie, CookieJar};
|
||||
use rand::{distr::Alphanumeric, rng, Rng as _};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_sessions::Session;
|
||||
use uuid::Uuid;
|
||||
|
@ -40,3 +45,66 @@ where
|
|||
Ok(Self { session, user })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct CsrfCookie {
|
||||
jar: CookieJar,
|
||||
token: String,
|
||||
}
|
||||
|
||||
impl CsrfCookie {
|
||||
pub fn token(&self) -> &str {
|
||||
&self.token
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> FromRequestParts<S> for CsrfCookie
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = (StatusCode, &'static str);
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||
let mut jar = CookieJar::from_request_parts(parts, state).await.unwrap(); // infalible result
|
||||
|
||||
let token = if let Some(token) = jar.get("csrf-token") {
|
||||
token.value().to_string()
|
||||
} else {
|
||||
let token = generate_csrf_token();
|
||||
jar = jar.add(
|
||||
Cookie::build(("csrf-token", token.clone()))
|
||||
.http_only(true)
|
||||
.path("/")
|
||||
.same_site(tower_sessions::cookie::SameSite::Strict),
|
||||
);
|
||||
token
|
||||
};
|
||||
|
||||
Ok(CsrfCookie { jar, token })
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponseParts for CsrfCookie {
|
||||
type Error = Infallible;
|
||||
|
||||
fn into_response_parts(
|
||||
self,
|
||||
res: axum::response::ResponseParts,
|
||||
) -> Result<axum::response::ResponseParts, Self::Error> {
|
||||
self.jar.into_response_parts(res)
|
||||
}
|
||||
}
|
||||
|
||||
const CSRF_TOKEN_LENGTH: usize = 64;
|
||||
fn generate_csrf_token() -> String {
|
||||
rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(CSRF_TOKEN_LENGTH)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
impl fmt::Debug for CsrfCookie {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("CsrfCookie").field(&self.token).finish()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
Password
|
||||
<input type=password name=password>
|
||||
</label>
|
||||
<input type=hidden name=csrf-token value="{{csrf_token}}" />
|
||||
|
||||
<button type="submit">Login</button>
|
||||
</form>
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
<html lang="en">
|
||||
<title>Logout</title>
|
||||
<form method="post">
|
||||
<input type=hidden name=csrf-token value="{{csrf_token}}" />
|
||||
<button type="submit">Logout</button>
|
||||
</form>
|
||||
</html>
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
Password
|
||||
<input type=password name=password>
|
||||
</label>
|
||||
<input type=hidden name=csrf-token value="{{csrf_token}}" />
|
||||
|
||||
<button type="submit">Signup</button>
|
||||
</form>
|
||||
|
|
Loading…
Add table
Reference in a new issue