@the-mikedavis
Last active May 5, 2026 22:06
AMQP 1.0 vs. 0-9-1 connection memory impact script

amqp-memory-compare

A small Go tool that opens a gradually growing number of idle connections against two single-instance RabbitMQ brokers — AMQP 0-9-1 to one, AMQP 1.0 to the other — and samples node memory via the RabbitMQ HTTP API into a CSV for side-by-side comparison.

It is intentionally minimal: no publishing, no consuming, one channel (0-9-1) or one session (1.0) per connection, all idle.

Requirements

  • Go 1.22 or newer
  • Two RabbitMQ 4.x brokers with the rabbitmq_management plugin enabled
  • A management user with at least the monitoring tag on each broker
  • The client process must be able to open enough sockets; see "File descriptor limits" below

Build

cd amqp-memory-compare
go mod tidy
go build -o amqp-memory-compare .

Usage

./amqp-memory-compare \
  --broker-091 amqps://user:pass@broker-a.example:5671/ \
  --broker-10  amqps://user:pass@broker-b.example:5671/ \
  --start 50 --max 2000 --step 50 \
  --step-interval 60s --sample-interval 5s \
  --insecure \
  --output results.csv

Which broker gets which protocol is determined by the flag name: the URL in --broker-091 only ever sees AMQP 0-9-1 connections, and the URL in --broker-10 only ever sees AMQP 1.0 connections.

Management URLs are derived from the AMQP URLs by default (5671 → 15671, 5672 → 15672; any other port is shifted up by 10000), using the same credentials. Override with --mgmt-091 and --mgmt-10 if the management endpoint is on a different host or port, or uses different credentials.

On SIGINT/SIGTERM the tool tears down all open connections, flushes the CSV and exits.

Flags

Flag                Default                  Description
--broker-091        required                 AMQP(S) URL for the AMQP 0-9-1 target
--broker-10         required                 AMQP(S) URL for the AMQP 1.0 target
--mgmt-091          derived                  HTTP(S) management URL for --broker-091
--mgmt-10           derived                  HTTP(S) management URL for --broker-10
--start             50                       Initial connections per broker
--max               2000                     Final connections per broker
--step              50                       Connections added per step
--step-interval     60s                      How long each step is held before the next ramp
--sample-interval   5s                       HTTP API sampling cadence within a step
--settle-interval   15s                      Wait after teardown before the final sample
--insecure          false                    Skip TLS verification for AMQPS and HTTPS
--open-channel      true                     Open one AMQP 0-9-1 channel per connection
--open-session      true                     Open one AMQP 1.0 session per connection
--ramp-concurrency  32                       Max concurrent dials during ramp-up
--dial-timeout      30s                      Per-connection dial timeout
--heartbeat         60s                      AMQP 0-9-1 heartbeat and AMQP 1.0 idle-timeout
--output            amqp-memory-compare.csv  CSV output path

Scaling model

The tool walks the target count from --start to --max in --step increments. At each step it opens the missing connections (ramp-up), holds the connection count for --step-interval, and samples memory every --sample-interval throughout the hold.

For example, --start 50 --max 300 --step 50 --step-interval 60s produces 6 steps (50, 100, 150, 200, 250, 300), each held for 60 seconds, for a total runtime of ~6 minutes plus ramp-up time and the final settle window.

Scale-down is not implemented. Once opened, connections stay open until the tool finishes or is interrupted.

Output

The CSV has one row per broker per RabbitMQ node per sample, with these columns:

  • timestamp_rfc3339, elapsed_seconds: sample timing
  • phase: one of baseline, hold, post-teardown
  • step_index, target_connections: the current step and target
  • broker: broker host as extracted from the AMQP URL
  • protocol: amqp-0-9-1 or amqp-1.0
  • client_alive, client_opened_total, client_failed_total, client_dropped_total: client-side counters
  • mgmt_connections_total: from GET /api/overview
  • node: RabbitMQ node name
  • mem_used: from GET /api/nodes?columns=... (same value the Prometheus endpoint reports as rabbitmq_process_resident_memory_bytes)
  • mem_total_rss, mem_total_allocated, mem_total_erlang: from GET /api/nodes/{node}/memory (memory.total.*)
  • mem_connection_readers, mem_connection_writers, mem_connection_channels, mem_connection_other: connection-attributed Erlang process memory
  • mem_binary, mem_other_proc, mem_other_system, mem_code, mem_atom: remaining notable rabbit_vm:memory/0 fields

baseline and post-teardown samples bookend the experiment so that the CSV shows broker memory both before and after the ramp.

Notes and caveats

  • The two brokers are compared independently. The tool does not itself check that the two brokers are comparably configured — it is up to you to match their hardware, Erlang VM flags, enabled plugins and memory high watermark settings before drawing conclusions from the output.
  • mem_connection_channels reflects AMQP 0-9-1 channels and AMQP 1.0 sessions, because both are tracked under the rabbit_channel_sup_sup family in rabbit_vm:memory/0.
  • The tool sets connection_name and product to amqp-memory-compare in the AMQP client properties so the load can be filtered out in the management UI.
  • A vhost in the AMQP URL path (e.g. .../vhost) is honoured on both protocols. AMQP 1.0 carries it as hostname = vhost:<name> per RabbitMQ's convention (see rabbit_amqp_reader:vhost/1).
  • The tool does not reconnect dropped connections. If drops happen during a step, the client_dropped_total counter records them but the target count is not refilled until the next step boundary.
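
The vhost mapping mentioned in the notes above can be sketched in a few lines. This mirrors what the tool does when dialing AMQP 1.0, but the function name amqp10Hostname and the vhost name staging are illustrative only.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// amqp10Hostname returns the hostname value to send when opening an AMQP 1.0
// connection for the given AMQP URL, following the "vhost:<name>" convention.
// It returns "" when the URL path names no vhost.
func amqp10Hostname(rawURL string) (string, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return "", err
	}
	vhost := strings.TrimPrefix(u.Path, "/")
	if vhost == "" {
		return "", nil
	}
	return "vhost:" + vhost, nil
}

func main() {
	for _, raw := range []string{
		"amqps://user:pass@broker-b.example:5671/staging", // hypothetical vhost
		"amqps://user:pass@broker-b.example:5671/",
	} {
		h, err := amqp10Hostname(raw)
		if err != nil {
			fmt.Println("parse error:", err)
			continue
		}
		fmt.Printf("%q -> hostname %q\n", raw, h)
	}
}
```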

File descriptor limits

Each open connection consumes one file descriptor on the client. Running with --max 5000 against two brokers means 10000+ sockets from a single process. On macOS and Linux the default per-process limit is often 256 or 1024; raise it before running:

ulimit -n 65536

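RabbitMQ can also enforce its own connection limits per vhost and per user when they are configured; if you hit one, the connection total reported by /api/overview (the mgmt_connections_total column) stops growing while client_failed_total rises.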
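
A pre-flight check along these lines can catch a low descriptor limit before the ramp starts. This is a sketch for Linux and macOS only and is not something the tool does itself; neededFDs, softNoFile and the fdHeadroom allowance of 64 are assumptions made for the example. Recent Go releases may raise the soft limit toward the hard limit at startup, so the value read inside a Go process can differ from what ulimit -n prints in the shell.

```go
package main

import (
	"fmt"
	"syscall"
)

// fdHeadroom is a hypothetical allowance for stdio, the CSV file and the
// HTTP connections to the management API.
const fdHeadroom = 64

// neededFDs estimates the descriptors the client needs: one socket per
// connection on each of the two brokers, plus headroom.
func neededFDs(maxPerBroker int) uint64 {
	return uint64(2*maxPerBroker + fdHeadroom)
}

// softNoFile returns the current soft RLIMIT_NOFILE of this process.
func softNoFile() (uint64, error) {
	var lim syscall.Rlimit
	if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &lim); err != nil {
		return 0, err
	}
	return uint64(lim.Cur), nil
}

func main() {
	const maxPerBroker = 5000 // the value that would be passed as --max
	need := neededFDs(maxPerBroker)
	have, err := softNoFile()
	if err != nil {
		fmt.Println("could not read RLIMIT_NOFILE:", err)
		return
	}
	if have < need {
		fmt.Printf("soft limit %d is below the estimated %d descriptors; raise it with ulimit -n\n", have, need)
		return
	}
	fmt.Printf("soft limit %d covers the estimated %d descriptors\n", have, need)
}
```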

TLS

--insecure is convenient for dev clusters with self-signed certificates but disables verification for both AMQPS and HTTPS management calls. For real testing, populate a system trust store (or provide a proper CA) and omit the flag.
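
For reference, trusting a private CA without disabling verification usually looks like the sketch below. The tool as shown has no flag for a CA file, so tlsConfigWithCA and the ca.pem path are hypothetical; the sketch only illustrates what "provide a proper CA" could mean in Go.

```go
package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"os"
)

// tlsConfigWithCA returns a tls.Config that trusts the system roots plus the
// PEM-encoded certificates in caPEM. Verification stays enabled.
func tlsConfigWithCA(caPEM []byte) (*tls.Config, error) {
	pool, err := x509.SystemCertPool()
	if err != nil || pool == nil {
		pool = x509.NewCertPool() // fall back to an empty pool
	}
	if !pool.AppendCertsFromPEM(caPEM) {
		return nil, errors.New("no certificates found in CA bundle")
	}
	return &tls.Config{RootCAs: pool}, nil
}

func main() {
	// "ca.pem" is a placeholder path for the CA that signed the broker certificates.
	caPEM, err := os.ReadFile("ca.pem")
	if err != nil {
		fmt.Println("read CA bundle:", err)
		return
	}
	conf, err := tlsConfigWithCA(caPEM)
	if err != nil {
		fmt.Println("build TLS config:", err)
		return
	}
	fmt.Println("verification enabled:", !conf.InsecureSkipVerify)
}
```

A config built this way could be used for both the AMQPS dials and the HTTPS management client, in place of the InsecureSkipVerify setting.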

License

MIT or Apache-2.0, pick whichever your project prefers — this is a throwaway test harness.

module github.com/the-mikedavis/amqp-memory-compare

go 1.22

require (
	github.com/Azure/go-amqp v1.3.0
	github.com/rabbitmq/amqp091-go v1.10.0
)

// amqp-memory-compare is a load tool that opens a gradually growing set of
// idle connections against two single-instance RabbitMQ brokers, speaking
// AMQP 0-9-1 to one and AMQP 1.0 to the other, and samples node memory via
// the RabbitMQ HTTP API into a CSV for side-by-side comparison.
//
// The tool is deliberately minimal: it opens connections, optionally opens
// one channel (0-9-1) or one session (1.0) per connection, keeps them idle,
// and records memory samples. It does not publish or consume.
package main
import (
"context"
"crypto/tls"
"encoding/csv"
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
amqp10 "github.com/Azure/go-amqp"
amqp091 "github.com/rabbitmq/amqp091-go"
)
const (
protoAMQP091 = "amqp-0-9-1"
protoAMQP10 = "amqp-1.0"
defaultAMQPSPort = "5671"
defaultAMQPPort = "5672"
defaultMgmtTLS = "15671"
defaultMgmtPort = "15672"
)
type flags struct {
broker091 string
broker10 string
mgmt091 string
mgmt10 string
start int
max int
step int
stepInterval time.Duration
sampleInterval time.Duration
settleInterval time.Duration
insecure bool
openChannel bool
openSession bool
rampConcurrency int
dialTimeout time.Duration
heartbeat time.Duration
output string
}
func main() {
var f flags
flag.StringVar(&f.broker091, "broker-091", "",
"AMQP(S) URL for the broker that receives AMQP 0-9-1 connections, e.g. amqps://user:pass@host:5671/vhost")
flag.StringVar(&f.broker10, "broker-10", "",
"AMQP(S) URL for the broker that receives AMQP 1.0 connections, e.g. amqps://user:pass@host:5671/vhost")
flag.StringVar(&f.mgmt091, "mgmt-091", "",
"HTTP(S) management URL for broker-091 (optional; derived from --broker-091 if omitted)")
flag.StringVar(&f.mgmt10, "mgmt-10", "",
"HTTP(S) management URL for broker-10 (optional; derived from --broker-10 if omitted)")
flag.IntVar(&f.start, "start", 50, "Initial number of connections per broker")
flag.IntVar(&f.max, "max", 2000, "Maximum number of connections per broker")
flag.IntVar(&f.step, "step", 50, "Number of connections added per step")
flag.DurationVar(&f.stepInterval, "step-interval", 60*time.Second,
"Duration to hold each step before scaling up, sampling memory throughout")
flag.DurationVar(&f.sampleInterval, "sample-interval", 5*time.Second,
"How often to sample the HTTP API during a step")
flag.DurationVar(&f.settleInterval, "settle-interval", 15*time.Second,
"Time to wait after teardown before taking a final sample")
flag.BoolVar(&f.insecure, "insecure", false,
"Skip TLS certificate verification for both AMQPS and HTTPS mgmt")
flag.BoolVar(&f.openChannel, "open-channel", true,
"Open one AMQP 0-9-1 channel per connection")
flag.BoolVar(&f.openSession, "open-session", true,
"Open one AMQP 1.0 session per connection")
flag.IntVar(&f.rampConcurrency, "ramp-concurrency", 32,
"Maximum number of concurrent dials during ramp-up")
flag.DurationVar(&f.dialTimeout, "dial-timeout", 30*time.Second,
"Per-connection dial timeout")
flag.DurationVar(&f.heartbeat, "heartbeat", 60*time.Second,
"AMQP 0-9-1 heartbeat and AMQP 1.0 idle-timeout")
flag.StringVar(&f.output, "output", "amqp-memory-compare.csv", "CSV output path")
flag.Parse()
if err := validateFlags(&f); err != nil {
fmt.Fprintln(os.Stderr, "error:", err)
flag.Usage()
os.Exit(2)
}
if err := run(f); err != nil {
log.Fatal(err)
}
}
func validateFlags(f *flags) error {
if f.broker091 == "" || f.broker10 == "" {
return fmt.Errorf("--broker-091 and --broker-10 are required")
}
if f.start < 1 || f.max < f.start || f.step < 1 {
return fmt.Errorf("require 1 <= --start <= --max and --step >= 1")
}
if f.sampleInterval <= 0 || f.stepInterval <= 0 {
return fmt.Errorf("--sample-interval and --step-interval must be positive")
}
if f.mgmt091 == "" {
f.mgmt091 = deriveMgmtURL(f.broker091)
}
if f.mgmt10 == "" {
f.mgmt10 = deriveMgmtURL(f.broker10)
}
return nil
}
// deriveMgmtURL derives a reasonable default management URL from an AMQP URL
// using the conventional port mapping: 5671 -> 15671 and 5672 -> 15672.
// Any custom port is mapped by adding 10000 to it.
func deriveMgmtURL(amqpURL string) string {
u, err := url.Parse(amqpURL)
if err != nil {
return ""
}
host, port := splitHostPort(u.Host, u.Scheme)
var scheme, mgmtPort string
switch u.Scheme {
case "amqps":
scheme = "https"
mgmtPort = defaultMgmtTLS
if port != "" && port != defaultAMQPSPort {
mgmtPort = shiftPort(port)
}
default:
scheme = "http"
mgmtPort = defaultMgmtPort
if port != "" && port != defaultAMQPPort {
mgmtPort = shiftPort(port)
}
}
mu := &url.URL{
Scheme: scheme,
Host: host + ":" + mgmtPort,
User: u.User,
}
return mu.String()
}
func splitHostPort(hostport, scheme string) (host, port string) {
if i := strings.LastIndex(hostport, ":"); i > 0 && !strings.Contains(hostport[i:], "]") {
return hostport[:i], hostport[i+1:]
}
return hostport, ""
}
func shiftPort(port string) string {
if n, err := strconv.Atoi(port); err == nil {
return strconv.Itoa(n + 10000)
}
return port
}
func run(f flags) error {
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer cancel()
tlsConf := &tls.Config{InsecureSkipVerify: f.insecure}
httpClient := &http.Client{
Timeout: 15 * time.Second,
Transport: &http.Transport{
TLSClientConfig: tlsConf,
MaxIdleConnsPerHost: 4,
},
}
b091, err := newBroker(protoAMQP091, f.broker091, f.mgmt091, tlsConf, httpClient, f)
if err != nil {
return fmt.Errorf("broker-091: %w", err)
}
b10, err := newBroker(protoAMQP10, f.broker10, f.mgmt10, tlsConf, httpClient, f)
if err != nil {
return fmt.Errorf("broker-10: %w", err)
}
// Validate management endpoints before starting the test proper.
if err := b091.probe(ctx); err != nil {
return fmt.Errorf("broker-091 mgmt probe failed: %w", err)
}
if err := b10.probe(ctx); err != nil {
return fmt.Errorf("broker-10 mgmt probe failed: %w", err)
}
log.Printf("broker-091 amqp=%s mgmt=%s", b091.amqpHost, b091.mgmtBase)
log.Printf("broker-10 amqp=%s mgmt=%s", b10.amqpHost, b10.mgmtBase)
out, err := os.Create(f.output)
if err != nil {
return err
}
defer out.Close()
w := csv.NewWriter(out)
defer w.Flush()
if err := w.Write(csvHeader()); err != nil {
return err
}
w.Flush()
// Teardown should always run, even on early return or signal.
defer func() {
log.Println("Tearing down connections")
b091.closeAll()
b10.closeAll()
}()
start := time.Now()
log.Println("Recording baseline sample (no test connections)")
if err := writeSamples(ctx, w, start, "baseline", 0, 0, b091, b10); err != nil {
log.Printf("baseline sample error: %v", err)
}
w.Flush()
steps := buildSteps(f.start, f.max, f.step)
log.Printf("Running %d steps: %v", len(steps), steps)
for i, target := range steps {
if ctx.Err() != nil {
break
}
stepIdx := i + 1
log.Printf("=== Step %d/%d: target=%d connections per broker ===", stepIdx, len(steps), target)
// Ramp both brokers in parallel so memory samples taken during the
// subsequent hold reflect roughly the same observation window on each.
var wg sync.WaitGroup
wg.Add(2)
go func() { defer wg.Done(); b091.scale(ctx, target, f.rampConcurrency) }()
go func() { defer wg.Done(); b10.scale(ctx, target, f.rampConcurrency) }()
wg.Wait()
log.Printf("Step %d ramp-up done: 091 alive=%d, 10 alive=%d; holding for %s",
stepIdx, b091.alive(), b10.alive(), f.stepInterval)
if err := holdAndSample(ctx, w, start, "hold", stepIdx, target, f.stepInterval, f.sampleInterval, b091, b10); err != nil {
// ctx cancellation is the expected exit path; other errors are already logged.
break
}
}
if ctx.Err() == nil && f.settleInterval > 0 {
// Close everything, wait briefly, then take one more sample so the
// CSV shows the broker settling back toward baseline.
b091.closeAll()
b10.closeAll()
select {
case <-ctx.Done():
case <-time.After(f.settleInterval):
}
if ctx.Err() == nil {
_ = writeSamples(ctx, w, start, "post-teardown", len(steps)+1, 0, b091, b10)
}
}
w.Flush()
log.Printf("Summary broker-091 (%s): opened=%d failed=%d dropped=%d",
b091.name(), b091.opened.Load(), b091.failed.Load(), b091.dropped.Load())
log.Printf("Summary broker-10 (%s): opened=%d failed=%d dropped=%d",
b10.name(), b10.opened.Load(), b10.failed.Load(), b10.dropped.Load())
log.Printf("Results written to %s", f.output)
return nil
}
// buildSteps returns a monotonically increasing list of per-broker connection
// targets, ending at max even if max is not a multiple of step.
func buildSteps(start, max, step int) []int {
out := []int{}
for v := start; v <= max; v += step {
out = append(out, v)
}
if len(out) == 0 || out[len(out)-1] != max {
out = append(out, max)
}
return out
}
func csvHeader() []string {
return []string{
"timestamp_rfc3339",
"elapsed_seconds",
"phase",
"step_index",
"target_connections",
"broker",
"protocol",
"client_alive",
"client_opened_total",
"client_failed_total",
"client_dropped_total",
"mgmt_connections_total",
"node",
"mem_used",
"mem_total_rss",
"mem_total_allocated",
"mem_total_erlang",
"mem_connection_readers",
"mem_connection_writers",
"mem_connection_channels",
"mem_connection_other",
"mem_binary",
"mem_other_proc",
"mem_other_system",
"mem_code",
"mem_atom",
}
}
func holdAndSample(ctx context.Context, w *csv.Writer, start time.Time, phase string, stepIdx, target int,
total, interval time.Duration, brokers ...*broker) error {
deadline := time.Now().Add(total)
for {
if err := writeSamples(ctx, w, start, phase, stepIdx, target, brokers...); err != nil {
log.Printf("sample error: %v", err)
}
w.Flush()
remaining := time.Until(deadline)
if remaining <= 0 {
return nil
}
wait := interval
if wait > remaining {
wait = remaining
}
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(wait):
}
}
}
func writeSamples(ctx context.Context, w *csv.Writer, start time.Time, phase string, stepIdx, target int, brokers ...*broker) error {
ts := time.Now()
elapsed := ts.Sub(start).Seconds()
var firstErr error
for _, br := range brokers {
samples, err := br.sample(ctx)
if err != nil {
if firstErr == nil {
firstErr = err
}
log.Printf("%s (%s): sample error: %v", br.name(), br.protocol, err)
// Emit one row with empty mgmt fields so the CSV stays aligned.
_ = w.Write(emptyMgmtRow(ts, elapsed, phase, stepIdx, target, br))
continue
}
for _, s := range samples {
_ = w.Write(sampleRow(ts, elapsed, phase, stepIdx, target, br, s))
}
if len(samples) > 0 {
s := samples[0]
log.Printf("%s (%s): alive=%d mgmt_conn=%d mem_used=%s rss=%s",
br.name(), br.protocol, br.alive(), s.MgmtConnections,
humanBytes(s.MemUsed), humanBytes(s.TotalRss))
}
}
return firstErr
}
func emptyMgmtRow(ts time.Time, elapsed float64, phase string, stepIdx, target int, br *broker) []string {
return []string{
ts.UTC().Format(time.RFC3339Nano),
strconv.FormatFloat(elapsed, 'f', 3, 64),
phase,
strconv.Itoa(stepIdx),
strconv.Itoa(target),
br.name(),
br.protocol,
strconv.Itoa(br.alive()),
strconv.FormatInt(br.opened.Load(), 10),
strconv.FormatInt(br.failed.Load(), 10),
strconv.FormatInt(br.dropped.Load(), 10),
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "",
}
}
func sampleRow(ts time.Time, elapsed float64, phase string, stepIdx, target int, br *broker, s memSample) []string {
return []string{
ts.UTC().Format(time.RFC3339Nano),
strconv.FormatFloat(elapsed, 'f', 3, 64),
phase,
strconv.Itoa(stepIdx),
strconv.Itoa(target),
br.name(),
br.protocol,
strconv.Itoa(br.alive()),
strconv.FormatInt(br.opened.Load(), 10),
strconv.FormatInt(br.failed.Load(), 10),
strconv.FormatInt(br.dropped.Load(), 10),
strconv.Itoa(s.MgmtConnections),
s.Node,
strconv.FormatInt(s.MemUsed, 10),
strconv.FormatInt(s.TotalRss, 10),
strconv.FormatInt(s.TotalAllocated, 10),
strconv.FormatInt(s.TotalErlang, 10),
strconv.FormatInt(s.ConnectionReaders, 10),
strconv.FormatInt(s.ConnectionWriters, 10),
strconv.FormatInt(s.ConnectionChannels, 10),
strconv.FormatInt(s.ConnectionOther, 10),
strconv.FormatInt(s.Binary, 10),
strconv.FormatInt(s.OtherProc, 10),
strconv.FormatInt(s.OtherSystem, 10),
strconv.FormatInt(s.Code, 10),
strconv.FormatInt(s.Atom, 10),
}
}
func humanBytes(n int64) string {
const unit = 1024
if n < unit {
return fmt.Sprintf("%d B", n)
}
div, exp := int64(unit), 0
for v := n / unit; v >= unit; v /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB", float64(n)/float64(div), "KMGTPE"[exp])
}
// broker tracks state for a single RabbitMQ node: its AMQP endpoint, the
// matching HTTP management endpoint, the open client-side connections and
// a few counters that summarise client-observed outcomes.
type broker struct {
protocol string
amqpURL string
amqpHost string
dialFlags flags
mgmtClient *http.Client
mgmtBase string
mgmtUser *url.Userinfo
// tlsConf is a per-broker clone of the shared TLS config with ServerName
// preset to the broker host. It is never mutated after newBroker returns
// so it can be shared across concurrent dials without locking.
// amqp091-go would otherwise mutate ServerName on the first AMQPS dial,
// which would race with other dials that use the same config.
tlsConf *tls.Config
mu sync.Mutex
conns []*trackedConn
shuttingDown atomic.Bool
opened atomic.Int64
failed atomic.Int64
dropped atomic.Int64
}
// trackedConn wraps a protocol-specific connection with a close marker used
// to distinguish intentional teardowns from unexpected drops.
type trackedConn struct {
closer io.Closer
closed atomic.Bool
}
func (t *trackedConn) Close() error {
t.closed.Store(true)
return t.closer.Close()
}
func newBroker(protocol, amqpURL, mgmtURL string, tlsConf *tls.Config, client *http.Client, f flags) (*broker, error) {
au, err := url.Parse(amqpURL)
if err != nil {
return nil, fmt.Errorf("invalid amqp URL: %w", err)
}
if au.Scheme != "amqp" && au.Scheme != "amqps" {
return nil, fmt.Errorf("amqp URL must use amqp or amqps scheme, got %q", au.Scheme)
}
mu, err := url.Parse(mgmtURL)
if err != nil {
return nil, fmt.Errorf("invalid mgmt URL: %w", err)
}
if mu.Scheme != "http" && mu.Scheme != "https" {
return nil, fmt.Errorf("mgmt URL must use http or https scheme, got %q", mu.Scheme)
}
auth := mu.User
mu.User = nil
mu.Path = strings.TrimSuffix(mu.Path, "/")
// Preset ServerName on a per-broker clone so the shared tls.Config is not
// mutated by amqp091-go's first AMQPS dial. Hostname() strips the port and
// any IPv6 brackets, neither of which belongs in ServerName.
perBrokerTLS := tlsConf.Clone()
if perBrokerTLS.ServerName == "" {
perBrokerTLS.ServerName = au.Hostname()
}
return &broker{
protocol: protocol,
amqpURL: amqpURL,
amqpHost: au.Host,
dialFlags: f,
mgmtClient: client,
mgmtBase: mu.String(),
mgmtUser: auth,
tlsConf: perBrokerTLS,
}, nil
}
func (b *broker) name() string { return b.amqpHost }
func (b *broker) alive() int {
b.mu.Lock()
defer b.mu.Unlock()
return len(b.conns)
}
// scale brings the number of alive connections up to target by dialing the
// missing count concurrently. Scale-down is not implemented because the tool
// only ramps up; closeAll handles teardown.
func (b *broker) scale(ctx context.Context, target, concurrency int) {
current := b.alive()
if current >= target {
return
}
needed := target - current
if concurrency < 1 {
concurrency = 1
}
sem := make(chan struct{}, concurrency)
var wg sync.WaitGroup
var errMu sync.Mutex
var errSamples []string
var errCount int
for i := 0; i < needed; i++ {
if ctx.Err() != nil {
break
}
wg.Add(1)
sem <- struct{}{}
go func() {
defer wg.Done()
defer func() { <-sem }()
if err := b.dial(ctx); err != nil {
b.failed.Add(1)
errMu.Lock()
errCount++
if len(errSamples) < 3 {
errSamples = append(errSamples, err.Error())
}
errMu.Unlock()
}
}()
}
wg.Wait()
if errCount > 0 {
log.Printf("%s (%s): %d dial errors; first samples: %v", b.name(), b.protocol, errCount, errSamples)
}
}
// dial establishes one connection using the broker's protocol. On success,
// it tracks the connection for subsequent teardown and drop detection, and
// returns nil. The caller never needs a reference to the connection.
func (b *broker) dial(ctx context.Context) error {
switch b.protocol {
case protoAMQP091:
return b.dial091(ctx)
case protoAMQP10:
return b.dial10(ctx)
}
return fmt.Errorf("unknown protocol %q", b.protocol)
}
// track registers tc in the alive list and starts a goroutine that watches
// for unexpected drops via watch. track MUST be called only after watch has
// been wired up (e.g. via NotifyClose) so that a close signal occurring
// between wiring and tracking is delivered via the channel buffer.
// If the broker is already shutting down, tc is closed immediately so a
// dial that completes mid-teardown does not leak past closeAll.
func (b *broker) track(tc *trackedConn, watch <-chan struct{}) {
b.mu.Lock()
if b.shuttingDown.Load() {
b.mu.Unlock()
_ = tc.Close()
return
}
b.conns = append(b.conns, tc)
b.mu.Unlock()
b.opened.Add(1)
go func() {
<-watch
if !tc.closed.Load() && !b.shuttingDown.Load() {
b.dropped.Add(1)
b.removeConn(tc)
}
}()
}
func (b *broker) dial091(ctx context.Context) error {
// Use a context-aware TCP dialer so SIGINT can interrupt dials in flight.
// Once the TCP connection is up, amqp091-go manages its own deadlines
// for the TLS and AMQP handshakes.
netDialer := &net.Dialer{Timeout: b.dialFlags.dialTimeout}
cfg := amqp091.Config{
TLSClientConfig: b.tlsConf,
Heartbeat: b.dialFlags.heartbeat,
Dial: func(network, addr string) (net.Conn, error) {
c, err := netDialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
// Match amqp091-go's DefaultDial: set a handshake deadline that is
// cleared once the AMQP open completes.
if err := c.SetDeadline(time.Now().Add(b.dialFlags.dialTimeout)); err != nil {
_ = c.Close()
return nil, err
}
return c, nil
},
Properties: amqp091.Table{
"connection_name": "amqp-memory-compare",
"product": "amqp-memory-compare",
},
}
if err := ctx.Err(); err != nil {
return err
}
conn, err := amqp091.DialConfig(b.amqpURL, cfg)
if err != nil {
return err
}
wrap := &conn091{conn: conn}
if b.dialFlags.openChannel {
ch, err := conn.Channel()
if err != nil {
_ = conn.Close()
return fmt.Errorf("open channel: %w", err)
}
wrap.ch = ch
}
// Wire the close listener before tracking so a close that happens between
// NotifyClose and track is still delivered via the buffered channel.
closeCh := make(chan *amqp091.Error, 1)
conn.NotifyClose(closeCh)
watch := make(chan struct{})
go func() {
<-closeCh
close(watch)
}()
b.track(&trackedConn{closer: wrap}, watch)
return nil
}
func (b *broker) dial10(ctx context.Context) error {
u, err := url.Parse(b.amqpURL)
if err != nil {
return err
}
host, port := splitHostPort(u.Host, u.Scheme)
if port == "" {
if u.Scheme == "amqps" {
port = defaultAMQPSPort
} else {
port = defaultAMQPPort
}
}
addr := u.Scheme + "://" + host + ":" + port
vhost := strings.TrimPrefix(u.Path, "/")
opts := &amqp10.ConnOptions{
TLSConfig: b.tlsConf,
IdleTimeout: b.dialFlags.heartbeat,
Properties: map[string]any{
"connection_name": "amqp-memory-compare",
"product": "amqp-memory-compare",
},
}
if u.User != nil {
pass, _ := u.User.Password()
opts.SASLType = amqp10.SASLTypePlain(u.User.Username(), pass)
}
if vhost != "" {
// RabbitMQ maps the AMQP 1.0 hostname "vhost:<name>" to a specific vhost.
// See rabbit_amqp_reader:vhost/1.
opts.HostName = "vhost:" + vhost
}
dialCtx, cancel := context.WithTimeout(ctx, b.dialFlags.dialTimeout)
defer cancel()
conn, err := amqp10.Dial(dialCtx, addr, opts)
if err != nil {
return err
}
wrap := &conn10{conn: conn}
if b.dialFlags.openSession {
// Reuse dialCtx so --dial-timeout also bounds the session handshake.
sess, err := conn.NewSession(dialCtx, nil)
if err != nil {
_ = conn.Close()
return fmt.Errorf("open session: %w", err)
}
wrap.sess = sess
}
b.track(&trackedConn{closer: wrap}, conn.Done())
return nil
}
func (b *broker) removeConn(tc *trackedConn) {
b.mu.Lock()
defer b.mu.Unlock()
for i, existing := range b.conns {
if existing == tc {
b.conns = append(b.conns[:i], b.conns[i+1:]...)
return
}
}
}
func (b *broker) closeAll() {
b.shuttingDown.Store(true)
b.mu.Lock()
conns := b.conns
b.conns = nil
b.mu.Unlock()
var wg sync.WaitGroup
for _, c := range conns {
wg.Add(1)
go func(c *trackedConn) {
defer wg.Done()
_ = c.Close()
}(c)
}
wg.Wait()
}
// conn091 groups an AMQP 0-9-1 connection with its single optional channel
// so they are torn down together.
type conn091 struct {
conn *amqp091.Connection
ch *amqp091.Channel
}
func (c *conn091) Close() error {
if c.ch != nil {
_ = c.ch.Close()
}
return c.conn.Close()
}
// conn10 groups an AMQP 1.0 connection with its single optional session.
type conn10 struct {
conn *amqp10.Conn
sess *amqp10.Session
}
func (c *conn10) Close() error {
if c.sess != nil {
// Session close takes a context. Use a short deadline so teardown does
// not stall if the peer has already gone away.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = c.sess.Close(ctx)
}
return c.conn.Close()
}
// memSample holds one node's memory snapshot. The tool emits one row per
// broker per node; single-instance brokers emit exactly one row per sample.
type memSample struct {
Node string
MgmtConnections int
MemUsed int64
TotalRss int64
TotalAllocated int64
TotalErlang int64
ConnectionReaders int64
ConnectionWriters int64
ConnectionChannels int64
ConnectionOther int64
Binary int64
OtherProc int64
OtherSystem int64
Code int64
Atom int64
}
type apiNode struct {
Name string `json:"name"`
MemUsed int64 `json:"mem_used"`
}
type apiOverview struct {
ObjectTotals struct {
Connections int `json:"connections"`
} `json:"object_totals"`
}
// apiNodeMemory mirrors the shape of GET /api/nodes/{name}/memory.
// See rabbit_vm:memory/0 in the RabbitMQ source.
type apiNodeMemory struct {
Memory struct {
ConnectionReaders int64 `json:"connection_readers"`
ConnectionWriters int64 `json:"connection_writers"`
ConnectionChannels int64 `json:"connection_channels"`
ConnectionOther int64 `json:"connection_other"`
Binary int64 `json:"binary"`
OtherProc int64 `json:"other_proc"`
OtherSystem int64 `json:"other_system"`
Code int64 `json:"code"`
Atom int64 `json:"atom"`
Total struct {
Rss int64 `json:"rss"`
Allocated int64 `json:"allocated"`
Erlang int64 `json:"erlang"`
} `json:"total"`
} `json:"memory"`
}
func (b *broker) probe(ctx context.Context) error {
// Exercise every endpoint sample() uses so that authorisation or URL
// problems surface up-front instead of as repeated errors mid-run.
nodes, err := b.fetchNodes(ctx)
if err != nil {
return fmt.Errorf("GET /api/nodes: %w", err)
}
if len(nodes) == 0 {
return fmt.Errorf("GET /api/nodes returned no nodes")
}
if _, err := b.fetchOverview(ctx); err != nil {
return fmt.Errorf("GET /api/overview: %w", err)
}
if _, err := b.fetchNodeMemory(ctx, nodes[0].Name); err != nil {
return fmt.Errorf("GET /api/nodes/%s/memory: %w", nodes[0].Name, err)
}
return nil
}
func (b *broker) sample(ctx context.Context) ([]memSample, error) {
nodes, err := b.fetchNodes(ctx)
if err != nil {
return nil, fmt.Errorf("fetch nodes: %w", err)
}
ov, err := b.fetchOverview(ctx)
if err != nil {
return nil, fmt.Errorf("fetch overview: %w", err)
}
out := make([]memSample, 0, len(nodes))
for _, n := range nodes {
m, err := b.fetchNodeMemory(ctx, n.Name)
if err != nil {
return nil, fmt.Errorf("fetch node memory %q: %w", n.Name, err)
}
out = append(out, memSample{
Node: n.Name,
MgmtConnections: ov.ObjectTotals.Connections,
MemUsed: n.MemUsed,
TotalRss: m.Memory.Total.Rss,
TotalAllocated: m.Memory.Total.Allocated,
TotalErlang: m.Memory.Total.Erlang,
ConnectionReaders: m.Memory.ConnectionReaders,
ConnectionWriters: m.Memory.ConnectionWriters,
ConnectionChannels: m.Memory.ConnectionChannels,
ConnectionOther: m.Memory.ConnectionOther,
Binary: m.Memory.Binary,
OtherProc: m.Memory.OtherProc,
OtherSystem: m.Memory.OtherSystem,
Code: m.Memory.Code,
Atom: m.Memory.Atom,
})
}
return out, nil
}
func (b *broker) fetchNodes(ctx context.Context) ([]apiNode, error) {
var nodes []apiNode
err := b.getJSON(ctx, "/api/nodes?columns=name,mem_used", &nodes)
return nodes, err
}
func (b *broker) fetchOverview(ctx context.Context) (apiOverview, error) {
var ov apiOverview
err := b.getJSON(ctx, "/api/overview?columns=object_totals.connections", &ov)
return ov, err
}
func (b *broker) fetchNodeMemory(ctx context.Context, node string) (apiNodeMemory, error) {
var mem apiNodeMemory
err := b.getJSON(ctx, "/api/nodes/"+url.PathEscape(node)+"/memory", &mem)
return mem, err
}
func (b *broker) getJSON(ctx context.Context, path string, v any) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, b.mgmtBase+path, nil)
if err != nil {
return err
}
if b.mgmtUser != nil {
pass, _ := b.mgmtUser.Password()
req.SetBasicAuth(b.mgmtUser.Username(), pass)
}
req.Header.Set("Accept", "application/json")
resp, err := b.mgmtClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode/100 != 2 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
return fmt.Errorf("GET %s: %s: %s", path, resp.Status, strings.TrimSpace(string(body)))
}
return json.NewDecoder(resp.Body).Decode(v)
}
package main
import (
"reflect"
"testing"
"time"
)
func TestBuildSteps(t *testing.T) {
cases := []struct {
name string
start, max, step int
want []int
}{
{"exact divisor", 50, 200, 50, []int{50, 100, 150, 200}},
{"remainder gets a final step at max", 50, 175, 50, []int{50, 100, 150, 175}},
{"start equals max", 100, 100, 10, []int{100}},
{"single step", 10, 10, 1, []int{10}},
{"step larger than range", 10, 15, 100, []int{10, 15}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
got := buildSteps(tc.start, tc.max, tc.step)
if !reflect.DeepEqual(got, tc.want) {
t.Fatalf("buildSteps(%d,%d,%d) = %v, want %v", tc.start, tc.max, tc.step, got, tc.want)
}
})
}
}
func TestSplitHostPort(t *testing.T) {
cases := []struct {
in string
wantHost, port string
}{
{"host:5671", "host", "5671"},
{"host", "host", ""},
{"[::1]:5671", "[::1]", "5671"},
{"[::1]", "[::1]", ""},
}
for _, tc := range cases {
t.Run(tc.in, func(t *testing.T) {
host, port := splitHostPort(tc.in, "amqps")
if host != tc.wantHost || port != tc.port {
t.Fatalf("splitHostPort(%q) = (%q, %q), want (%q, %q)", tc.in, host, port, tc.wantHost, tc.port)
}
})
}
}
func TestDeriveMgmtURL(t *testing.T) {
cases := []struct {
in, want string
}{
{"amqps://user:pass@host:5671/vh", "https://user:pass@host:15671"},
{"amqp://user:pass@host:5672/", "http://user:pass@host:15672"},
{"amqps://host:5671", "https://host:15671"},
{"amqps://host:6671", "https://host:16671"},
{"amqp://host:6672", "http://host:16672"},
}
for _, tc := range cases {
t.Run(tc.in, func(t *testing.T) {
got := deriveMgmtURL(tc.in)
if got != tc.want {
t.Fatalf("deriveMgmtURL(%q) = %q, want %q", tc.in, got, tc.want)
}
})
}
}
func TestValidateFlagsErrors(t *testing.T) {
cases := []struct {
name string
f flags
}{
{"missing broker-091", flags{broker10: "amqps://b", start: 1, max: 1, step: 1, sampleInterval: 1, stepInterval: 1}},
{"missing broker-10", flags{broker091: "amqps://a", start: 1, max: 1, step: 1, sampleInterval: 1, stepInterval: 1}},
{"start > max", flags{broker091: "amqps://a", broker10: "amqps://b", start: 10, max: 5, step: 1, sampleInterval: 1, stepInterval: 1}},
{"step zero", flags{broker091: "amqps://a", broker10: "amqps://b", start: 1, max: 10, step: 0, sampleInterval: 1, stepInterval: 1}},
{"zero sample-interval", flags{broker091: "amqps://a", broker10: "amqps://b", start: 1, max: 1, step: 1, sampleInterval: 0, stepInterval: 1}},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
f := tc.f
if err := validateFlags(&f); err == nil {
t.Fatalf("expected error, got nil")
}
})
}
}
// TestCSVRowShapes guards against header and row functions drifting out of
// sync, which would silently misalign columns in the output.
func TestCSVRowShapes(t *testing.T) {
want := len(csvHeader())
br := &broker{protocol: protoAMQP091, amqpHost: "h"}
ts := time.Now()
if got := len(emptyMgmtRow(ts, 0, "p", 0, 0, br)); got != want {
t.Fatalf("emptyMgmtRow len = %d, want %d", got, want)
}
if got := len(sampleRow(ts, 0, "p", 0, 0, br, memSample{})); got != want {
t.Fatalf("sampleRow len = %d, want %d", got, want)
}
}