Last active
December 20, 2015 01:09
-
-
Save grahamking/6047449 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"encoding/gob" | |
"fmt" | |
"net" | |
"os" | |
"os/exec" | |
"os/signal" | |
"strings" | |
"syscall" | |
"time" | |
) | |
// VERSION is the build number stamped on every log line and on every message
// written to the client connections.
const VERSION = 1

// pos is the running message counter. writeLoop increments it for every
// message it formats, and it is carried across to the child in the hand-off
// message so that the child can continue the sequence.
var pos int

// isRunning is the flag that keeps both the signal loop in main and the
// writeLoop goroutines going; main clears it when a shutdown signal arrives.
var isRunning bool
func main() { | |
log("Start") | |
var err error | |
parentPID := os.Getppid() | |
log(fmt.Sprintf("Parent: %d", parentPID)) | |
var c1, c2 net.Conn | |
var l net.Listener | |
var f *os.File | |
ppid := os.Getenv("PARENT_PID") | |
if ppid != "" { | |
// We're the child | |
// Connect to the domain socket pipe | |
f = os.NewFile(3, "domain socket") | |
_, err = f.Stat() | |
if err != nil { | |
log("Domain pipe err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
defer f.Close() | |
var ok bool | |
var uc *net.UnixConn | |
netc, _ := net.FileConn(f) | |
uc, ok = netc.(*net.UnixConn) | |
if !ok { | |
log("Domain pipe is not UnixConn") | |
os.Exit(1) | |
} | |
defer uc.Close() | |
// Signal the parent to send us the fds | |
var parent *os.Process | |
parent, err = os.FindProcess(parentPID) | |
if err != nil { | |
log("FindProcess Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
parent.Signal(syscall.SIGTERM) | |
// Receive the fd | |
buf := make([]byte, 1024) // connections metadata | |
oob := make([]byte, 32) // expect 24 bytes | |
_, oobn, _, _, err := uc.ReadMsgUnix(buf, oob) | |
scms, err := syscall.ParseSocketControlMessage(oob[:oobn]) | |
if err != nil { | |
fmt.Printf("ParseSocketControlMessage: %v\n", err) | |
os.Exit(1) | |
} | |
if len(scms) != 1 { | |
fmt.Printf("expected 1 SocketControlMessage; got scms = %#v\n", scms) | |
} | |
scm := scms[0] | |
gotFds, err := syscall.ParseUnixRights(&scm) | |
if err != nil { | |
fmt.Printf("syscall.ParseUnixRights: %v\n", err) | |
} | |
if len(gotFds) != 2 { | |
fmt.Printf("wanted 2 fd; got %#v\n", gotFds) | |
} | |
contentMsg := decodeContentMessage(buf) | |
fmt.Printf("Message received: %s (%d)\n", contentMsg, len(contentMsg)) | |
// Rebuild the net.Conn(s) | |
fConn1 := os.NewFile(uintptr(gotFds[0]), "fd-from-parent-1") | |
fConn2 := os.NewFile(uintptr(gotFds[1]), "fd-from-parent-2") | |
defer fConn1.Close() | |
defer fConn2.Close() | |
c1, err = net.FileConn(fConn1) | |
c2, err = net.FileConn(fConn2) | |
if err != nil { | |
log("FileConn Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
fmt.Println("c1.RemoteAddr():", c1.RemoteAddr()) | |
defer c1.Close() | |
defer c2.Close() | |
} else { | |
// We're the parent, open the connection | |
l, err = net.Listen("tcp", ":1122") | |
if err != nil { | |
log("Listen Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
defer l.Close() | |
c1, err = l.Accept() | |
if err != nil { | |
log("Accept Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
defer c1.Close() | |
fmt.Println("Got first connection") | |
c2, err = l.Accept() | |
if err != nil { | |
log("Accept Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
defer c2.Close() | |
fmt.Println("Got second connection") | |
} | |
go writeLoop(c1) | |
go writeLoop(c2) | |
// Build the domain socket pair to communicate with child | |
var fds [2]int | |
fds, err = syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0) | |
if err != nil { | |
fmt.Printf("Socketpair: %v\n", err) | |
} | |
defer syscall.Close(fds[0]) | |
defer syscall.Close(fds[1]) | |
writeFile := os.NewFile(uintptr(fds[0]), "write-end") | |
readFile := os.NewFile(uintptr(fds[1]), "read-end") | |
defer writeFile.Close() | |
defer readFile.Close() | |
// Turn writeFile into a UnixConn. We use it on SIGTERM receipt | |
var writeConnI net.Conn | |
writeConnI, err = net.FileConn(writeFile) | |
if err != nil { | |
fmt.Printf("FileConn: %v\n", err) | |
os.Exit(1) | |
} | |
defer writeConnI.Close() | |
writeConn, ok := writeConnI.(*net.UnixConn) | |
if !ok { | |
fmt.Printf("unexpected FileConn type; expected UnixConn, got %T\n", writeConnI) | |
os.Exit(1) | |
} | |
defer writeConn.Close() | |
incoming := make(chan os.Signal) | |
signal.Notify(incoming, | |
syscall.SIGINT, | |
syscall.SIGTERM, | |
syscall.SIGUSR2, | |
os.Interrupt) | |
isRunning = true | |
for isRunning { | |
sig := <-incoming | |
fmt.Println(sig) | |
switch sig { | |
case syscall.SIGINT, syscall.SIGKILL, syscall.SIGTERM: | |
isRunning = false | |
var connFile1, connFile2 *os.File | |
connFile1, err = c1.(*net.TCPConn).File() | |
if err != nil { | |
log("TCPConn File Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
defer connFile1.Close() | |
connFile2, err = c2.(*net.TCPConn).File() | |
if err != nil { | |
log("TCPConn File Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
defer connFile2.Close() | |
// Send socket fd(s) down the pipe | |
rights := syscall.UnixRights( | |
int(connFile1.Fd()), int(connFile2.Fd())) | |
contentMessage := encodeContentMessage() | |
n, oobn, err := writeConn.WriteMsgUnix(contentMessage, rights, nil) | |
if err != nil { | |
fmt.Printf("WriteMsgUnix: %v\n", err) | |
return | |
} | |
if oobn != len(rights) { | |
fmt.Printf("WriteMsgUnix = %d, %d\n", n, oobn) | |
return | |
} | |
// Parent done - quit | |
log("Bye") | |
case syscall.SIGUSR2: | |
startChild(readFile) | |
} | |
} | |
log("End") | |
} | |
func encodeContentMessage() []byte { | |
var buf bytes.Buffer | |
encoder := gob.NewEncoder(&buf) | |
encoder.Encode("LISTEN") | |
encoder.Encode(pos) | |
return buf.Bytes() | |
} | |
func decodeContentMessage(msg []byte) string { | |
decoder := gob.NewDecoder(bytes.NewBuffer(msg)) | |
var contentType string | |
decoder.Decode(&contentType) | |
decoder.Decode(&pos) | |
return contentType | |
} | |
func log(msg string) { | |
fmt.Printf("[%d] v%d %s\n", os.Getpid(), VERSION, msg) | |
} | |
func writeLoop(c net.Conn) { | |
var err error | |
for err == nil && isRunning { | |
msg := fmt.Sprintf("[v%d] Message %d\n", VERSION, pos) | |
pos++ | |
_, err = c.Write([]byte(msg)) | |
if err != nil { | |
fmt.Println(err) | |
} | |
time.Sleep(500 * time.Millisecond) | |
} | |
fmt.Println("Exit writeLoop") | |
} | |
func startChild(pipe *os.File) { | |
var err error | |
var path string | |
path, err = exec.LookPath(os.Args[0]) | |
if strings.HasPrefix(path, "./") { | |
var pwd string | |
pwd, err = os.Getwd() | |
if err != nil { | |
log("Getwd Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
path = pwd + string(os.PathSeparator) + path[2:] | |
} | |
cmd := exec.Command(path) | |
cmd.Stdin = os.Stdin | |
cmd.Stdout = os.Stdout | |
cmd.Stderr = os.Stderr | |
cmd.ExtraFiles = []*os.File{pipe} | |
cmd.Env = append(os.Environ(), fmt.Sprintf("PARENT_PID=%d", os.Getpid())) | |
err = cmd.Start() | |
if err != nil { | |
log("Start Err") | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment