Created
December 19, 2021 02:04
-
-
Save jxskiss/3d8a0c1362e961a2edb178473b91a154 to your computer and use it in GitHub Desktop.
Opinionated and simple Golang assertion helpers.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package assert | |
import ( | |
"bytes" | |
"fmt" | |
"reflect" | |
"regexp" | |
"strings" | |
"testing" | |
) | |
func Contains(t testing.TB, container interface{}, elem interface{}, msgAndArgs ...interface{}) bool { | |
if checkContains(container, elem) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("%#v does not contain %#v", container, elem) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func NotContains(t testing.TB, container interface{}, elem interface{}, msgAndArgs ...interface{}) bool { | |
if !checkContains(container, elem) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("%#v does contain %#v", container, elem) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func Equal(t testing.TB, left, right interface{}, msgAndArgs ...interface{}) bool { | |
if isEqual(left, right) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("%#v does not equal %#v", left, right) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func NotEqual(t testing.TB, left, right interface{}, msgAndArgs ...interface{}) bool { | |
if !isEqual(left, right) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("%#v does equal %#v", left, right) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func True(t testing.TB, value bool, msgAndArgs ...interface{}) bool { | |
if value { | |
return true | |
} | |
t.Helper() | |
return fail(t, "should be true, got false", msgAndArgs) | |
} | |
func False(t testing.TB, value bool, msgAndArgs ...interface{}) bool { | |
if !value { | |
return true | |
} | |
t.Helper() | |
return fail(t, "should be false, got true", msgAndArgs) | |
} | |
func Nil(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool { | |
if isNil(object) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("shoule be nil, got %#v", object) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func NotNil(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool { | |
if !isNil(object) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("should not be nil, got %#v", object) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func Empty(t testing.TB, object interface{}, msgAndArgs ...interface{}) bool { | |
if isEmpty(object) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("should be empty, got %#v", object) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func NotEmpty(t testing.TB, objeect interface{}, msgAndArgs ...interface{}) bool { | |
if !isEmpty(objeect) { | |
return true | |
} | |
t.Helper() | |
errmsg := fmt.Sprintf("should not be empty, got %#v", objeect) | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func Regexp(t testing.TB, re interface{}, str string, msgAndArgs ...interface{}) bool { | |
var match bool | |
var errmsg string | |
switch re := re.(type) { | |
case *regexp.Regexp: | |
match = re.MatchString(str) | |
case string: | |
match = regexp.MustCompile(re).MatchString(str) | |
default: | |
errmsg = fmt.Sprintf("%#v is not a valid re param", re) | |
} | |
if match { | |
return true | |
} | |
t.Helper() | |
if errmsg == "" { | |
errmsg = fmt.Sprintf("%q does not match re %q", str, re) | |
} | |
return fail(t, errmsg, msgAndArgs) | |
} | |
func Panics(t testing.TB, f func(), msgAndArgs ...interface{}) bool { | |
var didPanic bool | |
var exc interface{} | |
func() { | |
defer func() { | |
exc = recover() | |
didPanic = exc != nil | |
}() | |
f() | |
}() | |
if didPanic { | |
return true | |
} | |
t.Helper() | |
return fail(t, "the code does not panic", msgAndArgs) | |
} | |
// -------------------------------- // | |
// Short package-local aliases for the reflect entry points used by the
// helper functions below.
var (
	typeOf  = reflect.TypeOf
	valueOf = reflect.ValueOf
)
// fail reports errmsg as a test error through t and always returns false, so
// callers can write `return fail(...)`.
//
// userMsgAndArgs is the caller's optional extra message, appended on a new line:
//   - one argument: formatted with fmt.Sprint;
//   - several arguments whose first is a string containing '%': the first is
//     used as a fmt.Sprintf template for the rest;
//   - otherwise: all arguments are formatted with fmt.Sprint.
func fail(t testing.TB, errmsg string, userMsgAndArgs []interface{}) bool {
	t.Helper()
	if n := len(userMsgAndArgs); n > 0 {
		var extra string
		tmpl, isStr := userMsgAndArgs[0].(string)
		if n > 1 && isStr && strings.Contains(tmpl, "%") {
			extra = fmt.Sprintf(tmpl, userMsgAndArgs[1:]...)
		} else {
			extra = fmt.Sprint(userMsgAndArgs...)
		}
		errmsg = errmsg + "\n" + extra
	}
	t.Error(errmsg)
	return false
}
func checkContains(container, elem interface{}) bool { | |
cVal, eVal := valueOf(container), valueOf(elem) | |
cTyp, eTyp := cVal.Type(), eVal.Type() | |
cKind, eKind := cTyp.Kind(), eTyp.Kind() | |
if cKind == eKind { | |
if cKind == reflect.String { | |
return strings.Contains(cVal.String(), eVal.String()) | |
} | |
if cKind == reflect.Slice { | |
if cTyp.Elem().Kind() == reflect.Uint { | |
return bytes.Contains(cVal.Bytes(), eVal.Bytes()) | |
} | |
cLen, eLen := cVal.Len(), eVal.Len() | |
for i := 0; i+eLen < cLen; i++ { | |
if reflect.DeepEqual(cVal.Slice(i, i+eLen).Interface(), elem) { | |
return true | |
} | |
} | |
return false | |
} | |
} | |
if cKind == reflect.Slice && cTyp.Elem() == eTyp { | |
for i := 0; i < cVal.Len(); i++ { | |
if cVal.Index(i).Interface() == elem { | |
return true | |
} | |
} | |
return false | |
} | |
if cKind == reflect.Map && cTyp.Key() == eTyp { | |
if x := cVal.MapIndex(eVal); x.IsValid() { | |
return true | |
} | |
} | |
return false | |
} | |
func isEqual(left, right interface{}) bool { | |
aTyp, bTyp := typeOf(left), typeOf(right) | |
if aTyp != nil && bTyp != nil && aTyp.Comparable() && bTyp.Comparable() && | |
(left == right || literalConvert(left) == literalConvert(right)) { | |
return true | |
} | |
return reflect.DeepEqual(left, right) | |
} | |
func isNil(object interface{}) bool { | |
if object == nil { | |
return true | |
} | |
kind := typeOf(object).Kind() | |
return (kind == reflect.Chan || kind == reflect.Func || kind == reflect.Interface || | |
kind == reflect.Map || kind == reflect.Ptr || kind == reflect.Slice || | |
kind == reflect.UnsafePointer) && valueOf(object).IsNil() | |
} | |
func isEmpty(object interface{}) bool { | |
if object == nil { | |
return true | |
} | |
kind, val := typeOf(object).Kind(), valueOf(object) | |
return ((kind == reflect.Array || kind == reflect.Chan || kind == reflect.Map || kind == reflect.Slice) && val.Len() == 0) || | |
(kind == reflect.Ptr && (val.IsNil() || isEmpty(val.Elem()))) || | |
reflect.DeepEqual(object, reflect.Zero(val.Type()).Interface()) | |
} | |
func literalConvert(val interface{}) interface{} { | |
switch val := valueOf(val); val.Kind() { | |
case reflect.Bool: | |
return val.Bool() | |
case reflect.String: | |
return val.Convert(typeOf("")).Interface() | |
case reflect.Float32, reflect.Float64: | |
return val.Float() | |
case reflect.Complex64, reflect.Complex128: | |
return val.Complex() | |
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |
if asInt := val.Int(); asInt < 0 { | |
return asInt | |
} | |
return val.Convert(typeOf(uint64(0))).Uint() | |
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |
return val.Uint() | |
default: | |
return val | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment