Last active
May 7, 2025 22:58
-
-
Save Boris-creator/4e6ded4de3eaa211fe133b0cde21fa2c to your computer and use it in GitHub Desktop.
Restore a PostgreSQL schema (CREATE TABLE statements) from a data-only dump
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module pgd | |
go 1.23.8 | |
require github.com/pganalyze/pg_query_go/v6 v6.1.0 | |
require google.golang.org/protobuf v1.31.0 // indirect |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"errors" | |
"fmt" | |
"github.com/pganalyze/pg_query_go/v6" | |
"log" | |
"os" | |
"path/filepath" | |
"strconv" | |
"strings" | |
"time" | |
) | |
func main() { | |
args := os.Args | |
if len(args) < 2 { | |
return | |
} | |
sql, err := readDump(args[1]) | |
if err != nil { | |
log.Fatal(err) | |
} | |
schemaRawSql, err := compile(sql) | |
if err != nil { | |
log.Fatal(err) | |
} | |
_, _ = os.Stdout.Write([]byte(schemaRawSql)) | |
} | |
// file validation & reading | |
// readDump returns the full contents of the SQL dump at path.
// The file must carry a ".sql" extension; the comparison ignores case, so
// "dump.SQL" is accepted as well. Read failures are wrapped with "read file".
func readDump(path string) (string, error) {
	if !strings.EqualFold(filepath.Ext(path), ".sql") {
		return "", errors.New("sql file expected")
	}
	data, err := os.ReadFile(path)
	if err != nil {
		return "", fmt.Errorf("read file: %w", err)
	}
	return string(data), nil
}
// SQL parsing & analysis
// dataType labels the column type inferred from the dump. The label itself is
// only an identifier; compileColDef maps it to the SQL type that is emitted.
type dataType string

const (
	intType   dataType = "BIGINT"
	floatType dataType = "FLOAT"
	charType  dataType = "VARCHAR"
	boolType  dataType = "BOOLEAN"
	timeType  dataType = "TIME"
	// Timestamp with a UTC offset (timestamptz), not PostgreSQL's TIME WITH
	// TIME ZONE, which the previous label "TIME_WITH_TZ" suggested.
	timestampWithTZType dataType = "TIMESTAMP_WITH_TZ"
	jsonType            dataType = "JSON"
)

// columnDef accumulates what has been inferred about one table column while
// the INSERT statements of the dump are scanned.
type columnDef struct {
	name string
	// dataType is nil until a value with a literal type is seen
	// (NULLs and non-constant expressions do not set it).
	dataType *dataType
	// Refinement flags: they start from the first typed value and are cleared
	// as soon as a later value contradicts them.
	canBeInt             bool // integer column whose values still fit in int64
	canBeTime            bool // string column whose values all parse as a time of day
	canBeTimestampWithTZ bool // string column whose values all parse as a timestamp with offset
}

// tableDef lists a table's columns in positional order.
type tableDef []columnDef
func compile(input string) (string, error) { | |
tree, err := pg_query.Parse(input) | |
if err != nil { | |
return "", fmt.Errorf("parse sql: %w", err) | |
} | |
tablesSchema := make(map[string]tableDef) | |
var result []string | |
for _, stmt := range tree.Stmts { | |
insertStmt := stmt.GetStmt().GetInsertStmt() | |
if insertStmt == nil { | |
continue | |
} | |
tableName := insertStmt.GetRelation().Relname | |
var withColumns = true | |
if _, ok := tablesSchema[tableName]; !ok { | |
tableCols, hasCols := getColumns(insertStmt) | |
withColumns = hasCols | |
tablesSchema[tableName] = tableCols | |
} | |
tablesSchema[tableName] = getColsDefFromInsertStmt(insertStmt, tablesSchema[tableName], withColumns) | |
} | |
for name, def := range tablesSchema { | |
res, _ := compileCreateSt(name, def) | |
result = append(result, res) | |
} | |
return strings.Join(result, ""), nil | |
} | |
func getColumns(is *pg_query.InsertStmt) ([]columnDef, bool) { | |
cols := is.GetCols() | |
tableCols := make([]columnDef, 0, len(cols)) | |
for _, col := range cols { | |
tableCols = append(tableCols, columnDef{name: col.GetResTarget().Name}) | |
} | |
return tableCols, len(cols) != 0 | |
} | |
// getDefaultColumnName names the column at zero-based position pos for INSERTs
// that carry no column list: position 0 becomes "col_1", 1 becomes "col_2", ...
func getDefaultColumnName(pos int) string {
	return "col_" + strconv.Itoa(pos+1)
}
func getDataType(val *pg_query.A_Const) *dataType { | |
var dt dataType | |
switch val.GetVal().(type) { | |
case *pg_query.A_Const_Ival: | |
dt = intType | |
case *pg_query.A_Const_Sval: | |
dt = charType | |
case *pg_query.A_Const_Fval: | |
if canBeInt(val) { | |
dt = intType | |
} else { | |
dt = floatType | |
} | |
case *pg_query.A_Const_Boolval: | |
dt = boolType | |
default: | |
return nil | |
} | |
return &dt | |
} | |
func getColsDefFromInsertStmt(insertStmt *pg_query.InsertStmt, tableCols tableDef, withColumns bool) tableDef { | |
stmtListItems := insertStmt.GetSelectStmt().GetSelectStmt().GetValuesLists()[0].GetList().GetItems() | |
for i, valueItem := range stmtListItems { | |
if !withColumns { | |
tableCols = append(tableCols, columnDef{name: getDefaultColumnName(i)}) | |
} | |
columnVal := valueItem.GetAConst() | |
if tableCols[i].dataType == nil { | |
dt := getDataType(columnVal) | |
tableCols[i].dataType = dt | |
tableCols[i] = initDataTypeSuggestions(tableCols[i], columnVal) | |
} else { | |
tableCols[i] = updateDataTypeSuggestionsIfNeed(tableCols[i], columnVal) | |
} | |
} | |
return tableCols | |
} | |
func initDataTypeSuggestions(col columnDef, val *pg_query.A_Const) columnDef { | |
isTimeValue := canBeTime(val) | |
isTimestampValue := canBeTimestampWithTZ(val) | |
if col.dataType == nil { | |
col.canBeInt = true | |
col.canBeTime = isTimeValue | |
col.canBeTimestampWithTZ = isTimestampValue | |
return col | |
} | |
if *col.dataType == intType { | |
col.canBeInt = true | |
} | |
if *col.dataType == charType && isTimeValue { | |
col.canBeTime = true | |
} | |
if *col.dataType == charType && isTimestampValue { | |
col.canBeTimestampWithTZ = true | |
} | |
return col | |
} | |
func updateDataTypeSuggestionsIfNeed(col columnDef, val *pg_query.A_Const) columnDef { | |
if col.canBeInt && !canBeInt(val) { | |
col.canBeInt = false | |
} | |
if col.canBeTime && !canBeTime(val) { | |
col.canBeTime = false | |
} | |
if col.canBeTimestampWithTZ && !canBeTimestampWithTZ(val) { | |
col.canBeTimestampWithTZ = false | |
} | |
return col | |
} | |
func canBeTime(val *pg_query.A_Const) bool { | |
const layout = "15:04:05.999999" | |
value := val.GetSval() | |
if value == nil { | |
return true | |
} | |
_, err := time.Parse(layout, value.Sval) | |
return err == nil | |
} | |
func canBeTimestampWithTZ(val *pg_query.A_Const) bool { | |
const layout = "2006-01-02 15:04:05.999999+03" | |
value := val.GetSval() | |
if value == nil { | |
return true | |
} | |
_, err := time.Parse(layout, value.Sval) | |
return err == nil | |
} | |
func canBeInt(val *pg_query.A_Const) bool { | |
if val.GetIval() != nil { | |
return true | |
} | |
// for some reason pg_query parses big integers as floats | |
if val.GetFval() != nil { | |
_, err := strconv.ParseInt(val.GetFval().Fval, 10, 64) | |
return err == nil | |
} | |
return false | |
} | |
// SQL compilation | |
func compileCreateSt(tableName string, cols []columnDef) (string, error) { | |
st, _ := pg_query.Parse("CREATE TABLE t();") | |
ct := st.GetStmts()[0].GetStmt().GetCreateStmt() | |
ct.GetRelation().Relname = tableName | |
tableCols := make([]*pg_query.Node, 0, len(cols)) | |
for _, col := range cols { | |
tableCols = append(tableCols, compileColDef(col)) | |
} | |
ct.TableElts = tableCols | |
stmt, err := pg_query.Deparse(st) | |
if err != nil { | |
return "", err | |
} | |
return fmt.Sprintf("%s;\n", stmt), nil | |
} | |
func compileColDef(col columnDef) *pg_query.Node { | |
types := map[dataType]int{ | |
intType: 0, | |
floatType: 1, | |
charType: 2, | |
timeType: 3, | |
timestampWithTZType: 4, | |
boolType: 5, | |
jsonType: 6, | |
} | |
st, _ := pg_query.Parse("CREATE TABLE t (col1 BIGINT, col2 REAL, col3 VARCHAR, col4 TIME, col5 TIMESTAMP WITH TIME ZONE, col6 BOOLEAN, col7 JSON)") | |
createStmt := st.GetStmts()[0].GetStmt().GetCreateStmt() | |
var ct int | |
if col.dataType == nil { | |
jt := jsonType | |
col.dataType = &jt | |
} | |
if *col.dataType == intType && !col.canBeInt { | |
ft := floatType | |
col.dataType = &ft | |
} | |
if *col.dataType == charType && col.canBeTime { | |
tt := timeType | |
col.dataType = &tt | |
} | |
if *col.dataType == charType && col.canBeTimestampWithTZ { | |
tt := timestampWithTZType | |
col.dataType = &tt | |
} | |
ct, _ = types[*col.dataType] | |
names := pg_query.TypeName{Names: createStmt.TableElts[ct].GetColumnDef().GetTypeName().Names} | |
constraints := createStmt.TableElts[ct].GetColumnDef().Constraints | |
return pg_query.MakeSimpleColumnDefNode(col.name, &names, constraints, 0) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment