Last active
February 14, 2022 19:19
-
-
Save tadasv/525ab661bd7916d79cee83e6e1635a56 to your computer and use it in GitHub Desktop.
Simple ORM for Golang
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
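This is an implementation of a simple ORM that supports CRUD operations on a single object. Column names come from the
`db_col:"name"` struct tag (use `db_col:"-"` to skip a field) and otherwise default to the snake_cased field name.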
Example:

	type Account struct {
		orm.Model
		UUID                   string
		Name                   string
		Email                  string
		ActiveOrganizationUUID *string
		MemberSince            time.Time
		CreatedAt              time.Time
		ModifiedAt             time.Time
	}

	func (a *Account) BindTo(connector orm.Connector) *Account {
		a.Model.Connector = connector
		a.Model.Object = a
		return a
	}

	acc := (&Account{UUID: uuid}).BindTo(db)
	err := acc.Load()
	fmt.Printf("%v %v\n", err, acc)
Unlike other ORM frameworks I've seen, this one binds the DB connection to your object. This is very useful for
rapid web development: you can pass objects to a template renderer, which can then pull more related data from
the DB during rendering.
*/
package orm | |
import (
	"database/sql"
	"fmt"
	"reflect"
	"regexp"
	"strings"

	sq "github.com/Masterminds/squirrel"
)
// matchFirstCap finds a character followed by a capitalised word, so that
// toSnakeCase can insert an underscore before the word
// ("MemberSince" -> "Member_Since").
var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")

// matchAllCap finds a lowercase letter or digit followed by a capital letter,
// which separates a trailing acronym from the preceding word
// ("OrganizationUUID" -> "Organization_UUID").
var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")

// defaultPrimaryKeyColumnNames are the column names getPrimaryKey recognises
// as a primary key.
var defaultPrimaryKeyColumnNames = []string{"id", "uuid"}
// Connector is the subset of the database/sql query API that the ORM needs.
// *sql.DB and *sql.Tx both provide these methods, so either can be bound to a
// Model.
type Connector interface {
	// Exec runs a statement that returns no rows.
	Exec(statement string, params ...interface{}) (sql.Result, error)
	// Query runs a statement that returns rows.
	Query(statement string, params ...interface{}) (*sql.Rows, error)
	// QueryRow runs a statement expected to return at most one row.
	QueryRow(statement string, params ...interface{}) *sql.Row
}

// Model is embedded in a user struct to give it Load, Save, Create and Delete.
// Connector and Object must be set (the package example does this in a BindTo
// method) before any of those methods are called.
type Model struct {
	// PrimaryKeyColumnName is intended to name the primary-key column
	// explicitly; when empty, the key is chosen from
	// defaultPrimaryKeyColumnNames. TableName, when non-empty, replaces the
	// table name derived from Object's type.
	PrimaryKeyColumnName, TableName string
	// Connector executes the generated SQL.
	Connector Connector
	// Object should point at the struct that embeds this Model; its fields
	// are mapped to columns via reflection.
	Object interface{}
}
func (o Model) getTableName() string { | |
if o.TableName != "" { | |
return o.TableName | |
} | |
var name string | |
if t := reflect.TypeOf(o.Object); t.Kind() == reflect.Ptr { | |
name = t.Elem().Name() | |
} else { | |
name = t.Name() | |
} | |
return strings.ToLower(name) | |
} | |
func (o Model) getColumnsAndNames() ([]interface{}, []string) { | |
names := []string{} | |
columns := []interface{}{} | |
val := reflect.Indirect(reflect.ValueOf(o.Object)) | |
valType := val.Type() | |
for i := 0; i < valType.NumField(); i++ { | |
field := valType.Field(i) | |
if field.Type.AssignableTo(reflect.TypeOf(o)) { | |
// Ignore embedded Model internals when deriving column names | |
continue | |
} | |
if colName, ok := field.Tag.Lookup("db_col"); ok { | |
if colName == "-" { | |
continue | |
} | |
names = append(names, colName) | |
} else { | |
names = append(names, toSnakeCase(field.Name)) | |
} | |
columns = append(columns, val.Field(i).Addr().Interface()) | |
} | |
return columns, names | |
} | |
func (o Model) getPrimaryKey(names []string, columns []interface{}) (string, interface{}) { | |
for i, name := range names { | |
for _, pkName := range defaultPrimaryKeyColumnNames { | |
if name == pkName { | |
return name, columns[i] | |
} | |
} | |
} | |
return "", nil | |
} | |
func (o Model) Load() error { | |
columns, colNames := o.getColumnsAndNames() | |
primaryKeyName, primaryKeyValue := o.getPrimaryKey(colNames, columns) | |
q := sq.Select(colNames...).From(o.getTableName()) | |
q = q.Where(sq.Eq{ | |
primaryKeyName: primaryKeyValue, | |
}).Limit(1) | |
sql, args, err := q.ToSql() | |
if err != nil { | |
return err | |
} | |
row := o.Connector.QueryRow(sql, args...) | |
return row.Scan(columns...) | |
} | |
func (o Model) getUpdateQuery() sq.UpdateBuilder { | |
columns, colNames := o.getColumnsAndNames() | |
primaryKeyName, primaryKeyValue := o.getPrimaryKey(colNames, columns) | |
q := sq.Update(o.getTableName()) | |
for i, name := range colNames { | |
value := columns[i] | |
q = q.Set(name, value) | |
} | |
q = q.Where(sq.Eq{ | |
primaryKeyName: primaryKeyValue, | |
}).Limit(1) | |
return q | |
} | |
func (o Model) Save() error { | |
q := o.getUpdateQuery() | |
sql, args, err := q.ToSql() | |
if err != nil { | |
return err | |
} | |
_, err = o.Connector.Exec(sql, args...) | |
if err != nil { | |
return err | |
} | |
return nil | |
} | |
func (o Model) Create() error { | |
columns, colNames := o.getColumnsAndNames() | |
q := sq.Insert(o.getTableName()).Columns(colNames...).Values(columns...) | |
sql, args, err := q.ToSql() | |
if err != nil { | |
return err | |
} | |
_, err = o.Connector.Exec(sql, args...) | |
return err | |
} | |
func (o Model) Delete() error { | |
columns, colNames := o.getColumnsAndNames() | |
primaryKeyName, primaryKeyValue := o.getPrimaryKey(colNames, columns) | |
q := sq.Delete(o.getTableName()).Where(sq.Eq{ | |
primaryKeyName: primaryKeyValue, | |
}).Limit(1) | |
sql, args, err := q.ToSql() | |
if err != nil { | |
return err | |
} | |
_, err = o.Connector.Exec(sql, args...) | |
if err != nil { | |
return err | |
} | |
return nil | |
} | |
func toSnakeCase(str string) string { | |
snake := matchFirstCap.ReplaceAllString(str, "${1}_${2}") | |
snake = matchAllCap.ReplaceAllString(snake, "${1}_${2}") | |
return strings.ToLower(snake) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment