Created
September 29, 2022 07:49
-
-
Save Merovius/ca9e9199f8f46fea63f99744b61f7ea7 to your computer and use it in GitHub Desktop.
package callsrc
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Package callsrc helps enforce that a function is only called in certain contexts.
package callsrc | |
import ( | |
"io" | |
"os" | |
"runtime" | |
"strings" | |
) | |
// Flags accepted by Allow. Each flag names a context from which a restricted
// function may be called; several can be combined with |.
const (
	// Init allows a call from an init() function.
	Init = 1 << iota
	// PkgScope allows a call from package-scope variable initializers.
	PkgScope
	// MainPkg allows a call originating from package main. Since main.main is
	// part of package main, MainPkg implies MainFn.
	MainPkg
	// MainFn allows a call from main.main only. It is implied by MainPkg.
	MainFn
	// TestMain allows a call from TestMain.
	TestMain
	// TestFunc allows a call from a Test*, Bench* or Fuzz* function. It
	// implies TestMain, i.e. it also allows a call from TestMain.
	TestFunc
)
// has reports whether at least one bit of the mask f is set in flags.
func has(flags, f int) bool {
	common := flags & f
	return common != 0
}
// Allow asserts that a call happens in a particular context. skip is the | |
// number of frames to skip, with 0 identifying the caller of Allow. If the | |
// call is not allowed in the actual context, a message is printed and the | |
// program exits. | |
func Allow(skip, flags int) { | |
if skip < 0 { | |
io.WriteString(os.Stderr, "skip must not be negative\n") | |
os.Exit(1) | |
} | |
f := functions(skip, 3) | |
// f[0] is the restricted function | |
// f[1] is the caller of the restricted function | |
// f[2] is its caller, i.e. the testing or runtime package in common contexts | |
if f[1].full == "main.main" && has(flags, MainFn) { | |
return | |
} | |
if f[1].pkg == "main" && has(flags, MainPkg) { | |
return | |
} | |
if f[1].name == "init" && has(flags, PkgScope) { | |
return | |
} | |
if strings.HasPrefix(f[1].name, "init.") && has(flags, Init) { | |
return | |
} | |
if len(f) < 3 { | |
io.WriteString(os.Stderr, f[0].full+" must not be called from "+f[1].full+"\n") | |
os.Exit(1) | |
} | |
if f[1].name == "TestMain" && f[2].pkg == "testing" && has(flags, TestMain) { | |
return | |
} | |
if f[2].pkg == "testing" && has(flags, TestFunc) { | |
return | |
} | |
if f[1].name == "TestMain" && f[2].pkg == "testing" && has(flags, TestMain) { | |
// TODO: Technically, this isn't enough. e.g. testing.AllocsPerRun | |
// takes a callback and this callback would be matched here but isn't a | |
// Test*, Bench* or Fuzz* function. | |
return | |
} | |
io.WriteString(os.Stderr, f[0].full+" must not be called from "+f[1].full+"\n") | |
os.Exit(1) | |
} | |
type function struct { | |
full string | |
pkg string | |
name string | |
} | |
func functions(skip, n int) []function { | |
pc := make([]uintptr, 1024) | |
pc = pc[:runtime.Callers(3+skip, pc)] | |
frames := runtime.CallersFrames(pc) | |
var out []function | |
for n > 0 { | |
f, ok := frames.Next() | |
if !ok { | |
break | |
} | |
p, n, ok := strings.Cut(f.Function, ".") | |
if ok { | |
out = append(out, function{f.Function, p, n}) | |
} else { | |
out = append(out, function{f.Function, "", f.Function}) | |
} | |
} | |
return out | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment