Last active
August 18, 2018 17:37
-
-
Save zombiezen/83892a3ef7c5226f6af5bfca3f54cb84 to your computer and use it in GitHub Desktop.
schema2go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Copyright 2018 Google LLC | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); | |
// you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at | |
// | |
// https://www.apache.org/licenses/LICENSE-2.0 | |
// | |
// Unless required by applicable law or agreed to in writing, software | |
// distributed under the License is distributed on an "AS IS" BASIS, | |
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
// See the License for the specific language governing permissions and | |
// limitations under the License. | |
// schema2go converts a sequence of SQL files into a Go source file declaring the concatenated schema, a per-version migrations table, and the latest version number.
package main | |
import ( | |
"bytes" | |
"errors" | |
"flag" | |
"fmt" | |
"go/format" | |
"io/ioutil" | |
"os" | |
"path/filepath" | |
"sort" | |
"strconv" | |
"strings" | |
) | |
func main() { | |
pkg := flag.String("pkg", "main", "package name") | |
prefix := flag.String("prefix", "", "name prefix") | |
sortFiles := flag.Bool("sort", true, "sort files") | |
flag.Parse() | |
args := append([]string(nil), flag.Args()...) | |
if *sortFiles { | |
sort.Slice(args, func(i, j int) bool { | |
return filepath.Base(args[i]) < filepath.Base(args[j]) | |
}) | |
} | |
src, err := run(*pkg, *prefix, args) | |
if err != nil { | |
fmt.Fprintln(os.Stderr, "schema2go:", err) | |
os.Exit(1) | |
} | |
_, err = os.Stdout.Write(src) | |
if err != nil { | |
fmt.Fprintln(os.Stderr, "schema2go:", err) | |
os.Exit(1) | |
} | |
} | |
func run(pkg, prefix string, files []string) ([]byte, error) { | |
if len(files) == 0 { | |
return nil, errors.New("no schemas") | |
} | |
var schemas [][]byte | |
for _, path := range files { | |
sql, err := ioutil.ReadFile(path) | |
if err != nil { | |
return nil, err | |
} | |
schemas = append(schemas, minifySQL(sql)) | |
} | |
buf := new(bytes.Buffer) | |
fmt.Fprint(buf, "// Code generated by schema2go. DO NOT EDIT.\n\n") | |
fmt.Fprintf(buf, "package %s\n\n", pkg) | |
fmt.Fprintf(buf, "// %sFull is the in-order concatenation of all the schemas.\n", prefix) | |
fmt.Fprintf(buf, "const %sFull = ", prefix) | |
for first, s := true, joinSchemas(schemas); len(s) > 0; { | |
if first { | |
first = false | |
} else { | |
buf.WriteString(" +\n\t") | |
} | |
end := strings.IndexByte(s, '\n') | |
if end == -1 { | |
end = len(s) | |
} | |
buf.WriteString(strconv.Quote(s[:end+1])) | |
s = s[end+1:] | |
} | |
fmt.Fprint(buf, "\n\n") | |
fmt.Fprintf(buf, "// %sMigrations has the set of SQL to apply at each version to get to the latest schema.\n", prefix) | |
fmt.Fprintf(buf, "var %sMigrations = [...]string{\n", prefix) | |
for i, idx := 0, 0; i < len(schemas); i++ { | |
fmt.Fprintf(buf, "\t// %s", filepath.Base(files[i])) | |
if i < len(schemas)-1 { | |
fmt.Fprint(buf, " + ...") | |
} | |
fmt.Fprintf(buf, "\n\t%d: %sFull[%d:],\n", i, prefix, idx) | |
idx += len(schemas[i]) | |
} | |
fmt.Fprintln(buf, "\t// Last entry empty: all migrated.") | |
fmt.Fprintf(buf, "\t%d: \"\",\n", len(schemas)) | |
fmt.Fprint(buf, "}\n\n") | |
fmt.Fprintf(buf, "// %sLatest is the highest known schema version.\n", prefix) | |
fmt.Fprintf(buf, "const %sLatest = %d\n", prefix, len(schemas)) | |
return format.Source(buf.Bytes()) | |
} | |
// minifySQL returns sql with every comment line (first non-whitespace
// characters are "--") and every blank line removed, and with leading spaces
// and tabs trimmed from the remaining lines. Line endings are kept.
//
// If the result is non-empty and does not end in a newline, one is appended.
// Schemas are concatenated back to back, so without it a trailing
// "-- comment" on a file's last line would swallow the next schema's first
// line.
//
// The result reuses sql's storage, so the caller must not rely on the
// contents of sql afterwards.
func minifySQL(sql []byte) []byte {
	// Filter in place. Capping capacity at len(sql) guarantees that appending
	// the final newline never writes past the bytes the caller handed us.
	out := sql[:0:len(sql)]
	for len(sql) > 0 {
		var line []byte
		if i := bytes.IndexByte(sql, '\n'); i != -1 {
			line, sql = sql[:i+1], sql[i+1:]
		} else {
			line, sql = sql, nil
		}
		if isCommentLine(line) {
			continue
		}
		out = append(out, bytes.TrimLeft(line, " \t")...)
	}
	if len(out) > 0 && out[len(out)-1] != '\n' {
		out = append(out, '\n')
	}
	return out
}

// isCommentLine reports whether line, after skipping leading whitespace (as
// defined by isSpace), is empty or begins with the SQL line comment "--".
func isCommentLine(line []byte) bool {
	for ; len(line) > 0; line = line[1:] {
		if line[0] == '-' {
			return len(line) >= 2 && line[1] == '-'
		}
		if !isSpace(line[0]) {
			return false
		}
	}
	// Only whitespace (or nothing): treat as blank so the caller drops it.
	return true
}

// isSpace reports whether b is an ASCII space, tab, carriage return, or
// newline.
func isSpace(b byte) bool {
	return b == ' ' || b == '\t' || b == '\r' || b == '\n'
}
// joinSchemas concatenates the minified schemas in order, with no separator
// between them, and returns the result as a string.
func joinSchemas(schemas [][]byte) string {
	return string(bytes.Join(schemas, nil))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment