Skip to content

Instantly share code, notes, and snippets.

@alpominth
Last active June 2, 2025 18:26
Show Gist options
  • Save alpominth/e33c8ff8e63acf66ec5d2e7da4decf6a to your computer and use it in GitHub Desktop.
Save alpominth/e33c8ff8e63acf66ec5d2e7da4decf6a to your computer and use it in GitHub Desktop.
Port forwarder for network namespaces: forward TCP/UDP/SCTP ports from one netns to another easily.

netns-pf

This is a simple TCP/UDP/SCTP port forwarder between network namespaces. Because it switches network namespaces, it requires root privileges.

Compiling:

Download all files from this gist into a folder and run:

$ go build ./netns-pf.go

Usage:

sudo ./netns-pf -L "<protocol>://<listen_address>:<listen_port>/<outgoing_address>:<outgoing_port>?netns.in=<listen_netns>&netns.out=<outgoing_netns>&timeout=<seconds>"

Like this:

$ sudo ./netns-pf -L "tcp://127.0.0.1:1234/127.0.0.1:4321?netns.in=somens&netns.out=/proc/self/ns/net"

$ sudo ./netns-pf -L "udp://[::1]:53/[::1]:5353?netns.in=somens&netns.out=/proc/self/ns/net&timeout=5"

$ sudo ./netns-pf -L "sctp://127.0.0.1:1111,[::1]:1111/127.0.0.1,[::1]:2222?netns.in=somens&netns.out=/proc/1/ns/net"

You can pass as many "-L" flags as you want. The default timeout is 10 seconds; override it with the "timeout" query parameter (a whole number of seconds).

SCTP mode supports multihoming: you can give several comma-separated IPs for the listen and target addresses, as in the SCTP example above.

module netns-pf
go 1.24.2
require (
github.com/ishidawataru/sctp v0.0.0-20250530054746-c4c76e25d7e3
github.com/vishvananda/netns v0.0.5
golang.org/x/sys v0.33.0
)
github.com/ishidawataru/sctp v0.0.0-20250530054746-c4c76e25d7e3 h1:GeDZ2OFRJ50qCgx9yKGoIlS4I7UzekzINk372zC287w=
github.com/ishidawataru/sctp v0.0.0-20250530054746-c4c76e25d7e3/go.mod h1:co9pwDoBCm1kGxawmb4sPq0cSIOOWNPT4KnHotMP1Zg=
github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY=
github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
package main
import (
"context"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/url"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/ishidawataru/sctp"
"github.com/vishvananda/netns"
)
// defaultTimeout is used when a rule has no "timeout" query parameter.
const defaultTimeout = 10 * time.Second

// maxTimeoutSeconds is the largest "timeout" value that still fits in a
// time.Duration once multiplied by time.Second.
const maxTimeoutSeconds = int64(time.Duration(1<<63-1) / time.Second)

// ForwardRule is one forwarding rule parsed from a -L flag of the form
// protocol://listen/target?netns.in=...&netns.out=...[&timeout=seconds].
type ForwardRule struct {
	Protocol string        // URL scheme, e.g. "tcp", "udp" or "sctp"
	Listen   string        // listen address(es), taken from the URL host
	Target   string        // target address(es), taken from the URL path
	NetnsIn  string        // namespace (name or absolute path) to listen in
	NetnsOut string        // namespace (name or absolute path) to dial the target from
	Raw      string        // the rule exactly as given on the command line
	Timeout  time.Duration // "timeout" query parameter, or defaultTimeout
}

// parseRule parses one -L argument into a ForwardRule.
//
// It returns an error when the URL is malformed, the scheme (protocol),
// listen host or target path is missing, either netns.in or netns.out is
// absent, or timeout is not a positive whole number of seconds that fits in
// a time.Duration. The protocol itself is not checked against a supported
// list here.
func parseRule(arg string) (*ForwardRule, error) {
	u, err := url.Parse(arg)
	if err != nil {
		return nil, err
	}
	if u.Scheme == "" {
		return nil, errors.New("missing protocol (URL scheme)")
	}
	// An empty host would make net.Listen pick a random port on all
	// interfaces, which is never what the user meant.
	if u.Host == "" {
		return nil, errors.New("missing listen address in URL host")
	}
	target := strings.TrimPrefix(u.Path, "/")
	if target == "" {
		return nil, errors.New("missing target address in URL path")
	}
	query := u.Query()
	netnsIn := query.Get("netns.in")
	netnsOut := query.Get("netns.out")
	if netnsIn == "" || netnsOut == "" {
		return nil, errors.New("netns.in and netns.out are required")
	}
	timeout := defaultTimeout
	if t := query.Get("timeout"); t != "" {
		sec, err := strconv.Atoi(t)
		// Reject values that would overflow when converted to a Duration.
		if err != nil || sec <= 0 || int64(sec) > maxTimeoutSeconds {
			return nil, fmt.Errorf("invalid timeout value: %q", t)
		}
		timeout = time.Duration(sec) * time.Second
	}
	return &ForwardRule{
		Protocol: u.Scheme,
		Listen:   u.Host,
		Target:   target,
		NetnsIn:  netnsIn,
		NetnsOut: netnsOut,
		Raw:      arg,
		Timeout:  timeout,
	}, nil
}
// netnsHandles caches namespace handles by the exact name/path they were
// requested with, so each namespace is opened at most once. Cached handles
// are never closed; they live as long as the process.
var (
	netnsHandlesMu sync.Mutex
	netnsHandles   = map[string]netns.NsHandle{}
)

// getNetnsHandle returns the cached handle for path, opening it on first use.
// A path starting with "/" (e.g. /proc/1/ns/net or /proc/self/ns/net) is
// opened directly; anything else is treated as a named namespace and looked
// up with netns.GetFromName. Thread/process-relative paths such as
// /proc/self/ns/net are resolved at the moment they are first opened, so
// callers that want the startup namespace must open them early.
func getNetnsHandle(path string) (netns.NsHandle, error) {
	netnsHandlesMu.Lock()
	defer netnsHandlesMu.Unlock()
	if ns, ok := netnsHandles[path]; ok {
		return ns, nil
	}
	var (
		ns  netns.NsHandle
		err error
	)
	if strings.HasPrefix(path, "/") {
		// Use netns.GetFromPath for every absolute path, /proc/self/ns/net
		// included. A descriptor taken from an *os.File via Fd() is only
		// valid until that File is garbage collected (its finalizer closes
		// it), which would leave a dangling fd in the cache.
		ns, err = netns.GetFromPath(path)
	} else {
		ns, err = netns.GetFromName(path)
	}
	if err != nil {
		return netns.NsHandle(-1), err
	}
	netnsHandles[path] = ns
	return ns, nil
}
// enterNetns locks the calling goroutine to its OS thread and switches that
// thread into the network namespace named by path (a name or absolute path,
// resolved through getNetnsHandle). Sockets created before the returned
// function is called belong to that namespace.
//
// The returned function switches the thread back and unlocks it. Only its
// first call has any effect, so it is safe to call it both explicitly and
// via defer. If switching back fails, the error is logged and the thread is
// deliberately left locked: the Go runtime terminates a thread whose
// goroutine exits while still locked, so a thread stuck in the wrong
// namespace is never reused by other goroutines.
func enterNetns(path string) (func(), error) {
	ns, err := getNetnsHandle(path)
	if err != nil {
		return nil, err
	}
	runtime.LockOSThread()
	origNS, err := netns.Get()
	if err != nil {
		runtime.UnlockOSThread()
		return nil, err
	}
	if err := netns.Set(ns); err != nil {
		origNS.Close()
		runtime.UnlockOSThread()
		return nil, err
	}
	var once sync.Once
	return func() {
		once.Do(func() {
			defer origNS.Close()
			if err := netns.Set(origNS); err != nil {
				log.Printf("restore netns after %q failed, keeping OS thread locked: %v", path, err)
				return
			}
			runtime.UnlockOSThread()
		})
	}, nil
}
// waitForTarget repeatedly dials target over proto, about once per second,
// until a connection succeeds (it is closed immediately) or timeout elapses.
// Both the dial and the pause between attempts are bound to the overall
// deadline, so the call returns promptly once timeout has passed. The
// returned error wraps the last dial error.
func waitForTarget(proto, target string, timeout time.Duration) error {
	const retryInterval = time.Second
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	dialer := net.Dialer{Timeout: retryInterval}
	for {
		conn, err := dialer.DialContext(ctx, proto, target)
		if err == nil {
			conn.Close()
			return nil
		}
		retry := time.NewTimer(retryInterval)
		select {
		case <-ctx.Done():
			retry.Stop()
			return fmt.Errorf("target %s not available after %v: %w", target, timeout, err)
		case <-retry.C:
		}
	}
}
// handleTCP listens on rule.Listen inside rule.NetnsIn and, for every
// accepted connection, dials rule.Target from inside rule.NetnsOut and
// relays bytes in both directions. It returns only on a setup error or if
// the listener gets closed; transient accept errors are retried with a
// capped exponential backoff.
func handleTCP(rule *ForwardRule) error {
	exitIn, err := enterNetns(rule.NetnsIn)
	if err != nil {
		return fmt.Errorf("enter netns.in: %w", err)
	}
	defer exitIn()
	ln, err := net.Listen("tcp", rule.Listen)
	if err != nil {
		return fmt.Errorf("listen tcp: %w", err)
	}
	defer ln.Close()
	var backoff time.Duration
	for {
		connIn, err := ln.Accept()
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return fmt.Errorf("accept tcp: %w", err)
			}
			// Persistent errors (e.g. EMFILE) would otherwise spin this loop
			// and flood the log; back off from 5ms up to 1s.
			backoff = max(5*time.Millisecond, min(2*backoff, time.Second))
			log.Printf("Accept error: %v", err)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go proxyTCPConn(rule, connIn)
	}
}

// proxyTCPConn relays one accepted client connection to rule.Target until
// both directions have finished, then closes both connections. Each
// direction's EOF is forwarded as a TCP half-close (CloseWrite).
func proxyTCPConn(rule *ForwardRule, connIn net.Conn) {
	defer connIn.Close()
	connOut, err := dialTCPTarget(rule)
	if err != nil {
		log.Print(err)
		return
	}
	defer connOut.Close()
	var wg sync.WaitGroup
	wg.Add(2)
	go func() {
		defer wg.Done()
		// io.Copy never reports io.EOF; any error here is a real failure.
		if _, err := io.Copy(connOut, connIn); err != nil {
			log.Printf("Copy client→target error: %v", err)
		}
		if tcpConn, ok := connOut.(*net.TCPConn); ok {
			tcpConn.CloseWrite()
		}
	}()
	go func() {
		defer wg.Done()
		if _, err := io.Copy(connIn, connOut); err != nil {
			log.Printf("Copy target→client error: %v", err)
		}
		if tcpConn, ok := connIn.(*net.TCPConn); ok {
			tcpConn.CloseWrite()
		}
	}()
	wg.Wait()
}

// dialTCPTarget dials rule.Target (with rule.Timeout) from inside
// rule.NetnsOut. The namespace is only needed while the socket is created,
// so the calling thread is switched back and unlocked before returning
// instead of staying pinned for the connection's whole lifetime.
//
// NOTE(review): for host-name targets, name resolution and dual-stack
// dialing inside package net may run on other goroutines/threads outside the
// entered namespace; IP-literal targets are dialed on this goroutine.
func dialTCPTarget(rule *ForwardRule) (net.Conn, error) {
	exitOut, err := enterNetns(rule.NetnsOut)
	if err != nil {
		return nil, fmt.Errorf("enterNetns out error: %w", err)
	}
	defer exitOut()
	conn, err := net.DialTimeout("tcp", rule.Target, rule.Timeout)
	if err != nil {
		return nil, fmt.Errorf("Dial target error: %w", err)
	}
	return conn, nil
}
// udpClient is one UDP forwarding session: a distinct client source address
// and the connected socket (created inside netns.out) that carries that
// client's datagrams to the target and receives the replies.
type udpClient struct {
	srcAddr  *net.UDPAddr
	outConn  *net.UDPConn
	lastSeen time.Time // last datagram from the client; guarded by handleUDP's mutex
}

// udpBufSize fits the largest possible UDP payload, so datagrams are never
// silently truncated on either leg.
const udpBufSize = 64 * 1024

// handleUDP binds rule.Listen inside rule.NetnsIn and forwards every
// datagram to rule.Target, dialed from inside rule.NetnsOut. Each distinct
// client address gets its own outbound socket so replies can be routed back.
// A session is dropped once neither the client nor the target has sent
// anything for rule.Timeout (10s if unset), or when reading from the target
// fails; the next datagram from that client opens a fresh session.
//
// It returns after the listener is set up; forwarding continues in
// background goroutines.
func handleUDP(rule *ForwardRule) error {
	idleTimeout := rule.Timeout
	if idleTimeout <= 0 {
		idleTimeout = 10 * time.Second
	}
	exitIn, err := enterNetns(rule.NetnsIn)
	if err != nil {
		return fmt.Errorf("enter netns in failed: %w", err)
	}
	defer exitIn()
	laddr, err := net.ResolveUDPAddr("udp", rule.Listen)
	if err != nil {
		return fmt.Errorf("resolve listen addr failed: %w", err)
	}
	inConn, err := net.ListenUDP("udp", laddr)
	if err != nil {
		return fmt.Errorf("listenUDP failed: %w", err)
	}

	var mu sync.Mutex // guards clients and every udpClient.lastSeen
	clients := make(map[string]*udpClient)

	// relayReplies copies the target's replies for one session back to its
	// client and removes the session when it ends.
	relayReplies := func(key string, c *udpClient) {
		defer c.outConn.Close()
		buf := make([]byte, udpBufSize)
		for {
			if err := c.outConn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
				log.Printf("set read deadline failed: %v", err)
			}
			n, err := c.outConn.Read(buf)
			if err != nil {
				var nerr net.Error
				timedOut := errors.As(err, &nerr) && nerr.Timeout()
				mu.Lock()
				// A quiet target alone does not end the session while the
				// client is still sending.
				if timedOut && time.Since(c.lastSeen) < idleTimeout {
					mu.Unlock()
					continue
				}
				if clients[key] == c {
					delete(clients, key)
				}
				mu.Unlock()
				if !timedOut {
					log.Printf("read from target failed: %v", err)
				}
				return
			}
			if _, err := inConn.WriteToUDP(buf[:n], c.srcAddr); err != nil {
				log.Printf("write back to client failed: %v", err)
			}
		}
	}

	go func() {
		buf := make([]byte, udpBufSize)
		for {
			n, srcAddr, err := inConn.ReadFromUDP(buf)
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					return
				}
				log.Println("UDP read error:", err)
				continue
			}
			key := srcAddr.String()
			mu.Lock()
			client, ok := clients[key]
			if !ok {
				outConn, err := dialUDPTarget(rule)
				if err != nil {
					mu.Unlock()
					log.Print(err)
					continue
				}
				client = &udpClient{srcAddr: srcAddr, outConn: outConn}
				clients[key] = client
				go relayReplies(key, client)
			}
			client.lastSeen = time.Now()
			mu.Unlock()
			// buf is only reused by this goroutine's next read, so the
			// datagram can be written out without copying it first.
			if _, err := client.outConn.Write(buf[:n]); err != nil {
				log.Printf("write to target failed: %v", err)
			}
		}
	}()
	return nil
}

// dialUDPTarget opens a connected UDP socket to rule.Target from inside
// rule.NetnsOut. The calling goroutine's thread is switched back (and
// unlocked) before returning; the socket itself stays in netns.out.
func dialUDPTarget(rule *ForwardRule) (*net.UDPConn, error) {
	exitOut, err := enterNetns(rule.NetnsOut)
	if err != nil {
		return nil, fmt.Errorf("netns out error: %w", err)
	}
	defer exitOut()
	raddr, err := net.ResolveUDPAddr("udp", rule.Target)
	if err != nil {
		return nil, fmt.Errorf("resolve target: %w", err)
	}
	outConn, err := net.DialUDP("udp", nil, raddr)
	if err != nil {
		return nil, fmt.Errorf("dial UDP: %w", err)
	}
	return outConn, nil
}
// parseSCTPAddr turns a single "host:port" string into an SCTP address that
// lists every IP net.LookupIP returns for host. It accepts host names as
// well as IP literals, but only one host.
func parseSCTPAddr(addr string) (*sctp.SCTPAddr, error) {
	host, portStr, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		return nil, splitErr
	}
	port, convErr := strconv.Atoi(portStr)
	if convErr != nil {
		return nil, convErr
	}
	resolved, lookupErr := net.LookupIP(host)
	if lookupErr != nil || len(resolved) == 0 {
		return nil, fmt.Errorf("cannot resolve host %s: %v", host, lookupErr)
	}
	addrs := make([]net.IPAddr, len(resolved))
	for i := range resolved {
		addrs[i] = net.IPAddr{IP: resolved[i]}
	}
	return &sctp.SCTPAddr{Port: port, IPAddrs: addrs}, nil
}
// resolveSCTPAddrMulti parses a multihomed SCTP address of the form
// "ip1,ip2,...:port", e.g. "127.0.0.1,[::1]:2222". Each entry may be wrapped
// in brackets and may carry a "%zone" suffix. Only IP literals are accepted
// (no DNS lookups). The port must be in the range 0-65535.
func resolveSCTPAddrMulti(addr string) (*sctp.SCTPAddr, error) {
	lastColon := strings.LastIndex(addr, ":")
	if lastColon == -1 {
		return nil, fmt.Errorf("missing port in SCTP addr %q", addr)
	}
	hostList, portStr := addr[:lastColon], addr[lastColon+1:]
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("invalid port %q: %w", portStr, err)
	}
	// Out-of-range values would otherwise be truncated silently when the
	// port is packed into a 16-bit socket address field.
	if port < 0 || port > 65535 {
		return nil, fmt.Errorf("invalid port %q: out of range 0-65535", portStr)
	}
	ipsRaw := strings.Split(hostList, ",")
	ipAddrs := make([]net.IPAddr, 0, len(ipsRaw))
	for _, ipRaw := range ipsRaw {
		ipRaw = strings.Trim(strings.TrimSpace(ipRaw), "[]")
		ip, zone := parseIPZone(ipRaw)
		if ip == nil {
			return nil, fmt.Errorf("invalid IP %q", ipRaw)
		}
		ipAddrs = append(ipAddrs, net.IPAddr{IP: ip, Zone: zone})
	}
	return &sctp.SCTPAddr{
		IPAddrs: ipAddrs,
		Port:    port,
	}, nil
}
// parseIPZone splits an IP literal with an optional "%zone" suffix (such as
// "fe80::1%eth0") at the first '%'. It returns the parsed IP, which is nil
// if the part before '%' is not a valid IP, and the zone ("" if none).
func parseIPZone(ip string) (net.IP, string) {
	addrPart, zonePart, _ := strings.Cut(ip, "%")
	return net.ParseIP(addrPart), zonePart
}
// dialSCTPWithRetry dials raddr up to maxRetries times (at least once). Each
// attempt gets at most timeout, and no new attempt starts once ctx is done.
// Between failed attempts it backs off quadratically (0.5s, 2s, 4.5s, ...).
// The returned error wraps the last failure.
//
// Network namespaces are per OS thread. Each attempt runs on a helper
// goroutine so that a hung dial can be abandoned, so the helper locks its
// own thread and switches it into the namespace of the caller's current
// thread before creating the socket. Callers that need a particular
// namespace must be locked to a thread that is in it (enterNetns does this).
// A connection produced by an abandoned attempt is closed when it arrives.
func dialSCTPWithRetry(ctx context.Context, raddr *sctp.SCTPAddr, timeout time.Duration, maxRetries int) (*sctp.SCTPConn, error) {
	type dialResult struct {
		conn *sctp.SCTPConn
		err  error
	}
	if maxRetries < 1 {
		maxRetries = 1
	}
	var lastErr error
	attempts := 0
	for i := 0; i < maxRetries; i++ {
		if err := ctx.Err(); err != nil {
			if lastErr == nil {
				lastErr = err
			}
			break
		}
		// Snapshot the caller thread's namespace; the helper owns this handle.
		callerNS, err := netns.Get()
		if err != nil {
			return nil, fmt.Errorf("get current netns: %w", err)
		}
		attempts++
		results := make(chan dialResult, 1) // buffered so an abandoned helper never blocks
		go func(ns netns.NsHandle) {
			// Deliberately never unlocked: this thread's namespace is changed,
			// and the runtime discards a thread whose goroutine exits locked.
			runtime.LockOSThread()
			defer ns.Close()
			if err := netns.Set(ns); err != nil {
				results <- dialResult{err: fmt.Errorf("enter caller netns: %w", err)}
				return
			}
			conn, err := sctp.DialSCTP("sctp", nil, raddr)
			results <- dialResult{conn: conn, err: err}
		}(callerNS)

		dialCtx, cancel := context.WithTimeout(ctx, timeout)
		select {
		case r := <-results:
			cancel()
			if r.err == nil {
				return r.conn, nil
			}
			lastErr = r.err
		case <-dialCtx.Done():
			cancel()
			lastErr = dialCtx.Err()
			go func() {
				if r := <-results; r.conn != nil {
					r.conn.Close()
				}
			}()
		}

		if i == maxRetries-1 {
			break // no point sleeping after the final attempt
		}
		backoff := time.Duration((i+1)*(i+1)) * 500 * time.Millisecond
		wait := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			wait.Stop() // the check at the top of the loop ends retrying
		case <-wait.C:
		}
	}
	return nil, fmt.Errorf("dial SCTP failed after %d retries: %w", attempts, lastErr)
}
// tuneSCTPSocket is a placeholder hook for adjusting options on the SCTP
// listener. It currently does nothing with ln and always returns nil.
func tuneSCTPSocket(ln *sctp.SCTPListener) error {
	return nil
}
// handleSCTP listens on rule.Listen (one or more comma-separated IPs, for
// multihoming) inside rule.NetnsIn. For every accepted association it dials
// rule.Target from inside rule.NetnsOut and relays data both ways. It
// returns only on a setup error; accept errors are retried with a capped
// exponential backoff.
func handleSCTP(rule *ForwardRule) error {
	// parseRule has already validated the "timeout" query parameter; fall
	// back to the 10s default only for rules that carry no timeout.
	timeout := rule.Timeout
	if timeout <= 0 {
		timeout = 10 * time.Second
	}
	exitIn, err := enterNetns(rule.NetnsIn)
	if err != nil {
		return err
	}
	defer exitIn()
	laddr, err := resolveSCTPAddrMulti(rule.Listen)
	if err != nil {
		return fmt.Errorf("resolve SCTP listen addr: %w", err)
	}
	inLn, err := sctp.ListenSCTP("sctp", laddr)
	if err != nil {
		return fmt.Errorf("listen SCTP: %w", err)
	}
	defer inLn.Close()
	if err := tuneSCTPSocket(inLn); err != nil {
		log.Printf("warning: SCTP socket tuning failed: %v", err)
	}
	var backoff time.Duration
	for {
		inConn, err := inLn.AcceptSCTP()
		if err != nil {
			// Persistent errors (e.g. EMFILE) would otherwise spin this loop;
			// back off from 5ms up to 1s.
			backoff = max(5*time.Millisecond, min(2*backoff, time.Second))
			log.Println("SCTP accept error:", err)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go proxySCTPConn(inConn, rule, timeout)
	}
}

// proxySCTPConn dials the target for one accepted association and relays
// data both ways until either direction ends. Both connections are closed
// exactly once before it returns.
func proxySCTPConn(inConn *sctp.SCTPConn, rule *ForwardRule, timeout time.Duration) {
	outConn, err := dialSCTPTarget(rule, timeout)
	if err != nil {
		log.Println(err)
		inConn.Close()
		return
	}
	var (
		wg   sync.WaitGroup
		once sync.Once
		done = make(chan struct{})
	)
	// SCTP has no half-closed state, so once either direction ends the
	// association is over. Close both ends so the other copy stops instead
	// of waiting for a peer that may never close.
	// NOTE(review): confirm SCTPConn.Close wakes a Read blocked in another
	// goroutine for the sctp library version in use.
	finish := func() {
		once.Do(func() {
			close(done)
			inConn.Close()
			outConn.Close()
		})
	}
	pipe := func(dst, src *sctp.SCTPConn, dir string) {
		defer wg.Done()
		defer finish()
		_, err := io.CopyBuffer(dst, src, make([]byte, 32*1024))
		if err == nil || errors.Is(err, io.EOF) {
			return
		}
		select {
		case <-done:
			// Expected: the other direction already closed both connections.
		default:
			log.Println("SCTP copy "+dir+" error:", err)
		}
	}
	wg.Add(2)
	go pipe(outConn, inConn, "client→target")
	go pipe(inConn, outConn, "target→client")
	wg.Wait()
}

// dialSCTPTarget resolves rule.Target and dials it from inside
// rule.NetnsOut, with up to three attempts bounded by timeout. The calling
// thread is switched back and unlocked before returning, rather than staying
// pinned for the association's lifetime.
func dialSCTPTarget(rule *ForwardRule, timeout time.Duration) (*sctp.SCTPConn, error) {
	exitOut, err := enterNetns(rule.NetnsOut)
	if err != nil {
		return nil, fmt.Errorf("enter netns.out failed: %w", err)
	}
	defer exitOut()
	raddr, err := resolveSCTPAddrMulti(rule.Target)
	if err != nil {
		return nil, fmt.Errorf("resolve SCTP target failed: %w", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	conn, err := dialSCTPWithRetry(ctx, raddr, timeout, 3)
	if err != nil {
		return nil, fmt.Errorf("dial SCTP target failed: %w", err)
	}
	return conn, nil
}
// multiFlag is a flag.Value that accumulates every occurrence of a
// repeatable flag, in command-line order.
type multiFlag []string

// String renders the collected values joined by commas.
func (m *multiFlag) String() string {
	values := []string(*m)
	return strings.Join(values, ",")
}

// Set appends one more occurrence of the flag; it never fails.
func (m *multiFlag) Set(value string) error {
	updated := append(*m, value)
	*m = updated
	return nil
}
// main parses every -L rule, starts one forwarder goroutine per rule and
// runs until SIGINT or SIGTERM. An invalid rule, an unsupported protocol or
// a forwarder setup failure terminates the process.
func main() {
	var listenRules multiFlag
	flag.Var(&listenRules, "L", "forward rule: protocol://listen/target?netns.in=NAME|PATH&netns.out=NAME|PATH (can be repeated)")
	flag.Parse()
	if len(listenRules) == 0 {
		log.Fatalf("no forwarding rules specified")
	}
	handlers := map[string]func(*ForwardRule) error{
		"tcp":  handleTCP,
		"udp":  handleUDP,
		"sctp": handleSCTP,
	}
	// Validate every rule before starting anything, so a bad rule cannot
	// leave earlier forwarders briefly running before the process exits.
	rules := make([]*ForwardRule, 0, len(listenRules))
	for _, r := range listenRules {
		rule, err := parseRule(r)
		if err != nil {
			log.Fatalf("invalid rule %q: %v", r, err)
		}
		if _, ok := handlers[rule.Protocol]; !ok {
			log.Fatalf("unsupported protocol %q", rule.Protocol)
		}
		rules = append(rules, rule)
	}
	// /proc/self/ns/net is resolved when it is opened. Open and cache it now,
	// before any forwarder switches a thread into another namespace.
	const selfNetns = "/proc/self/ns/net"
	for _, rule := range rules {
		if rule.NetnsIn == selfNetns || rule.NetnsOut == selfNetns {
			if _, err := getNetnsHandle(selfNetns); err != nil {
				log.Fatalf("failed to open /proc/self/ns/net: %v", err)
			}
			break
		}
	}
	// Register for signals before any forwarder starts.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()
	for _, rule := range rules {
		go func(r *ForwardRule, run func(*ForwardRule) error) {
			if err := run(r); err != nil {
				log.Fatalf("%s forward failed: %v", r.Protocol, err)
			}
		}(rule, handlers[rule.Protocol])
	}
	<-ctx.Done()
	log.Println("Shutting down...")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment