package goose

import (
	"database/sql"
	"fmt"
)

// SQLDialect abstracts the handful of SQL statements goose issues whose
// syntax differs between database engines.
type SQLDialect interface {
	// createVersionTableSQL returns the statement that creates the version table.
	createVersionTableSQL() string
	// insertVersionSQL returns the parameterized statement that inserts a
	// (version_id, is_applied) row into the version table.
	insertVersionSQL() string
	// deleteVersionSQL returns the parameterized statement that deletes the
	// rows for a given version_id.
	deleteVersionSQL() string
	// dbVersionQuery runs the query that lists version_id/is_applied rows.
	dbVersionQuery(db *sql.DB) (*sql.Rows, error)
}

// dialect is the currently selected SQLDialect. It starts out as Postgres;
// GetDialect reads it and SetDialect replaces it.
var dialect SQLDialect = &PostgresDialect{}

// GetDialect returns the currently selected SQLDialect.
func GetDialect() SQLDialect {
	return dialect
}

// SetDialect selects the SQLDialect identified by d, which must be one of
// "postgres", "mysql", "sqlite3", "redshift" or "tidb". For any other value
// it returns an error and leaves the current dialect untouched.
func SetDialect(d string) error {
	var selected SQLDialect
	switch d {
	case "postgres":
		selected = &PostgresDialect{}
	case "mysql":
		selected = &MySQLDialect{}
	case "sqlite3":
		selected = &Sqlite3Dialect{}
	case "redshift":
		selected = &RedshiftDialect{}
	case "tidb":
		selected = &TiDBDialect{}
	default:
		return fmt.Errorf("%q: unknown dialect", d)
	}
	dialect = selected
	return nil
}

////////////////////////////
// Postgres
////////////////////////////

// PostgresDialect implements SQLDialect for PostgreSQL, using $N bind
// parameters.
type PostgresDialect struct{}

// createVersionTableSQL returns the CREATE TABLE statement for the table
// named by TableName(): a serial id primary key, version_id, is_applied and
// a tstamp column defaulting to now().
func (p PostgresDialect) createVersionTableSQL() string {
	const ddl = `CREATE TABLE %s (
		id serial NOT NULL,
		version_id bigint NOT NULL,
		is_applied boolean NOT NULL,
		tstamp timestamp NULL default now(),
		PRIMARY KEY(id)
	);`
	return fmt.Sprintf(ddl, TableName())
}

// insertVersionSQL returns an INSERT into the version table that takes
// version_id and is_applied as $1 and $2.
func (p PostgresDialect) insertVersionSQL() string {
	table := TableName()
	return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", table)
}

// dbVersionQuery selects version_id and is_applied for every row of the
// version table, highest id first. On failure the returned rows are nil;
// on success the caller is responsible for closing them.
func (p PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
	query := fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())
	rows, err := db.Query(query)
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// deleteVersionSQL returns a DELETE that removes the rows whose version_id
// equals $1.
func (p PostgresDialect) deleteVersionSQL() string {
	table := TableName()
	return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", table)
}

////////////////////////////
// MySQL
////////////////////////////

// MySQLDialect implements SQLDialect for MySQL, using ? bind parameters.
type MySQLDialect struct{} func (m MySQLDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id serial NOT NULL, version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) );`, TableName()) } func (m MySQLDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) } func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { return nil, err } return rows, err } func (m MySQLDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) } //////////////////////////// // sqlite3 //////////////////////////// // Sqlite3Dialect struct. type Sqlite3Dialect struct{} func (m Sqlite3Dialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id INTEGER PRIMARY KEY AUTOINCREMENT, version_id INTEGER NOT NULL, is_applied INTEGER NOT NULL, tstamp TIMESTAMP DEFAULT (datetime('now')) );`, TableName()) } func (m Sqlite3Dialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) } func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { return nil, err } return rows, err } func (m Sqlite3Dialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) } //////////////////////////// // Redshift //////////////////////////// // RedshiftDialect struct. 
type RedshiftDialect struct{} func (rs RedshiftDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id integer NOT NULL identity(1, 1), version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default sysdate, PRIMARY KEY(id) );`, TableName()) } func (rs RedshiftDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) } func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { return nil, err } return rows, err } func (rs RedshiftDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) } //////////////////////////// // TiDB //////////////////////////// // TiDBDialect struct. type TiDBDialect struct{} func (m TiDBDialect) createVersionTableSQL() string { return fmt.Sprintf(`CREATE TABLE %s ( id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, version_id bigint NOT NULL, is_applied boolean NOT NULL, tstamp timestamp NULL default now(), PRIMARY KEY(id) );`, TableName()) } func (m TiDBDialect) insertVersionSQL() string { return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) } func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) if err != nil { return nil, err } return rows, err } func (m TiDBDialect) deleteVersionSQL() string { return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) }