Skip to content

Instantly share code, notes, and snippets.

@patrickod
Last active April 30, 2025 16:53
Show Gist options
  • Save patrickod/d5c152e9d76e8b2cca39837056137a2c to your computer and use it in GitHub Desktop.
Save patrickod/d5c152e9d76e8b2cca39837056137a2c to your computer and use it in GitHub Desktop.
Go AST analysis tool to identify instances of unsafe r.URL.Scheme value comparison in HTTP request handlers
module github.com/patrickod/schemevet
go 1.23.4
require golang.org/x/tools v0.28.0
require (
golang.org/x/mod v0.22.0 // indirect
golang.org/x/sync v0.10.0 // indirect
)
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8=
golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw=
// Package schemevet is a Go AST analysis tool that identifies instances of
// unsafe comparison of the unpopulated URL.Scheme field in HTTP request handlers.
package schemevet
import (
"go/ast"
"go/token"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
)
// Analyzer is the schemevet analysis pass. It requires the inspect pass and
// uses run as its entry point, which reports comparisons of an HTTP handler's
// request URL.Scheme field against the string literal "https".
var Analyzer = &analysis.Analyzer{
Name: "schemevet",
Doc: "reports unsafe comparison of unpopulated URL.Scheme field in HTTP request handlers",
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: run,
}
// run is the entry point of the analyzer. It fetches the inspector produced by
// the inspect pass, visits every *ast.File it knows about and hands each one
// to analyzeFile. It yields no result value and never returns an error.
func run(pass *analysis.Pass) (any, error) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	// Restrict the traversal to file nodes; analyzeFile walks each file itself.
	fileFilter := []ast.Node{(*ast.File)(nil)}
	insp.Preorder(fileFilter, func(node ast.Node) {
		analyzeFile(pass, node.(*ast.File))
	})

	return nil, nil
}
// analyzeFile walks f and, for every function declaration or function literal
// that isHTTPHandlerFunc recognizes as an HTTP handler, reports each
// comparison of the form <request>.URL.Scheme == "https" found in its body,
// where <request> is the name of the handler's *http.Request parameter.
//
// Functions without a body are skipped, and the walk keeps descending into
// handler bodies so that handler literals nested inside another handler are
// analyzed under their own request parameter name.
func analyzeFile(pass *analysis.Pass, f *ast.File) {
	ast.Inspect(f, func(n ast.Node) bool {
		var fn *ast.FuncType
		var body *ast.BlockStmt
		switch x := n.(type) {
		case *ast.FuncDecl:
			fn, body = x.Type, x.Body
		case *ast.FuncLit:
			fn, body = x.Type, x.Body
		default:
			return true
		}

		// A declaration without a body (a function implemented outside Go)
		// has nothing to scan, and walking a nil *ast.BlockStmt would panic.
		if fn == nil || fn.Params == nil || body == nil {
			return true
		}

		// A handler is any function with both an http.ResponseWriter and a
		// *http.Request parameter; the latter's name is what we look for.
		if isHandler, requestVarName := isHTTPHandlerFunc(fn.Params.List); isHandler {
			inspectHandlerBody(pass, body, requestVarName)
		}

		// Always descend: nested function literals may be handlers themselves.
		return true
	})
}

// inspectHandlerBody reports every unsafe scheme comparison on requestVarName
// inside body, including those in closures that capture the request.
func inspectHandlerBody(pass *analysis.Pass, body *ast.BlockStmt, requestVarName string) {
	ast.Inspect(body, func(n ast.Node) bool {
		switch x := n.(type) {
		case *ast.FuncLit:
			// A nested handler using the same request name is visited on its
			// own by analyzeFile; skipping it here avoids duplicate reports.
			if x.Type != nil && x.Type.Params != nil {
				if nested, name := isHTTPHandlerFunc(x.Type.Params.List); nested && name == requestVarName {
					return false
				}
			}
		case *ast.BinaryExpr:
			if isUnsafeSchemeComparison(x, requestVarName) {
				pass.Reportf(x.X.Pos(), "unsafe comparison of unpopulated request URL.Scheme field in HTTP handler")
			}
		}
		return true
	})
}

// isUnsafeSchemeComparison reports whether expr has the exact shape
// <requestVarName>.URL.Scheme == "https".
func isUnsafeSchemeComparison(expr *ast.BinaryExpr, requestVarName string) bool {
	if expr.Op != token.EQL {
		return false
	}
	// Left operand must be the selector chain <ident>.URL.Scheme.
	schemeSel, ok := expr.X.(*ast.SelectorExpr)
	if !ok || schemeSel.Sel.Name != "Scheme" {
		return false
	}
	urlSel, ok := schemeSel.X.(*ast.SelectorExpr)
	if !ok || urlSel.Sel.Name != "URL" {
		return false
	}
	req, ok := urlSel.X.(*ast.Ident)
	if !ok || req.Name != requestVarName {
		return false
	}
	// Right operand must be the interpreted string literal "https".
	lit, ok := expr.Y.(*ast.BasicLit)
	return ok && lit.Kind == token.STRING && lit.Value == `"https"`
}
// isHTTPHandlerFunc checks if the function is an HTTP handler function as
// determined by having both http.ResponseWriter and *http.Request parameters.
// The second return value is the name of the *http.Request parameter.
//
// Other parameters are ignored regardless of how they are declared, so a
// field such as "id, name string" does not disqualify the function. The
// request parameter must carry a usable name: an unnamed or blank ("_")
// request cannot be referenced in the body, so such a function is not
// reported as a handler. When one field declares several request names, the
// first is returned; when several fields qualify, the last one wins.
func isHTTPHandlerFunc(params []*ast.Field) (bool, string) {
	var hasResponseWriter bool
	var requestVarName string
	for _, p := range params {
		switch {
		case isType(p.Type, "http.ResponseWriter"):
			hasResponseWriter = true
		case isType(p.Type, "*http.Request"):
			if len(p.Names) > 0 && p.Names[0].Name != "_" {
				requestVarName = p.Names[0].Name
			}
		}
	}
	if hasResponseWriter && requestVarName != "" {
		return true, requestVarName
	}
	return false, ""
}
// isType reports whether the type expression expr is spelled exactly as
// typeName. Three spellings are recognized: a bare identifier ("T"), a
// package-qualified name ("pkg.T") and a pointer to a package-qualified name
// ("*pkg.T"). Any other expression, including a pointer to a bare identifier,
// yields false. The comparison is purely textual; no type information is used.
func isType(expr ast.Expr, typeName string) bool {
	// Bare identifiers are compared directly and never carry a "*" prefix.
	if id, ok := expr.(*ast.Ident); ok {
		return id.Name == typeName
	}

	// Peel off a single pointer level, remembering it for the final spelling.
	prefix := ""
	if ptr, ok := expr.(*ast.StarExpr); ok {
		prefix, expr = "*", ptr.X
	}

	// What remains must be <package identifier>.<name>.
	qualified, ok := expr.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	pkg, ok := qualified.X.(*ast.Ident)
	if !ok {
		return false
	}
	return prefix+pkg.Name+"."+qualified.Sel.Name == typeName
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment