Skip to content

Instantly share code, notes, and snippets.

@Boris-creator
Last active May 7, 2025 22:58
Show Gist options
  • Save Boris-creator/4e6ded4de3eaa211fe133b0cde21fa2c to your computer and use it in GitHub Desktop.
Save Boris-creator/4e6ded4de3eaa211fe133b0cde21fa2c to your computer and use it in GitHub Desktop.
Restore a Postgres DB schema from a data-only dump
module pgd
go 1.23.8
require github.com/pganalyze/pg_query_go/v6 v6.1.0
require google.golang.org/protobuf v1.31.0 // indirect
package main
import (
"errors"
"fmt"
"github.com/pganalyze/pg_query_go/v6"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"time"
)
// main reads the data-only SQL dump named by the first command-line argument,
// infers table definitions from its INSERT statements and writes the resulting
// CREATE TABLE statements to stdout. Without an argument it prints usage to
// stderr and exits with status 2; any other failure exits via log.Fatal.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintf(os.Stderr, "usage: %s <dump.sql>\n", filepath.Base(os.Args[0]))
		os.Exit(2)
	}
	sql, err := readDump(os.Args[1])
	if err != nil {
		log.Fatal(err)
	}
	schemaRawSQL, err := compile(sql)
	if err != nil {
		log.Fatal(err)
	}
	// A failed write (closed pipe, full disk) must not look like success.
	if _, err := os.Stdout.WriteString(schemaRawSQL); err != nil {
		log.Fatal(fmt.Errorf("write output: %w", err))
	}
}
// file validation & reading
// readDump returns the contents of the dump file at path.
//
// The path must end in ".sql" (compared case-sensitively); otherwise an
// error is returned without touching the file system. Read failures are
// wrapped with a "read file" prefix.
func readDump(path string) (string, error) {
	if ext := filepath.Ext(path); ext != ".sql" {
		return "", errors.New("sql file expected")
	}

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return "", fmt.Errorf("read file: %w", readErr)
	}
	return string(raw), nil
}
// SQL parsing & analysis
// dataType identifies the column type inferred for a table column. The
// string values are internal identifiers (compileColDef uses them as lookup
// keys) and are not all valid PostgreSQL type names, e.g. "TIME_WITH_TZ".
type dataType string

// The column types the inference can produce.
const (
	intType             dataType = "BIGINT"
	floatType           dataType = "FLOAT"
	charType            dataType = "VARCHAR"
	boolType            dataType = "BOOLEAN"
	timeType            dataType = "TIME"
	timestampWithTZType dataType = "TIME_WITH_TZ"
	jsonType            dataType = "JSON"
)

type (
	// columnDef accumulates what has been learned about one column from the
	// values inserted into it.
	columnDef struct {
		// name is the column name, either taken from the INSERT column list
		// or generated positionally.
		name string
		// dataType is nil until a value whose constant type is recognised
		// has been seen.
		dataType *dataType
		// canBeInt, canBeTime and canBeTimestampWithTZ record whether every
		// value examined so far is compatible with a narrower type than
		// dataType suggests on its own.
		canBeInt             bool
		canBeTime            bool
		canBeTimestampWithTZ bool
	}

	// tableDef is the ordered list of columns of one table.
	tableDef []columnDef
)
// compile parses input as PostgreSQL SQL and infers a column list and column
// types for every table that is the target of an INSERT statement. It returns
// one CREATE TABLE statement per table, in the order each table first appears
// in input. Statements other than INSERT are ignored.
//
// NOTE: tables are keyed by their unqualified name (Relname); the schema
// qualifier is not used, so same-named tables in different schemas merge.
func compile(input string) (string, error) {
	tree, err := pg_query.Parse(input)
	if err != nil {
		return "", fmt.Errorf("parse sql: %w", err)
	}

	tablesSchema := make(map[string]tableDef)
	// Map iteration order is random; remember first-seen order so the output
	// is deterministic and follows the dump.
	var tableOrder []string
	for _, stmt := range tree.GetStmts() {
		insertStmt := stmt.GetStmt().GetInsertStmt()
		if insertStmt == nil {
			continue
		}
		tableName := insertStmt.GetRelation().GetRelname()
		withColumns := true
		if _, seen := tablesSchema[tableName]; !seen {
			tableCols, hasCols := getColumns(insertStmt)
			withColumns = hasCols
			tablesSchema[tableName] = tableCols
			tableOrder = append(tableOrder, tableName)
		}
		tablesSchema[tableName] = getColsDefFromInsertStmt(insertStmt, tablesSchema[tableName], withColumns)
	}

	var out strings.Builder
	for _, name := range tableOrder {
		createSQL, err := compileCreateSt(name, tablesSchema[name])
		if err != nil {
			return "", fmt.Errorf("compile table %q: %w", name, err)
		}
		out.WriteString(createSQL)
	}
	return out.String(), nil
}
// getColumns returns one columnDef (name only, type still unknown) per entry
// of the INSERT's explicit column list, in list order, and reports whether the
// statement had such a list at all.
func getColumns(is *pg_query.InsertStmt) ([]columnDef, bool) {
	targets := is.GetCols()
	defs := make([]columnDef, len(targets))
	for i, target := range targets {
		defs[i].name = target.GetResTarget().Name
	}
	return defs, len(defs) > 0
}
// getDefaultColumnName returns the generated name for the column at the
// zero-based position pos: "col_1" for position 0, "col_2" for 1, and so on.
func getDefaultColumnName(pos int) string {
	return "col_" + strconv.Itoa(pos+1)
}
// getDataType maps the kind of constant in val to a column type and returns a
// freshly allocated pointer to it:
//
//   - integer constants → intType
//   - float constants → intType if the text fits in int64, else floatType
//   - string constants → charType
//   - boolean constants → boolType
//
// Any other value (NULL, other constant kinds, or a nil val) yields nil.
func getDataType(val *pg_query.A_Const) *dataType {
	typed := func(dt dataType) *dataType { return &dt }

	switch val.GetVal().(type) {
	case *pg_query.A_Const_Boolval:
		return typed(boolType)
	case *pg_query.A_Const_Sval:
		return typed(charType)
	case *pg_query.A_Const_Ival:
		return typed(intType)
	case *pg_query.A_Const_Fval:
		// Integer literals too large for the parser's int constant arrive
		// as floats; canBeInt recovers those that still fit in int64.
		if canBeInt(val) {
			return typed(intType)
		}
		return typed(floatType)
	}
	return nil
}
// getColsDefFromInsertStmt folds every row of insertStmt's VALUES list into
// tableCols and returns the updated slice.
//
// Values are matched to columns by position. A column whose type is still
// unknown is (re)initialised from the value; otherwise the value only narrows
// that column's type suggestions.
//
// Whenever a row is wider than tableCols, the missing columns are appended
// with names from getDefaultColumnName. That covers tables first inserted
// into without a column list (withColumns == false, tableCols empty). It also
// covers later rows that are wider than the first one. Because of this,
// withColumns no longer needs to be consulted; it is kept for call
// compatibility.
//
// INSERTs with no VALUES list (INSERT ... SELECT, DEFAULT VALUES) leave
// tableCols unchanged instead of panicking.
func getColsDefFromInsertStmt(insertStmt *pg_query.InsertStmt, tableCols tableDef, withColumns bool) tableDef {
	for _, row := range insertStmt.GetSelectStmt().GetSelectStmt().GetValuesLists() {
		for i, valueItem := range row.GetList().GetItems() {
			if i >= len(tableCols) {
				tableCols = append(tableCols, columnDef{name: getDefaultColumnName(i)})
			}
			// nil for non-constant expressions (casts, DEFAULT, arrays, ...).
			columnVal := valueItem.GetAConst()
			if tableCols[i].dataType == nil {
				tableCols[i].dataType = getDataType(columnVal)
				tableCols[i] = initDataTypeSuggestions(tableCols[i], columnVal)
			} else {
				tableCols[i] = updateDataTypeSuggestionsIfNeed(tableCols[i], columnVal)
			}
		}
	}
	return tableCols
}
// initDataTypeSuggestions sets col's type-suggestion flags from val, the
// value that has just determined (or failed to determine) col.dataType.
//
// While col.dataType is nil (e.g. the value was NULL), every option is kept
// open. canBeTime and canBeTimestampWithTZ report true for non-string values,
// so in practice all flags end up set.
//
// Once a type is known, the flags are assigned outright rather than only
// switched on. Otherwise flags left over from an earlier NULL would survive.
// For example, a text column whose first value was NULL would keep canBeTime
// and be emitted as TIME.
func initDataTypeSuggestions(col columnDef, val *pg_query.A_Const) columnDef {
	if col.dataType == nil {
		col.canBeInt = true
		col.canBeTime = canBeTime(val)
		col.canBeTimestampWithTZ = canBeTimestampWithTZ(val)
		return col
	}

	// The time suggestions only matter for VARCHAR columns, and canBeInt
	// only for BIGINT ones (see compileColDef).
	isChar := *col.dataType == charType
	col.canBeInt = *col.dataType == intType
	col.canBeTime = isChar && canBeTime(val)
	col.canBeTimestampWithTZ = isChar && canBeTimestampWithTZ(val)
	return col
}
// updateDataTypeSuggestionsIfNeed narrows col's type suggestions using one
// more value of an already-typed column. Each flag that is still set is
// cleared if val is incompatible with it; flags are never set here.
//
// A NULL constant, or a nil val (non-constant expression), carries no type
// information and leaves col unchanged. Skipping it matters mainly for
// canBeInt, which rejects anything that is not an integer constant: without
// the check, a single NULL would demote a BIGINT column to FLOAT.
func updateDataTypeSuggestionsIfNeed(col columnDef, val *pg_query.A_Const) columnDef {
	if val.GetVal() == nil {
		return col
	}
	// && short-circuits, so the checks only run while a flag is still set.
	col.canBeInt = col.canBeInt && canBeInt(val)
	col.canBeTime = col.canBeTime && canBeTime(val)
	col.canBeTimestampWithTZ = col.canBeTimestampWithTZ && canBeTimestampWithTZ(val)
	return col
}
// canBeTime reports whether val is compatible with a PostgreSQL TIME column.
//
// A string constant must parse as a 24-hour HH:MM:SS time with optional
// fractional seconds (Go's parser rejects hour 24, which PostgreSQL accepts).
// Any value that is not a string constant, including NULL and a nil val, is
// treated as compatible.
func canBeTime(val *pg_query.A_Const) bool {
	if str := val.GetSval(); str != nil {
		_, parseErr := time.Parse("15:04:05.999999", str.Sval)
		return parseErr == nil
	}
	return true
}
// canBeTimestampWithTZ reports whether val is compatible with a PostgreSQL
// TIMESTAMP WITH TIME ZONE column.
//
// A string constant must look like timestamptz text in PostgreSQL's ISO
// output style: "YYYY-MM-DD HH:MM:SS[.ffffff]" followed by a UTC offset. The
// offset may be whole hours (+03, -05), hours and minutes (+05:30), or carry
// seconds (historical local-mean-time zones). Any value that is not a string
// constant, including NULL and a nil val, is treated as compatible.
//
// The offset must use Go's numeric-zone layout elements (-07, -07:00,
// -07:00:00). A literal such as "+03" in a layout is NOT an offset: "03" is
// the zero-padded 12-hour clock field. Such a layout would reject negative,
// half-hour and +13/+14 offsets.
func canBeTimestampWithTZ(val *pg_query.A_Const) bool {
	value := val.GetSval()
	if value == nil {
		return true
	}
	layouts := [...]string{
		"2006-01-02 15:04:05.999999-07",
		"2006-01-02 15:04:05.999999-07:00",
		"2006-01-02 15:04:05.999999-07:00:00",
	}
	for _, layout := range layouts {
		if _, err := time.Parse(layout, value.Sval); err == nil {
			return true
		}
	}
	return false
}
// canBeInt reports whether val is an integer constant that fits in int64
// (BIGINT). NULL, strings, booleans, non-integral floats and a nil val all
// report false.
//
// PostgreSQL's lexer emits integer literals that overflow a 32-bit int as
// float constants, so pg_query returns large integers as Fval. Those are
// accepted when their text parses as a base-10 int64.
func canBeInt(val *pg_query.A_Const) bool {
	switch {
	case val.GetIval() != nil:
		return true
	case val.GetFval() != nil:
		_, err := strconv.ParseInt(val.GetFval().Fval, 10, 64)
		return err == nil
	default:
		return false
	}
}
// SQL compilation
// compileCreateSt renders a CREATE TABLE statement for tableName with one
// column per entry of cols (built by compileColDef), terminated by ";\n".
// Identifier quoting is left to pg_query.Deparse.
func compileCreateSt(tableName string, cols []columnDef) (string, error) {
	// Start from a parsed empty CREATE TABLE so the tree has the shape
	// Deparse expects, then fill in the name and the columns.
	tree, err := pg_query.Parse("CREATE TABLE t();")
	if err != nil {
		return "", fmt.Errorf("parse create table template: %w", err)
	}
	createStmt := tree.GetStmts()[0].GetStmt().GetCreateStmt()
	createStmt.GetRelation().Relname = tableName

	elts := make([]*pg_query.Node, 0, len(cols))
	for _, col := range cols {
		elts = append(elts, compileColDef(col))
	}
	createStmt.TableElts = elts

	rendered, err := pg_query.Deparse(tree)
	if err != nil {
		return "", fmt.Errorf("deparse create table %q: %w", tableName, err)
	}
	return rendered + ";\n", nil
}
// compileColDef builds a ColumnDef node named col.name. Its SQL type comes
// from the inferred col.dataType refined by the suggestion flags:
//
//   - no inferred type → JSON
//   - BIGINT with some value that was not an int64 integer → DOUBLE PRECISION
//   - VARCHAR whose values all parsed as times → TIME
//   - VARCHAR whose values all parsed as timestamptz → TIMESTAMP WITH TIME ZONE
//   - otherwise the inferred type itself (FLOAT is DOUBLE PRECISION)
//
// floatType is emitted as DOUBLE PRECISION (float8) rather than REAL. REAL is
// a 4-byte float with about 6 significant digits and would silently truncate
// values restored from double precision or numeric columns.
func compileColDef(col columnDef) *pg_query.Node {
	// Position of each type's column in the template statement below.
	templateIndex := map[dataType]int{
		intType:             0,
		floatType:           1,
		charType:            2,
		timeType:            3,
		timestampWithTZType: 4,
		boolType:            5,
		jsonType:            6,
	}
	// Parsing a template lets pg_query produce the canonical type names
	// instead of building TypeName nodes by hand.
	const template = "CREATE TABLE t (col1 BIGINT, col2 DOUBLE PRECISION, col3 VARCHAR, col4 TIME, col5 TIMESTAMP WITH TIME ZONE, col6 BOOLEAN, col7 JSON)"
	tree, err := pg_query.Parse(template)
	if err != nil {
		// The template is a constant; failing to parse it is a programming error.
		panic(fmt.Sprintf("compileColDef: parse type template: %v", err))
	}

	resolved := jsonType
	if col.dataType != nil {
		resolved = *col.dataType
	}
	// canBeTime is checked before canBeTimestampWithTZ.
	switch {
	case resolved == intType && !col.canBeInt:
		resolved = floatType
	case resolved == charType && col.canBeTime:
		resolved = timeType
	case resolved == charType && col.canBeTimestampWithTZ:
		resolved = timestampWithTZType
	}
	idx, ok := templateIndex[resolved]
	if !ok {
		idx = templateIndex[jsonType]
	}

	tmplCol := tree.GetStmts()[0].GetStmt().GetCreateStmt().GetTableElts()[idx].GetColumnDef()
	typeName := pg_query.TypeName{Names: tmplCol.GetTypeName().GetNames()}
	return pg_query.MakeSimpleColumnDefNode(col.name, &typeName, tmplCol.GetConstraints(), 0)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment