Created
January 3, 2025 10:54
-
-
Save fabriziosalmi/4f60cd215e9d7ef11836b790387f0bba to your computer and use it in GitHub Desktop.
4xx, 5xx error loop protection proposal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package caddywaf | |
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
	"github.com/oschwald/maxminddb-golang"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
)
func init() { | |
caddy.RegisterModule(Middleware{}) | |
httpcaddyfile.RegisterHandlerDirective("waf", parseCaddyfile) | |
} | |
// RateLimit holds the rate limiting settings of the middleware.
//
// Provision only creates a rate limiter when Requests is greater than zero.
type RateLimit struct {
	// Requests is the request count of the limit (JSON key "requests").
	// Presumably the maximum allowed per Window; the enforcing code is not
	// part of this file view.
	Requests int `json:"requests"`
	// Window is the duration the request count applies to (JSON key "window").
	Window time.Duration `json:"window"`
}
// requestCounter is the per-IP bookkeeping record for rate limiting: a request
// count paired with a timestamp for the window it belongs to.
type requestCounter struct {
	// count is the number of requests recorded so far.
	count int
	// window is the timestamp associated with the counting window.
	// NOTE(review): whether this marks the window start or end is decided by
	// code outside this view.
	window time.Time
}
// RateLimiter implements rate limiting for IP addresses. | |
type RateLimiter struct { | |
requests sync.Map | |
config RateLimit | |
} | |
// CountryBlocking holds the settings for rejecting requests by country.
type CountryBlocking struct {
	// Enabled turns country blocking on; the "country_block" Caddyfile
	// subdirective sets it to true.
	Enabled bool `json:"enabled"`
	// BlockList is the list of country codes to reject. isCountryBlocked
	// compares them case-insensitively with the code returned by the GeoIP
	// lookup.
	BlockList []string `json:"block_list"`
	// GeoIPDBPath is the path of the MaxMind database that Provision opens
	// when Enabled is true.
	GeoIPDBPath string `json:"geoip_db_path"`
}
// GeoIPCache caches GeoIP lookups. | |
type GeoIPCache struct { | |
cache sync.Map | |
geoIP *maxminddb.Reader | |
} | |
// GeoIPRecord is the decoding target for a MaxMind lookup. Only the country
// ISO code is mapped; everything else in the database record is ignored.
type GeoIPRecord struct {
	// Country maps the "country" section of the database record.
	Country struct {
		// ISOCode maps the "iso_code" entry of the country section.
		ISOCode string `maxminddb:"iso_code"`
	} `maxminddb:"country"`
}
// Rule is a single WAF rule as stored in a JSON rule file.
type Rule struct {
	// ID identifies the rule; it is quoted in pattern-compilation errors.
	ID string `json:"id"`
	// Phase selects the processing phase; handlePhase2 evaluates rules whose
	// Phase is 2.
	Phase int `json:"phase"`
	// Pattern is the regular expression source, compiled when Mode is "re".
	Pattern string `json:"pattern"`
	// Targets lists the request parts to inspect; handlePhase2 understands
	// "url" and "header".
	Targets []string `json:"targets"`
	// Severity is a free-form severity label (not interpreted in this file view).
	Severity string `json:"severity"`
	// Action is the configured action (not interpreted in this file view).
	Action string `json:"action"`
	// Score is added to the request's anomaly score for every match.
	Score int `json:"score"`
	// Mode selects how Pattern is interpreted; only "re" triggers compilation.
	Mode string `json:"mode"`
	// Description is a human-readable note about the rule.
	Description string `json:"description"`

	// regex is the compiled Pattern; it stays nil unless Mode is "re".
	regex *regexp.Regexp
}
// SeverityConfig maps each severity level to a string setting. Empty values
// are omitted from the JSON encoding.
// NOTE(review): the meaning of the values (e.g. an action name) is not
// established by the code in this file view.
type SeverityConfig struct {
	// Critical is the setting for critical-severity rules.
	Critical string `json:"critical,omitempty"`
	// High is the setting for high-severity rules.
	High string `json:"high,omitempty"`
	// Medium is the setting for medium-severity rules.
	Medium string `json:"medium,omitempty"`
	// Low is the setting for low-severity rules.
	Low string `json:"low,omitempty"`
}
// Middleware implements the WAF functionality. | |
type Middleware struct { | |
RuleFiles []string `json:"rule_files"` | |
IPBlacklistFile string `json:"ip_blacklist_file"` | |
DNSBlacklistFile string `json:"dns_blacklist_file"` | |
LogAll bool `json:"log_all"` | |
AnomalyThreshold int `json:"anomaly_threshold"` | |
RateLimit RateLimit `json:"rate_limit"` | |
CountryBlock CountryBlocking `json:"country_block"` | |
Severity SeverityConfig `json:"severity,omitempty"` | |
ErrorThreshold int `json:"error_threshold"` | |
ErrorWindow time.Duration `json:"error_window"` | |
ErrorStatusCodes []int `json:"error_status_codes"` // Customizable status codes for loop protection | |
Rules []Rule `json:"-"` | |
logger *zap.Logger | |
logChan chan *zapcore.Entry | |
ipBlacklistCIDRs []IPNet | |
ipBlacklist map[string]bool | |
dnsBlacklist map[string]bool | |
geoIPCache *GeoIPCache | |
errorTracker *ErrorTracker | |
rateLimiter *RateLimiter | |
} | |
// IPNet is an address/mask pair describing one CIDR entry of the IP
// blacklist. It mirrors the two fields of net.IPNet.
type IPNet struct {
	// IP is the address part of the CIDR entry.
	IP net.IP
	// Mask is the network mask of the CIDR entry.
	Mask net.IPMask
}
// ErrorTracker counts HTTP error responses per client IP. All methods are
// safe for concurrent use; the zero value is ready to use.
type ErrorTracker struct {
	// mu guards errors.
	mu sync.Mutex
	// errors maps a client IP to its current error count. IPs with a count
	// of zero are not stored.
	errors map[string]int
}

// NewErrorTracker returns an empty ErrorTracker.
func NewErrorTracker() *ErrorTracker {
	return &ErrorTracker{
		errors: make(map[string]int),
	}
}

// Increment adds one to the error count of ip.
func (et *ErrorTracker) Increment(ip string) {
	et.mu.Lock()
	defer et.mu.Unlock()
	// Lazily allocate so a zero-value ErrorTracker does not panic.
	if et.errors == nil {
		et.errors = make(map[string]int)
	}
	et.errors[ip]++
}

// Reset clears the error count of ip.
//
// The entry is deleted rather than zeroed: Reset runs for every successful
// request, so keeping zero entries would grow the map by one key per distinct
// client for the lifetime of the process.
func (et *ErrorTracker) Reset(ip string) {
	et.mu.Lock()
	defer et.mu.Unlock()
	delete(et.errors, ip)
}

// Count returns the current error count of ip, or 0 if none is recorded.
func (et *ErrorTracker) Count(ip string) int {
	et.mu.Lock()
	defer et.mu.Unlock()
	return et.errors[ip]
}
// CaddyModule returns the Caddy module information. | |
func (Middleware) CaddyModule() caddy.ModuleInfo { | |
return caddy.ModuleInfo{ | |
ID: "http.handlers.waf", | |
New: func() caddy.Module { return &Middleware{} }, | |
} | |
} | |
// parseCaddyfile parses the Caddyfile configuration. | |
func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) { | |
var m Middleware | |
err := m.UnmarshalCaddyfile(h.Dispenser) | |
if err != nil { | |
return nil, err | |
} | |
return &m, nil | |
} | |
// UnmarshalCaddyfile unmarshals the Caddyfile configuration. | |
func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { | |
for d.Next() { | |
for d.NextBlock(0) { | |
switch d.Val() { | |
case "error_threshold": | |
if !d.NextArg() { | |
return d.ArgErr() | |
} | |
threshold, err := strconv.Atoi(d.Val()) | |
if err != nil { | |
return d.Errf("invalid error_threshold: %v", err) | |
} | |
m.ErrorThreshold = threshold | |
case "error_window": | |
if !d.NextArg() { | |
return d.ArgErr() | |
} | |
window, err := time.ParseDuration(d.Val()) | |
if err != nil { | |
return d.Errf("invalid error_window: %v", err) | |
} | |
m.ErrorWindow = window | |
case "error_status_codes": | |
for d.NextArg() { | |
code, err := strconv.Atoi(d.Val()) | |
if err != nil { | |
return d.Errf("invalid error_status_code: %v", err) | |
} | |
m.ErrorStatusCodes = append(m.ErrorStatusCodes, code) | |
} | |
case "rate_limit": | |
if !d.NextArg() { | |
return d.ArgErr() | |
} | |
requests, err := strconv.Atoi(d.Val()) | |
if err != nil { | |
return d.Errf("invalid rate_limit requests: %v", err) | |
} | |
if !d.NextArg() { | |
return d.ArgErr() | |
} | |
window, err := time.ParseDuration(d.Val()) | |
if err != nil { | |
return d.Errf("invalid rate_limit window: %v", err) | |
} | |
m.RateLimit = RateLimit{ | |
Requests: requests, | |
Window: window, | |
} | |
case "ip_blacklist_file": | |
if !d.NextArg() { | |
return d.ArgErr() | |
} | |
m.IPBlacklistFile = d.Val() | |
case "dns_blacklist_file": | |
if !d.NextArg() { | |
return d.ArgErr() | |
} | |
m.DNSBlacklistFile = d.Val() | |
case "country_block": | |
m.CountryBlock.Enabled = true | |
for d.NextArg() { | |
m.CountryBlock.BlockList = append(m.CountryBlock.BlockList, d.Val()) | |
} | |
if d.NextArg() { | |
m.CountryBlock.GeoIPDBPath = d.Val() | |
} | |
case "rule_files": | |
for d.NextArg() { | |
m.RuleFiles = append(m.RuleFiles, d.Val()) | |
} | |
} | |
} | |
} | |
return nil | |
} | |
// ServeHTTP handles incoming HTTP requests. | |
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { | |
ip, _, err := net.SplitHostPort(r.RemoteAddr) | |
if err != nil { | |
ip = r.RemoteAddr | |
} | |
// Check if the IP is blacklisted | |
if m.isIPBlacklisted(ip) { | |
m.logRequest(zapcore.WarnLevel, "Blocked blacklisted IP", zap.String("ip", ip)) | |
http.Error(w, "Forbidden", http.StatusForbidden) | |
return nil | |
} | |
// Check if the country is blocked | |
if m.isCountryBlocked(r.RemoteAddr) { | |
m.logRequest(zapcore.WarnLevel, "Blocked country", zap.String("ip", ip)) | |
http.Error(w, "Forbidden", http.StatusForbidden) | |
return nil | |
} | |
// Check rate limiting | |
if m.rateLimiter != nil && m.rateLimiter.isRateLimited(ip) { | |
m.logRequest(zapcore.WarnLevel, "Rate limited IP", zap.String("ip", ip)) | |
http.Error(w, "Too many requests", http.StatusTooManyRequests) | |
return nil | |
} | |
// Apply WAF rules | |
totalScore := m.handlePhase2(w, r) | |
if totalScore >= m.AnomalyThreshold { | |
m.handlePhase3(w, r) | |
return nil | |
} | |
// Call the next handler and capture the response status | |
recorder := &responseRecorder{ResponseWriter: w} | |
err = next.ServeHTTP(recorder, r) | |
// Track HTTP errors based on configured status codes | |
if contains(m.ErrorStatusCodes, recorder.status) { | |
m.errorTracker.Increment(ip) | |
if m.errorTracker.Count(ip) > m.ErrorThreshold { | |
// Block the IP if it exceeds the threshold | |
m.logRequest(zapcore.WarnLevel, "Blocking IP due to too many errors", zap.String("ip", ip)) | |
http.Error(w, "Too many errors", http.StatusTooManyRequests) | |
return nil | |
} | |
} else { | |
// Reset the error count if the request is successful | |
m.errorTracker.Reset(ip) | |
} | |
return err | |
} | |
// contains reports whether value occurs anywhere in slice. A nil or empty
// slice never contains anything.
func contains(slice []int, value int) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] != value {
			continue
		}
		return true
	}
	return false
}
// handlePhase2 applies WAF rules and returns the total anomaly score. | |
func (m *Middleware) handlePhase2(w http.ResponseWriter, r *http.Request) int { | |
totalScore := 0 | |
for _, rule := range m.Rules { | |
if rule.Phase == 2 { | |
for _, target := range rule.Targets { | |
switch target { | |
case "url": | |
if rule.regex != nil && rule.regex.MatchString(r.URL.Path) { | |
totalScore += rule.Score | |
} | |
case "header": | |
for _, header := range r.Header { | |
for _, value := range header { | |
if rule.regex != nil && rule.regex.MatchString(value) { | |
totalScore += rule.Score | |
} | |
} | |
} | |
} | |
} | |
} | |
} | |
return totalScore | |
} | |
// handlePhase3 handles high anomaly scores (e.g., blocking or logging). | |
func (m *Middleware) handlePhase3(w http.ResponseWriter, r *http.Request) { | |
m.logRequest(zapcore.WarnLevel, "Blocked request due to high anomaly score", zap.String("ip", r.RemoteAddr)) | |
http.Error(w, "Forbidden", http.StatusForbidden) | |
} | |
// loadRulesFromFiles loads WAF rules from the specified files. | |
func (m *Middleware) loadRulesFromFiles() error { | |
for _, file := range m.RuleFiles { | |
if err := m.loadRulesFromFile(file); err != nil { | |
return fmt.Errorf("failed to load rules from %s: %v", file, err) | |
} | |
} | |
return nil | |
} | |
// loadRulesFromFile loads WAF rules from a single file. | |
func (m *Middleware) loadRulesFromFile(path string) error { | |
content, err := os.ReadFile(path) | |
if err != nil { | |
return err | |
} | |
var rules []Rule | |
if err := json.Unmarshal(content, &rules); err != nil { | |
return err | |
} | |
for i, rule := range rules { | |
if rule.Mode == "re" { | |
regex, err := regexp.Compile(rule.Pattern) | |
if err != nil { | |
return fmt.Errorf("invalid pattern in rule %s: %v", rule.ID, err) | |
} | |
rules[i].regex = regex | |
} | |
} | |
m.Rules = append(m.Rules, rules...) | |
return nil | |
} | |
// responseRecorder wraps an http.ResponseWriter and remembers the status code
// of the response written through it. status is 0 until something is written.
type responseRecorder struct {
	http.ResponseWriter
	// status is the final (non-1xx) status code sent to the client.
	status int
}

// WriteHeader records the status code and forwards the call.
//
// Only the first final status is kept: net/http ignores later WriteHeader
// calls, so a later value would not be what the client received. 1xx
// informational statuses may be followed by the final one and are therefore
// allowed to be overwritten.
func (r *responseRecorder) WriteHeader(status int) {
	if r.status < 200 {
		r.status = status
	}
	r.ResponseWriter.WriteHeader(status)
}

// Write forwards the body bytes. A handler that writes without calling
// WriteHeader first implicitly sends 200 OK, so that status is recorded here.
func (r *responseRecorder) Write(b []byte) (int, error) {
	if r.status < 200 {
		r.status = http.StatusOK
	}
	return r.ResponseWriter.Write(b)
}
// Provision initializes the middleware. | |
func (m *Middleware) Provision(ctx caddy.Context) error { | |
m.logger = ctx.Logger() | |
m.logChan = make(chan *zapcore.Entry, 1000) | |
go m.logWorker() | |
m.errorTracker = NewErrorTracker() | |
if m.RateLimit.Requests > 0 { | |
m.rateLimiter = &RateLimiter{ | |
config: m.RateLimit, | |
} | |
} | |
if m.CountryBlock.Enabled { | |
geoIP, err := maxminddb.Open(m.CountryBlock.GeoIPDBPath) | |
if err != nil { | |
return fmt.Errorf("failed to load GeoIP database: %v", err) | |
} | |
m.geoIPCache = &GeoIPCache{ | |
geoIP: geoIP, | |
} | |
} | |
if err := m.loadRulesFromFiles(); err != nil { | |
return err | |
} | |
if m.IPBlacklistFile != "" { | |
if err := m.loadIPBlacklistFromFile(m.IPBlacklistFile); err != nil { | |
return fmt.Errorf("failed to load IP blacklist from %s: %v", m.IPBlacklistFile, err) | |
} | |
} | |
if m.DNSBlacklistFile != "" { | |
if err := m.loadDNSBlacklistFromFile(m.DNSBlacklistFile); err != nil { | |
return fmt.Errorf("failed to load DNS blacklist from %s: %v", m.DNSBlacklistFile, err) | |
} | |
} | |
return nil | |
} | |
// loadIPBlacklistFromFile loads IP blacklist from a file. | |
func (m *Middleware) loadIPBlacklistFromFile(path string) error { | |
content, err := os.ReadFile(path) | |
if err != nil { | |
return err | |
} | |
lines := strings.Split(string(content), "\n") | |
ipNets := make([]IPNet, 0, len(lines)) | |
m.ipBlacklist = make(map[string]bool) | |
for _, line := range lines { | |
line = strings.TrimSpace(line) | |
if line == "" { | |
continue | |
} | |
ip, ipNet, err := net.ParseCIDR(line) | |
if err != nil { | |
m.ipBlacklist[ip.String()] = true | |
continue | |
} | |
ipNets = append(ipNets, IPNet{IP: ip, Mask: ipNet.Mask}) | |
} | |
m.ipBlacklistCIDRs = ipNets | |
return nil | |
} | |
// loadDNSBlacklistFromFile loads DNS blacklist from a file. | |
func (m *Middleware) loadDNSBlacklistFromFile(path string) error { | |
content, err := os.ReadFile(path) | |
if err != nil { | |
return err | |
} | |
domains := strings.Split(string(content), "\n") | |
m.dnsBlacklist = make(map[string]bool) | |
for _, domain := range domains { | |
domain = strings.TrimSpace(domain) | |
if domain != "" { | |
m.dnsBlacklist[strings.ToLower(domain)] = true | |
} | |
} | |
return nil | |
} | |
// isIPBlacklisted checks if an IP is blacklisted. | |
func (m *Middleware) isIPBlacklisted(remoteAddr string) bool { | |
ip := net.ParseIP(remoteAddr) | |
if ip == nil { | |
return false | |
} | |
for _, ipNet := range m.ipBlacklistCIDRs { | |
if ipNet.Mask.Contains(ip) { | |
return true | |
} | |
} | |
return m.ipBlacklist[remoteAddr] | |
} | |
// isCountryBlocked checks if a request is from a blocked country. | |
func (m *Middleware) isCountryBlocked(remoteAddr string) bool { | |
if !m.CountryBlock.Enabled || m.geoIPCache == nil { | |
return false | |
} | |
ip, _, err := net.SplitHostPort(remoteAddr) | |
if err != nil { | |
ip = remoteAddr | |
} | |
parsedIP := net.ParseIP(ip) | |
if parsedIP == nil { | |
return false | |
} | |
country, err := m.geoIPCache.getCountry(parsedIP) | |
if err != nil { | |
return false | |
} | |
for _, blockedCountry := range m.CountryBlock.BlockList { | |
if strings.EqualFold(country, blockedCountry) { | |
return true | |
} | |
} | |
return false | |
} | |
// logRequest logs a request with the specified level and message. | |
func (m *Middleware) logRequest(level zapcore.Level, msg string, fields ...zap.Field) { | |
entry := &zapcore.Entry{ | |
Level: level, | |
Message: msg, | |
Time: time.Now(), | |
Fields: fields, | |
} | |
select { | |
case m.logChan <- entry: | |
default: | |
// Optionally handle the case where the channel is full | |
} | |
} | |
// logWorker processes log entries. | |
func (m *Middleware) logWorker() { | |
for entry := range m.logChan { | |
m.logger.Write(entry) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment