Skip to content

Instantly share code, notes, and snippets.

@fabriziosalmi
Created January 3, 2025 10:54
Show Gist options
  • Save fabriziosalmi/4f60cd215e9d7ef11836b790387f0bba to your computer and use it in GitHub Desktop.
Save fabriziosalmi/4f60cd215e9d7ef11836b790387f0bba to your computer and use it in GitHub Desktop.
4xx, 5xx error loop protection proposal
package caddywaf
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
	"github.com/oschwald/maxminddb-golang"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)
// init registers the WAF handler module with Caddy and exposes it in the
// Caddyfile as the "waf" handler directive, which is parsed by parseCaddyfile.
func init() {
caddy.RegisterModule(Middleware{})
httpcaddyfile.RegisterHandlerDirective("waf", parseCaddyfile)
}
// RateLimit defines the rate limiting configuration.
// UnmarshalCaddyfile fills it from `rate_limit <requests> <window>`, and
// Provision only creates a RateLimiter when Requests is greater than zero.
type RateLimit struct {
// Requests is the request count given as the first rate_limit argument.
Requests int `json:"requests"`
// Window is the duration given as the second rate_limit argument.
Window time.Duration `json:"window"`
}
// requestCounter tracks the number of requests from a single IP within a time window.
// NOTE(review): nothing in the visible part of this file reads or writes this
// type; it is presumably used by RateLimiter.isRateLimited, which is defined
// elsewhere — confirm.
type requestCounter struct {
// count is presumably the number of requests seen in the current window — TODO confirm.
count int
// window is presumably the instant that delimits the current window — TODO confirm.
window time.Time
}
// RateLimiter implements rate limiting for IP addresses.
// Provision creates one only when RateLimit.Requests is greater than zero, and
// ServeHTTP consults it through isRateLimited, which is defined outside the
// visible part of this file.
type RateLimiter struct {
// requests is a concurrent map; presumably keyed by client IP with
// requestCounter values — TODO confirm against isRateLimited.
requests sync.Map
// config is the RateLimit copied from the Middleware configuration by Provision.
config RateLimit
}
// CountryBlocking defines the country blocking configuration.
type CountryBlocking struct {
// Enabled turns country blocking on; the country_block directive sets it to true.
Enabled bool `json:"enabled"`
// BlockList holds the countries to block; isCountryBlocked compares each
// entry case-insensitively with the country looked up for the client IP.
BlockList []string `json:"block_list"`
// GeoIPDBPath is the MaxMind database file that Provision opens when Enabled is true.
GeoIPDBPath string `json:"geoip_db_path"`
}
// GeoIPCache caches GeoIP lookups.
// Provision fills in geoIP; isCountryBlocked queries the cache through
// getCountry, which is defined outside the visible part of this file.
type GeoIPCache struct {
// cache is a concurrent map; presumably IP -> country — TODO confirm against getCountry.
cache sync.Map
// geoIP is the MaxMind database reader opened by Provision.
geoIP *maxminddb.Reader
}
// GeoIPRecord represents the GeoIP data for an IP address.
// The maxminddb struct tags select the "country" -> "iso_code" value of a
// database record.
type GeoIPRecord struct {
// Country is the "country" section of the database record.
Country struct {
// ISOCode is the record's "iso_code" value (presumably the two-letter
// ISO 3166-1 code — verify against the database in use).
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
// Rule defines a WAF rule.
// Rules are decoded from JSON rule files by loadRulesFromFile.
type Rule struct {
// ID identifies the rule; it is quoted in pattern-compilation errors.
ID string `json:"id"`
// Phase selects when the rule runs; handlePhase2 evaluates only rules whose Phase is 2.
Phase int `json:"phase"`
// Pattern is the regular expression source that is compiled into regex when Mode is "re".
Pattern string `json:"pattern"`
// Targets names the request parts to match; handlePhase2 understands "url"
// (the URL path) and "header" (every header value).
Targets []string `json:"targets"`
// Severity is decoded but not read in the visible part of this file.
Severity string `json:"severity"`
// Action is decoded but not read in the visible part of this file.
Action string `json:"action"`
// Score is added to the request's anomaly score for every match of this rule.
Score int `json:"score"`
// Mode controls pattern handling; only the value "re" causes Pattern to be compiled.
Mode string `json:"mode"`
// Description is decoded but not read in the visible part of this file.
Description string `json:"description"`
// regex is the compiled Pattern; it stays nil unless Mode is "re".
regex *regexp.Regexp
}
// SeverityConfig defines severity levels for WAF rules.
// Each field is an optional string keyed by severity name in the JSON config.
// NOTE(review): none of these values is read in the visible part of this file;
// they presumably name the action taken per severity — confirm.
type SeverityConfig struct {
Critical string `json:"critical,omitempty"`
High string `json:"high,omitempty"`
Medium string `json:"medium,omitempty"`
Low string `json:"low,omitempty"`
}
// Middleware implements the WAF functionality.
// Exported fields carry the JSON configuration; unexported fields hold runtime
// state that Provision and the load* helpers build.
type Middleware struct {
// RuleFiles lists the JSON rule files that Provision loads into Rules.
RuleFiles []string `json:"rule_files"`
// IPBlacklistFile is an optional file with one IP or CIDR per line, loaded by Provision.
IPBlacklistFile string `json:"ip_blacklist_file"`
// DNSBlacklistFile is an optional file with one domain per line, loaded by Provision.
DNSBlacklistFile string `json:"dns_blacklist_file"`
// LogAll is not read in the visible part of this file.
LogAll bool `json:"log_all"`
// AnomalyThreshold is the value ServeHTTP compares the phase-2 anomaly score against.
AnomalyThreshold int `json:"anomaly_threshold"`
// RateLimit configures the optional per-IP rate limiter.
RateLimit RateLimit `json:"rate_limit"`
// CountryBlock configures GeoIP-based country blocking.
CountryBlock CountryBlocking `json:"country_block"`
// Severity is not read in the visible part of this file.
Severity SeverityConfig `json:"severity,omitempty"`
// ErrorThreshold is the value ServeHTTP compares an IP's error count against.
ErrorThreshold int `json:"error_threshold"`
// ErrorWindow is the duration set by the error_window directive.
ErrorWindow time.Duration `json:"error_window"`
ErrorStatusCodes []int `json:"error_status_codes"` // Customizable status codes for loop protection
// Rules holds every rule loaded from RuleFiles; excluded from JSON.
Rules []Rule `json:"-"`
// logger is the logger obtained from the Caddy context in Provision.
logger *zap.Logger
// logChan is the buffered queue of log entries that logWorker drains.
logChan chan *zapcore.Entry
// ipBlacklistCIDRs holds the CIDR entries of the IP blacklist file.
ipBlacklistCIDRs []IPNet
// ipBlacklist holds the single-address entries of the IP blacklist file.
ipBlacklist map[string]bool
// dnsBlacklist holds the lower-cased entries of the DNS blacklist file; it is
// not consulted in the visible part of this file.
dnsBlacklist map[string]bool
// geoIPCache is set by Provision when country blocking is enabled.
geoIPCache *GeoIPCache
// errorTracker counts error responses per client IP.
errorTracker *ErrorTracker
// rateLimiter is nil unless RateLimit.Requests is greater than zero.
rateLimiter *RateLimiter
}
// IPNet represents an IP network.
// loadIPBlacklistFromFile builds these from the CIDR lines of the blacklist file.
type IPNet struct {
// IP is an address taken from the parsed CIDR entry.
IP net.IP
// Mask is the network mask of the parsed CIDR entry.
Mask net.IPMask
}
// ErrorTracker counts HTTP error responses per client IP address.
//
// All methods are safe for concurrent use: every access to the map is guarded
// by mu. Because it embeds a sync.Mutex, an ErrorTracker must not be copied
// after first use; share a pointer instead.
type ErrorTracker struct {
	errors map[string]int // error count keyed by client IP string
	mu     sync.Mutex     // guards errors
}

// NewErrorTracker returns an empty ErrorTracker that is ready for use.
func NewErrorTracker() *ErrorTracker {
	return &ErrorTracker{
		errors: make(map[string]int),
	}
}

// Increment adds one to the error count recorded for ip.
func (et *ErrorTracker) Increment(ip string) {
	et.mu.Lock()
	defer et.mu.Unlock()
	if et.errors == nil {
		// Allow a zero-value ErrorTracker to be used without NewErrorTracker.
		et.errors = make(map[string]int)
	}
	et.errors[ip]++
}

// Reset forgets the error count recorded for ip, so that Count(ip) reports 0.
//
// The entry is deleted rather than set to zero: Reset runs for every client
// that receives a non-error response, so keeping zeroed entries would grow the
// map by one key per distinct client IP for the lifetime of the process.
func (et *ErrorTracker) Reset(ip string) {
	et.mu.Lock()
	defer et.mu.Unlock()
	delete(et.errors, ip)
}

// Count returns the error count recorded for ip, or 0 when none is recorded.
func (et *ErrorTracker) Count(ip string) int {
	et.mu.Lock()
	defer et.mu.Unlock()
	return et.errors[ip]
}
// CaddyModule describes this handler to Caddy: it lives under the module ID
// "http.handlers.waf" and every instance starts out as an empty Middleware.
func (Middleware) CaddyModule() caddy.ModuleInfo {
	newModule := func() caddy.Module {
		return new(Middleware)
	}
	return caddy.ModuleInfo{ID: "http.handlers.waf", New: newModule}
}
// parseCaddyfile builds a Middleware from the tokens of a "waf" directive.
// It returns the configured handler, or the error reported by
// UnmarshalCaddyfile when the directive is malformed.
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) {
	handler := new(Middleware)
	if err := handler.UnmarshalCaddyfile(h.Dispenser); err != nil {
		return nil, err
	}
	return handler, nil
}
// UnmarshalCaddyfile configures m from the block of a "waf" directive.
//
// Recognized subdirectives:
//
//	error_threshold    <count>
//	error_window       <duration>
//	error_status_codes <code>...
//	rate_limit         <requests> <window>
//	ip_blacklist_file  <path>
//	dns_blacklist_file <path>
//	country_block      <country>... <geoip_db_path>
//	rule_files         <path>...
//
// Subdirectives that are not listed are ignored. An error is returned when a
// required argument is missing or cannot be parsed.
func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	for d.Next() {
		for d.NextBlock(0) {
			switch d.Val() {
			case "error_threshold":
				if !d.NextArg() {
					return d.ArgErr()
				}
				threshold, err := strconv.Atoi(d.Val())
				if err != nil {
					return d.Errf("invalid error_threshold: %v", err)
				}
				m.ErrorThreshold = threshold
			case "error_window":
				if !d.NextArg() {
					return d.ArgErr()
				}
				window, err := time.ParseDuration(d.Val())
				if err != nil {
					return d.Errf("invalid error_window: %v", err)
				}
				m.ErrorWindow = window
			case "error_status_codes":
				args := d.RemainingArgs()
				if len(args) == 0 {
					return d.ArgErr()
				}
				for _, arg := range args {
					code, err := strconv.Atoi(arg)
					if err != nil {
						return d.Errf("invalid error_status_code: %v", err)
					}
					m.ErrorStatusCodes = append(m.ErrorStatusCodes, code)
				}
			case "rate_limit":
				if !d.NextArg() {
					return d.ArgErr()
				}
				requests, err := strconv.Atoi(d.Val())
				if err != nil {
					return d.Errf("invalid rate_limit requests: %v", err)
				}
				if !d.NextArg() {
					return d.ArgErr()
				}
				window, err := time.ParseDuration(d.Val())
				if err != nil {
					return d.Errf("invalid rate_limit window: %v", err)
				}
				m.RateLimit = RateLimit{
					Requests: requests,
					Window:   window,
				}
			case "ip_blacklist_file":
				if !d.NextArg() {
					return d.ArgErr()
				}
				m.IPBlacklistFile = d.Val()
			case "dns_blacklist_file":
				if !d.NextArg() {
					return d.ArgErr()
				}
				m.DNSBlacklistFile = d.Val()
			case "country_block":
				// The arguments are the country codes followed by the GeoIP
				// database path. Reading them with consecutive NextArg loops
				// cannot work: the first loop consumes every argument, leaving
				// nothing for the path, so the last argument is split off here.
				args := d.RemainingArgs()
				if len(args) < 2 {
					return d.Errf("country_block requires at least one country code followed by the GeoIP database path")
				}
				last := len(args) - 1
				m.CountryBlock.Enabled = true
				m.CountryBlock.BlockList = append(m.CountryBlock.BlockList, args[:last]...)
				m.CountryBlock.GeoIPDBPath = args[last]
			case "rule_files":
				files := d.RemainingArgs()
				if len(files) == 0 {
					return d.ArgErr()
				}
				m.RuleFiles = append(m.RuleFiles, files...)
			}
		}
	}
	return nil
}
// ServeHTTP runs the WAF checks for one request and, when none of them rejects
// it, passes the request on to next.
//
// Checks, in order; the first one that matches answers the request and ends
// processing with a nil error:
//
//  1. IP blacklist                       -> 403 Forbidden
//  2. country block                      -> 403 Forbidden
//  3. rate limit                         -> 429 Too Many Requests
//  4. error-loop protection (see below)  -> 429 Too Many Requests
//  5. phase-2 anomaly score >= AnomalyThreshold -> 403 Forbidden
//
// The anomaly check only runs when AnomalyThreshold is positive; with an unset
// (zero) threshold a score of zero would otherwise satisfy ">=" and every
// request would be rejected.
//
// Error-loop protection is active when ErrorStatusCodes is non-empty. After
// next has run, a response whose status is listed in ErrorStatusCodes
// increments the client's error count and any other status clears it. A client
// whose count exceeds ErrorThreshold is rejected *before* its next request is
// handled — a response that next has already written cannot be replaced. When
// ErrorWindow is positive the count is cleared that long after the threshold
// was exceeded, which lifts the block; otherwise the block lasts until the
// tracker is discarded.
//
// The error returned by next is passed through unchanged.
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr carries no port; use it as is.
		ip = r.RemoteAddr
	}

	if m.isIPBlacklisted(ip) {
		m.logRequest(zapcore.WarnLevel, "Blocked blacklisted IP", zap.String("ip", ip))
		http.Error(w, "Forbidden", http.StatusForbidden)
		return nil
	}

	if m.isCountryBlocked(r.RemoteAddr) {
		m.logRequest(zapcore.WarnLevel, "Blocked country", zap.String("ip", ip))
		http.Error(w, "Forbidden", http.StatusForbidden)
		return nil
	}

	if m.rateLimiter != nil && m.rateLimiter.isRateLimited(ip) {
		m.logRequest(zapcore.WarnLevel, "Rate limited IP", zap.String("ip", ip))
		http.Error(w, "Too many requests", http.StatusTooManyRequests)
		return nil
	}

	trackErrors := m.errorTracker != nil && len(m.ErrorStatusCodes) > 0
	if trackErrors && m.errorTracker.Count(ip) > m.ErrorThreshold {
		m.logRequest(zapcore.WarnLevel, "Blocking IP due to too many errors", zap.String("ip", ip))
		http.Error(w, "Too many errors", http.StatusTooManyRequests)
		return nil
	}

	if m.AnomalyThreshold > 0 && m.handlePhase2(w, r) >= m.AnomalyThreshold {
		m.handlePhase3(w, r)
		return nil
	}

	// Wrap the writer so the status chosen by the downstream handler is visible.
	recorder := &responseRecorder{ResponseWriter: w}
	err = next.ServeHTTP(recorder, r)
	if !trackErrors {
		return err
	}

	status := recorder.status
	if status == 0 {
		// Caddy handlers commonly report a failure by returning a
		// HandlerError instead of writing the response themselves.
		if handlerErr, ok := err.(caddyhttp.HandlerError); ok {
			status = handlerErr.StatusCode
		}
	}

	if !contains(m.ErrorStatusCodes, status) {
		// Not a tracked error: the client's streak is over.
		m.errorTracker.Reset(ip)
		return err
	}

	m.errorTracker.Increment(ip)
	if m.errorTracker.Count(ip) > m.ErrorThreshold && m.ErrorWindow > 0 {
		// Lift the block once the window has elapsed. Requests that were
		// already in flight may schedule a few extra resets; they are harmless.
		time.AfterFunc(m.ErrorWindow, func() { m.errorTracker.Reset(ip) })
	}
	return err
}
// contains reports whether value occurs anywhere in slice.
// A nil or empty slice never contains anything.
func contains(slice []int, value int) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == value {
			return true
		}
	}
	return false
}
// handlePhase2 evaluates every phase-2 rule against the request and returns the
// accumulated anomaly score.
//
// Only rules with a compiled pattern can match. For the "url" target the
// pattern is tested against the URL path; for the "header" target it is tested
// against every value of every request header. Each match adds the rule's
// Score, so one rule may contribute several times. Other targets are ignored.
// The response writer is not used.
func (m *Middleware) handlePhase2(w http.ResponseWriter, r *http.Request) int {
	score := 0
	for i := range m.Rules {
		rule := &m.Rules[i]
		if rule.Phase != 2 || rule.regex == nil {
			continue
		}
		for _, target := range rule.Targets {
			if target == "url" {
				if rule.regex.MatchString(r.URL.Path) {
					score += rule.Score
				}
				continue
			}
			if target != "header" {
				continue
			}
			for _, values := range r.Header {
				for _, v := range values {
					if rule.regex.MatchString(v) {
						score += rule.Score
					}
				}
			}
		}
	}
	return score
}
// handlePhase3 rejects a request whose anomaly score reached the threshold: it
// logs a warning and answers 403 Forbidden.
//
// The port is stripped from the remote address before logging so that the
// "ip" field has the same form as in the other log entries of this handler.
func (m *Middleware) handlePhase3(w http.ResponseWriter, r *http.Request) {
	ip, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr carries no port; log it as is.
		ip = r.RemoteAddr
	}
	m.logRequest(zapcore.WarnLevel, "Blocked request due to high anomaly score", zap.String("ip", ip))
	http.Error(w, "Forbidden", http.StatusForbidden)
}
// loadRulesFromFiles loads the rules of every file in RuleFiles, in order, and
// appends them to m.Rules.
//
// It stops at the first file that fails and returns an error that names the
// file and wraps the cause, so callers can inspect it with errors.Is/As.
// Rules from files loaded before the failing one remain in m.Rules.
func (m *Middleware) loadRulesFromFiles() error {
	for _, file := range m.RuleFiles {
		if err := m.loadRulesFromFile(file); err != nil {
			return fmt.Errorf("failed to load rules from %s: %w", file, err)
		}
	}
	return nil
}
// loadRulesFromFile reads a JSON array of rules from path and appends the rules
// to m.Rules.
//
// Rules whose Mode is "re" get their Pattern compiled; other rules are stored
// without a compiled pattern. The file is handled atomically: when it cannot be
// read or decoded, or when any pattern fails to compile, an error is returned
// and none of its rules is added. A compile error names the offending rule and
// wraps the cause.
func (m *Middleware) loadRulesFromFile(path string) error {
	content, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	var rules []Rule
	if err := json.Unmarshal(content, &rules); err != nil {
		return err
	}
	// Index the slice so the compiled pattern is stored in place and the
	// (fairly large) Rule values are not copied on every iteration.
	for i := range rules {
		rule := &rules[i]
		if rule.Mode != "re" {
			continue
		}
		regex, err := regexp.Compile(rule.Pattern)
		if err != nil {
			return fmt.Errorf("invalid pattern in rule %s: %w", rule.ID, err)
		}
		rule.regex = regex
	}
	m.Rules = append(m.Rules, rules...)
	return nil
}
// responseRecorder wraps an http.ResponseWriter and remembers the status code
// of the response that is sent through it.
//
// status is 0 until the response has been started, either by WriteHeader with
// a non-informational status or by the first Write.
type responseRecorder struct {
	http.ResponseWriter
	status int
}

// WriteHeader records the response status and forwards the call.
//
// Only the first final status is recorded, because that is the one actually
// sent: net/http ignores later WriteHeader calls. Informational 1xx statuses
// other than 101 may precede the final status and are therefore not recorded.
func (r *responseRecorder) WriteHeader(status int) {
	informational := status >= 100 && status <= 199 && status != http.StatusSwitchingProtocols
	if r.status == 0 && !informational {
		r.status = status
	}
	r.ResponseWriter.WriteHeader(status)
}

// Write forwards b to the wrapped writer. Writing a body without a prior
// WriteHeader call implicitly sends 200 OK, so that status is recorded.
func (r *responseRecorder) Write(b []byte) (int, error) {
	if r.status == 0 {
		r.status = http.StatusOK
	}
	return r.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer so that http.ResponseController can reach
// optional features of the underlying writer, such as flushing and hijacking.
func (r *responseRecorder) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}
// Provision prepares the middleware for serving requests: it obtains the
// logger, creates the error tracker and (when RateLimit.Requests is positive)
// the rate limiter, loads the rule and blacklist files, opens the GeoIP
// database when country blocking is enabled, and starts the log worker.
//
// Steps that can fail run before any long-lived resource is acquired: the GeoIP
// database is opened after all files have loaded and the log worker goroutine
// is started last, so a failed Provision leaves neither an open database nor a
// running goroutine behind. Returned errors wrap their cause.
func (m *Middleware) Provision(ctx caddy.Context) error {
	m.logger = ctx.Logger()
	m.errorTracker = NewErrorTracker()
	if m.RateLimit.Requests > 0 {
		m.rateLimiter = &RateLimiter{
			config: m.RateLimit,
		}
	}

	if err := m.loadRulesFromFiles(); err != nil {
		return err
	}
	if m.IPBlacklistFile != "" {
		if err := m.loadIPBlacklistFromFile(m.IPBlacklistFile); err != nil {
			return fmt.Errorf("failed to load IP blacklist from %s: %w", m.IPBlacklistFile, err)
		}
	}
	if m.DNSBlacklistFile != "" {
		if err := m.loadDNSBlacklistFromFile(m.DNSBlacklistFile); err != nil {
			return fmt.Errorf("failed to load DNS blacklist from %s: %w", m.DNSBlacklistFile, err)
		}
	}

	if m.CountryBlock.Enabled {
		if m.CountryBlock.GeoIPDBPath == "" {
			return fmt.Errorf("failed to load GeoIP database: country blocking is enabled but geoip_db_path is empty")
		}
		geoIP, err := maxminddb.Open(m.CountryBlock.GeoIPDBPath)
		if err != nil {
			return fmt.Errorf("failed to load GeoIP database: %w", err)
		}
		m.geoIPCache = &GeoIPCache{
			geoIP: geoIP,
		}
	}

	// Nothing below can fail, so the worker never outlives a failed Provision.
	m.logChan = make(chan *zapcore.Entry, 1000)
	go m.logWorker()
	return nil
}
// loadIPBlacklistFromFile replaces the IP blacklist with the entries of the
// file at path.
//
// The file holds one entry per line: either a CIDR range ("10.0.0.0/8") or a
// single address ("192.0.2.7", "2001:db8::1"). Blank lines and lines starting
// with "#" are skipped. CIDR ranges are stored, with their network address, in
// m.ipBlacklistCIDRs; single addresses are stored in m.ipBlacklist under their
// canonical text form. A line that is neither is skipped with a warning.
//
// An error is returned only when the file cannot be read; in that case the
// current blacklist is left untouched.
func (m *Middleware) loadIPBlacklistFromFile(path string) error {
	content, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	lines := strings.Split(string(content), "\n")
	cidrs := make([]IPNet, 0, len(lines))
	addrs := make(map[string]bool)
	for _, line := range lines {
		line = strings.TrimSpace(line)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		if _, network, err := net.ParseCIDR(line); err == nil {
			cidrs = append(cidrs, IPNet{IP: network.IP, Mask: network.Mask})
			continue
		}
		// Not a CIDR range. ParseCIDR returns a nil IP on failure, so the line
		// has to be parsed again as a plain address.
		if ip := net.ParseIP(line); ip != nil {
			addrs[ip.String()] = true
			continue
		}
		if m.logger != nil {
			m.logger.Warn("skipping invalid IP blacklist entry",
				zap.String("file", path),
				zap.String("entry", line))
		}
	}
	m.ipBlacklist = addrs
	m.ipBlacklistCIDRs = cidrs
	return nil
}
// loadDNSBlacklistFromFile replaces m.dnsBlacklist with the domains listed in
// the file at path, one per line. Surrounding whitespace is trimmed, blank
// lines are skipped, and names are stored in lower case. The only error
// returned is a failure to read the file.
func (m *Middleware) loadDNSBlacklistFromFile(path string) error {
	raw, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	m.dnsBlacklist = make(map[string]bool)
	for _, entry := range strings.Split(string(raw), "\n") {
		name := strings.TrimSpace(entry)
		if name == "" {
			continue
		}
		m.dnsBlacklist[strings.ToLower(name)] = true
	}
	return nil
}
// isIPBlacklisted reports whether remoteAddr is covered by the IP blacklist.
//
// remoteAddr must be a bare IP address without a port; anything that does not
// parse as an IP is reported as not blacklisted. The address matches when it
// lies inside one of the blacklisted CIDR ranges or when its canonical text
// form is one of the blacklisted single addresses.
func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
	ip := net.ParseIP(remoteAddr)
	if ip == nil {
		return false
	}
	for _, entry := range m.ipBlacklistCIDRs {
		// net.IPMask has no containment test of its own; net.IPNet does, and it
		// masks both sides, so host bits left in entry.IP do not matter.
		network := net.IPNet{IP: entry.IP, Mask: entry.Mask}
		if network.Contains(ip) {
			return true
		}
	}
	// Look up the canonical form so that equivalent spellings of one address
	// hit the same key. Reading from a nil map is safe and yields false.
	return m.ipBlacklist[ip.String()]
}
// isCountryBlocked reports whether the client behind remoteAddr is located in
// one of the blocked countries.
//
// remoteAddr may be "host:port" or a bare host. The answer is false when
// country blocking is disabled or not provisioned, when the host is not an IP
// address, or when the country lookup fails. Countries are compared
// case-insensitively.
func (m *Middleware) isCountryBlocked(remoteAddr string) bool {
	if m.geoIPCache == nil || !m.CountryBlock.Enabled {
		return false
	}

	host := remoteAddr
	if h, _, splitErr := net.SplitHostPort(remoteAddr); splitErr == nil {
		host = h
	}

	addr := net.ParseIP(host)
	if addr == nil {
		return false
	}

	code, err := m.geoIPCache.getCountry(addr)
	if err != nil {
		return false
	}

	blocked := m.CountryBlock.BlockList
	for i := range blocked {
		if strings.EqualFold(code, blocked[i]) {
			return true
		}
	}
	return false
}
// logRequest writes one log entry with the given level, message and structured
// fields to the middleware's logger. It does nothing when no logger has been
// set or when the logger is not enabled for level.
//
// The entry is written directly instead of being queued on logChan:
// zapcore.Entry has no member that could carry the fields, so a queued entry
// would lose them. zap loggers are safe for concurrent use, so no additional
// synchronization is needed here.
func (m *Middleware) logRequest(level zapcore.Level, msg string, fields ...zap.Field) {
	if m.logger == nil {
		return
	}
	if checked := m.logger.Check(level, msg); checked != nil {
		checked.Write(fields...)
	}
}
// logWorker drains logChan and writes every queued entry to the logger's core,
// keeping the entry's own level, time and message. Entries whose level is not
// enabled are dropped. The loop ends when logChan is closed.
func (m *Middleware) logWorker() {
	for entry := range m.logChan {
		if entry == nil || m.logger == nil {
			continue
		}
		core := m.logger.Core()
		if !core.Enabled(entry.Level) {
			continue
		}
		// Logging is best effort: a failed write has nowhere better to be
		// reported than the very logger that just failed.
		_ = core.Write(*entry, nil)
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment