diff --git a/Cargo.lock b/Cargo.lock index 8c5a5c4..06f2ed7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -121,6 +121,72 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" +[[package]] +name = "askama" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f75363874b771be265f4ffe307ca705ef6f3baa19011c149da8674a87f1b75c4" +dependencies = [ + "askama_derive", + "itoa", + "percent-encoding", + "serde", + "serde_json", +] + +[[package]] +name = "askama_derive" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "129397200fe83088e8a68407a8e2b1f826cf0086b21ccdb866a722c8bcd3a94f" +dependencies = [ + "askama_parser", + "basic-toml", + "memchr", + "proc-macro2", + "quote", + "rustc-hash", + "serde", + "serde_derive", + "syn", +] + +[[package]] +name = "askama_parser" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6ab5630b3d5eaf232620167977f95eb51f3432fc76852328774afbd242d4358" +dependencies = [ + "memchr", + "serde", + "serde_derive", + "winnow", +] + +[[package]] +name = "askama_web" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83731f1a2286209c2b679445e8faaa53270646a90c509bf92729e966d198cb6b" +dependencies = [ + "askama", + "askama_web_derive", + "axum-core", + "bytes", + "http", +] + +[[package]] +name = "askama_web_derive" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34921de3d57974069bad483fdfe0ec65d88c4ff892edd1ab4d8b03be0dda1b9b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -276,6 +342,15 @@ version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" +[[package]] +name = "basic-toml" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a" +dependencies = [ + "serde", +] + [[package]] name = "bitflags" version = "2.9.0" @@ -798,6 +873,20 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1503,6 +1592,7 @@ checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", + "serde", ] [[package]] @@ -2267,6 +2357,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustix" version = "1.0.7" @@ -3114,6 +3210,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-cookies" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "151b5a3e3c45df17466454bb74e9ecedecc955269bdedbf4d150dfa393b55a36" +dependencies = [ + "axum-core", + "cookie", + "futures-util", + "http", + "parking_lot", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-http" version = "0.6.6" @@ -3142,6 +3254,57 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" +[[package]] +name = "tower-sessions" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a05911f23e8fae446005fe9b7b97e66d95b6db589dc1c4d59f6a2d4d4927d3" +dependencies = [ + "async-trait", + "http", + "time", + "tokio", + "tower-cookies", + "tower-layer", + "tower-service", + "tower-sessions-core", + "tower-sessions-memory-store", + "tracing", +] + +[[package]] +name = "tower-sessions-core" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce8cce604865576b7751b7a6bc3058f754569a60d689328bb74c52b1d87e355b" +dependencies = [ + "async-trait", + "axum-core", + "base64 0.22.1", + "futures", + "http", + "parking_lot", + "rand 0.8.5", + "serde", + "serde_json", + "thiserror", + "time", + "tokio", + "tracing", +] + +[[package]] +name = "tower-sessions-memory-store" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb05909f2e1420135a831dd5df9f5596d69196d0a64c3499ca474c4bd3d33242" +dependencies = [ + "async-trait", + "time", + "tokio", + "tower-sessions-core", +] + [[package]] name = "tracing" version = "0.1.41" @@ -3311,6 +3474,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ "getrandom 0.3.3", + "serde", ] [[package]] @@ -3871,6 +4035,8 @@ version = "0.1.0-dev" dependencies = [ "anyhow", "argon2", + "askama", + "askama_web", "axum", "axum-extra", "bollard", @@ -3896,6 +4062,7 @@ dependencies = [ "tokio-util", "tower", "tower-http", + "tower-sessions", "tracing", "tracing-subscriber", "uuid", diff --git a/Cargo.toml b/Cargo.toml index 1e3f909..66bf9a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,8 @@ license = "MIT OR Apache-2.0" [dependencies] anyhow = { version = "1.0.71", features = ["backtrace"] } argon2 = "0.5.3" +askama = "0.14.0" +askama_web = { version = "0.14.4", features = ["axum-0.8"] } axum = { version = "0.8", features = ["tokio", "http1", "http2", "macros"] } axum-extra = { version = "0.10", features = ["cookie-private", "typed-header"] } config = { version = "0.15", features = ["toml"] } @@ -29,9 +31,10 @@ tokio-stream = "0.1" tokio-util = "0.7.15" tower = "0.5.2" tower-http = { version = "0.6.6", features = ["trace"] } +tower-sessions = "0.14.0" tracing = "0.1.37" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } -uuid = { version = "1.16.0", features = ["v4"] } +uuid = { version = "1.16.0", features = ["serde", "v4"] } [dev-dependencies] bollard = { git = "https://github.com/fussybeaver/bollard.git", rev = "50a25a0" } diff --git a/src/server/mod.rs b/src/server/mod.rs index e82e987..259278f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,4 +1,5 @@ pub mod routes; +pub mod session; use anyhow::{Context as _, Result}; use axum::extract::FromRef; @@ -13,6 +14,8 @@ use std::pin::Pin; use std::sync::Arc; use tokio::signal; use tower_http::trace::TraceLayer; +use tower_sessions::cookie::time::Duration; +use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; use tracing::info; use crate::email_client::EmailClient; @@ -58,9 +61,16 @@ impl ZeroToAxum { email_client, }; + // Just store locally for now. Supports database connections. + let session_store = MemoryStore::default(); + let session_layer = SessionManagerLayer::new(session_store) + .with_secure(false) + .with_expiry(Expiry::OnInactivity(Duration::weeks(1))); + let app = routes::build() .with_state(app_state) - .layer(TraceLayer::new_for_http()); + .layer(TraceLayer::new_for_http()) + .layer(session_layer); let listener = tokio::net::TcpListener::bind(&conf.app.listen) .await diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs index cd0c374..8fffe2c 100644 --- a/src/server/routes/auth/mod.rs +++ b/src/server/routes/auth/mod.rs @@ -2,20 +2,30 @@ use std::fmt; use anyhow::Context as _; use argon2::Argon2; -use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::post, Form, Router}; -use axum_extra::extract::cookie::{Cookie, PrivateCookieJar}; +use askama::Template; +use askama_web::WebTemplate; +use axum::{ + extract::State, + http::StatusCode, + response::{IntoResponse, Redirect}, + routing::{get, post}, + Form, Router, +}; use password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString}; use serde::Deserialize; use tracing::{error, info, warn}; use uuid::Uuid; -use crate::server::AppState; +use crate::server::{ + session::{Auth, User}, + AppState, +}; pub fn build() -> Router { Router::new() .route("/signup", post(signup)) - .route("/login", post(login)) - .route("/logout", post(logout)) + .route("/login", get(login_page).post(login)) + .route("/logout", get(logout_page).post(logout)) } #[derive(Deserialize)] @@ -33,12 +43,12 @@ impl fmt::Debug for SignupForm { } } -#[tracing::instrument(skip(db))] +#[tracing::instrument(skip(db, user))] pub async fn signup( State(AppState { db, .. }): State, - jar: PrivateCookieJar, + mut user: Auth, Form(form): Form, -) -> Result { +) -> Result { info!("signup attempt"); info!("hash password: {}", &form.password); @@ -72,9 +82,10 @@ pub async fn signup( .await .context("insert new user into database")?; - let authed_jar = jar.add(Cookie::new("username", form.email)); + user.user = Some(User::new(user_id)); + user.session.insert("user", &user.user).await.unwrap(); - Ok(authed_jar) + Ok(Redirect::to("/")) } #[derive(thiserror::Error, Debug)] @@ -95,6 +106,19 @@ impl IntoResponse for SignupError { } } +#[derive(Template, WebTemplate, Default)] +#[template(path = "login.html")] +struct LoginPage { + error: Option, +} + +#[tracing::instrument] +pub async fn login_page() -> LoginPage { + info!("get login page"); + + Default::default() +} + #[derive(Deserialize)] pub struct LoginForm { email: String, @@ -110,32 +134,32 @@ impl fmt::Debug for LoginForm { } } -#[tracing::instrument(skip(db))] +#[tracing::instrument(skip(db, auth))] pub async fn login( State(AppState { db, .. }): State, - jar: PrivateCookieJar, + mut auth: Auth, Form(form): Form, -) -> Result { +) -> Result { info!("login attempt"); - let password_hash = sqlx::query!( + let user = sqlx::query!( r#" - SELECT password FROM users WHERE email = $1 LIMIT 1; + SELECT id, password FROM users WHERE email = $1 LIMIT 1; "#, form.email ) .fetch_one(&db) .await; - let password_hash = match password_hash { - Ok(ph) => ph, - Err(sqlx::Error::RowNotFound) => return Err(LoginError::UnknownUser), - Err(e) => Err(e).context("get user info from db")?, - }; + if matches!(user, Err(sqlx::Error::RowNotFound)) { + return Err(LoginError::UnknownUser); + } + + let user = user.context("get user info from db")?; tokio::task::spawn_blocking(async move || { let parsed_hash = - PasswordHash::new(&password_hash.password).context("parse stored password hash")?; + PasswordHash::new(&user.password).context("parse stored password hash")?; match Argon2::default().verify_password(form.password.as_bytes(), &parsed_hash) { Ok(()) => Ok(()), @@ -147,9 +171,17 @@ pub async fn login( .context("spawn password verifier task")? .await?; - let authed_jar = jar.add(Cookie::new("username", "admin")); + auth.user = Some(User::new(user.id)); + auth.session + .cycle_id() + .await + .context("refresh session id")?; + auth.session + .insert("user", &auth.user) + .await + .context("set user data in session")?; - Ok(authed_jar) + Ok(Redirect::to("/")) } #[derive(thiserror::Error, Debug)] @@ -164,37 +196,66 @@ pub enum LoginError { impl IntoResponse for LoginError { fn into_response(self) -> axum::response::Response { - match self { + let (status, message) = match self { LoginError::InvalidPassword => (StatusCode::UNAUTHORIZED, "Invalid Password"), LoginError::UnknownUser => (StatusCode::UNAUTHORIZED, "Unknown User"), LoginError::Unknown(e) => { error!(?e, "returning INTERNAL SERVER ERROR"); (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") } + }; + + ( + status, + LoginPage { + error: Some(message.to_string()), + }, + ) + .into_response() + } +} + +#[derive(Template, WebTemplate)] +#[template(path = "logout.html")] +struct LogoutPage; + +#[tracing::instrument] +pub async fn logout_page() -> LogoutPage { + info!("get logout page"); + + LogoutPage +} + +pub async fn logout(mut user: Auth) -> Result { + info!("logout attempt"); + + if user.user.is_none() { + return Err(LogoutError::NotLoggedIn); + } + + user.user = None; + user.session.flush().await.context("flush user session")?; + + Ok(Redirect::to("/")) +} + +#[derive(thiserror::Error, Debug)] +pub enum LogoutError { + #[error("Not Logged In")] + NotLoggedIn, + #[error("Unknown Error: {0}")] + Unknown(#[from] anyhow::Error), +} + +impl IntoResponse for LogoutError { + fn into_response(self) -> axum::response::Response { + match self { + LogoutError::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Unknown User"), + LogoutError::Unknown(e) => { + error!(?e, "returning INTERNAL SERVER ERROR"); + (StatusCode::INTERNAL_SERVER_ERROR, "Unknown Error") + } } .into_response() } } - -pub async fn logout(jar: PrivateCookieJar) -> Result { - info!("logout attempt"); - - if jar.get("username").is_none() { - return Err(LogoutError::NotLoggedIn); - } - - Ok(jar.remove("username")) -} - -pub enum LogoutError { - NotLoggedIn, -} - -impl IntoResponse for LogoutError { - fn into_response(self) -> axum::response::Response { - match self { - LogoutError::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Unknown User"), - } - .into_response() - } -} diff --git a/src/server/routes/mod.rs b/src/server/routes/mod.rs index e132a61..fd22d7a 100644 --- a/src/server/routes/mod.rs +++ b/src/server/routes/mod.rs @@ -1,16 +1,31 @@ mod auth; mod subscriptions; +use askama::Template; +use askama_web::WebTemplate; use axum::{routing::get, Router}; -use super::AppState; +use super::{session::Auth, AppState}; pub fn build() -> Router { Router::new() + .route("/", get(homepage)) .route("/health", get(health_check)) .nest("/auth", auth::build()) .nest("/subscriptions", subscriptions::build()) } // just always returns a 200 OK for now, the server has no state, if it's up, it's working -pub async fn health_check() {} +async fn health_check() {} + +#[derive(Template, WebTemplate)] +#[template(path = "homepage.html")] +struct Homepage { + is_logged_in: bool, +} + +async fn homepage(user: Auth) -> Homepage { + Homepage { + is_logged_in: user.user.is_some(), + } +} diff --git a/src/server/session/mod.rs b/src/server/session/mod.rs new file mode 100644 index 0000000..fb67389 --- /dev/null +++ b/src/server/session/mod.rs @@ -0,0 +1,42 @@ +use axum::{ + extract::FromRequestParts, + http::{request::Parts, StatusCode}, +}; +use serde::{Deserialize, Serialize}; +use tower_sessions::Session; +use uuid::Uuid; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct User { + pub id: Uuid, + // roles: Vec, +} + +impl User { + pub fn new(id: Uuid) -> User { + User { id } + } +} + +pub struct Auth { + // id: Uuid, + pub session: Session, + pub user: Option, +} + +impl FromRequestParts for Auth +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let session = Session::from_request_parts(parts, state).await?; + + let user: Option = session.get("user").await.unwrap_or_default(); + + // session.insert("user", &data).await.unwrap(); + + Ok(Self { session, user }) + } +} diff --git a/templates/homepage.html b/templates/homepage.html new file mode 100644 index 0000000..d50a761 --- /dev/null +++ b/templates/homepage.html @@ -0,0 +1,16 @@ + + + Login +

Welcome!

+

Pages

+ + {% if is_logged_in %} +

Super Secret Pages

+ + {% endif %} + diff --git a/templates/login.html b/templates/login.html new file mode 100644 index 0000000..1feeb9e --- /dev/null +++ b/templates/login.html @@ -0,0 +1,19 @@ + + + Login + {%if let Some(msg) = error %} +

{{msg}}

+ {% endif %} +
+ + + + +
+ diff --git a/templates/logout.html b/templates/logout.html new file mode 100644 index 0000000..2465ff0 --- /dev/null +++ b/templates/logout.html @@ -0,0 +1,7 @@ + + + Logout +
+ +
+ diff --git a/tests/auth.rs b/tests/auth.rs index 8f56487..040df8d 100644 --- a/tests/auth.rs +++ b/tests/auth.rs @@ -28,23 +28,11 @@ async fn login_succeeds_with_valid_credentials() -> Result<()> { .await?; assert_eq!(resp.status(), 200, "login succeeds"); - assert!( - resp.headers().get("Set-Cookie").is_some(), - "cookie set on successful login" - ); // Logout let resp = client.post(server.url("/auth/logout")).send().await?; assert_eq!(resp.status(), 200, "logout succeeds"); - let set_cookie = resp - .headers() - .get("Set-Cookie") - .expect("logout has set-cookie header"); - assert!( - set_cookie.to_str().unwrap().starts_with("username=;"), - "cookie unset on sucessful logout" - ); server.shutdown().await }