Skip to content

Instantly share code, notes, and snippets.

@jorgecarleitao
Created July 18, 2025 18:01
Show Gist options
  • Save jorgecarleitao/59abcb82f9cfbeaa21305cc65fcacdce to your computer and use it in GitHub Desktop.
Save jorgecarleitao/59abcb82f9cfbeaa21305cc65fcacdce to your computer and use it in GitHub Desktop.
Rust Test Context to manage a postgres instance via docker
#![forbid(unsafe_code)]
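/*
This code:
* provides a test context that starts a postgres DB via docker the first time it is used and shares it across tests
* removes the docker container once no test is using it any more
* works with a plain `cargo test` (i.e. with tests running in parallel)
* requires docker and a free host port 5432
In short, it does so by ensuring that Drop is (almost always) called on the lazily created, globally shared instance:
Rust never drops values held in a `static`, so the last test to finish takes the instance back out of the static.
If the process aborts (or is killed) before that happens, Drop does not run and the container is left running.
*/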
use std::sync::{Arc, LazyLock, Mutex};
pub use test_context::{AsyncTestContext, test_context};
/// Opens a Postgres connection without TLS and returns its [`tokio_postgres::Client`].
///
/// `connection` is passed unchanged to `tokio_postgres::connect`. The
/// connection half returned by `connect` is driven on a detached
/// `tokio::spawn` task, so this must be called from inside a tokio runtime.
/// An error hit by that task is only printed to stderr.
///
/// # Errors
///
/// Returns the error from `tokio_postgres::connect` if the connection cannot
/// be established.
pub async fn db_client(connection: &str) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
    let (client, driver) = tokio_postgres::connect(connection, tokio_postgres::NoTls).await?;

    // `driver` performs the socket I/O on behalf of `client` and has to be
    // polled for as long as the client is used, so give it its own task.
    tokio::spawn(async move {
        match driver.await {
            Err(e) => eprintln!("connection error: {}", e),
            Ok(()) => (),
        }
    });

    Ok(client)
}
fn start_db() -> String {
let args = [
"run",
"-d",
"-e",
"POSTGRES_PASSWORD=password",
"-p",
"5432:5432",
"postgres:latest",
];
let r = std::process::Command::new("docker")
.args(args)
.output()
.unwrap();
assert!(r.status.success());
let id = String::from_utf8(r.stdout).unwrap();
for _ in 0..20 {
let r = std::process::Command::new("docker")
.args(["exec", &id, "pg_isready", "-t", "90"])
.output()
.unwrap();
if r.status.success() {
break;
};
std::thread::sleep(std::time::Duration::from_secs(1));
}
id
}
/// Force-removes the docker container `id` with `docker rm -f`.
///
/// This is called from `Drop`, so it never panics. If `docker` cannot be
/// spawned, or exits with a non-zero status, the problem is reported on stderr
/// together with the container id so the container can be removed by hand.
fn stop_db(id: &str) {
    match std::process::Command::new("docker")
        .args(["rm", "-f", id])
        .output()
    {
        Ok(out) if out.status.success() => {}
        Ok(out) => eprintln!(
            "failed to remove docker container {id} ({}): {}",
            out.status,
            String::from_utf8_lossy(&out.stderr).trim()
        ),
        Err(err) => eprintln!("failed to run `docker rm -f {id}`: {err}"),
    }
}
/// Owns one running postgres container for as long as the value is alive.
///
/// The container is started by [`DB::new`] and removed when the value is dropped.
struct DB {
    /// Docker container id, as returned by `start_db`.
    id: String,
}

impl DB {
    /// Starts a new container (blocking until `start_db` returns) and takes ownership of it.
    fn new() -> Self {
        let id = start_db();
        DB { id }
    }
}

impl Drop for DB {
    /// Removes the owned container via `stop_db`.
    fn drop(&mut self) {
        let container_id: &str = &self.id;
        stop_db(container_id);
    }
}
/// Process-wide slot for the database shared by all test contexts.
///
/// `None` means no shared `DB` is currently registered. Rust never runs `Drop`
/// for values held in a `static`, so the `DB` (and its container) is only
/// dropped after a context takes it back out of this slot.
static DB_INSTANCE: LazyLock<Mutex<Option<Arc<DB>>>> = LazyLock::new(Mutex::default);
/// Per-test handle to the shared postgres container.
///
/// Each context holds one strong reference to the shared `DB`, which keeps the
/// container alive for the duration of the test.
pub struct DBContext {
    // Never read directly: it exists only to keep the shared container alive.
    #[allow(dead_code)]
    db: Arc<DB>,
}

impl DBContext {
    /// Connection string for the shared container, usable with `db_client`.
    ///
    /// Mirrors the settings `start_db` uses: password `password` and host port
    /// 5432. User and database are the image defaults (`postgres`), since
    /// `start_db` does not override them.
    pub fn url(&self) -> &str {
        "postgresql://postgres:password@localhost:5432/postgres"
    }
}
impl AsyncTestContext for DBContext {
    /// Hands out a reference to the shared database, first starting the
    /// container if the shared slot is empty.
    ///
    /// `DB::new` runs while the lock is held, so tests that start concurrently
    /// wait for a single container instead of each starting their own.
    async fn setup() -> DBContext {
        // If an earlier `DB::new` panicked, the mutex is poisoned but nothing was
        // stored, so the slot is still consistent. Recover it instead of failing
        // every later test with a `PoisonError`; this test then retries the start.
        let mut slot = DB_INSTANCE
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let db = Arc::clone(slot.get_or_insert_with(|| Arc::new(DB::new())));
        drop(slot);
        DBContext { db }
    }

    /// Releases this context's reference. The last context to finish removes
    /// the shared database; statics are never dropped, so clearing the slot
    /// here is the only way the container gets cleaned up.
    ///
    /// Both the reference-count check and the drop of this context's `Arc`
    /// happen while the lock is held. Dropping only after unlocking races in
    /// two ways:
    /// - Two contexts finishing together can both observe a count of 3, so
    ///   neither clears the slot and the container leaks.
    /// - A `setup` running right after the slot is cleared can try to publish
    ///   port 5432 before the old container has been removed.
    async fn teardown(self) {
        let DBContext { db } = self;
        let mut slot = DB_INSTANCE
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // We are the last user when only the static's reference and ours remain.
        let last_user = slot
            .as_ref()
            .is_some_and(|shared| Arc::ptr_eq(shared, &db) && Arc::strong_count(shared) == 2);
        if last_user {
            *slot = None;
        }
        // For the last user this drops the final reference. That runs `DB::drop`,
        // removing the container before any other test can take the lock.
        drop(db);
        drop(slot);
    }
}
#[test_context(DBContext)]
#[tokio::test]
async fn test_db_client(ctx: &mut DBContext) -> Result<(), String> {
let client = db_client(ctx.url())
.await
.map_err(|e| e.to_string())?;
let rows = client
.query("SELECT $1::TEXT", &[&"hello world"])
.await
.map_err(|e| e.to_string())?;
let value: &str = rows[0].get(0);
assert_eq!(value, "hello world");
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment