Last active
January 25, 2025 21:28
-
-
Save briceburg/632ba62773d5ee430a40c026647cac7d to your computer and use it in GitHub Desktop.
Quick compare of postgres databases (using random sampling)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
# @code-style: black
import argparse
import atexit
import difflib
import os
import random
import sys

import psycopg2  # export LIBRARY_PATH=$LIBRARY_PATH:/opt/homebrew/opt/openssl/lib when doing pip install
from psycopg2 import sql
def die(errMsg):
    """Print an error message to stderr and terminate the process.

    Args:
        errMsg: Description of the failure; printed with an ``ERROR: `` prefix.

    Raises:
        SystemExit: Always, with exit status 1.
    """
    print(f"ERROR: {errMsg}", file=sys.stderr)
    # Use sys.exit rather than the `exit` builtin: the builtin is injected by
    # the `site` module for interactive sessions and is missing when Python
    # runs with -S or from some frozen/embedded builds.
    sys.exit(1)
def help(errMsg=None):
    """Print the usage text and exit.

    Note that this intentionally keeps the name ``help`` (shadowing the
    builtin) because the rest of the script calls it by that name.

    Args:
        errMsg: Optional error text. When truthy it is reported through
            ``die`` after the usage text, giving exit status 1.

    Raises:
        SystemExit: Always. Status 0 when no error message was supplied,
            status 1 otherwise.
    """
    print(
        """
USAGE
$ db:compare <DATABASE_A_URL> <DATABASE_B_URL> [-h|--help|help] [--table-count <COUNT>] [--row-count <COUNT>]
db:compare postgres://user:password@host:port/database_a postgres://user:password@host:port/database_b
SUMMARY
Given two connection strings, attempt to compare databases for parity.
OPTIONS
--table-count <count> Number of tables to sample. Defaults to 20. If 0, all tables are sampled.
--row-count <count> Number of rows to sample. Defaults to 12.
"""
    )
    # Explicit branches instead of a conditional expression evaluated only
    # for its side effects.
    if errMsg:
        die(errMsg)
    sys.exit(0)
class Sdiffer:
    """Print the differences between two sequences side by side.

    Items from the first sequence are shown in the left column and items
    from the second in the right column, separated by `` | ``. Only
    differing entries are printed; they are tagged ``[!]`` (changed),
    ``[-]`` (only on the left) or ``[+]`` (only on the right).
    """

    def __init__(self, width=130):
        """Store the layout settings.

        Args:
            width: Total width of an output line in characters.
        """
        self.width = width
        # Three characters are reserved for the " | " gutter.
        self.half_width = (width - 3) // 2

    def format_line(self, left, right):
        """Return ``left`` and ``right`` padded into two columns.

        Values longer than a column are not truncated; the line simply
        becomes wider.
        """
        left = str(left).ljust(self.half_width)
        right = str(right).ljust(self.half_width)
        return f"{left} | {right}"

    def print_sdiff(self, a_lines, b_lines):
        """Print every difference between ``a_lines`` and ``b_lines``.

        Both inputs are converted element-wise with ``str`` before being
        compared.

        Args:
            a_lines: Items for the left-hand side.
            b_lines: Items for the right-hand side.

        Returns:
            True if at least one difference was printed, otherwise False.
        """
        a_lines = [str(x) for x in a_lines]
        b_lines = [str(x) for x in b_lines]
        changed = False
        matcher = difflib.SequenceMatcher(None, a_lines, b_lines)
        for tag, a_start, a_end, b_start, b_end in matcher.get_opcodes():
            if tag == "equal":
                continue
            changed = True
            # A "replace" opcode may cover spans of different lengths. Pair
            # up as many lines as both sides have, then report the surplus
            # as deletions/insertions so that no line is silently dropped.
            # For "delete"/"insert" one span is empty, so nothing is paired.
            paired = 0
            if tag == "replace":
                paired = min(a_end - a_start, b_end - b_start)
            for offset in range(paired):
                print(
                    self.format_line(
                        f" [!] {a_lines[a_start + offset]}",
                        f"{b_lines[b_start + offset]}",
                    )
                )
            for i in range(a_start + paired, a_end):
                print(self.format_line(f" [-] {a_lines[i]}", ""))
            for j in range(b_start + paired, b_end):
                print(self.format_line("", f" [+] {b_lines[j]}"))
        return changed
def _print_outliers(a_rows, b_rows, limit=12):
    """Print the keys whose values differ most between two result sets.

    Each row is expected to be a ``(key, value)`` pair. Rows are matched by
    key rather than by position, so a key present on only one side does not
    misalign the remaining rows. Pairs where either value is NULL are
    skipped because no numeric drift can be computed for them.

    Args:
        a_rows: ``(key, value)`` rows from the first database.
        b_rows: ``(key, value)`` rows from the second database.
        limit: Maximum number of outliers to print.
    """
    b_values = dict(b_rows)
    drifts = []
    for key, a_value in a_rows:
        if key not in b_values:
            continue
        b_value = b_values[key]
        if a_value == b_value or a_value is None or b_value is None:
            continue
        drifts.append((key, abs(a_value - b_value)))
    if not drifts:
        return

    key_width = max(len(key) for key, _ in drifts) + 1
    # sorted() is stable, so equal drifts keep their query order.
    ranked = sorted(drifts, key=lambda pair: pair[1], reverse=True)
    for key, drift in ranked[:limit]:
        print(f"- {key:<{key_width}}: {drift}")


def diff_query(
    query: str,
    conn_a: psycopg2.extensions.connection,
    conn_b: psycopg2.extensions.connection,
    **kwargs,
):
    """Run the same query on both connections and print how the results differ.

    Args:
        query: SQL text executed unchanged on both connections.
        conn_a: Connection to the first database.
        conn_b: Connection to the second database.
        **kwargs: Optional settings:
            query_vars: Parameters passed to ``cursor.execute`` (default ``()``).
            outliers: When true and the results differ, also print the keys
                with the largest absolute value difference. This requires
                the query to return ``(key, numeric value)`` rows.

    Returns:
        A ``(diff, db_a_results, db_b_results)`` tuple, where ``diff`` is
        True when the two result sets differ and the other two items are
        the fetched rows of each database.
    """
    query_vars = kwargs.get("query_vars", ())
    show_outliers = kwargs.get("outliers", False)

    with conn_a.cursor() as cursor_a, conn_b.cursor() as cursor_b:
        cursor_a.execute(query, query_vars)
        cursor_b.execute(query, query_vars)
        db_a_results = cursor_a.fetchall()
        db_b_results = cursor_b.fetchall()

    diff = Sdiffer().print_sdiff(db_a_results, db_b_results)
    if diff:
        if show_outliers:
            print("\n calculating outliers...")
            _print_outliers(db_a_results, db_b_results)
        print(" [err] data is different")
    else:
        print(" [ok] data is identical")
    return diff, db_a_results, db_b_results
def _shuffled_table_names(table_rows, table_limit):
    """Return the table names to walk when sampling.

    Args:
        table_rows: Single-column rows holding table names.
        table_limit: Requested number of tables; zero or less means "all".

    Returns:
        All names in random order when a positive limit applies (the caller
        stops once enough tables were sampled), otherwise all names in
        their original order.
    """
    names = [row[0] for row in table_rows]
    if table_limit > 0:
        # Sampling the full length is a non-mutating shuffle. Every table is
        # kept so that skipped tables can be replaced by later candidates.
        return random.sample(names, len(names))
    return names


def _sample_ids(conn, table, row_count):
    """Return randomly sampled ``id`` values from ``table``.

    A message is printed and an empty list returned when the table has no
    ``id`` column or yields no rows.

    Args:
        conn: Connection to the database that is sampled.
        table: Unqualified table name.
        row_count: Number of rows requested from ``system_rows``.
    """
    with conn.cursor() as cursor:
        cursor.execute(
            "SELECT column_name FROM information_schema.columns"
            " WHERE table_schema = 'public' AND table_name = %s"
            " AND column_name = 'id';",
            (table,),
        )
        # fetchone() returns None when no row matched; indexing it directly
        # would raise instead of reporting the missing column.
        if cursor.fetchone() is None:
            print(f" skipping {table} - no id column")
            return []
        cursor.execute(
            # Identifier() quotes the name so mixed-case tables work too.
            sql.SQL("SELECT id FROM {} TABLESAMPLE system_rows(%s);").format(
                sql.Identifier(table)
            ),
            (row_count,),
        )
        ids = [record[0] for record in cursor]
    if not ids:
        print(f" skipping {table} - no rows")
    return ids


def _run_on_each(connections, statement):
    """Execute ``statement`` once on every connection in ``connections``."""
    for conn in connections:
        with conn.cursor() as cursor:
            cursor.execute(statement)


def _main():
    """Parse the command line and compare the two databases."""
    parser = argparse.ArgumentParser("db-compare", add_help=False)
    parser.add_argument("-h", "--help", action="store_true")
    parser.add_argument("--table-count", type=int, default=20)
    parser.add_argument("--row-count", type=int, default=12)
    # Connection strings are collected as "unknown" arguments so that a bare
    # `help` word can be recognised as well.
    args, unknown = parser.parse_known_args()

    # help() always exits, so these act as guard clauses.
    if args.help or "help" in unknown:
        help()
    if not unknown:
        help("please provide database connection strings")
    if len(unknown) != 2:
        help("please provide two database connection strings")
    connstr_a, connstr_b = unknown

    # Register each close only once the connection exists, instead of a
    # catch-all cleanup that has to swallow NameError.
    db_a = psycopg2.connect(connstr_a)
    atexit.register(db_a.close)
    db_b = psycopg2.connect(connstr_b)
    atexit.register(db_b.close)

    print("\nCOMPARING VERSIONS\n")
    diff_query("SHOW server_version;", db_a, db_b)

    print("\nCOMPARING TABLE NAMES\n")
    _, db_a_tables, _ = diff_query(
        "SELECT table_name FROM information_schema.tables WHERE table_schema='public' AND table_type='BASE TABLE' ORDER BY table_name;",
        db_a,
        db_b,
    )

    print("\nCOMPARING SEQUENCE NAMES\n")
    diff, _, _ = diff_query(
        "SELECT sequence_name FROM information_schema.sequences ORDER BY sequence_name;",
        db_a,
        db_b,
    )
    if diff:
        die("please sync sequences before continuing")

    print("\nCOMPARING SEQUENCE VALUES\n")
    diff_query(
        "SELECT sequencename, last_value FROM pg_sequences ORDER BY sequencename;",
        db_a,
        db_b,
        outliers=True,
    )

    print("\nCOMPARING PG_STAT ROW COUNTS ACROSS PUBLIC TABLES\n")
    diff_query(
        "SELECT relname, n_live_tup FROM pg_stat_user_tables WHERE schemaname = 'public' ORDER BY relname;",
        db_a,
        db_b,
        outliers=True,
    )

    sample_table_count = args.table_count
    sample_row_count = args.row_count
    sample_all = sample_table_count <= 0
    sample_tables = _shuffled_table_names(db_a_tables, sample_table_count)

    if sample_all:
        print("\nSAMPLING DATA FROM ALL TABLES\n")
    else:
        print(f"\nSAMPLING DATA FROM {sample_table_count} RANDOM TABLES\n")

    try:
        with db_a.cursor() as cursor_a:
            cursor_a.execute("CREATE EXTENSION IF NOT EXISTS tsm_system_rows;")
    except Exception as e:
        die(str(e))

    # A failed statement aborts the open transaction, after which every
    # further statement on that connection fails. A savepoint lets a failed
    # table be undone without discarding the (uncommitted) extension above.
    savepoint = "db_compare_sampling"
    connections = (db_a, db_b)
    _run_on_each(connections, f"SAVEPOINT {savepoint}")

    count = 0
    for table in sample_tables:
        # ">=" so that exactly the requested number of tables is sampled;
        # no limit applies when all tables were requested.
        if not sample_all and count >= sample_table_count:
            break
        try:
            ids = _sample_ids(db_a, table, sample_row_count)
            if not ids:
                continue
            count += 1
            print(f"{table} - sampling ids {', '.join(str(row_id) for row_id in ids)}")
            compare_query = (
                sql.SQL("SELECT * FROM {} WHERE id=ANY(%s) ORDER BY id LIMIT %s;")
                .format(sql.Identifier(table))
                .as_string(db_a)
            )
            diff_query(
                compare_query,
                db_a,
                db_b,
                query_vars=(ids, sample_row_count),
            )
        except Exception as e:
            print(f" skipping {table} - {e}")
            try:
                _run_on_each(connections, f"ROLLBACK TO SAVEPOINT {savepoint}")
            except psycopg2.Error as recover_error:
                die(f"unable to recover database session: {recover_error}")


if __name__ == "__main__":
    _main()
else:
    help("db-compare cannot be imported at this time")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment