|
package main |
|
|
|
import ( |
|
"context" |
|
"errors" |
|
"flag" |
|
"fmt" |
|
"io" |
|
"log" |
|
"net" |
|
"net/url" |
|
"os" |
|
"os/signal" |
|
"runtime" |
|
"strconv" |
|
"strings" |
|
"sync" |
|
"syscall" |
|
"time" |
|
|
|
"github.com/ishidawataru/sctp" |
|
"github.com/vishvananda/netns" |
|
) |
|
|
|
// ForwardRule is one parsed -L forwarding rule: traffic of kind Protocol is
// accepted on Listen inside namespace NetnsIn and relayed to Target from
// inside namespace NetnsOut.
type ForwardRule struct {
	// Protocol is the rule URL's scheme (main dispatches on "tcp", "udp", "sctp").
	Protocol string
	// Listen is the rule URL's host part: the local address to accept on.
	Listen string
	// Target is the rule URL's path without its leading slash: where to forward.
	Target string

	// NetnsIn and NetnsOut identify the listening and dialing namespaces,
	// each either a namespace name or an absolute path (see getNetnsHandle).
	NetnsIn  string
	NetnsOut string

	// Raw is the rule exactly as given on the command line.
	Raw string
	// Timeout comes from the ?timeout= query parameter (whole seconds);
	// parseRule fills in 10s when it is absent.
	Timeout time.Duration
}
|
|
|
func parseRule(arg string) (*ForwardRule, error) { |
|
u, err := url.Parse(arg) |
|
if err != nil { |
|
return nil, err |
|
} |
|
target := strings.TrimPrefix(u.Path, "/") |
|
if target == "" { |
|
return nil, errors.New("missing target address in URL path") |
|
} |
|
query := u.Query() |
|
netnsIn := query.Get("netns.in") |
|
netnsOut := query.Get("netns.out") |
|
if netnsIn == "" || netnsOut == "" { |
|
return nil, errors.New("netns.in and netns.out are required") |
|
} |
|
|
|
timeout := 10 * time.Second |
|
if t := query.Get("timeout"); t != "" { |
|
if sec, err := strconv.Atoi(t); err == nil && sec > 0 { |
|
timeout = time.Duration(sec) * time.Second |
|
} else { |
|
return nil, fmt.Errorf("invalid timeout value: %q", t) |
|
} |
|
} |
|
|
|
return &ForwardRule{ |
|
Protocol: u.Scheme, |
|
Listen: u.Host, |
|
Target: target, |
|
NetnsIn: netnsIn, |
|
NetnsOut: netnsOut, |
|
Raw: arg, |
|
Timeout: timeout, |
|
}, nil |
|
} |
|
|
|
var ( |
|
netnsHandlesMu sync.Mutex |
|
netnsHandles = map[string]netns.NsHandle{} |
|
) |
|
|
|
func getNetnsHandle(path string) (netns.NsHandle, error) { |
|
netnsHandlesMu.Lock() |
|
defer netnsHandlesMu.Unlock() |
|
|
|
if ns, ok := netnsHandles[path]; ok { |
|
return ns, nil |
|
} |
|
|
|
var ns netns.NsHandle |
|
var err error |
|
|
|
if path == "/proc/self/ns/net" { |
|
f, err := os.Open(path) |
|
if err != nil { |
|
return netns.NsHandle(-1), err |
|
} |
|
fd := int(f.Fd()) |
|
ns = netns.NsHandle(fd) |
|
netnsHandles[path] = ns |
|
return ns, nil |
|
} |
|
|
|
if strings.HasPrefix(path, "/") { |
|
ns, err = netns.GetFromPath(path) |
|
} else { |
|
ns, err = netns.GetFromName(path) |
|
} |
|
if err != nil { |
|
return netns.NsHandle(-1), err |
|
} |
|
netnsHandles[path] = ns |
|
return ns, nil |
|
} |
|
|
|
func enterNetns(path string) (func(), error) { |
|
ns, err := getNetnsHandle(path) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
runtime.LockOSThread() |
|
origNS, err := netns.Get() |
|
if err != nil { |
|
runtime.UnlockOSThread() |
|
return nil, err |
|
} |
|
|
|
if err := netns.Set(ns); err != nil { |
|
origNS.Close() |
|
runtime.UnlockOSThread() |
|
return nil, err |
|
} |
|
|
|
return func() { |
|
defer func() { |
|
if r := recover(); r != nil { |
|
log.Printf("Recovered from panic in netns context: %v", r) |
|
} |
|
origNS.Close() |
|
runtime.UnlockOSThread() |
|
}() |
|
netns.Set(origNS) |
|
}, nil |
|
} |
|
|
|
// waitForTarget repeatedly dials target over proto until one dial succeeds or
// timeout elapses. Attempts are about one second apart and each is capped at
// one second, but neither an attempt nor the pause between attempts runs past
// the overall deadline. A successful probe connection is closed immediately.
//
// On timeout the returned error wraps the last dial error, if any.
func waitForTarget(proto, target string, timeout time.Duration) error {
	const retryInterval = time.Second

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	var (
		dialer  net.Dialer
		lastErr error
	)
	for {
		attemptCtx, attemptCancel := context.WithTimeout(ctx, retryInterval)
		conn, err := dialer.DialContext(attemptCtx, proto, target)
		attemptCancel()
		if err == nil {
			conn.Close()
			return nil
		}
		lastErr = err

		pause := time.NewTimer(retryInterval)
		select {
		case <-ctx.Done():
			pause.Stop()
			return fmt.Errorf("target %s not available after %v: %w", target, timeout, lastErr)
		case <-pause.C:
		}
	}
}
|
|
|
func handleTCP(rule *ForwardRule) error { |
|
exitIn, err := enterNetns(rule.NetnsIn) |
|
if err != nil { |
|
return fmt.Errorf("enter netns.in: %w", err) |
|
} |
|
defer exitIn() |
|
|
|
ln, err := net.Listen("tcp", rule.Listen) |
|
if err != nil { |
|
return fmt.Errorf("listen tcp: %w", err) |
|
} |
|
defer ln.Close() |
|
|
|
for { |
|
connIn, err := ln.Accept() |
|
if err != nil { |
|
log.Printf("Accept error: %v", err) |
|
continue |
|
} |
|
|
|
go func(connIn net.Conn) { |
|
defer connIn.Close() |
|
|
|
exitOut, err := enterNetns(rule.NetnsOut) |
|
if err != nil { |
|
log.Printf("enterNetns out error: %v", err) |
|
return |
|
} |
|
defer exitOut() |
|
|
|
connOut, err := net.DialTimeout("tcp", rule.Target, rule.Timeout) |
|
if err != nil { |
|
log.Printf("Dial target error: %v", err) |
|
return |
|
} |
|
defer connOut.Close() |
|
|
|
var wg sync.WaitGroup |
|
copyBufSize := 64 * 1024 |
|
wg.Add(2) |
|
|
|
go func() { |
|
defer wg.Done() |
|
_, err := io.CopyBuffer(connOut, connIn, make([]byte, copyBufSize)) |
|
if err != nil && !errors.Is(err, io.EOF) { |
|
log.Printf("Copy client→target error: %v", err) |
|
} |
|
if tcpConn, ok := connOut.(*net.TCPConn); ok { |
|
tcpConn.CloseWrite() |
|
} |
|
}() |
|
|
|
go func() { |
|
defer wg.Done() |
|
_, err := io.CopyBuffer(connIn, connOut, make([]byte, copyBufSize)) |
|
if err != nil && !errors.Is(err, io.EOF) { |
|
log.Printf("Copy target→client error: %v", err) |
|
} |
|
if tcpConn, ok := connIn.(*net.TCPConn); ok { |
|
tcpConn.CloseWrite() |
|
} |
|
}() |
|
|
|
wg.Wait() |
|
}(connIn) |
|
} |
|
} |
|
|
|
// udpClient pairs a UDP peer address with the time it was last heard from.
//
// NOTE(review): nothing in this file appears to use this package-level type
// (handleUDP tracks its sessions with a function-local type), so it looks
// like dead code — confirm no other file in the package uses it, then remove.
type udpClient struct {
	addr *net.UDPAddr // peer address

	lastSeen time.Time // when the peer was last active
}
|
|
|
func handleUDP(rule *ForwardRule) error { |
|
exitIn, err := enterNetns(rule.NetnsIn) |
|
if err != nil { |
|
return fmt.Errorf("enter netns in failed: %w", err) |
|
} |
|
defer exitIn() |
|
|
|
laddr, err := net.ResolveUDPAddr("udp", rule.Listen) |
|
if err != nil { |
|
return fmt.Errorf("resolve listen addr failed: %w", err) |
|
} |
|
inConn, err := net.ListenUDP("udp", laddr) |
|
if err != nil { |
|
return fmt.Errorf("listenUDP failed: %w", err) |
|
} |
|
|
|
type udpClient struct { |
|
srcAddr *net.UDPAddr |
|
outConn *net.UDPConn |
|
lastSeen time.Time |
|
cancel context.CancelFunc |
|
} |
|
|
|
clients := make(map[string]*udpClient) |
|
mu := sync.Mutex{} |
|
|
|
go func() { |
|
buf := make([]byte, 2048) |
|
for { |
|
n, srcAddr, err := inConn.ReadFromUDP(buf) |
|
if err != nil { |
|
log.Println("UDP read error:", err) |
|
continue |
|
} |
|
data := make([]byte, n) |
|
copy(data, buf[:n]) |
|
|
|
clientKey := srcAddr.String() |
|
|
|
mu.Lock() |
|
client, exists := clients[clientKey] |
|
if !exists { |
|
exitOut, err := enterNetns(rule.NetnsOut) |
|
if err != nil { |
|
log.Printf("netns out error: %v", err) |
|
mu.Unlock() |
|
continue |
|
} |
|
raddr, err := net.ResolveUDPAddr("udp", rule.Target) |
|
if err != nil { |
|
log.Printf("resolve target: %v", err) |
|
exitOut() |
|
mu.Unlock() |
|
continue |
|
} |
|
outConn, err := net.DialUDP("udp", nil, raddr) |
|
if err != nil { |
|
log.Printf("dial UDP: %v", err) |
|
exitOut() |
|
mu.Unlock() |
|
continue |
|
} |
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
client = &udpClient{ |
|
srcAddr: srcAddr, |
|
outConn: outConn, |
|
lastSeen: time.Now(), |
|
cancel: cancel, |
|
} |
|
clients[clientKey] = client |
|
|
|
go func(ctx context.Context, conn *net.UDPConn, src *net.UDPAddr, key string) { |
|
buf := make([]byte, 2048) |
|
defer conn.Close() |
|
for { |
|
select { |
|
case <-ctx.Done(): |
|
return |
|
default: |
|
n, err := conn.Read(buf) |
|
if err != nil { |
|
log.Printf("read from target failed: %v", err) |
|
return |
|
} |
|
mu.Lock() |
|
client := clients[key] |
|
mu.Unlock() |
|
if client != nil { |
|
_, err := inConn.WriteToUDP(buf[:n], client.srcAddr) |
|
if err != nil { |
|
log.Printf("write back to client failed: %v", err) |
|
} |
|
} |
|
} |
|
} |
|
}(ctx, outConn, srcAddr, clientKey) |
|
} |
|
client.lastSeen = time.Now() |
|
mu.Unlock() |
|
|
|
_, err = client.outConn.Write(data) |
|
if err != nil { |
|
log.Printf("write to target failed: %v", err) |
|
} |
|
} |
|
}() |
|
|
|
return nil |
|
} |
|
|
|
|
|
func parseSCTPAddr(addr string) (*sctp.SCTPAddr, error) { |
|
host, portStr, err := net.SplitHostPort(addr) |
|
if err != nil { |
|
return nil, err |
|
} |
|
port, err := strconv.Atoi(portStr) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
ips, err := net.LookupIP(host) |
|
if err != nil || len(ips) == 0 { |
|
return nil, fmt.Errorf("cannot resolve host %s: %v", host, err) |
|
} |
|
var ipAddrs []net.IPAddr |
|
for _, ip := range ips { |
|
ipAddrs = append(ipAddrs, net.IPAddr{IP: ip}) |
|
} |
|
return &sctp.SCTPAddr{ |
|
IPAddrs: ipAddrs, |
|
Port: port, |
|
}, nil |
|
} |
|
|
|
func resolveSCTPAddrMulti(addr string) (*sctp.SCTPAddr, error) { |
|
hostPort := addr |
|
var portStr string |
|
lastColon := strings.LastIndex(addr, ":") |
|
if lastColon == -1 { |
|
return nil, fmt.Errorf("missing port in SCTP addr %q", addr) |
|
} |
|
hostPort = addr[:lastColon] |
|
portStr = addr[lastColon+1:] |
|
port, err := strconv.Atoi(portStr) |
|
if err != nil { |
|
return nil, fmt.Errorf("invalid port %q: %w", portStr, err) |
|
} |
|
|
|
ipsRaw := strings.Split(hostPort, ",") |
|
ipAddrs := make([]net.IPAddr, 0, len(ipsRaw)) |
|
|
|
for _, ipRaw := range ipsRaw { |
|
ipRaw = strings.Trim(ipRaw, "[]") |
|
|
|
ip, zone := parseIPZone(ipRaw) |
|
if ip == nil { |
|
return nil, fmt.Errorf("invalid IP %q", ipRaw) |
|
} |
|
ipAddrs = append(ipAddrs, net.IPAddr{IP: ip, Zone: zone}) |
|
} |
|
|
|
return &sctp.SCTPAddr{ |
|
IPAddrs: ipAddrs, |
|
Port: port, |
|
}, nil |
|
} |
|
|
|
// parseIPZone splits an IP literal of the form "addr%zone" into the parsed
// address and the zone. Without a '%', the zone is empty. The returned IP is
// nil when the address part is not a valid IP literal; the zone is returned
// as-is either way.
func parseIPZone(ip string) (net.IP, string) {
	pct := strings.IndexByte(ip, '%')
	if pct < 0 {
		return net.ParseIP(ip), ""
	}
	addrPart, zonePart := ip[:pct], ip[pct+1:]
	return net.ParseIP(addrPart), zonePart
}
|
|
|
func dialSCTPWithRetry(ctx context.Context, raddr *sctp.SCTPAddr, timeout time.Duration, maxRetries int) (*sctp.SCTPConn, error) { |
|
var lastErr error |
|
for i := 0; i < maxRetries; i++ { |
|
dialCtx, cancel := context.WithTimeout(ctx, timeout) |
|
done := make(chan struct{}) |
|
var conn *sctp.SCTPConn |
|
go func() { |
|
var err error |
|
conn, err = sctp.DialSCTP("sctp", nil, raddr) |
|
if err != nil { |
|
lastErr = err |
|
} |
|
close(done) |
|
}() |
|
select { |
|
case <-dialCtx.Done(): |
|
lastErr = dialCtx.Err() |
|
case <-done: |
|
cancel() |
|
if lastErr == nil { |
|
return conn, nil |
|
} |
|
} |
|
cancel() |
|
backoff := time.Duration((i+1)*(i+1)) * 500 * time.Millisecond |
|
time.Sleep(backoff) |
|
} |
|
return nil, fmt.Errorf("dial SCTP failed after %d retries: %w", maxRetries, lastErr) |
|
} |
|
|
|
func tuneSCTPSocket(ln *sctp.SCTPListener) error { |
|
return nil |
|
} |
|
|
|
func handleSCTP(rule *ForwardRule) error { |
|
timeout := 10 * time.Second |
|
u, err := url.Parse(rule.Raw) |
|
if err == nil { |
|
if tStr := u.Query().Get("timeout"); tStr != "" { |
|
if t, err := strconv.Atoi(tStr); err == nil && t > 0 { |
|
timeout = time.Duration(t) * time.Second |
|
} |
|
} |
|
} else { |
|
log.Printf("failed to parse rule URL for timeout, using default 10s: %v", err) |
|
} |
|
|
|
exitIn, err := enterNetns(rule.NetnsIn) |
|
if err != nil { |
|
return err |
|
} |
|
defer exitIn() |
|
|
|
laddr, err := resolveSCTPAddrMulti(rule.Listen) |
|
if err != nil { |
|
return fmt.Errorf("resolve SCTP listen addr: %w", err) |
|
} |
|
|
|
inLn, err := sctp.ListenSCTP("sctp", laddr) |
|
if err != nil { |
|
return fmt.Errorf("listen SCTP: %w", err) |
|
} |
|
defer inLn.Close() |
|
|
|
if err := tuneSCTPSocket(inLn); err != nil { |
|
log.Printf("warning: SCTP socket tuning failed: %v", err) |
|
} |
|
|
|
for { |
|
inConn, err := inLn.AcceptSCTP() |
|
if err != nil { |
|
log.Println("SCTP accept error:", err) |
|
continue |
|
} |
|
|
|
go func(inConn *sctp.SCTPConn) { |
|
defer inConn.Close() |
|
|
|
exitOut, err := enterNetns(rule.NetnsOut) |
|
if err != nil { |
|
log.Println("enter netns.out failed:", err) |
|
return |
|
} |
|
defer exitOut() |
|
|
|
raddr, err := resolveSCTPAddrMulti(rule.Target) |
|
if err != nil { |
|
log.Println("resolve SCTP target failed:", err) |
|
return |
|
} |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout) |
|
defer cancel() |
|
|
|
outConn, err := dialSCTPWithRetry(ctx, raddr, timeout, 3) |
|
if err != nil { |
|
log.Println("dial SCTP target failed:", err) |
|
return |
|
} |
|
defer outConn.Close() |
|
|
|
buf1 := make([]byte, 32*1024) |
|
buf2 := make([]byte, 32*1024) |
|
|
|
var wg sync.WaitGroup |
|
wg.Add(2) |
|
|
|
go func() { |
|
defer wg.Done() |
|
if _, err := io.CopyBuffer(outConn, inConn, buf1); err != nil && !errors.Is(err, io.EOF) { |
|
log.Println("SCTP copy client→target error:", err) |
|
} |
|
}() |
|
|
|
go func() { |
|
defer wg.Done() |
|
if _, err := io.CopyBuffer(inConn, outConn, buf2); err != nil && !errors.Is(err, io.EOF) { |
|
log.Println("SCTP copy target→client error:", err) |
|
} |
|
}() |
|
|
|
wg.Wait() |
|
}(inConn) |
|
} |
|
} |
|
|
|
// multiFlag is a flag.Value that collects every occurrence of a repeatable
// command-line flag, in the order given.
type multiFlag []string

// Compile-time check that *multiFlag can be passed to flag.Var.
var _ flag.Value = (*multiFlag)(nil)

// String reports the collected values joined with commas.
func (mf *multiFlag) String() string {
	collected := []string(*mf)
	return strings.Join(collected, ",")
}

// Set records one more occurrence of the flag; it never rejects a value.
func (mf *multiFlag) Set(value string) error {
	updated := append(*mf, value)
	*mf = updated
	return nil
}
|
|
|
func main() { |
|
var listenRules multiFlag |
|
flag.Var(&listenRules, "L", "forward rule: protocol://listen/target?netns.in=NAME|PATH&netns.out=NAME|PATH (can be repeated)") |
|
flag.Parse() |
|
|
|
if len(listenRules) == 0 { |
|
log.Fatalf("no forwarding rules specified") |
|
} |
|
|
|
for _, r := range listenRules { |
|
if strings.Contains(r, "/proc/self/ns/net") { |
|
_, err := getNetnsHandle("/proc/self/ns/net") |
|
if err != nil { |
|
log.Fatalf("failed to open /proc/self/ns/net: %v", err) |
|
} |
|
break |
|
} |
|
} |
|
|
|
var rules []*ForwardRule |
|
for _, r := range listenRules { |
|
rule, err := parseRule(r) |
|
if err != nil { |
|
log.Fatalf("invalid rule %q: %v", r, err) |
|
} |
|
rules = append(rules, rule) |
|
} |
|
|
|
for _, rule := range rules { |
|
switch rule.Protocol { |
|
case "tcp": |
|
go func(r *ForwardRule) { |
|
if err := handleTCP(r); err != nil { |
|
log.Fatalf("tcp forward failed: %v", err) |
|
} |
|
}(rule) |
|
case "udp": |
|
go func(r *ForwardRule) { |
|
if err := handleUDP(r); err != nil { |
|
log.Fatalf("udp forward failed: %v", err) |
|
} |
|
}(rule) |
|
case "sctp": |
|
go func(r *ForwardRule) { |
|
if err := handleSCTP(r); err != nil { |
|
log.Fatalf("sctp forward failed: %v", err) |
|
} |
|
}(rule) |
|
default: |
|
log.Fatalf("unsupported protocol %q", rule.Protocol) |
|
} |
|
} |
|
|
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) |
|
defer stop() |
|
<-ctx.Done() |
|
log.Println("Shutting down...") |
|
time.Sleep(1 * time.Second) |
|
} |