|
// amqp-memory-compare is a load tool that opens a gradually growing set of
|
// idle connections against two single-instance RabbitMQ brokers, speaking |
|
// AMQP 0-9-1 to one and AMQP 1.0 to the other, and samples node memory via |
|
// the RabbitMQ HTTP API into a CSV for side-by-side comparison. |
|
// |
|
// The tool is deliberately minimal: it opens connections, optionally opens |
|
// one channel (0-9-1) or one session (1.0) per connection, keeps them idle, |
|
// and records memory samples. It does not publish or consume. |
|
package main |
|
|
|
import ( |
|
"context" |
|
"crypto/tls" |
|
"encoding/csv" |
|
"encoding/json" |
|
"flag" |
|
"fmt" |
|
"io" |
|
"log" |
|
"net" |
|
"net/http" |
|
"net/url" |
|
"os" |
|
"os/signal" |
|
"strconv" |
|
"strings" |
|
"sync" |
|
"sync/atomic" |
|
"syscall" |
|
"time" |
|
|
|
amqp10 "github.com/Azure/go-amqp" |
|
amqp091 "github.com/rabbitmq/amqp091-go" |
|
) |
|
|
|
// Protocol labels: they select the dial path in broker.dial and are written
// verbatim to the CSV "protocol" column.
const (
	protoAMQP091 = "amqp-0-9-1"
	protoAMQP10  = "amqp-1.0"
)

// Conventional RabbitMQ listener ports. The AMQP ports are the dial defaults
// when a URL omits the port; deriveMgmtURL maps each AMQP port to the
// matching management port.
const (
	defaultAMQPSPort = "5671"
	defaultAMQPPort  = "5672"
	defaultMgmtTLS   = "15671"
	defaultMgmtPort  = "15672"
)
|
|
|
// flags holds the parsed command-line configuration. validateFlags fills in
// mgmt091/mgmt10 from the broker URLs when they are left empty.
type flags struct {
	// AMQP URLs of the two brokers under test, and their management API URLs.
	broker091, broker10 string
	mgmt091, mgmt10     string

	// Ramp schedule: per-broker connection targets go from start to max in
	// increments of step, each target held for stepInterval.
	start, max, step int
	stepInterval     time.Duration

	// How often to sample during a hold, and how long to wait after teardown
	// before the final sample.
	sampleInterval, settleInterval time.Duration

	insecure        bool          // skip TLS verification for AMQPS and HTTPS mgmt
	openChannel     bool          // open one 0-9-1 channel per connection
	openSession     bool          // open one 1.0 session per connection
	rampConcurrency int           // max concurrent dials while ramping
	dialTimeout     time.Duration // per-connection dial timeout
	heartbeat       time.Duration // 0-9-1 heartbeat and 1.0 idle timeout

	output string // CSV output path
}
|
|
|
func main() { |
|
var f flags |
|
flag.StringVar(&f.broker091, "broker-091", "", |
|
"AMQP(S) URL for the broker that receives AMQP 0-9-1 connections, e.g. amqps://user:pass@host:5671/vhost") |
|
flag.StringVar(&f.broker10, "broker-10", "", |
|
"AMQP(S) URL for the broker that receives AMQP 1.0 connections, e.g. amqps://user:pass@host:5671/vhost") |
|
flag.StringVar(&f.mgmt091, "mgmt-091", "", |
|
"HTTP(S) management URL for broker-091 (optional; derived from --broker-091 if omitted)") |
|
flag.StringVar(&f.mgmt10, "mgmt-10", "", |
|
"HTTP(S) management URL for broker-10 (optional; derived from --broker-10 if omitted)") |
|
|
|
flag.IntVar(&f.start, "start", 50, "Initial number of connections per broker") |
|
flag.IntVar(&f.max, "max", 2000, "Maximum number of connections per broker") |
|
flag.IntVar(&f.step, "step", 50, "Number of connections added per step") |
|
flag.DurationVar(&f.stepInterval, "step-interval", 60*time.Second, |
|
"Duration to hold each step before scaling up, sampling memory throughout") |
|
|
|
flag.DurationVar(&f.sampleInterval, "sample-interval", 5*time.Second, |
|
"How often to sample the HTTP API during a step") |
|
flag.DurationVar(&f.settleInterval, "settle-interval", 15*time.Second, |
|
"Time to wait after teardown before taking a final sample") |
|
|
|
flag.BoolVar(&f.insecure, "insecure", false, |
|
"Skip TLS certificate verification for both AMQPS and HTTPS mgmt") |
|
flag.BoolVar(&f.openChannel, "open-channel", true, |
|
"Open one AMQP 0-9-1 channel per connection") |
|
flag.BoolVar(&f.openSession, "open-session", true, |
|
"Open one AMQP 1.0 session per connection") |
|
flag.IntVar(&f.rampConcurrency, "ramp-concurrency", 32, |
|
"Maximum number of concurrent dials during ramp-up") |
|
flag.DurationVar(&f.dialTimeout, "dial-timeout", 30*time.Second, |
|
"Per-connection dial timeout") |
|
flag.DurationVar(&f.heartbeat, "heartbeat", 60*time.Second, |
|
"AMQP 0-9-1 heartbeat and AMQP 1.0 idle-timeout") |
|
|
|
flag.StringVar(&f.output, "output", "amqp-memory-compare.csv", "CSV output path") |
|
flag.Parse() |
|
|
|
if err := validateFlags(&f); err != nil { |
|
fmt.Fprintln(os.Stderr, "error:", err) |
|
flag.Usage() |
|
os.Exit(2) |
|
} |
|
|
|
if err := run(f); err != nil { |
|
log.Fatal(err) |
|
} |
|
} |
|
|
|
func validateFlags(f *flags) error { |
|
if f.broker091 == "" || f.broker10 == "" { |
|
return fmt.Errorf("--broker-091 and --broker-10 are required") |
|
} |
|
if f.start < 1 || f.max < f.start || f.step < 1 { |
|
return fmt.Errorf("require 1 <= --start <= --max and --step >= 1") |
|
} |
|
if f.sampleInterval <= 0 || f.stepInterval <= 0 { |
|
return fmt.Errorf("--sample-interval and --step-interval must be positive") |
|
} |
|
if f.mgmt091 == "" { |
|
f.mgmt091 = deriveMgmtURL(f.broker091) |
|
} |
|
if f.mgmt10 == "" { |
|
f.mgmt10 = deriveMgmtURL(f.broker10) |
|
} |
|
return nil |
|
} |
|
|
|
// deriveMgmtURL derives a reasonable default management URL from an AMQP URL |
|
// using the conventional port mapping: 5671 -> 15671 and 5672 -> 15672. |
|
// Any custom port is mapped by adding 10000 to it. |
|
func deriveMgmtURL(amqpURL string) string { |
|
u, err := url.Parse(amqpURL) |
|
if err != nil { |
|
return "" |
|
} |
|
host, port := splitHostPort(u.Host, u.Scheme) |
|
var scheme, mgmtPort string |
|
switch u.Scheme { |
|
case "amqps": |
|
scheme = "https" |
|
mgmtPort = defaultMgmtTLS |
|
if port != "" && port != defaultAMQPSPort { |
|
mgmtPort = shiftPort(port) |
|
} |
|
default: |
|
scheme = "http" |
|
mgmtPort = defaultMgmtPort |
|
if port != "" && port != defaultAMQPPort { |
|
mgmtPort = shiftPort(port) |
|
} |
|
} |
|
mu := &url.URL{ |
|
Scheme: scheme, |
|
Host: host + ":" + mgmtPort, |
|
User: u.User, |
|
} |
|
return mu.String() |
|
} |
|
|
|
// splitHostPort splits "host:port" at the last colon. Unlike
// net.SplitHostPort, it keeps IPv6 brackets, so "[::1]:5672" yields
// ("[::1]", "5672"). If there is no colon outside brackets, or the only
// colon is the first byte, the input is returned whole as host with an
// empty port. The scheme argument is currently unused.
func splitHostPort(hostport, scheme string) (host, port string) {
	idx := strings.LastIndexByte(hostport, ':')
	if idx <= 0 || strings.IndexByte(hostport[idx:], ']') >= 0 {
		return hostport, ""
	}
	return hostport[:idx], hostport[idx+1:]
}
|
|
|
// shiftPort maps a custom AMQP port to a management port by adding 10000
// ("5673" -> "15673"). Input that does not parse as an integer is returned
// unchanged. The result is not range-checked against 65535.
func shiftPort(port string) string {
	n, err := strconv.Atoi(port)
	if err != nil {
		return port
	}
	return strconv.Itoa(n + 10000)
}
|
|
|
func run(f flags) error { |
|
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) |
|
defer cancel() |
|
|
|
tlsConf := &tls.Config{InsecureSkipVerify: f.insecure} |
|
httpClient := &http.Client{ |
|
Timeout: 15 * time.Second, |
|
Transport: &http.Transport{ |
|
TLSClientConfig: tlsConf, |
|
MaxIdleConnsPerHost: 4, |
|
}, |
|
} |
|
|
|
b091, err := newBroker(protoAMQP091, f.broker091, f.mgmt091, tlsConf, httpClient, f) |
|
if err != nil { |
|
return fmt.Errorf("broker-091: %w", err) |
|
} |
|
b10, err := newBroker(protoAMQP10, f.broker10, f.mgmt10, tlsConf, httpClient, f) |
|
if err != nil { |
|
return fmt.Errorf("broker-10: %w", err) |
|
} |
|
|
|
// Validate management endpoints before starting the test proper. |
|
if err := b091.probe(ctx); err != nil { |
|
return fmt.Errorf("broker-091 mgmt probe failed: %w", err) |
|
} |
|
if err := b10.probe(ctx); err != nil { |
|
return fmt.Errorf("broker-10 mgmt probe failed: %w", err) |
|
} |
|
|
|
log.Printf("broker-091 amqp=%s mgmt=%s", b091.amqpHost, b091.mgmtBase) |
|
log.Printf("broker-10 amqp=%s mgmt=%s", b10.amqpHost, b10.mgmtBase) |
|
|
|
out, err := os.Create(f.output) |
|
if err != nil { |
|
return err |
|
} |
|
defer out.Close() |
|
w := csv.NewWriter(out) |
|
defer w.Flush() |
|
|
|
if err := w.Write(csvHeader()); err != nil { |
|
return err |
|
} |
|
w.Flush() |
|
|
|
// Teardown should always run, even on early return or signal. |
|
defer func() { |
|
log.Println("Tearing down connections") |
|
b091.closeAll() |
|
b10.closeAll() |
|
}() |
|
|
|
start := time.Now() |
|
log.Println("Recording baseline sample (no test connections)") |
|
if err := writeSamples(ctx, w, start, "baseline", 0, 0, b091, b10); err != nil { |
|
log.Printf("baseline sample error: %v", err) |
|
} |
|
w.Flush() |
|
|
|
steps := buildSteps(f.start, f.max, f.step) |
|
log.Printf("Running %d steps: %v", len(steps), steps) |
|
|
|
for i, target := range steps { |
|
if ctx.Err() != nil { |
|
break |
|
} |
|
stepIdx := i + 1 |
|
log.Printf("=== Step %d/%d: target=%d connections per broker ===", stepIdx, len(steps), target) |
|
|
|
// Ramp both brokers in parallel so memory samples taken during the |
|
// subsequent hold reflect roughly the same observation window on each. |
|
var wg sync.WaitGroup |
|
wg.Add(2) |
|
go func() { defer wg.Done(); b091.scale(ctx, target, f.rampConcurrency) }() |
|
go func() { defer wg.Done(); b10.scale(ctx, target, f.rampConcurrency) }() |
|
wg.Wait() |
|
|
|
log.Printf("Step %d ramp-up done: 091 alive=%d, 10 alive=%d; holding for %s", |
|
stepIdx, b091.alive(), b10.alive(), f.stepInterval) |
|
|
|
if err := holdAndSample(ctx, w, start, "hold", stepIdx, target, f.stepInterval, f.sampleInterval, b091, b10); err != nil { |
|
// ctx cancellation is the expected exit path; other errors are already logged. |
|
break |
|
} |
|
} |
|
|
|
if ctx.Err() == nil && f.settleInterval > 0 { |
|
// Close everything, wait briefly, then take one more sample so the |
|
// CSV shows the broker settling back toward baseline. |
|
b091.closeAll() |
|
b10.closeAll() |
|
select { |
|
case <-ctx.Done(): |
|
case <-time.After(f.settleInterval): |
|
} |
|
if ctx.Err() == nil { |
|
_ = writeSamples(ctx, w, start, "post-teardown", len(steps)+1, 0, b091, b10) |
|
} |
|
} |
|
w.Flush() |
|
|
|
log.Printf("Summary broker-091 (%s): opened=%d failed=%d dropped=%d", |
|
b091.name(), b091.opened.Load(), b091.failed.Load(), b091.dropped.Load()) |
|
log.Printf("Summary broker-10 (%s): opened=%d failed=%d dropped=%d", |
|
b10.name(), b10.opened.Load(), b10.failed.Load(), b10.dropped.Load()) |
|
log.Printf("Results written to %s", f.output) |
|
return nil |
|
} |
|
|
|
// buildSteps returns the strictly increasing per-broker connection targets
// start, start+step, ... up to and including max. If stepping does not land
// exactly on max, max is appended as the final target.
//
// The loop stops before v+step could pass max, so values near math.MaxInt do
// not overflow int, wrap negative and loop forever. Callers are expected to
// pass 1 <= start <= max and step >= 1 (validateFlags enforces this).
func buildSteps(start, max, step int) []int {
	var out []int
	for v := start; v <= max; v += step {
		out = append(out, v)
		if v > max-step {
			// The next increment would exceed max (and could overflow int).
			break
		}
	}
	if len(out) == 0 || out[len(out)-1] != max {
		out = append(out, max)
	}
	return out
}
|
|
|
// csvHeader returns the CSV column names in the order sampleRow and
// emptyMgmtRow emit them: 11 run/client columns followed by 15
// management-API columns.
func csvHeader() []string {
	runColumns := []string{
		"timestamp_rfc3339", "elapsed_seconds", "phase", "step_index", "target_connections",
		"broker", "protocol",
		"client_alive", "client_opened_total", "client_failed_total", "client_dropped_total",
	}
	mgmtColumns := []string{
		"mgmt_connections_total", "node",
		"mem_used", "mem_total_rss", "mem_total_allocated", "mem_total_erlang",
		"mem_connection_readers", "mem_connection_writers", "mem_connection_channels", "mem_connection_other",
		"mem_binary", "mem_other_proc", "mem_other_system", "mem_code", "mem_atom",
	}
	return append(runColumns, mgmtColumns...)
}
|
|
|
func holdAndSample(ctx context.Context, w *csv.Writer, start time.Time, phase string, stepIdx, target int, |
|
total, interval time.Duration, brokers ...*broker) error { |
|
deadline := time.Now().Add(total) |
|
for { |
|
if err := writeSamples(ctx, w, start, phase, stepIdx, target, brokers...); err != nil { |
|
log.Printf("sample error: %v", err) |
|
} |
|
w.Flush() |
|
remaining := time.Until(deadline) |
|
if remaining <= 0 { |
|
return nil |
|
} |
|
wait := interval |
|
if wait > remaining { |
|
wait = remaining |
|
} |
|
select { |
|
case <-ctx.Done(): |
|
return ctx.Err() |
|
case <-time.After(wait): |
|
} |
|
} |
|
} |
|
|
|
func writeSamples(ctx context.Context, w *csv.Writer, start time.Time, phase string, stepIdx, target int, brokers ...*broker) error { |
|
ts := time.Now() |
|
elapsed := ts.Sub(start).Seconds() |
|
var firstErr error |
|
for _, br := range brokers { |
|
samples, err := br.sample(ctx) |
|
if err != nil { |
|
if firstErr == nil { |
|
firstErr = err |
|
} |
|
log.Printf("%s (%s): sample error: %v", br.name(), br.protocol, err) |
|
// Emit one row with empty mgmt fields so the CSV stays aligned. |
|
_ = w.Write(emptyMgmtRow(ts, elapsed, phase, stepIdx, target, br)) |
|
continue |
|
} |
|
for _, s := range samples { |
|
_ = w.Write(sampleRow(ts, elapsed, phase, stepIdx, target, br, s)) |
|
} |
|
if len(samples) > 0 { |
|
s := samples[0] |
|
log.Printf("%s (%s): alive=%d mgmt_conn=%d mem_used=%s rss=%s", |
|
br.name(), br.protocol, br.alive(), s.MgmtConnections, |
|
humanBytes(s.MemUsed), humanBytes(s.TotalRss)) |
|
} |
|
} |
|
return firstErr |
|
} |
|
|
|
func emptyMgmtRow(ts time.Time, elapsed float64, phase string, stepIdx, target int, br *broker) []string { |
|
return []string{ |
|
ts.UTC().Format(time.RFC3339Nano), |
|
strconv.FormatFloat(elapsed, 'f', 3, 64), |
|
phase, |
|
strconv.Itoa(stepIdx), |
|
strconv.Itoa(target), |
|
br.name(), |
|
br.protocol, |
|
strconv.Itoa(br.alive()), |
|
strconv.FormatInt(br.opened.Load(), 10), |
|
strconv.FormatInt(br.failed.Load(), 10), |
|
strconv.FormatInt(br.dropped.Load(), 10), |
|
"", "", "", "", "", "", "", "", "", "", "", "", "", "", "", |
|
} |
|
} |
|
|
|
func sampleRow(ts time.Time, elapsed float64, phase string, stepIdx, target int, br *broker, s memSample) []string { |
|
return []string{ |
|
ts.UTC().Format(time.RFC3339Nano), |
|
strconv.FormatFloat(elapsed, 'f', 3, 64), |
|
phase, |
|
strconv.Itoa(stepIdx), |
|
strconv.Itoa(target), |
|
br.name(), |
|
br.protocol, |
|
strconv.Itoa(br.alive()), |
|
strconv.FormatInt(br.opened.Load(), 10), |
|
strconv.FormatInt(br.failed.Load(), 10), |
|
strconv.FormatInt(br.dropped.Load(), 10), |
|
strconv.Itoa(s.MgmtConnections), |
|
s.Node, |
|
strconv.FormatInt(s.MemUsed, 10), |
|
strconv.FormatInt(s.TotalRss, 10), |
|
strconv.FormatInt(s.TotalAllocated, 10), |
|
strconv.FormatInt(s.TotalErlang, 10), |
|
strconv.FormatInt(s.ConnectionReaders, 10), |
|
strconv.FormatInt(s.ConnectionWriters, 10), |
|
strconv.FormatInt(s.ConnectionChannels, 10), |
|
strconv.FormatInt(s.ConnectionOther, 10), |
|
strconv.FormatInt(s.Binary, 10), |
|
strconv.FormatInt(s.OtherProc, 10), |
|
strconv.FormatInt(s.OtherSystem, 10), |
|
strconv.FormatInt(s.Code, 10), |
|
strconv.FormatInt(s.Atom, 10), |
|
} |
|
} |
|
|
|
// humanBytes formats a byte count with binary (IEC) units: values below
// 1024 print as "<n> B"; larger values print with one decimal place and the
// largest unit (KiB through EiB) whose value is at least 1.
func humanBytes(n int64) string {
	const base = 1024
	if n < base {
		return fmt.Sprintf("%d B", n)
	}
	const units = "KMGTPE"
	idx := 0
	divisor := int64(base)
	for n/divisor >= base {
		divisor *= base
		idx++
	}
	return fmt.Sprintf("%.1f %ciB", float64(n)/float64(divisor), units[idx])
}
|
|
|
// broker tracks state for a single RabbitMQ node: its AMQP endpoint, the |
|
// matching HTTP management endpoint, the open client-side connections and |
|
// a few counters that summarise client-observed outcomes. |
|
type broker struct { |
|
protocol string |
|
|
|
amqpURL string |
|
amqpHost string |
|
dialFlags flags |
|
|
|
mgmtClient *http.Client |
|
mgmtBase string |
|
mgmtUser *url.Userinfo |
|
|
|
// tlsConf is a per-broker clone of the shared TLS config with ServerName |
|
// preset to the broker host. It is never mutated after newBroker returns |
|
// so it can be shared across concurrent dials without locking. |
|
// amqp091-go would otherwise mutate ServerName on the first AMQPS dial, |
|
// which would race with other dials that use the same config. |
|
tlsConf *tls.Config |
|
|
|
mu sync.Mutex |
|
conns []*trackedConn |
|
shuttingDown atomic.Bool |
|
|
|
opened atomic.Int64 |
|
failed atomic.Int64 |
|
dropped atomic.Int64 |
|
} |
|
|
|
// trackedConn wraps a protocol-specific connection with a flag that records
// whether this tool closed it on purpose. This lets a later close
// notification be told apart from an unexpected drop.
type trackedConn struct {
	closer io.Closer
	closed atomic.Bool
}

// Close marks the connection as intentionally closed and then closes the
// underlying connection, returning its error. The flag is set before the
// close so the drop watcher started by track, which wakes on the resulting
// close notification, observes it.
func (tc *trackedConn) Close() error {
	tc.closed.Store(true)
	underlying := tc.closer
	return underlying.Close()
}
|
|
|
func newBroker(protocol, amqpURL, mgmtURL string, tlsConf *tls.Config, client *http.Client, f flags) (*broker, error) { |
|
au, err := url.Parse(amqpURL) |
|
if err != nil { |
|
return nil, fmt.Errorf("invalid amqp URL: %w", err) |
|
} |
|
if au.Scheme != "amqp" && au.Scheme != "amqps" { |
|
return nil, fmt.Errorf("amqp URL must use amqp or amqps scheme, got %q", au.Scheme) |
|
} |
|
mu, err := url.Parse(mgmtURL) |
|
if err != nil { |
|
return nil, fmt.Errorf("invalid mgmt URL: %w", err) |
|
} |
|
if mu.Scheme != "http" && mu.Scheme != "https" { |
|
return nil, fmt.Errorf("mgmt URL must use http or https scheme, got %q", mu.Scheme) |
|
} |
|
auth := mu.User |
|
mu.User = nil |
|
mu.Path = strings.TrimSuffix(mu.Path, "/") |
|
|
|
// Preset ServerName on a per-broker clone so the shared tls.Config is not |
|
// mutated by amqp091-go's first AMQPS dial. |
|
host, _ := splitHostPort(au.Host, au.Scheme) |
|
perBrokerTLS := tlsConf.Clone() |
|
if perBrokerTLS.ServerName == "" { |
|
perBrokerTLS.ServerName = host |
|
} |
|
|
|
return &broker{ |
|
protocol: protocol, |
|
amqpURL: amqpURL, |
|
amqpHost: au.Host, |
|
dialFlags: f, |
|
mgmtClient: client, |
|
mgmtBase: mu.String(), |
|
mgmtUser: auth, |
|
tlsConf: perBrokerTLS, |
|
}, nil |
|
} |
|
|
|
func (b *broker) name() string { return b.amqpHost } |
|
|
|
func (b *broker) alive() int { |
|
b.mu.Lock() |
|
defer b.mu.Unlock() |
|
return len(b.conns) |
|
} |
|
|
|
// scale brings the number of alive connections up to target by dialing the |
|
// missing count concurrently. Scale-down is not implemented because the tool |
|
// only ramps up; closeAll handles teardown. |
|
func (b *broker) scale(ctx context.Context, target, concurrency int) { |
|
current := b.alive() |
|
if current >= target { |
|
return |
|
} |
|
needed := target - current |
|
if concurrency < 1 { |
|
concurrency = 1 |
|
} |
|
sem := make(chan struct{}, concurrency) |
|
var wg sync.WaitGroup |
|
var errMu sync.Mutex |
|
var errSamples []string |
|
var errCount int |
|
|
|
for i := 0; i < needed; i++ { |
|
if ctx.Err() != nil { |
|
break |
|
} |
|
wg.Add(1) |
|
sem <- struct{}{} |
|
go func() { |
|
defer wg.Done() |
|
defer func() { <-sem }() |
|
if err := b.dial(ctx); err != nil { |
|
b.failed.Add(1) |
|
errMu.Lock() |
|
errCount++ |
|
if len(errSamples) < 3 { |
|
errSamples = append(errSamples, err.Error()) |
|
} |
|
errMu.Unlock() |
|
} |
|
}() |
|
} |
|
wg.Wait() |
|
if errCount > 0 { |
|
log.Printf("%s (%s): %d dial errors; first samples: %v", b.name(), b.protocol, errCount, errSamples) |
|
} |
|
} |
|
|
|
// dial establishes one connection using the broker's protocol. On success, |
|
// it tracks the connection for subsequent teardown and drop detection, and |
|
// returns nil. The caller never needs a reference to the connection. |
|
func (b *broker) dial(ctx context.Context) error { |
|
switch b.protocol { |
|
case protoAMQP091: |
|
return b.dial091(ctx) |
|
case protoAMQP10: |
|
return b.dial10(ctx) |
|
} |
|
return fmt.Errorf("unknown protocol %q", b.protocol) |
|
} |
|
|
|
// track registers tc in the alive list and starts a goroutine that watches |
|
// for unexpected drops via watch. track MUST be called only after watch has |
|
// been wired up (e.g. via NotifyClose) so that a close signal occurring |
|
// between wiring and tracking is delivered via the channel buffer. |
|
// If the broker is already shutting down, tc is closed immediately so a |
|
// dial that completes mid-teardown does not leak past closeAll. |
|
func (b *broker) track(tc *trackedConn, watch <-chan struct{}) { |
|
b.mu.Lock() |
|
if b.shuttingDown.Load() { |
|
b.mu.Unlock() |
|
_ = tc.Close() |
|
return |
|
} |
|
b.conns = append(b.conns, tc) |
|
b.mu.Unlock() |
|
b.opened.Add(1) |
|
go func() { |
|
<-watch |
|
if !tc.closed.Load() && !b.shuttingDown.Load() { |
|
b.dropped.Add(1) |
|
b.removeConn(tc) |
|
} |
|
}() |
|
} |
|
|
|
func (b *broker) dial091(ctx context.Context) error { |
|
// Use a context-aware TCP dialer so SIGINT can interrupt dials in flight. |
|
// Once the TCP connection is up, amqp091-go manages its own deadlines |
|
// for the TLS and AMQP handshakes. |
|
netDialer := &net.Dialer{Timeout: b.dialFlags.dialTimeout} |
|
cfg := amqp091.Config{ |
|
TLSClientConfig: b.tlsConf, |
|
Heartbeat: b.dialFlags.heartbeat, |
|
Dial: func(network, addr string) (net.Conn, error) { |
|
c, err := netDialer.DialContext(ctx, network, addr) |
|
if err != nil { |
|
return nil, err |
|
} |
|
// Match amqp091-go's DefaultDial: set a handshake deadline that is |
|
// cleared once the AMQP open completes. |
|
if err := c.SetDeadline(time.Now().Add(b.dialFlags.dialTimeout)); err != nil { |
|
_ = c.Close() |
|
return nil, err |
|
} |
|
return c, nil |
|
}, |
|
Properties: amqp091.Table{ |
|
"connection_name": "amqp-memory-compare", |
|
"product": "amqp-memory-compare", |
|
}, |
|
} |
|
if err := ctx.Err(); err != nil { |
|
return err |
|
} |
|
conn, err := amqp091.DialConfig(b.amqpURL, cfg) |
|
if err != nil { |
|
return err |
|
} |
|
wrap := &conn091{conn: conn} |
|
if b.dialFlags.openChannel { |
|
ch, err := conn.Channel() |
|
if err != nil { |
|
_ = conn.Close() |
|
return fmt.Errorf("open channel: %w", err) |
|
} |
|
wrap.ch = ch |
|
} |
|
// Wire the close listener before tracking so a close that happens between |
|
// NotifyClose and track is still delivered via the buffered channel. |
|
closeCh := make(chan *amqp091.Error, 1) |
|
conn.NotifyClose(closeCh) |
|
watch := make(chan struct{}) |
|
go func() { |
|
<-closeCh |
|
close(watch) |
|
}() |
|
b.track(&trackedConn{closer: wrap}, watch) |
|
return nil |
|
} |
|
|
|
func (b *broker) dial10(ctx context.Context) error { |
|
u, err := url.Parse(b.amqpURL) |
|
if err != nil { |
|
return err |
|
} |
|
host, port := splitHostPort(u.Host, u.Scheme) |
|
if port == "" { |
|
if u.Scheme == "amqps" { |
|
port = defaultAMQPSPort |
|
} else { |
|
port = defaultAMQPPort |
|
} |
|
} |
|
addr := u.Scheme + "://" + host + ":" + port |
|
vhost := strings.TrimPrefix(u.Path, "/") |
|
|
|
opts := &amqp10.ConnOptions{ |
|
TLSConfig: b.tlsConf, |
|
IdleTimeout: b.dialFlags.heartbeat, |
|
Properties: map[string]any{ |
|
"connection_name": "amqp-memory-compare", |
|
"product": "amqp-memory-compare", |
|
}, |
|
} |
|
if u.User != nil { |
|
pass, _ := u.User.Password() |
|
opts.SASLType = amqp10.SASLTypePlain(u.User.Username(), pass) |
|
} |
|
if vhost != "" { |
|
// RabbitMQ maps the AMQP 1.0 hostname "vhost:<name>" to a specific vhost. |
|
// See rabbit_amqp_reader:vhost/1. |
|
opts.HostName = "vhost:" + vhost |
|
} |
|
|
|
dialCtx, cancel := context.WithTimeout(ctx, b.dialFlags.dialTimeout) |
|
defer cancel() |
|
conn, err := amqp10.Dial(dialCtx, addr, opts) |
|
if err != nil { |
|
return err |
|
} |
|
wrap := &conn10{conn: conn} |
|
if b.dialFlags.openSession { |
|
sess, err := conn.NewSession(ctx, nil) |
|
if err != nil { |
|
_ = conn.Close() |
|
return fmt.Errorf("open session: %w", err) |
|
} |
|
wrap.sess = sess |
|
} |
|
b.track(&trackedConn{closer: wrap}, conn.Done()) |
|
return nil |
|
} |
|
|
|
func (b *broker) removeConn(tc *trackedConn) { |
|
b.mu.Lock() |
|
defer b.mu.Unlock() |
|
for i, existing := range b.conns { |
|
if existing == tc { |
|
b.conns = append(b.conns[:i], b.conns[i+1:]...) |
|
return |
|
} |
|
} |
|
} |
|
|
|
func (b *broker) closeAll() { |
|
b.shuttingDown.Store(true) |
|
b.mu.Lock() |
|
conns := b.conns |
|
b.conns = nil |
|
b.mu.Unlock() |
|
var wg sync.WaitGroup |
|
for _, c := range conns { |
|
wg.Add(1) |
|
go func(c *trackedConn) { |
|
defer wg.Done() |
|
_ = c.Close() |
|
}(c) |
|
} |
|
wg.Wait() |
|
} |
|
|
|
// conn091 groups an AMQP 0-9-1 connection with its single optional channel |
|
// so they are torn down together. |
|
type conn091 struct { |
|
conn *amqp091.Connection |
|
ch *amqp091.Channel |
|
} |
|
|
|
func (c *conn091) Close() error { |
|
if c.ch != nil { |
|
_ = c.ch.Close() |
|
} |
|
return c.conn.Close() |
|
} |
|
|
|
// conn10 groups an AMQP 1.0 connection with its single optional session. |
|
type conn10 struct { |
|
conn *amqp10.Conn |
|
sess *amqp10.Session |
|
} |
|
|
|
func (c *conn10) Close() error { |
|
if c.sess != nil { |
|
// Session close takes a context. Use a short deadline so teardown does |
|
// not stall if the peer has already gone away. |
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) |
|
defer cancel() |
|
_ = c.sess.Close(ctx) |
|
} |
|
return c.conn.Close() |
|
} |
|
|
|
// memSample is one node's memory snapshot as written to the CSV.
//
// MemUsed comes from /api/nodes. The Total* and breakdown fields come from
// /api/nodes/{name}/memory. MgmtConnections is the connection total from
// /api/overview. Single-instance brokers produce one memSample per sample.
type memSample struct {
	Node            string
	MgmtConnections int

	// Headline figures.
	MemUsed, TotalRss, TotalAllocated, TotalErlang int64

	// Memory breakdown categories.
	ConnectionReaders, ConnectionWriters, ConnectionChannels, ConnectionOther int64
	Binary, OtherProc, OtherSystem, Code, Atom                                int64
}
|
|
|
// Response shapes for the RabbitMQ management HTTP API. Only the fields this
// tool reads are declared.
type (
	// apiNode is one element of GET /api/nodes?columns=name,mem_used.
	apiNode struct {
		Name    string `json:"name"`
		MemUsed int64  `json:"mem_used"`
	}

	// apiOverview is GET /api/overview?columns=object_totals.connections.
	apiOverview struct {
		ObjectTotals struct {
			Connections int `json:"connections"`
		} `json:"object_totals"`
	}

	// apiNodeMemory mirrors the shape of GET /api/nodes/{name}/memory.
	// See rabbit_vm:memory/0 in the RabbitMQ source.
	apiNodeMemory struct {
		Memory struct {
			ConnectionReaders  int64 `json:"connection_readers"`
			ConnectionWriters  int64 `json:"connection_writers"`
			ConnectionChannels int64 `json:"connection_channels"`
			ConnectionOther    int64 `json:"connection_other"`
			Binary             int64 `json:"binary"`
			OtherProc          int64 `json:"other_proc"`
			OtherSystem        int64 `json:"other_system"`
			Code               int64 `json:"code"`
			Atom               int64 `json:"atom"`
			Total              struct {
				Rss       int64 `json:"rss"`
				Allocated int64 `json:"allocated"`
				Erlang    int64 `json:"erlang"`
			} `json:"total"`
		} `json:"memory"`
	}
)
|
|
|
func (b *broker) probe(ctx context.Context) error { |
|
// Exercise every endpoint sample() uses so that authorisation or URL |
|
// problems surface up-front instead of as repeated errors mid-run. |
|
nodes, err := b.fetchNodes(ctx) |
|
if err != nil { |
|
return fmt.Errorf("GET /api/nodes: %w", err) |
|
} |
|
if len(nodes) == 0 { |
|
return fmt.Errorf("GET /api/nodes returned no nodes") |
|
} |
|
if _, err := b.fetchOverview(ctx); err != nil { |
|
return fmt.Errorf("GET /api/overview: %w", err) |
|
} |
|
if _, err := b.fetchNodeMemory(ctx, nodes[0].Name); err != nil { |
|
return fmt.Errorf("GET /api/nodes/%s/memory: %w", nodes[0].Name, err) |
|
} |
|
return nil |
|
} |
|
|
|
func (b *broker) sample(ctx context.Context) ([]memSample, error) { |
|
nodes, err := b.fetchNodes(ctx) |
|
if err != nil { |
|
return nil, fmt.Errorf("fetch nodes: %w", err) |
|
} |
|
ov, err := b.fetchOverview(ctx) |
|
if err != nil { |
|
return nil, fmt.Errorf("fetch overview: %w", err) |
|
} |
|
out := make([]memSample, 0, len(nodes)) |
|
for _, n := range nodes { |
|
m, err := b.fetchNodeMemory(ctx, n.Name) |
|
if err != nil { |
|
return nil, fmt.Errorf("fetch node memory %q: %w", n.Name, err) |
|
} |
|
out = append(out, memSample{ |
|
Node: n.Name, |
|
MgmtConnections: ov.ObjectTotals.Connections, |
|
MemUsed: n.MemUsed, |
|
TotalRss: m.Memory.Total.Rss, |
|
TotalAllocated: m.Memory.Total.Allocated, |
|
TotalErlang: m.Memory.Total.Erlang, |
|
ConnectionReaders: m.Memory.ConnectionReaders, |
|
ConnectionWriters: m.Memory.ConnectionWriters, |
|
ConnectionChannels: m.Memory.ConnectionChannels, |
|
ConnectionOther: m.Memory.ConnectionOther, |
|
Binary: m.Memory.Binary, |
|
OtherProc: m.Memory.OtherProc, |
|
OtherSystem: m.Memory.OtherSystem, |
|
Code: m.Memory.Code, |
|
Atom: m.Memory.Atom, |
|
}) |
|
} |
|
return out, nil |
|
} |
|
|
|
func (b *broker) fetchNodes(ctx context.Context) ([]apiNode, error) { |
|
var nodes []apiNode |
|
err := b.getJSON(ctx, "/api/nodes?columns=name,mem_used", &nodes) |
|
return nodes, err |
|
} |
|
|
|
func (b *broker) fetchOverview(ctx context.Context) (apiOverview, error) { |
|
var ov apiOverview |
|
err := b.getJSON(ctx, "/api/overview?columns=object_totals.connections", &ov) |
|
return ov, err |
|
} |
|
|
|
func (b *broker) fetchNodeMemory(ctx context.Context, node string) (apiNodeMemory, error) { |
|
var mem apiNodeMemory |
|
err := b.getJSON(ctx, "/api/nodes/"+url.PathEscape(node)+"/memory", &mem) |
|
return mem, err |
|
} |
|
|
|
func (b *broker) getJSON(ctx context.Context, path string, v any) error { |
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, b.mgmtBase+path, nil) |
|
if err != nil { |
|
return err |
|
} |
|
if b.mgmtUser != nil { |
|
pass, _ := b.mgmtUser.Password() |
|
req.SetBasicAuth(b.mgmtUser.Username(), pass) |
|
} |
|
req.Header.Set("Accept", "application/json") |
|
resp, err := b.mgmtClient.Do(req) |
|
if err != nil { |
|
return err |
|
} |
|
defer resp.Body.Close() |
|
if resp.StatusCode/100 != 2 { |
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) |
|
return fmt.Errorf("GET %s: %s: %s", path, resp.Status, strings.TrimSpace(string(body))) |
|
} |
|
return json.NewDecoder(resp.Body).Decode(v) |
|
} |