|
#!/usr/bin/env python3 |
|
# -*- coding: utf-8 -*- |
|
""" |
|
A script to check that a load balancer is functioning correctly by: |
|
1) Looking up DNS records: A/AAAA for HTTP, and address hints from HTTPS (SVCB-format, RFC 9460) records for HTTPS, falling back to A/AAAA when no hints are published.
|
2) Checking both HTTP and HTTPS for each discovered IP. |
|
3) Using a raw socket for HTTP and an SSL socket with SNI for HTTPS, |
|
ensuring the correct certificate is presented/validated. |
|
4) Noting redirects (3xx + Location header). |
|
5) Optionally doing reverse DNS on each IP (skippable via -n). |
|
6) Providing a DNS Source icon (🌐 or 🔐) indicating how the IP was discovered. |
|
7) Collecting connect-time and response-time metrics if --timing is requested. |
|
8) Reporting an overall "success"/"fail" verdict after the table, or as the "overall_status" field in the JSON output.
|
9) Printing the table with dynamically calculated column widths (minus ANSI codes), |
|
so everything is neatly aligned. |
|
|
|
Dependencies: |
|
pip install dnspython |
|
|
|
Example usage: |
|
./check-lb.py https://cp.mdt.zone |
|
./check-lb.py https://cp.mdt.zone --no-reverse |
|
./check-lb.py https://cp.mdt.zone --json |
|
./check-lb.py https://cp.mdt.zone --timing |
|
""" |
|
|
|
import argparse
import json
import re
import socket
import ssl
import sys
import time
import unicodedata
import urllib.parse
from typing import Dict, List, Optional, Tuple, Union

import dns.exception
import dns.resolver
|
|
|
# A dedicated Resolver instance (rather than dns.resolver.default_resolver),
# so the settings below apply only to the lookups made by this script.
# Per dnspython, configure=True loads the system resolver configuration.
resolver = dns.resolver.Resolver(configure=True)

# Overall time budget, in seconds, for a single resolve() call
# (dnspython's Resolver.lifetime).
_DNS_LIFETIME_SECONDS = 5.0
resolver.lifetime = _DNS_LIFETIME_SECONDS
|
|
|
# ANSI CSI sequences ("ESC [" + digits/semicolons + final letter), such as
# "\x1b[32m" to switch to green or "\x1b[0m" to reset attributes.
ANSI_ESCAPE_PATTERN = re.compile(r'\x1b\[[0-9;]*[A-Za-z]')


def strip_ansi(s: str) -> str:
    """Return ``s`` with every ANSI CSI escape sequence deleted.

    Handy for measuring how much visible text a colored string contains.
    """
    visible = ANSI_ESCAPE_PATTERN.sub("", s)
    return visible
|
|
|
|
|
def reverse_dns(ip_addr: str) -> Optional[str]:
    """Look up the PTR-style host name for an IP address.

    Args:
        ip_addr: The IP address to look up.

    Returns:
        The primary host name reported by ``socket.gethostbyaddr``, or None
        when the lookup raises ``socket.herror`` or ``socket.gaierror``.
    """
    try:
        name, *_ = socket.gethostbyaddr(ip_addr)
    except (socket.herror, socket.gaierror):
        return None
    return name
|
|
|
|
|
def resolve_a_aaaa(domain: str) -> Tuple[List[str], List[str]]:
    """Resolve the A and AAAA records for ``domain``.

    Each lookup is best-effort: any dnspython error (no answer, NXDOMAIN,
    timeout, no reachable nameservers, ...) yields an empty list for that
    record type instead of aborting the script.

    Args:
        domain: The domain to query.

    Returns:
        A tuple ``(ipv4_list, ipv6_list)``. Each list is de-duplicated and
        keeps the order in which the resolver returned the addresses.
    """

    def _lookup(rdtype: str) -> List[str]:
        """Return the unique addresses of one record type, or [] on failure."""
        try:
            answers = resolver.resolve(domain, rdtype)
        except dns.exception.DNSException:
            # DNSException is the base of NoAnswer, NXDOMAIN, Timeout and
            # also NoNameservers (e.g. SERVFAIL), which previously crashed.
            return []
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(rdata.address for rdata in answers))

    return _lookup("A"), _lookup("AAAA")
|
|
|
|
|
def resolve_https_records(domain: str) -> List[str]:
    """Collect address hints from the HTTPS (RR type 65) records of ``domain``.

    HTTPS records use the SVCB format defined in RFC 9460. A record may
    carry an ``ipv4hint`` (SvcParamKey 4) and/or ``ipv6hint`` (SvcParamKey 6)
    parameter; the addresses listed there are gathered. The TargetName is not
    resolved, so records without hints contribute nothing.

    Args:
        domain: The domain to query.

    Returns:
        A de-duplicated list of IP address strings (v4 and/or v6), in answer
        order with each record's IPv4 hints before its IPv6 hints. Empty when
        there are no HTTPS records, no hints, or the lookup fails.
    """
    # dnspython keys SvcParams by their numeric SvcParamKey (its ParamKey
    # IntEnum), not by presentation names such as "ipv4hint", and each hint
    # value keeps its addresses in an ``addresses`` tuple.
    ipv4hint_key = 4
    ipv6hint_key = 6

    try:
        answers = resolver.resolve(domain, "HTTPS")
    except dns.exception.DNSException:
        # Best-effort: NoAnswer/NXDOMAIN/Timeout as before, plus e.g.
        # NoNameservers; the caller then falls back to A/AAAA.
        return []

    ip_list: List[str] = []
    for ans in answers:
        params = getattr(ans, "params", None) or {}
        for key in (ipv4hint_key, ipv6hint_key):
            hint = params.get(key)
            ip_list.extend(str(addr) for addr in getattr(hint, "addresses", ()))

    # dict.fromkeys de-duplicates deterministically (set() did not keep order).
    return list(dict.fromkeys(ip_list))
|
|
|
|
|
def get_ip_family(ip_addr: str) -> str:
    """Classify an address literal as 'IPv4' or 'IPv6'.

    Any string containing a colon is reported as IPv6 (dotted-quad IPv4
    literals never contain one); everything else is reported as IPv4.
    """
    return "IPv6" if ":" in ip_addr else "IPv4"
|
|
|
|
|
def build_hostport(ip_addr: str, scheme: str) -> Tuple[str, int]:
    """Return the ``(host, port)`` pair to connect to for ``ip_addr`` + ``scheme``.

    HTTP  => port 80
    anything else (HTTPS) => port 443

    The scheme comparison is case-insensitive, so "HTTP" also maps to 80.
    """
    port = 80 if scheme.lower() == "http" else 443
    return ip_addr, port
|
|
|
|
|
def http_request(
    ip_addr: str,
    domain: str,
    scheme: str,
    path: str,
    do_reverse: bool,
    dns_source: str,
    measure_timing: bool
) -> Dict[str, Union[str, int, float, None]]:
    """Send one minimal HTTP/1.1 GET to ``ip_addr``, over TLS with SNI for https.

    Redirects are NOT followed. Only the response head is read, to extract the
    status code and, for 3xx responses, the Location header.

    For scheme "https" the certificate is validated with hostname checking
    against ``domain``, which is also sent as SNI. Any ``ssl.SSLError``
    (including a verification failure) is reported as status "TLS".

    Args:
        ip_addr: The IP address to connect to (port 80 for http, 443 otherwise).
        domain: Host name used for the Host header and for TLS SNI/validation.
        scheme: "http" or "https".
        path: The request target (path plus optional query string).
        do_reverse: Whether to attempt a reverse DNS lookup of ip_addr.
        dns_source: How ip_addr was discovered ("A/AAAA" or "SVCB").
        measure_timing: If True, record connect time and response time in ms.

    Returns:
        A dict with keys:
          - "dns_source": the dns_source argument
          - "ip": raw IP address
          - "hostname": reverse-DNS name if found and different from ip, else None
          - "family": "IPv4" or "IPv6"
          - "scheme": "http" or "https"
          - "status": int HTTP status, or "TLS"/"ERR" on failure
          - "location_header": Location value of a 3xx response, else None
          - "error_message": description of the failure, else None
          - "connect_time_ms": TCP connect (+ TLS handshake) time, or None
          - "response_time_ms": time from sending the request until the
            response head was read, or None
    """
    # Cap on how much we buffer while waiting for the end of the header
    # section, so a server that never sends it cannot grow memory unbounded.
    max_head_bytes = 64 * 1024

    result: Dict[str, Union[str, int, float, None]] = {
        "dns_source": dns_source,
        "ip": ip_addr,
        "hostname": None,
        "family": get_ip_family(ip_addr),
        "scheme": scheme,
        "status": None,
        "location_header": None,
        "error_message": None,
        "connect_time_ms": None,
        "response_time_ms": None,
    }

    if do_reverse:
        rdns = reverse_dns(ip_addr)
        if rdns and rdns != ip_addr:
            result["hostname"] = rdns

    host, port = build_hostport(ip_addr, scheme)

    sock: Optional[socket.socket] = None
    ssl_sock: Optional[ssl.SSLSocket] = None

    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which can be stepped (e.g. by NTP) in the middle of a measurement.
    connect_start = time.perf_counter()
    try:
        sock = socket.create_connection((host, port), timeout=10)
        if scheme == "https":
            context = ssl.create_default_context()
            # Already the defaults of create_default_context(); kept explicit
            # because certificate validation is the point of this check.
            context.check_hostname = True
            context.verify_mode = ssl.CERT_REQUIRED
            ssl_sock = context.wrap_socket(sock, server_hostname=domain)
            stream = ssl_sock
        else:
            stream = sock
        connect_end = time.perf_counter()

        if measure_timing:
            result["connect_time_ms"] = round((connect_end - connect_start) * 1000.0, 3)

        stream.settimeout(10)

        request_lines = [
            f"GET {path} HTTP/1.1",
            f"Host: {domain}",
            "User-Agent: check-lb/1.0",
            "Accept: */*",
            "Connection: close",
            "",
            "",
        ]
        request_data = "\r\n".join(request_lines).encode("utf-8")

        read_start = time.perf_counter()
        stream.sendall(request_data)

        # Read until the blank line that ends the headers, the peer closes,
        # or the size cap is reached. The body is never needed.
        raw_response = bytearray()
        while b"\r\n\r\n" not in raw_response and len(raw_response) < max_head_bytes:
            chunk = stream.recv(4096)
            if not chunk:
                break
            raw_response += chunk
        read_end = time.perf_counter()

        if measure_timing:
            result["response_time_ms"] = round((read_end - read_start) * 1000.0, 3)

        # str.split() never returns an empty list, so an empty reply must be
        # detected on the raw bytes (the old `if not lines` check never fired).
        if not raw_response:
            result["status"] = "ERR"
            result["error_message"] = "No response data"
            return result

        response_str = raw_response.decode("iso-8859-1", errors="replace")
        head_part, _, _ = response_str.partition("\r\n\r\n")
        lines = head_part.split("\r\n")

        first_line = lines[0]
        match = re.match(r"^HTTP/\d\.\d\s+(\d+)", first_line)
        if not match:
            result["status"] = "ERR"
            result["error_message"] = f"Malformed status line: {first_line}"
            return result

        status_code = int(match.group(1))
        result["status"] = status_code

        if 300 <= status_code < 400:
            # Report the first Location header (case-insensitive name match).
            for line in lines[1:]:
                if line.lower().startswith("location:"):
                    _, _, val = line.partition(":")
                    result["location_header"] = val.strip()
                    break

    except ssl.SSLError as ssl_ex:
        result["status"] = "TLS"
        result["error_message"] = str(ssl_ex)
    except socket.timeout as t_ex:
        result["status"] = "ERR"
        result["error_message"] = f"Timeout: {t_ex}"
    except OSError as os_ex:
        result["status"] = "ERR"
        result["error_message"] = str(os_ex)
    except Exception as ex:
        # Last-resort boundary: one bad endpoint must not abort the whole run.
        result["status"] = "ERR"
        result["error_message"] = str(ex)
    finally:
        # Closing the SSL wrapper also closes the socket it took over.
        if ssl_sock is not None:
            ssl_sock.close()
        elif sock is not None:
            sock.close()

    return result
|
|
|
|
|
def compute_overall_status(results: List[Dict[str, Union[str, int, float, None]]]) -> str:
    """Reduce per-endpoint results to a single 'success' or 'fail' verdict.

    A row passes only when its "status" is an integer below 400. A row with a
    non-integer status ('TLS', 'ERR', None) or a status >= 400 makes the whole
    check 'fail'. An empty list is vacuously 'success'.
    """

    def _row_passes(row: Dict[str, Union[str, int, float, None]]) -> bool:
        code = row["status"]
        return isinstance(code, int) and code < 400

    return "success" if all(_row_passes(row) for row in results) else "fail"
|
|
|
|
|
def _display_width(text: str) -> int:
    """Return how many terminal columns ``text`` occupies, ignoring ANSI codes.

    ``len()`` is not enough for alignment. Emoji used in the table such as
    "🌐" or "✅" are East Asian Wide and take two columns, while combining
    marks and format characters take none. A U+FE0F emoji-presentation
    selector after a narrow character (as in "↪️") is counted as widening
    that character to two columns.
    """
    width = 0
    last = 0  # columns contributed by the previous visible character
    for ch in strip_ansi(text):
        if ch == "\ufe0f":
            if last == 1:
                width += 1
                last = 2
            continue
        if unicodedata.combining(ch) or unicodedata.category(ch) in ("Mn", "Me", "Cf"):
            continue
        last = 2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1
        width += last
    return width


def _pad_cell(text: str, width: int) -> str:
    """Left-align ``text`` in a cell ``width`` terminal columns wide."""
    return text + " " * max(0, width - _display_width(text))


def print_table(
    domain: str,
    url: str,
    results: List[Dict[str, Union[str, int, float, None]]],
    measure_timing: bool
) -> None:
    """Print the results as an aligned table, followed by the overall status.

    Column widths are measured in terminal columns: ANSI color codes count
    as zero and wide emoji as two. This gives correct alignment for colored
    status cells. Plain ``str`` formatting would count the escape codes and
    leave those cells unpadded. The "Redirect / Error" text is capped at 60
    characters to avoid very wide lines.

    Columns:
        1) DNS (icon)
        2) Hostname
        3) IP
        4) Proto
        5) Status
        6) (optional) ConnT
        7) (optional) RespT
        8) Redirect / Error
    """
    GREEN = "\033[32m"
    RED = "\033[31m"
    YELLOW = "\033[33m"
    RESET = "\033[0m"

    # Column keys in display order.
    col_keys = ["dns_icon", "hostname", "ip", "scheme", "status_str"]
    if measure_timing:
        col_keys.extend(["conn_s", "resp_s"])
    col_keys.append("redirect_str")

    headings = {
        "dns_icon": "DNS",
        "hostname": "Hostname",
        "ip": "IP",
        "scheme": "Proto",
        "status_str": "St",
        "conn_s": "ConnT",
        "resp_s": "RespT",
        "redirect_str": "Redirect / Error",
    }

    # Build every row as a dict of column key -> display string.
    table_rows: List[Dict[str, str]] = []
    for r in results:
        dns_src = r["dns_source"] if isinstance(r["dns_source"], str) else "?"
        dns_icon = "🔐" if dns_src == "SVCB" else "🌐"

        hostname = r["hostname"] if isinstance(r["hostname"], str) else ""
        ip_str = r["ip"] if isinstance(r["ip"], str) else ""
        # Reverse DNS equal to the IP adds nothing; blank it to reduce clutter.
        if hostname == ip_str:
            hostname = ""

        scheme = r["scheme"] if isinstance(r["scheme"], str) else "?"
        status = r["status"]
        loc = r["location_header"] if isinstance(r["location_header"], str) else None
        err = r["error_message"] if isinstance(r["error_message"], str) else None
        ct = r["connect_time_ms"]
        rt = r["response_time_ms"]

        # Status cell (ANSI-colored) plus the redirect/error cell.
        if isinstance(status, int):
            code = status
            if 200 <= code < 300:
                color, emoji = GREEN, "✅"
            elif 300 <= code < 400:
                color, emoji = YELLOW, "↪️"
            else:
                color, emoji = RED, "❌"
            status_str = f"{color}{code}{RESET} {emoji}"
            redirect_str = loc if loc else ""
        elif status == "TLS":
            status_str = f"{RED}TLS{RESET} 🔒"
            redirect_str = err if err else "(cert error)"
        else:
            status_str = f"{RED}{status}{RESET} ❌"
            redirect_str = err if err else ""

        # Cap redirect/error at 60 chars.
        if len(redirect_str) > 60:
            redirect_str = redirect_str[:57] + "..."

        conn_s = f"{ct:.1f}" if measure_timing and isinstance(ct, float) else "-"
        resp_s = f"{rt:.1f}" if measure_timing and isinstance(rt, float) else "-"

        table_rows.append({
            "dns_icon": dns_icon,
            "hostname": hostname,
            "ip": ip_str,
            "scheme": scheme,
            "status_str": status_str,
            "conn_s": conn_s,
            "resp_s": resp_s,
            "redirect_str": redirect_str,
        })

    # Each column is as wide as its widest cell or heading, and at least 2.
    col_widths: Dict[str, int] = {}
    for col in col_keys:
        widest_cell = max((_display_width(row[col]) for row in table_rows), default=0)
        col_widths[col] = max(_display_width(headings[col]), widest_cell, 2)

    heading_line = " | ".join(_pad_cell(headings[col], col_widths[col]) for col in col_keys)
    sep_line = "-" * _display_width(heading_line)

    print(f"Load Balancer Check for domain: {domain}")
    print(f"Original URL: {url}")
    print(heading_line)
    print(sep_line)

    for row in table_rows:
        print(" | ".join(_pad_cell(row[col], col_widths[col]) for col in col_keys))

    overall = compute_overall_status(results)
    verdict_color = GREEN if overall == "success" else RED
    print(f"\nOverall LB check: {verdict_color}{overall.upper()}{RESET}")
|
|
|
|
|
def print_json_output(
    domain: str,
    url: str,
    results: List[Dict[str, Union[str, int, float, None]]]
) -> None:
    """Write the check results to stdout as pretty-printed JSON.

    The top-level object holds the domain, the URL as given, the overall
    'success'/'fail' verdict and one entry per probed endpoint. An empty
    hostname is emitted as null.
    """
    verdict = compute_overall_status(results)

    field_order = (
        "dns_source",
        "ip",
        "hostname",
        "family",
        "scheme",
        "status",
        "location_header",
        "error_message",
        "connect_time_ms",
        "response_time_ms",
    )

    entries = []
    for row in results:
        entry = {field: row[field] for field in field_order}
        # Falsy hostnames ("" or None) become JSON null; key order is unchanged.
        entry["hostname"] = row["hostname"] or None
        entries.append(entry)

    payload = {
        "domain": domain,
        "original_url": url,
        "overall_status": verdict,
        "results": entries,
    }
    print(json.dumps(payload, indent=2))
|
|
|
|
|
def main() -> None:
    """Command-line entry point.

    Steps:
      1. Parse the arguments.
      2. Split the URL into a host name and a request target.
      3. Resolve the addresses to probe.
      4. Probe each address over HTTP (and HTTPS) and print a table or JSON.
    """
    parser = argparse.ArgumentParser(
        description="Check a LB cluster by verifying both HTTP & HTTPS over IPv4 & IPv6, with SNI/cert checks."
    )
    parser.add_argument(
        "url",
        type=str,
        help="URL to check (e.g. https://example.com/health). "
             "We will do both HTTP and HTTPS checks for discovered IPs."
    )
    parser.add_argument(
        "-n",
        "--no-reverse",
        action="store_true",
        default=False,
        help="Skip reverse DNS lookups."
    )
    parser.add_argument(
        "--json",
        action="store_true",
        default=False,
        help="Output results as JSON instead of a colorful table."
    )
    parser.add_argument(
        "--timing",
        action="store_true",
        default=False,
        help="Measure connect time and response time in milliseconds, show in table/JSON."
    )
    args = parser.parse_args()

    input_url: str = args.url
    do_reverse: bool = not args.no_reverse
    as_json: bool = args.json
    measure_timing: bool = args.timing

    # urlsplit (unlike urlparse) keeps any ";params" inside the path.
    parsed = urllib.parse.urlsplit(input_url)
    # .hostname strips userinfo and port from netloc. Neither is valid in a DNS
    # query or in TLS SNI, and the probes always use ports 80/443 anyway.
    domain = parsed.hostname
    if not parsed.netloc or not domain:
        print(f"Invalid URL: {input_url}", file=sys.stderr)
        sys.exit(1)

    path = parsed.path or "/"
    if parsed.query:
        path += f"?{parsed.query}"
    # The fragment is deliberately not appended. Fragments are client-side
    # only and must not be sent in an HTTP request target.

    # 1) Resolve IPs for HTTP (A/AAAA).
    http_ipv4, http_ipv6 = resolve_a_aaaa(domain)
    http_ips = http_ipv4 + http_ipv6

    # 2) HTTPS targets come from HTTPS-record hints. If there are none, fall
    #    back to the A/AAAA answers above instead of querying them again.
    https_ips_svcb = resolve_https_records(domain)
    https_ips_fallback: List[str] = [] if https_ips_svcb else list(http_ips)

    # If we have no addresses for either protocol, bail out.
    if not http_ips and not https_ips_svcb and not https_ips_fallback:
        print(f"No valid DNS records found for domain {domain}. Exiting.", file=sys.stderr)
        sys.exit(1)

    # 3) Probe every (address, scheme, DNS source) combination, in order:
    #    HTTP via A/AAAA, HTTPS via hints, HTTPS via the A/AAAA fallback.
    probes = (
        [(ip, "http", "A/AAAA") for ip in http_ips]
        + [(ip, "https", "SVCB") for ip in https_ips_svcb]
        + [(ip, "https", "A/AAAA") for ip in https_ips_fallback]
    )
    results: List[Dict[str, Union[str, int, float, None]]] = [
        http_request(
            ip_addr=ip,
            domain=domain,
            scheme=scheme,
            path=path,
            do_reverse=do_reverse,
            dns_source=source,
            measure_timing=measure_timing,
        )
        for ip, scheme, source in probes
    ]

    # 4) Output
    if as_json:
        print_json_output(domain, input_url, results)
    else:
        print_table(domain, input_url, results, measure_timing)


if __name__ == "__main__":
    main()