use anyhow::{Context as _, Result}; use axum::http; use bollard::query_parameters::{CreateImageOptionsBuilder, ListContainersOptionsBuilder}; use bollard::secret::{ ContainerCreateBody, ContainerInspectResponse, ContainerState, CreateImageInfo, Health, HealthConfig, HealthStatusEnum, HostConfig, PortBinding, }; use bollard::Docker; use futures_util::{FutureExt, StreamExt as _}; use rand::distr::slice::Choose; use rand::{rng, Rng}; use reqwest::header::{HeaderName, HeaderValue}; use reqwest::Body; use select::document::Document; use select::predicate::Attr; use serde::Serialize; use sqlx::migrate::MigrateDatabase; use sqlx::{Connection, PgConnection}; use std::net::SocketAddr; use std::path::Path; use std::time::Duration; use tokio::task::JoinHandle; use tokio::time::sleep; use tracing::{debug, trace}; use zero_to_axum::{conf, Conf, ZeroToAxum}; pub struct TestServer { server_task_handle: JoinHandle<()>, addr: SocketAddr, _db: TestDb, pub mock_smtp_server: maik::MockServer, } impl TestServer { pub async fn spawn() -> TestServer { debug!("start test server"); let db = TestDb::spawn().await; let url = dbg!(db.get_url()); let mock_smtp_server = maik::MockServer::builder() .add_mailbox("bot@example.com", "1234") .assert_after_n_emails(1) .build(); mock_smtp_server.start(); let server = ZeroToAxum::serve(Conf { app: conf::App { listen: "[::]:0".parse().unwrap(), // TODO: how do I both configure this and use a random port? public_url: "http://localhost/".to_string(), }, database: conf::Database { url }, debug: true, email: Some(conf::Email { server: mock_smtp_server.host().to_string(), port: Some(mock_smtp_server.port()), username: "bot@example.com".to_owned(), password: "1234".to_owned(), sender: "bot@example.com".to_owned(), cert: Some(String::from_utf8_lossy(mock_smtp_server.cert_pem()).to_string()), }), }) .await .unwrap(); let addr = server.local_addr(); let server_task_handle = tokio::spawn(server.map(|res| res.unwrap())); debug!(?addr, "test server spawned"); TestServer { server_task_handle, addr, _db: db, mock_smtp_server, } } /// format a URL for the given path pub fn url(&self, path: &str) -> String { format!("http://{}{path}", self.addr) } /// Construct a browser-like client. pub fn browser_client(&self) -> BrowserClient { BrowserClient::new(self.addr) } /// Request a graceful shutdown and return imideately. pub async fn start_shutdown(&self) -> Result<()> { self.server_task_handle.abort(); Ok(()) } /// Request a graceful shutdown and then wait for shutdown to complete pub async fn shutdown(self) -> Result<()> { self.server_task_handle.abort(); let _ = self.server_task_handle.await; Ok(()) } } pub struct BrowserClient { inner: reqwest::Client, server: SocketAddr, csrf_token: Option, } impl BrowserClient { fn new(server: SocketAddr) -> BrowserClient { let inner = reqwest::Client::builder() .cookie_store(true) .build() .expect("build reqwest client"); BrowserClient { inner, server, csrf_token: None, } } /// format a URL for the given path fn url(&self, path: &str) -> String { println!("url: http://{}{path}", self.server); format!("http://{}{path}", self.server) } pub fn post>(&self, url: U) -> RequestBuilder { RequestBuilder::new( self.inner.post(self.url(url.as_ref())), self.csrf_token.clone(), ) } pub async fn get_csrf_token(&mut self) -> Result { // Any page with a CSRF-cookie will do, tokens are valid for a session. let resp = self.inner.get(&self.url("/auth/signup")).send().await?; assert_eq!(resp.status(), 200, "get CSRF page"); let signup_page = resp.text().await.context("recv signup page body")?; let document = Document::from(signup_page.as_str()); let csrf_node = document .find(Attr("name", "csrf-token")) .next() .context("find csrf node")?; let csrf_token = csrf_node .attr("value") .context("get csrf token from node")? .to_string(); self.csrf_token = Some(csrf_token.clone()); Ok(csrf_token) } } pub struct RequestBuilder { inner: reqwest::RequestBuilder, csrf_token: Option, } impl RequestBuilder { fn new(request: reqwest::RequestBuilder, csrf_token: Option) -> RequestBuilder { RequestBuilder { inner: request, csrf_token, } } pub fn header(mut self, key: K, value: V) -> Self where HeaderName: TryFrom, >::Error: Into, HeaderValue: TryFrom, >::Error: Into, { self.inner = self.inner.header(key, value); self } pub fn form(mut self, form: &T) -> RequestBuilder { self.inner = self.inner.form(form); self } pub fn csrf_form(mut self, form: &T) -> RequestBuilder { let body_str = serde_urlencoded::to_string(form).unwrap(); let full_body_str = format!( "{body_str}&csrf-token={}", self.csrf_token.as_ref().unwrap() ); self.inner = self .inner .body(full_body_str) .header("Content-Type", "application/x-www-form-urlencoded"); self } pub fn body>(mut self, body: T) -> Self { self.inner = self.inner.body(body); self } pub async fn send(self) -> Result { self.inner.send().await } } const TEST_DB_IMAGE_NAME: &str = "postgres"; const TEST_DB_SUPERUSER: &str = "postgres"; const TEST_DB_SUPERUSER_PASS: &str = "password"; pub struct TestDb { container: bollard::secret::ContainerInspectResponse, user: String, pass: String, name: String, } impl TestDb { pub async fn spawn() -> Self { let docker = Docker::connect_with_local_defaults().expect("connect to docker daemon"); let docker = docker.negotiate_version().await.unwrap(); let version = docker.version().await.unwrap(); trace!("version: {version:?}"); // check for existing container let mut found_containers = docker .list_containers(Some( ListContainersOptionsBuilder::new() .filters( &([( "label".to_string(), vec!["zero-to-axum_test-db".to_string()], )] .into()), ) .build(), )) .await .unwrap(); let container_id; if let Some(container) = found_containers.pop() { container_id = container.id.unwrap(); } else { // build container let mut image_id = None; // check for image if let Ok(image) = docker.inspect_image(TEST_DB_IMAGE_NAME).await { image_id = Some(image.id.unwrap()); } // build docker image from docker file // let mut image_id = None; // { // let filename = "Dockerfile.db"; // let image_options = bollard::query_parameters::BuildImageOptionsBuilder::default() // .dockerfile(filename) // .rm(true) // .build(); // let archive_bytes = { // let mut archive = tar::Builder::new(Vec::new()); // archive.append_path(filename).unwrap(); // archive.into_inner().unwrap() // }; // let mut image_build_stream = docker.build_image( // image_options, // None, // Some(http_body_util::Either::Left(http_body_util::Full::new( // archive_bytes.into(), // ))), // ); // while let Some(msg) = image_build_stream.next().await { // info!("Message: {msg:?}"); // if let Ok(BuildInfo { // aux: Some(ImageId { id: Some(id) }), // .. // }) = msg // { // trace!("Image ID: {id}"); // image_id = Some(id); // } // } // } // let image_id = image_id.expect("get image id for built docker image"); // pull image if image_id.is_none() { let image_opts = CreateImageOptionsBuilder::new() .from_image(TEST_DB_IMAGE_NAME) .build(); trace!(?image_opts, "pull image"); let mut image_create_stream = docker.create_image(Some(image_opts), None, None); while let Some(msg) = image_create_stream.next().await { trace!("Message: {msg:?}"); if let Ok(CreateImageInfo { id: Some(id), .. }) = msg { trace!("Image ID: {id}"); image_id = Some(id); } } } let image_id = image_id.expect("get image id for built docker image"); // create and start docker container { let container_config = ContainerCreateBody { image: Some(image_id.clone()), exposed_ports: Some([("5432/tcp".to_string(), [].into())].into()), host_config: Some(HostConfig { port_bindings: Some( [( "5432/tcp".to_string(), Some(vec![PortBinding { host_ip: Some("127.0.0.1".to_string()), host_port: None, // auto-assign }]), )] .into(), ), ..Default::default() }), env: Some(vec![ format!("POSTGRES_USER={TEST_DB_SUPERUSER}"), format!("POSTGRES_PASSWORD={TEST_DB_SUPERUSER_PASS}"), ]), healthcheck: Some(HealthConfig { test: Some(vec!["pg_isready -U postgres || exit 1".to_string()]), // nano seconds interval: Some(1 * 1000 * 1000 * 1000), timeout: Some(5 * 1000 * 1000 * 1000), retries: Some(5 * 1000 * 1000 * 1000), ..Default::default() }), labels: Some([("zero-to-axum_test-db".to_string(), String::new())].into()), ..Default::default() }; trace!("create container"); bollard::secret::ContainerCreateResponse { id: container_id, .. } = docker .create_container( None::, container_config, ) .await .unwrap(); trace!("start container"); docker .start_container( &container_id, None::, ) .await .unwrap(); } } // wait for container to be started let container = loop { trace!("inspect container"); let container = docker .inspect_container( &container_id, None::, ) .await .unwrap(); if let ContainerInspectResponse { state: Some(ContainerState { health: Some(Health { status: Some(status), .. }), .. }), .. } = &container { trace!("status: {status:?}"); if *status == HealthStatusEnum::HEALTHY { break container; } } sleep(Duration::from_secs(2)).await; }; let lowercase_alpha = Choose::new(b"abcdefghijklmnopqrstuvwxyz").unwrap(); let db = TestDb { container, user: rng() .sample_iter(lowercase_alpha) .take(16) .map(|i| char::from(*i)) .collect(), pass: rng() .sample_iter(lowercase_alpha) .take(32) .map(|i| char::from(*i)) .collect(), name: rng() .sample_iter(lowercase_alpha) .take(8) .map(|i| char::from(*i)) .collect(), }; // setup app db { let mut conn = PgConnection::connect(&dbg!(db.get_superuser_url())) .await .unwrap(); // create application user // Note: In general, string formtting a query is bad practice, but it's required here. dbg!( sqlx::query(&dbg!(format!( "CREATE USER {} WITH PASSWORD '{}';", db.user, db.pass ))) .execute(&mut conn) .await ) .unwrap(); // grant privs to app user // Note: In general, string formtting a query is bad practice, but it's required here. sqlx::query(&format!("ALTER USER {} CREATEDB;", db.user)) .execute(&mut conn) .await .unwrap(); } // create test db sqlx::Postgres::create_database(&dbg!(db.get_url())) .await .unwrap(); let mut conn = PgConnection::connect(&db.get_url()).await.unwrap(); // run migrations on test db let m = sqlx::migrate::Migrator::new(Path::new("./migrations")) .await .unwrap(); m.run(&mut conn).await.unwrap(); db } /// Get the authenticated URL for accessing the test DB from the host. pub fn get_url(&self) -> String { let binding = self .container .network_settings .as_ref() .unwrap() .ports .as_ref() .unwrap() .get("5432/tcp") .as_ref() .unwrap() .as_ref() .unwrap() .first() .unwrap(); let host_ip = binding.host_ip.as_ref().unwrap().clone(); let host_port = binding.host_port.as_ref().unwrap().clone(); format!( "postgres://{}:{}@{host_ip}:{host_port}/{}", self.user, self.pass, self.name ) } /// Get the superuser-authenticated URL for accessing the `postgres` db from the host. fn get_superuser_url(&self) -> String { let binding = self .container .network_settings .as_ref() .unwrap() .ports .as_ref() .unwrap() .get("5432/tcp") .as_ref() .unwrap() .as_ref() .unwrap() .first() .unwrap(); let host_ip = binding.host_ip.as_ref().unwrap().clone(); let host_port = binding.host_port.as_ref().unwrap().clone(); format!("postgres://{TEST_DB_SUPERUSER}:{TEST_DB_SUPERUSER_PASS}@{host_ip}:{host_port}/postgres") } }