241 lines
6.0 KiB
Go
241 lines
6.0 KiB
Go
|
package goose
|
||
|
|
||
|
import (
|
||
|
"bufio"
|
||
|
"bytes"
|
||
|
"database/sql"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"os"
|
||
|
"regexp"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
)
|
||
|
|
||
|
const (
	// sqlCmdPrefix is the marker that introduces a goose directive line
	// inside a SQL migration script (e.g. "-- +goose Up").
	sqlCmdPrefix = "-- +goose "

	// scanBufSize is the length, in bytes, of every buffer handed out by
	// bufferPool (4 MiB). It is also passed to bufio.Scanner.Buffer as the
	// maximum token size when migration scripts are scanned.
	scanBufSize = 4 << 20
)

// bufferPool recycles scanBufSize-byte scratch slices so that scanning a
// migration script does not have to allocate a fresh buffer every time.
var bufferPool = sync.Pool{
	New: newScanBuffer,
}

// newScanBuffer allocates one scratch buffer for bufferPool.
func newScanBuffer() interface{} {
	return make([]byte, scanBufSize)
}
|
||
|
|
||
|
// endsWithSemicolon reports whether line carries a statement-ending
// semicolon.
//
// The line is split into whitespace-separated words. Scanning stops at the
// first word that begins with "--" (the start of a SQL line comment), and the
// result is true only when the last word seen before that point ends with
// ";". Consequently:
//
//   - "SELECT 1;"            -> true
//   - "SELECT 1; -- note"    -> true  (comment is ignored)
//   - "SELECT 1 -- note;"    -> false (semicolon is inside the comment)
//   - "SELECT 1;--note"      -> false (no whitespace before "--", so the
//     word neither starts with "--" nor ends with ";")
//
// The words are obtained with strings.Fields, which splits on the same
// Unicode white space as a word scanner but needs no scratch buffer and
// cannot fail.
func endsWithSemicolon(line string) bool {
	lastWord := ""
	for _, word := range strings.Fields(line) {
		if strings.HasPrefix(word, "--") {
			// Everything from here on is a comment.
			break
		}
		lastWord = word
	}
	return strings.HasSuffix(lastWord, ";")
}
|
||
|
|
||
|
// Split the given sql script into individual statements.
|
||
|
//
|
||
|
// The base case is to simply split on semicolons, as these
|
||
|
// naturally terminate a statement.
|
||
|
//
|
||
|
// However, more complex cases like pl/pgsql can have semicolons
|
||
|
// within a statement. For these cases, we provide the explicit annotations
|
||
|
// 'StatementBegin' and 'StatementEnd' to allow the script to
|
||
|
// tell us to ignore semicolons.
|
||
|
func getSQLStatements(r io.Reader, direction bool) ([]string, bool, error) {
|
||
|
var buf bytes.Buffer
|
||
|
scanBuf := bufferPool.Get().([]byte)
|
||
|
defer bufferPool.Put(scanBuf)
|
||
|
|
||
|
scanner := bufio.NewScanner(r)
|
||
|
scanner.Buffer(scanBuf, scanBufSize)
|
||
|
|
||
|
// track the count of each section
|
||
|
// so we can diagnose scripts with no annotations
|
||
|
upSections := 0
|
||
|
downSections := 0
|
||
|
|
||
|
statementEnded := false
|
||
|
ignoreSemicolons := false
|
||
|
directionIsActive := false
|
||
|
tx := true
|
||
|
stmts := []string{}
|
||
|
|
||
|
for scanner.Scan() {
|
||
|
|
||
|
line := scanner.Text()
|
||
|
|
||
|
// handle any goose-specific commands
|
||
|
if strings.HasPrefix(line, sqlCmdPrefix) {
|
||
|
cmd := strings.TrimSpace(line[len(sqlCmdPrefix):])
|
||
|
switch cmd {
|
||
|
case "Up":
|
||
|
directionIsActive = (direction == true)
|
||
|
upSections++
|
||
|
break
|
||
|
|
||
|
case "Down":
|
||
|
directionIsActive = (direction == false)
|
||
|
downSections++
|
||
|
break
|
||
|
|
||
|
case "StatementBegin":
|
||
|
if directionIsActive {
|
||
|
ignoreSemicolons = true
|
||
|
}
|
||
|
break
|
||
|
|
||
|
case "StatementEnd":
|
||
|
if directionIsActive {
|
||
|
statementEnded = (ignoreSemicolons == true)
|
||
|
ignoreSemicolons = false
|
||
|
}
|
||
|
break
|
||
|
|
||
|
case "NO TRANSACTION":
|
||
|
tx = false
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if !directionIsActive {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
if _, err := buf.WriteString(line + "\n"); err != nil {
|
||
|
return nil, false, fmt.Errorf("io err: %v", err)
|
||
|
}
|
||
|
|
||
|
// Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement
|
||
|
// Lines that end with semicolon that are in a statement block
|
||
|
// do not conclude statement.
|
||
|
if (!ignoreSemicolons && endsWithSemicolon(line)) || statementEnded {
|
||
|
statementEnded = false
|
||
|
stmts = append(stmts, buf.String())
|
||
|
buf.Reset()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err := scanner.Err(); err != nil {
|
||
|
return nil, false, fmt.Errorf("scanning migration: %v", err)
|
||
|
}
|
||
|
|
||
|
// diagnose likely migration script errors
|
||
|
if ignoreSemicolons {
|
||
|
return nil, false, fmt.Errorf("parsing migration: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'")
|
||
|
}
|
||
|
|
||
|
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
|
||
|
return nil, false, fmt.Errorf("parsing migration: unexpected unfinished SQL query: %s. potential missing semicolon", bufferRemaining)
|
||
|
}
|
||
|
|
||
|
if upSections == 0 && downSections == 0 {
|
||
|
return nil, false, fmt.Errorf("parsing migration: no Up/Down annotations found, so no statements were executed. See https://bitbucket.org/liamstask/goose/overview for details")
|
||
|
}
|
||
|
|
||
|
return stmts, tx, nil
|
||
|
}
|
||
|
|
||
|
// Run a migration specified in raw SQL.
|
||
|
//
|
||
|
// Sections of the script can be annotated with a special comment,
|
||
|
// starting with "-- +goose" to specify whether the section should
|
||
|
// be applied during an Up or Down migration
|
||
|
//
|
||
|
// All statements following an Up or Down directive are grouped together
|
||
|
// until another direction directive is found.
|
||
|
func runSQLMigration(db *sql.DB, sqlFile string, v int64, direction bool) error {
|
||
|
f, err := os.Open(sqlFile)
|
||
|
if err != nil {
|
||
|
return errors.Wrap(err, "failed to open SQL migration file")
|
||
|
}
|
||
|
defer f.Close()
|
||
|
|
||
|
statements, useTx, err := getSQLStatements(f, direction)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if useTx {
|
||
|
// TRANSACTION.
|
||
|
|
||
|
printInfo("Begin transaction\n")
|
||
|
|
||
|
tx, err := db.Begin()
|
||
|
if err != nil {
|
||
|
errors.Wrap(err, "failed to begin transaction")
|
||
|
}
|
||
|
|
||
|
for _, query := range statements {
|
||
|
printInfo("Executing statement: %s\n", clearStatement(query))
|
||
|
if _, err = tx.Exec(query); err != nil {
|
||
|
printInfo("Rollback transaction\n")
|
||
|
tx.Rollback()
|
||
|
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if direction {
|
||
|
if _, err := tx.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
|
||
|
printInfo("Rollback transaction\n")
|
||
|
tx.Rollback()
|
||
|
return errors.Wrap(err, "failed to insert new goose version")
|
||
|
}
|
||
|
} else {
|
||
|
if _, err := tx.Exec(GetDialect().deleteVersionSQL(), v); err != nil {
|
||
|
printInfo("Rollback transaction\n")
|
||
|
tx.Rollback()
|
||
|
return errors.Wrap(err, "failed to delete goose version")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
printInfo("Commit transaction\n")
|
||
|
if err := tx.Commit(); err != nil {
|
||
|
return errors.Wrap(err, "failed to commit transaction")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// NO TRANSACTION.
|
||
|
for _, query := range statements {
|
||
|
printInfo("Executing statement: %s\n", clearStatement(query))
|
||
|
if _, err := db.Exec(query); err != nil {
|
||
|
return errors.Wrapf(err, "failed to execute SQL query %q", clearStatement(query))
|
||
|
}
|
||
|
}
|
||
|
if _, err := db.Exec(GetDialect().insertVersionSQL(), v, direction); err != nil {
|
||
|
return errors.Wrap(err, "failed to insert new goose version")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func printInfo(s string, args ...interface{}) {
|
||
|
if verbose {
|
||
|
log.Printf(s, args...)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// matchSQLComments matches every line that begins with "--", together with
// the line break(s) that follow it.
var matchSQLComments = regexp.MustCompile(`(?m)^--.*$[\r\n]*`)

// matchEmptyLines matches every empty line, together with the line break(s)
// that follow it.
var matchEmptyLines = regexp.MustCompile(`(?m)^$[\r\n]*`)

// clearStatement returns s with full-line "--" comments removed first and
// empty lines removed afterwards. Comments that trail other text on the same
// line are left untouched.
func clearStatement(s string) string {
	withoutComments := matchSQLComments.ReplaceAllString(s, "")
	compact := matchEmptyLines.ReplaceAllString(withoutComments, "")
	return compact
}
|