Ability to specify custom migrations table name (#178)

Supported via `--migrations-table` CLI flag or `DBMATE_MIGRATIONS_TABLE` environment variable.

Specified table name is quoted when necessary.

For PostgreSQL specifically, it's also possible to specify a custom schema (for example: `--migrations-table=foo.migrations`).

Closes #168
This commit is contained in:
Adrian Macneil 2020-11-17 18:11:24 +13:00 committed by GitHub
parent 656dc0253a
commit c907c3f5c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 807 additions and 325 deletions

View file

@ -52,6 +52,12 @@ func NewApp() *cli.App {
Value: dbmate.DefaultMigrationsDir, Value: dbmate.DefaultMigrationsDir,
Usage: "specify the directory containing migration files", Usage: "specify the directory containing migration files",
}, },
&cli.StringFlag{
Name: "migrations-table",
EnvVars: []string{"DBMATE_MIGRATIONS_TABLE"},
Value: dbmate.DefaultMigrationsTableName,
Usage: "specify the database table to record migrations in",
},
&cli.StringFlag{ &cli.StringFlag{
Name: "schema-file", Name: "schema-file",
Aliases: []string{"s"}, Aliases: []string{"s"},
@ -222,6 +228,7 @@ func action(f func(*dbmate.DB, *cli.Context) error) cli.ActionFunc {
db := dbmate.New(u) db := dbmate.New(u)
db.AutoDumpSchema = !c.Bool("no-dump-schema") db.AutoDumpSchema = !c.Bool("no-dump-schema")
db.MigrationsDir = c.String("migrations-dir") db.MigrationsDir = c.String("migrations-dir")
db.MigrationsTableName = c.String("migrations-table")
db.SchemaFile = c.String("schema-file") db.SchemaFile = c.String("schema-file")
db.WaitBefore = c.Bool("wait") db.WaitBefore = c.Bool("wait")
overrideTimeout := c.Duration("wait-timeout") overrideTimeout := c.Duration("wait-timeout")

View file

@ -13,11 +13,12 @@ import (
) )
func init() { func init() {
RegisterDriver(ClickHouseDriver{}, "clickhouse") RegisterDriver(&ClickHouseDriver{}, "clickhouse")
} }
// ClickHouseDriver provides top level database functions // ClickHouseDriver provides top level database functions
type ClickHouseDriver struct { type ClickHouseDriver struct {
migrationsTableName string
} }
func normalizeClickHouseURL(initialURL *url.URL) *url.URL { func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
@ -52,12 +53,17 @@ func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
return &u return &u
} }
// SetMigrationsTableName sets the schema migrations table name
func (drv *ClickHouseDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection // Open creates a new database connection
func (drv ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) { func (drv *ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("clickhouse", normalizeClickHouseURL(u).String()) return sql.Open("clickhouse", normalizeClickHouseURL(u).String())
} }
func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) { func (drv *ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
// connect to clickhouse database // connect to clickhouse database
clickhouseURL := normalizeClickHouseURL(u) clickhouseURL := normalizeClickHouseURL(u)
values := clickhouseURL.Query() values := clickhouseURL.Query()
@ -67,7 +73,7 @@ func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
return drv.Open(clickhouseURL) return drv.Open(clickhouseURL)
} }
func (drv ClickHouseDriver) databaseName(u *url.URL) string { func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
name := normalizeClickHouseURL(u).Query().Get("database") name := normalizeClickHouseURL(u).Query().Get("database")
if name == "" { if name == "" {
name = "default" name = "default"
@ -77,7 +83,7 @@ func (drv ClickHouseDriver) databaseName(u *url.URL) string {
var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`) var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
func clickhouseQuoteIdentifier(str string) string { func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
if clickhouseValidIdentifier.MatchString(str) { if clickhouseValidIdentifier.MatchString(str) {
return str return str
} }
@ -88,7 +94,7 @@ func clickhouseQuoteIdentifier(str string) string {
} }
// CreateDatabase creates the specified database // CreateDatabase creates the specified database
func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error { func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
name := drv.databaseName(u) name := drv.databaseName(u)
fmt.Printf("Creating: %s\n", name) fmt.Printf("Creating: %s\n", name)
@ -98,13 +104,13 @@ func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error {
} }
defer mustClose(db) defer mustClose(db)
_, err = db.Exec("create database " + clickhouseQuoteIdentifier(name)) _, err = db.Exec("create database " + drv.quoteIdentifier(name))
return err return err
} }
// DropDatabase drops the specified database (if it exists) // DropDatabase drops the specified database (if it exists)
func (drv ClickHouseDriver) DropDatabase(u *url.URL) error { func (drv *ClickHouseDriver) DropDatabase(u *url.URL) error {
name := drv.databaseName(u) name := drv.databaseName(u)
fmt.Printf("Dropping: %s\n", name) fmt.Printf("Dropping: %s\n", name)
@ -114,15 +120,15 @@ func (drv ClickHouseDriver) DropDatabase(u *url.URL) error {
} }
defer mustClose(db) defer mustClose(db)
_, err = db.Exec("drop database if exists " + clickhouseQuoteIdentifier(name)) _, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
return err return err
} }
func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error { func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
buf.WriteString("\n--\n-- Database schema\n--\n\n") buf.WriteString("\n--\n-- Database schema\n--\n\n")
buf.WriteString("CREATE DATABASE " + clickhouseQuoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n") buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
tables, err := queryColumn(db, "show tables") tables, err := queryColumn(db, "show tables")
if err != nil { if err != nil {
@ -132,7 +138,7 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er
for _, table := range tables { for _, table := range tables {
var clause string var clause string
err = db.QueryRow("show create table " + clickhouseQuoteIdentifier(table)).Scan(&clause) err = db.QueryRow("show create table " + drv.quoteIdentifier(table)).Scan(&clause)
if err != nil { if err != nil {
return err return err
} }
@ -141,10 +147,13 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er
return nil return nil
} }
func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error { func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations // load applied migrations
migrations, err := queryColumn(db, migrations, err := queryColumn(db,
"select version from schema_migrations final where applied order by version asc", fmt.Sprintf("select version from %s final ", migrationsTable)+
"where applied order by version asc",
) )
if err != nil { if err != nil {
return err return err
@ -155,11 +164,12 @@ func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
migrations[i] = "'" + quoter.Replace(migrations[i]) + "'" migrations[i] = "'" + quoter.Replace(migrations[i]) + "'"
} }
// build schema_migrations table data // build schema migrations table data
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n") buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
if len(migrations) > 0 { if len(migrations) > 0 {
buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" + buf.WriteString(
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
strings.Join(migrations, "),\n (") + strings.Join(migrations, "),\n (") +
");\n") ");\n")
} }
@ -168,16 +178,16 @@ func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
} }
// DumpSchema returns the current database schema // DumpSchema returns the current database schema
func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
var buf bytes.Buffer var buf bytes.Buffer
var err error var err error
err = clickhouseSchemaDump(db, &buf, drv.databaseName(u)) err = drv.schemaDump(db, &buf, drv.databaseName(u))
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = clickhouseSchemaMigrationsDump(db, &buf) err = drv.schemaMigrationsDump(db, &buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -186,7 +196,7 @@ func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
} }
// DatabaseExists determines whether the database exists // DatabaseExists determines whether the database exists
func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) { func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
name := drv.databaseName(u) name := drv.databaseName(u)
db, err := drv.openClickHouseDB(u) db, err := drv.openClickHouseDB(u)
@ -205,24 +215,27 @@ func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
return exists, err return exists, err
} }
// CreateMigrationsTable creates the schema_migrations table // CreateMigrationsTable creates the schema migrations table
func (drv ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec(` _, err := db.Exec(fmt.Sprintf(`
create table if not exists schema_migrations ( create table if not exists %s (
version String, version String,
ts DateTime default now(), ts DateTime default now(),
applied UInt8 default 1 applied UInt8 default 1
) engine = ReplacingMergeTree(ts) ) engine = ReplacingMergeTree(ts)
primary key version primary key version
order by version order by version
`) `, drv.quotedMigrationsTableName()))
return err return err
} }
// SelectMigrations returns a list of applied migrations // SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order) // with an optional limit (in descending order)
func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations final where applied order by version desc" query := fmt.Sprintf("select version from %s final where applied order by version desc",
drv.quotedMigrationsTableName())
if limit >= 0 { if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit) query = fmt.Sprintf("%s limit %d", query, limit)
} }
@ -251,15 +264,19 @@ func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]
} }
// InsertMigration adds a new migration record // InsertMigration adds a new migration record
func (drv ClickHouseDriver) InsertMigration(db Transaction, version string) error { func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version) _, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
return err return err
} }
// DeleteMigration removes a migration record // DeleteMigration removes a migration record
func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) error { func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec( _, err := db.Exec(
"insert into schema_migrations (version, applied) values (?, ?)", fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
drv.quotedMigrationsTableName()),
version, false, version, false,
) )
@ -268,7 +285,7 @@ func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) erro
// Ping verifies a connection to the database server. It does not verify whether the // Ping verifies a connection to the database server. It does not verify whether the
// specified database exists. // specified database exists.
func (drv ClickHouseDriver) Ping(u *url.URL) error { func (drv *ClickHouseDriver) Ping(u *url.URL) error {
// attempt connection to primary database, not "clickhouse" database // attempt connection to primary database, not "clickhouse" database
// to support servers with no "clickhouse" database // to support servers with no "clickhouse" database
// (see https://github.com/amacneil/dbmate/issues/78) // (see https://github.com/amacneil/dbmate/issues/78)
@ -291,3 +308,7 @@ func (drv ClickHouseDriver) Ping(u *url.URL) error {
return err return err
} }
func (drv *ClickHouseDriver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}

View file

@ -15,8 +15,15 @@ func clickhouseTestURL(t *testing.T) *url.URL {
return u return u
} }
func testClickHouseDriver() *ClickHouseDriver {
drv := &ClickHouseDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB { func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
// drop any existing database // drop any existing database
err := drv.DropDatabase(u) err := drv.DropDatabase(u)
@ -50,7 +57,7 @@ func TestNormalizeClickHouseURLCanonical(t *testing.T) {
} }
func TestClickHouseCreateDropDatabase(t *testing.T) { func TestClickHouseCreateDropDatabase(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
// drop any existing database // drop any existing database
@ -87,7 +94,9 @@ func TestClickHouseCreateDropDatabase(t *testing.T) {
} }
func TestClickHouseDumpSchema(t *testing.T) { func TestClickHouseDumpSchema(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
// prepare database // prepare database
@ -113,11 +122,11 @@ func TestClickHouseDumpSchema(t *testing.T) {
// DumpSchema should return schema // DumpSchema should return schema
schema, err := drv.DumpSchema(u, db) schema, err := drv.DumpSchema(u, db)
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".schema_migrations") require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".test_migrations")
require.Contains(t, string(schema), "--\n"+ require.Contains(t, string(schema), "--\n"+
"-- Dbmate schema migrations\n"+ "-- Dbmate schema migrations\n"+
"--\n\n"+ "--\n\n"+
"INSERT INTO schema_migrations (version) VALUES\n"+ "INSERT INTO test_migrations (version) VALUES\n"+
" ('abc1'),\n"+ " ('abc1'),\n"+
" ('abc2');\n") " ('abc2');\n")
@ -134,7 +143,7 @@ func TestClickHouseDumpSchema(t *testing.T) {
} }
func TestClickHouseDatabaseExists(t *testing.T) { func TestClickHouseDatabaseExists(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
// drop any existing database // drop any existing database
@ -157,7 +166,7 @@ func TestClickHouseDatabaseExists(t *testing.T) {
} }
func TestClickHouseDatabaseExists_Error(t *testing.T) { func TestClickHouseDatabaseExists_Error(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
values := u.Query() values := u.Query()
values.Set("username", "invalid") values.Set("username", "invalid")
@ -169,7 +178,8 @@ func TestClickHouseDatabaseExists_Error(t *testing.T) {
} }
func TestClickHouseCreateMigrationsTable(t *testing.T) { func TestClickHouseCreateMigrationsTable(t *testing.T) {
drv := ClickHouseDriver{} t.Run("default table", func(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u) db := prepTestClickHouseDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -190,10 +200,39 @@ func TestClickHouseCreateMigrationsTable(t *testing.T) {
// create table should be idempotent // create table should be idempotent
err = drv.CreateMigrationsTable(u, db) err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
})
t.Run("custom table", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("testMigrations")
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
require.EqualError(t, err, "code: 60, message: Table dbmate.testMigrations doesn't exist.")
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
} }
func TestClickHouseSelectMigrations(t *testing.T) { func TestClickHouseSelectMigrations(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u) db := prepTestClickHouseDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -203,7 +242,7 @@ func TestClickHouseSelectMigrations(t *testing.T) {
tx, err := db.Begin() tx, err := db.Begin()
require.NoError(t, err) require.NoError(t, err)
stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)") stmt, err := tx.Prepare("insert into test_migrations (version) values (?)")
require.NoError(t, err) require.NoError(t, err)
_, err = stmt.Exec("abc2") _, err = stmt.Exec("abc2")
require.NoError(t, err) require.NoError(t, err)
@ -229,7 +268,9 @@ func TestClickHouseSelectMigrations(t *testing.T) {
} }
func TestClickHouseInsertMigration(t *testing.T) { func TestClickHouseInsertMigration(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u) db := prepTestClickHouseDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -238,7 +279,7 @@ func TestClickHouseInsertMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, count) require.Equal(t, 0, count)
@ -250,13 +291,15 @@ func TestClickHouseInsertMigration(t *testing.T) {
err = tx.Commit() err = tx.Commit()
require.NoError(t, err) require.NoError(t, err)
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").Scan(&count) err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestClickHouseDeleteMigration(t *testing.T) { func TestClickHouseDeleteMigration(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u) db := prepTestClickHouseDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -266,7 +309,7 @@ func TestClickHouseDeleteMigration(t *testing.T) {
tx, err := db.Begin() tx, err := db.Begin()
require.NoError(t, err) require.NoError(t, err)
stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)") stmt, err := tx.Prepare("insert into test_migrations (version) values (?)")
require.NoError(t, err) require.NoError(t, err)
_, err = stmt.Exec("abc2") _, err = stmt.Exec("abc2")
require.NoError(t, err) require.NoError(t, err)
@ -283,13 +326,13 @@ func TestClickHouseDeleteMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from schema_migrations final where applied").Scan(&count) err = db.QueryRow("select count(*) from test_migrations final where applied").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestClickHousePing(t *testing.T) { func TestClickHousePing(t *testing.T) {
drv := ClickHouseDriver{} drv := testClickHouseDriver()
u := clickhouseTestURL(t) u := clickhouseTestURL(t)
// drop any existing database // drop any existing database
@ -306,3 +349,27 @@ func TestClickHousePing(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused") require.Contains(t, err.Error(), "connect: connection refused")
} }
func TestClickHouseQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testClickHouseDriver()
name := drv.quotedMigrationsTableName()
require.Equal(t, "schema_migrations", name)
})
t.Run("custom name", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("fooMigrations")
name := drv.quotedMigrationsTableName()
require.Equal(t, "fooMigrations", name)
})
t.Run("quoted name", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("bizarre\"$name")
name := drv.quotedMigrationsTableName()
require.Equal(t, `"bizarre""$name"`, name)
})
}

View file

@ -15,6 +15,9 @@ import (
// DefaultMigrationsDir specifies default directory to find migration files // DefaultMigrationsDir specifies default directory to find migration files
const DefaultMigrationsDir = "./db/migrations" const DefaultMigrationsDir = "./db/migrations"
// DefaultMigrationsTableName specifies default database tables to record migraitons in
const DefaultMigrationsTableName = "schema_migrations"
// DefaultSchemaFile specifies default location for schema.sql // DefaultSchemaFile specifies default location for schema.sql
const DefaultSchemaFile = "./db/schema.sql" const DefaultSchemaFile = "./db/schema.sql"
@ -29,6 +32,7 @@ type DB struct {
AutoDumpSchema bool AutoDumpSchema bool
DatabaseURL *url.URL DatabaseURL *url.URL
MigrationsDir string MigrationsDir string
MigrationsTableName string
SchemaFile string SchemaFile string
Verbose bool Verbose bool
WaitBefore bool WaitBefore bool
@ -50,6 +54,7 @@ func New(databaseURL *url.URL) *DB {
AutoDumpSchema: true, AutoDumpSchema: true,
DatabaseURL: databaseURL, DatabaseURL: databaseURL,
MigrationsDir: DefaultMigrationsDir, MigrationsDir: DefaultMigrationsDir,
MigrationsTableName: DefaultMigrationsTableName,
SchemaFile: DefaultSchemaFile, SchemaFile: DefaultSchemaFile,
WaitBefore: false, WaitBefore: false,
WaitInterval: DefaultWaitInterval, WaitInterval: DefaultWaitInterval,
@ -59,7 +64,14 @@ func New(databaseURL *url.URL) *DB {
// GetDriver loads the required database driver // GetDriver loads the required database driver
func (db *DB) GetDriver() (Driver, error) { func (db *DB) GetDriver() (Driver, error) {
return GetDriver(db.DatabaseURL.Scheme) drv, err := getDriver(db.DatabaseURL.Scheme)
if err != nil {
return nil, err
}
drv.SetMigrationsTableName(db.MigrationsTableName)
return drv, err
} }
// Wait blocks until the database server is available. It does not verify that // Wait blocks until the database server is available. It does not verify that

View file

@ -38,12 +38,26 @@ func TestNew(t *testing.T) {
require.True(t, db.AutoDumpSchema) require.True(t, db.AutoDumpSchema)
require.Equal(t, u.String(), db.DatabaseURL.String()) require.Equal(t, u.String(), db.DatabaseURL.String())
require.Equal(t, "./db/migrations", db.MigrationsDir) require.Equal(t, "./db/migrations", db.MigrationsDir)
require.Equal(t, "schema_migrations", db.MigrationsTableName)
require.Equal(t, "./db/schema.sql", db.SchemaFile) require.Equal(t, "./db/schema.sql", db.SchemaFile)
require.False(t, db.WaitBefore) require.False(t, db.WaitBefore)
require.Equal(t, time.Second, db.WaitInterval) require.Equal(t, time.Second, db.WaitInterval)
require.Equal(t, 60*time.Second, db.WaitTimeout) require.Equal(t, 60*time.Second, db.WaitTimeout)
} }
func TestGetDriver(t *testing.T) {
u := postgresTestURL(t)
db := New(u)
drv, err := db.GetDriver()
require.NoError(t, err)
// driver should have default migrations table set
pgDrv, ok := drv.(*PostgresDriver)
require.True(t, ok)
require.Equal(t, "schema_migrations", pgDrv.migrationsTableName)
}
func TestWait(t *testing.T) { func TestWait(t *testing.T) {
u := postgresTestURL(t) u := postgresTestURL(t)
db := newTestDB(t, u) db := newTestDB(t, u)
@ -242,7 +256,7 @@ func testMigrateURL(t *testing.T, u *url.URL) {
require.NoError(t, err) require.NoError(t, err)
// verify results // verify results
sqlDB, err := GetDriverOpen(u) sqlDB, err := getDriverOpen(u)
require.NoError(t, err) require.NoError(t, err)
defer mustClose(sqlDB) defer mustClose(sqlDB)
@ -275,7 +289,7 @@ func testUpURL(t *testing.T, u *url.URL) {
require.NoError(t, err) require.NoError(t, err)
// verify results // verify results
sqlDB, err := GetDriverOpen(u) sqlDB, err := getDriverOpen(u)
require.NoError(t, err) require.NoError(t, err)
defer mustClose(sqlDB) defer mustClose(sqlDB)
@ -308,7 +322,7 @@ func testRollbackURL(t *testing.T, u *url.URL) {
require.NoError(t, err) require.NoError(t, err)
// verify migration // verify migration
sqlDB, err := GetDriverOpen(u) sqlDB, err := getDriverOpen(u)
require.NoError(t, err) require.NoError(t, err)
defer mustClose(sqlDB) defer mustClose(sqlDB)
@ -351,7 +365,7 @@ func testStatusURL(t *testing.T, u *url.URL) {
require.NoError(t, err) require.NoError(t, err)
// verify migration // verify migration
sqlDB, err := GetDriverOpen(u) sqlDB, err := getDriverOpen(u)
require.NoError(t, err) require.NoError(t, err)
defer mustClose(sqlDB) defer mustClose(sqlDB)

View file

@ -13,6 +13,7 @@ type Driver interface {
CreateDatabase(*url.URL) error CreateDatabase(*url.URL) error
DropDatabase(*url.URL) error DropDatabase(*url.URL) error
DumpSchema(*url.URL, *sql.DB) ([]byte, error) DumpSchema(*url.URL, *sql.DB) ([]byte, error)
SetMigrationsTableName(string)
CreateMigrationsTable(*url.URL, *sql.DB) error CreateMigrationsTable(*url.URL, *sql.DB) error
SelectMigrations(*sql.DB, int) (map[string]bool, error) SelectMigrations(*sql.DB, int) (map[string]bool, error)
InsertMigration(Transaction, string) error InsertMigration(Transaction, string) error
@ -34,18 +35,20 @@ type Transaction interface {
QueryRow(query string, args ...interface{}) *sql.Row QueryRow(query string, args ...interface{}) *sql.Row
} }
// GetDriver loads a database driver by name // getDriver loads a database driver by name
func GetDriver(name string) (Driver, error) { func getDriver(name string) (Driver, error) {
if val, ok := drivers[name]; ok { if drv, ok := drivers[name]; ok {
return val, nil drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv, nil
} }
return nil, fmt.Errorf("unsupported driver: %s", name) return nil, fmt.Errorf("unsupported driver: %s", name)
} }
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u) // getDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
func GetDriverOpen(u *url.URL) (*sql.DB, error) { func getDriverOpen(u *url.URL) (*sql.DB, error) {
drv, err := GetDriver(u.Scheme) drv, err := getDriver(u.Scheme)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -7,21 +7,21 @@ import (
) )
func TestGetDriver_Postgres(t *testing.T) { func TestGetDriver_Postgres(t *testing.T) {
drv, err := GetDriver("postgres") drv, err := getDriver("postgres")
require.NoError(t, err) require.NoError(t, err)
_, ok := drv.(PostgresDriver) _, ok := drv.(*PostgresDriver)
require.Equal(t, true, ok) require.Equal(t, true, ok)
} }
func TestGetDriver_MySQL(t *testing.T) { func TestGetDriver_MySQL(t *testing.T) {
drv, err := GetDriver("mysql") drv, err := getDriver("mysql")
require.NoError(t, err) require.NoError(t, err)
_, ok := drv.(MySQLDriver) _, ok := drv.(*MySQLDriver)
require.Equal(t, true, ok) require.Equal(t, true, ok)
} }
func TestGetDriver_Error(t *testing.T) { func TestGetDriver_Error(t *testing.T) {
drv, err := GetDriver("foo") drv, err := getDriver("foo")
require.EqualError(t, err, "unsupported driver: foo") require.EqualError(t, err, "unsupported driver: foo")
require.Nil(t, drv) require.Nil(t, drv)
} }

View file

@ -11,11 +11,12 @@ import (
) )
func init() { func init() {
RegisterDriver(MySQLDriver{}, "mysql") RegisterDriver(&MySQLDriver{}, "mysql")
} }
// MySQLDriver provides top level database functions // MySQLDriver provides top level database functions
type MySQLDriver struct { type MySQLDriver struct {
migrationsTableName string
} }
func normalizeMySQLURL(u *url.URL) string { func normalizeMySQLURL(u *url.URL) string {
@ -52,12 +53,17 @@ func normalizeMySQLURL(u *url.URL) string {
return normalizedString return normalizedString
} }
// SetMigrationsTableName sets the schema migrations table name
func (drv *MySQLDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection // Open creates a new database connection
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) { func (drv *MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("mysql", normalizeMySQLURL(u)) return sql.Open("mysql", normalizeMySQLURL(u))
} }
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) { func (drv *MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
// connect to no particular database // connect to no particular database
rootURL := *u rootURL := *u
rootURL.Path = "/" rootURL.Path = "/"
@ -65,14 +71,14 @@ func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
return drv.Open(&rootURL) return drv.Open(&rootURL)
} }
func mysqlQuoteIdentifier(str string) string { func (drv *MySQLDriver) quoteIdentifier(str string) string {
str = strings.Replace(str, "`", "\\`", -1) str = strings.Replace(str, "`", "\\`", -1)
return fmt.Sprintf("`%s`", str) return fmt.Sprintf("`%s`", str)
} }
// CreateDatabase creates the specified database // CreateDatabase creates the specified database
func (drv MySQLDriver) CreateDatabase(u *url.URL) error { func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u) name := databaseName(u)
fmt.Printf("Creating: %s\n", name) fmt.Printf("Creating: %s\n", name)
@ -83,13 +89,13 @@ func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
defer mustClose(db) defer mustClose(db)
_, err = db.Exec(fmt.Sprintf("create database %s", _, err = db.Exec(fmt.Sprintf("create database %s",
mysqlQuoteIdentifier(name))) drv.quoteIdentifier(name)))
return err return err
} }
// DropDatabase drops the specified database (if it exists) // DropDatabase drops the specified database (if it exists)
func (drv MySQLDriver) DropDatabase(u *url.URL) error { func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
name := databaseName(u) name := databaseName(u)
fmt.Printf("Dropping: %s\n", name) fmt.Printf("Dropping: %s\n", name)
@ -100,12 +106,12 @@ func (drv MySQLDriver) DropDatabase(u *url.URL) error {
defer mustClose(db) defer mustClose(db)
_, err = db.Exec(fmt.Sprintf("drop database if exists %s", _, err = db.Exec(fmt.Sprintf("drop database if exists %s",
mysqlQuoteIdentifier(name))) drv.quoteIdentifier(name)))
return err return err
} }
func mysqldumpArgs(u *url.URL) []string { func (drv *MySQLDriver) mysqldumpArgs(u *url.URL) []string {
// generate CLI arguments // generate CLI arguments
args := []string{"--opt", "--routines", "--no-data", args := []string{"--opt", "--routines", "--no-data",
"--skip-dump-date", "--skip-add-drop-table"} "--skip-dump-date", "--skip-add-drop-table"}
@ -131,10 +137,12 @@ func mysqldumpArgs(u *url.URL) []string {
return args return args
} }
func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) { func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations // load applied migrations
migrations, err := queryColumn(db, migrations, err := queryColumn(db,
"select quote(version) from schema_migrations order by version asc") fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -142,10 +150,11 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
// build schema_migrations table data // build schema_migrations table data
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" + buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" +
"LOCK TABLES `schema_migrations` WRITE;\n") fmt.Sprintf("LOCK TABLES %s WRITE;\n", migrationsTable))
if len(migrations) > 0 { if len(migrations) > 0 {
buf.WriteString("INSERT INTO `schema_migrations` (version) VALUES\n (" + buf.WriteString(
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
strings.Join(migrations, "),\n (") + strings.Join(migrations, "),\n (") +
");\n") ");\n")
} }
@ -156,13 +165,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
} }
// DumpSchema returns the current database schema // DumpSchema returns the current database schema
func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
schema, err := runCommand("mysqldump", mysqldumpArgs(u)...) schema, err := runCommand("mysqldump", drv.mysqldumpArgs(u)...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
migrations, err := mysqlSchemaMigrationsDump(db) migrations, err := drv.schemaMigrationsDump(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -172,7 +181,7 @@ func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
} }
// DatabaseExists determines whether the database exists // DatabaseExists determines whether the database exists
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) { func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u) name := databaseName(u)
db, err := drv.openRootDB(u) db, err := drv.openRootDB(u)
@ -192,17 +201,18 @@ func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
} }
// CreateMigrationsTable creates the schema_migrations table // CreateMigrationsTable creates the schema_migrations table
func (drv MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec("create table if not exists schema_migrations " + _, err := db.Exec(fmt.Sprintf("create table if not exists %s "+
"(version varchar(255) primary key) character set latin1 collate latin1_bin") "(version varchar(255) primary key) character set latin1 collate latin1_bin",
drv.quotedMigrationsTableName()))
return err return err
} }
// SelectMigrations returns a list of applied migrations // SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order) // with an optional limit (in descending order)
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations order by version desc" query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
if limit >= 0 { if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit) query = fmt.Sprintf("%s limit %d", query, limit)
} }
@ -231,22 +241,26 @@ func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool,
} }
// InsertMigration adds a new migration record // InsertMigration adds a new migration record
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error { func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version) _, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
return err return err
} }
// DeleteMigration removes a migration record // DeleteMigration removes a migration record
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error { func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from schema_migrations where version = ?", version) _, err := db.Exec(
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
version)
return err return err
} }
// Ping verifies a connection to the database server. It does not verify whether the // Ping verifies a connection to the database server. It does not verify whether the
// specified database exists. // specified database exists.
func (drv MySQLDriver) Ping(u *url.URL) error { func (drv *MySQLDriver) Ping(u *url.URL) error {
db, err := drv.openRootDB(u) db, err := drv.openRootDB(u)
if err != nil { if err != nil {
return err return err
@ -255,3 +269,7 @@ func (drv MySQLDriver) Ping(u *url.URL) error {
return db.Ping() return db.Ping()
} }
func (drv *MySQLDriver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}

View file

@ -15,8 +15,15 @@ func mySQLTestURL(t *testing.T) *url.URL {
return u return u
} }
func testMySQLDriver() *MySQLDriver {
drv := &MySQLDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB { func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB {
drv := MySQLDriver{} drv := testMySQLDriver()
// drop any existing database // drop any existing database
err := drv.DropDatabase(u) err := drv.DropDatabase(u)
@ -78,7 +85,7 @@ func TestNormalizeMySQLURLSocket(t *testing.T) {
} }
func TestMySQLCreateDropDatabase(t *testing.T) { func TestMySQLCreateDropDatabase(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
u := mySQLTestURL(t) u := mySQLTestURL(t)
// drop any existing database // drop any existing database
@ -116,7 +123,9 @@ func TestMySQLCreateDropDatabase(t *testing.T) {
} }
func TestMySQLDumpSchema(t *testing.T) { func TestMySQLDumpSchema(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t) u := mySQLTestURL(t)
// prepare database // prepare database
@ -134,13 +143,13 @@ func TestMySQLDumpSchema(t *testing.T) {
// DumpSchema should return schema // DumpSchema should return schema
schema, err := drv.DumpSchema(u, db) schema, err := drv.DumpSchema(u, db)
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE `schema_migrations`") require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
require.Contains(t, string(schema), "\n-- Dump completed\n\n"+ require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
"--\n"+ "--\n"+
"-- Dbmate schema migrations\n"+ "-- Dbmate schema migrations\n"+
"--\n\n"+ "--\n\n"+
"LOCK TABLES `schema_migrations` WRITE;\n"+ "LOCK TABLES `test_migrations` WRITE;\n"+
"INSERT INTO `schema_migrations` (version) VALUES\n"+ "INSERT INTO `test_migrations` (version) VALUES\n"+
" ('abc1'),\n"+ " ('abc1'),\n"+
" ('abc2');\n"+ " ('abc2');\n"+
"UNLOCK TABLES;\n") "UNLOCK TABLES;\n")
@ -156,7 +165,7 @@ func TestMySQLDumpSchema(t *testing.T) {
} }
func TestMySQLDatabaseExists(t *testing.T) { func TestMySQLDatabaseExists(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
u := mySQLTestURL(t) u := mySQLTestURL(t)
// drop any existing database // drop any existing database
@ -179,7 +188,7 @@ func TestMySQLDatabaseExists(t *testing.T) {
} }
func TestMySQLDatabaseExists_Error(t *testing.T) { func TestMySQLDatabaseExists_Error(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
u := mySQLTestURL(t) u := mySQLTestURL(t)
u.User = url.User("invalid") u.User = url.User("invalid")
@ -189,22 +198,25 @@ func TestMySQLDatabaseExists_Error(t *testing.T) {
} }
func TestMySQLCreateMigrationsTable(t *testing.T) { func TestMySQLCreateMigrationsTable(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t) u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u) db := prepTestMySQLDB(t, u)
defer mustClose(db) defer mustClose(db)
// migrations table should not exist // migrations table should not exist
count := 0 count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.Regexp(t, "Table 'dbmate.schema_migrations' doesn't exist", err.Error()) require.Error(t, err)
require.Regexp(t, "Table 'dbmate.test_migrations' doesn't exist", err.Error())
// create table // create table
err = drv.CreateMigrationsTable(u, db) err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
// migrations table should exist // migrations table should exist
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
// create table should be idempotent // create table should be idempotent
@ -213,7 +225,9 @@ func TestMySQLCreateMigrationsTable(t *testing.T) {
} }
func TestMySQLSelectMigrations(t *testing.T) { func TestMySQLSelectMigrations(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t) u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u) db := prepTestMySQLDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -221,7 +235,7 @@ func TestMySQLSelectMigrations(t *testing.T) {
err := drv.CreateMigrationsTable(u, db) err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version) _, err = db.Exec(`insert into test_migrations (version)
values ('abc2'), ('abc1'), ('abc3')`) values ('abc2'), ('abc1'), ('abc3')`)
require.NoError(t, err) require.NoError(t, err)
@ -240,7 +254,9 @@ func TestMySQLSelectMigrations(t *testing.T) {
} }
func TestMySQLInsertMigration(t *testing.T) { func TestMySQLInsertMigration(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t) u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u) db := prepTestMySQLDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -249,7 +265,7 @@ func TestMySQLInsertMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, count) require.Equal(t, 0, count)
@ -257,14 +273,16 @@ func TestMySQLInsertMigration(t *testing.T) {
err = drv.InsertMigration(db, "abc1") err = drv.InsertMigration(db, "abc1")
require.NoError(t, err) require.NoError(t, err)
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'"). err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
Scan(&count) Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestMySQLDeleteMigration(t *testing.T) { func TestMySQLDeleteMigration(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t) u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u) db := prepTestMySQLDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -272,7 +290,7 @@ func TestMySQLDeleteMigration(t *testing.T) {
err := drv.CreateMigrationsTable(u, db) err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version) _, err = db.Exec(`insert into test_migrations (version)
values ('abc1'), ('abc2')`) values ('abc1'), ('abc2')`)
require.NoError(t, err) require.NoError(t, err)
@ -280,13 +298,13 @@ func TestMySQLDeleteMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestMySQLPing(t *testing.T) { func TestMySQLPing(t *testing.T) {
drv := MySQLDriver{} drv := testMySQLDriver()
u := mySQLTestURL(t) u := mySQLTestURL(t)
// drop any existing database // drop any existing database
@ -303,3 +321,19 @@ func TestMySQLPing(t *testing.T) {
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused") require.Contains(t, err.Error(), "connect: connection refused")
} }
func TestMySQLQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testMySQLDriver()
name := drv.quotedMigrationsTableName()
require.Equal(t, "`schema_migrations`", name)
})
t.Run("custom name", func(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("fooMigrations")
name := drv.quotedMigrationsTableName()
require.Equal(t, "`fooMigrations`", name)
})
}

View file

@ -11,12 +11,14 @@ import (
) )
func init() { func init() {
RegisterDriver(PostgresDriver{}, "postgres") drv := &PostgresDriver{}
RegisterDriver(PostgresDriver{}, "postgresql") RegisterDriver(drv, "postgres")
RegisterDriver(drv, "postgresql")
} }
// PostgresDriver provides top level database functions // PostgresDriver provides top level database functions
type PostgresDriver struct { type PostgresDriver struct {
migrationsTableName string
} }
func normalizePostgresURL(u *url.URL) *url.URL { func normalizePostgresURL(u *url.URL) *url.URL {
@ -78,12 +80,17 @@ func normalizePostgresURLForDump(u *url.URL) []string {
return out return out
} }
// SetMigrationsTableName sets the schema migrations table name
func (drv *PostgresDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection // Open creates a new database connection
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) { func (drv *PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("postgres", normalizePostgresURL(u).String()) return sql.Open("postgres", normalizePostgresURL(u).String())
} }
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) { func (drv *PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
// connect to postgres database // connect to postgres database
postgresURL := *u postgresURL := *u
postgresURL.Path = "postgres" postgresURL.Path = "postgres"
@ -92,7 +99,7 @@ func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
} }
// CreateDatabase creates the specified database // CreateDatabase creates the specified database
func (drv PostgresDriver) CreateDatabase(u *url.URL) error { func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u) name := databaseName(u)
fmt.Printf("Creating: %s\n", name) fmt.Printf("Creating: %s\n", name)
@ -109,7 +116,7 @@ func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
} }
// DropDatabase drops the specified database (if it exists) // DropDatabase drops the specified database (if it exists)
func (drv PostgresDriver) DropDatabase(u *url.URL) error { func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
name := databaseName(u) name := databaseName(u)
fmt.Printf("Dropping: %s\n", name) fmt.Printf("Dropping: %s\n", name)
@ -125,8 +132,8 @@ func (drv PostgresDriver) DropDatabase(u *url.URL) error {
return err return err
} }
func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) { func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable, err := drv.migrationsTableName(db) migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -152,7 +159,7 @@ func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, erro
} }
// DumpSchema returns the current database schema // DumpSchema returns the current database schema
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
// load schema // load schema
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only", args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...) "--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
@ -161,7 +168,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
return nil, err return nil, err
} }
migrations, err := drv.postgresSchemaMigrationsDump(db) migrations, err := drv.schemaMigrationsDump(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -171,7 +178,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
} }
// DatabaseExists determines whether the database exists // DatabaseExists determines whether the database exists
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) { func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u) name := databaseName(u)
db, err := drv.openPostgresDB(u) db, err := drv.openPostgresDB(u)
@ -191,48 +198,45 @@ func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
} }
// CreateMigrationsTable creates the schema_migrations table // CreateMigrationsTable creates the schema_migrations table
func (drv PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
// get schema from URL search_path param schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db, u)
searchPath := strings.Split(u.Query().Get("search_path"), ",")
urlSchema := strings.TrimSpace(searchPath[0])
if urlSchema == "" {
urlSchema = "public"
}
// get *unquoted* current schema from database
dbSchema, err := queryRow(db, "select current_schema()")
if err != nil { if err != nil {
return err return err
} }
// if urlSchema and dbSchema are not equal, the most likely explanation is that the schema // first attempt at creating migrations table
// has not yet been created createTableStmt := fmt.Sprintf("create table if not exists %s.%s", schema, migrationsTable) +
if urlSchema != dbSchema { " (version varchar(255) primary key)"
// in theory we could just execute this statement every time, but we do the comparison _, err = db.Exec(createTableStmt)
// above in case the user doesn't have permissions to create schemas and the schema if err == nil {
// already exists // table exists or created successfully
fmt.Printf("Creating schema: %s\n", urlSchema) return nil
_, err = db.Exec("create schema if not exists " + pq.QuoteIdentifier(urlSchema))
if err != nil {
return err
}
} }
migrationsTable, err := drv.migrationsTableName(db) // catch 'schema does not exist' error
pqErr, ok := err.(*pq.Error)
if !ok || pqErr.Code != "3F000" {
// unknown error
return err
}
// in theory we could attempt to create the schema every time, but we avoid that
// in case the user doesn't have permissions to create schemas
fmt.Printf("Creating schema: %s\n", schema)
_, err = db.Exec(fmt.Sprintf("create schema if not exists %s", schema))
if err != nil { if err != nil {
return err return err
} }
_, err = db.Exec("create table if not exists " + migrationsTable + // second and final attempt at creating migrations table
" (version varchar(255) primary key)") _, err = db.Exec(createTableStmt)
return err return err
} }
// SelectMigrations returns a list of applied migrations // SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order) // with an optional limit (in descending order)
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
migrationsTable, err := drv.migrationsTableName(db) migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -266,8 +270,8 @@ func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bo
} }
// InsertMigration adds a new migration record // InsertMigration adds a new migration record
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error { func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error {
migrationsTable, err := drv.migrationsTableName(db) migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil { if err != nil {
return err return err
} }
@ -278,8 +282,8 @@ func (drv PostgresDriver) InsertMigration(db Transaction, version string) error
} }
// DeleteMigration removes a migration record // DeleteMigration removes a migration record
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error { func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error {
migrationsTable, err := drv.migrationsTableName(db) migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil { if err != nil {
return err return err
} }
@ -291,7 +295,7 @@ func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error
// Ping verifies a connection to the database server. It does not verify whether the // Ping verifies a connection to the database server. It does not verify whether the
// specified database exists. // specified database exists.
func (drv PostgresDriver) Ping(u *url.URL) error { func (drv *PostgresDriver) Ping(u *url.URL) error {
// attempt connection to primary database, not "postgres" database // attempt connection to primary database, not "postgres" database
// to support servers with no "postgres" database // to support servers with no "postgres" database
// (see https://github.com/amacneil/dbmate/issues/78) // (see https://github.com/amacneil/dbmate/issues/78)
@ -306,7 +310,7 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
return nil return nil
} }
// ignore 'database "foo" does not exist' error // ignore 'database does not exist' error
pqErr, ok := err.(*pq.Error) pqErr, ok := err.(*pq.Error)
if ok && pqErr.Code == "3D000" { if ok && pqErr.Code == "3D000" {
return nil return nil
@ -315,17 +319,53 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
return err return err
} }
func (drv PostgresDriver) migrationsTableName(db Transaction) (string, error) { func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, error) {
// get current schema schema, name, err := drv.quotedMigrationsTableNameParts(db, nil)
schema, err := queryRow(db, "select quote_ident(current_schema())")
if err != nil { if err != nil {
return "", err return "", err
} }
// if the search path is empty, or does not contain a valid schema, default to public return schema + "." + name, nil
}
func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url.URL) (string, string, error) {
schema := ""
tableNameParts := strings.Split(drv.migrationsTableName, ".")
if len(tableNameParts) > 1 {
// schema specified as part of table name
schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
}
if schema == "" && u != nil {
// no schema specified with table name, try URL search path if available
searchPath := strings.Split(u.Query().Get("search_path"), ",")
schema = strings.TrimSpace(searchPath[0])
}
var err error
if schema == "" {
// if no URL available, use current schema
// this is a hack because we don't always have the URL context available
schema, err = queryValue(db, "select current_schema()")
if err != nil {
return "", "", err
}
}
// fall back to public schema as last resort
if schema == "" { if schema == "" {
schema = "public" schema = "public"
} }
return schema + ".schema_migrations", nil // quote all parts
// use server rather than client to do this to avoid unnecessary quotes
// (which would change schema.sql diff)
tableNameParts = append([]string{schema}, tableNameParts...)
quotedNameParts, err := queryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
if err != nil {
return "", "", err
}
// if more than one part, we already have a schema
return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil
} }

View file

@ -15,8 +15,15 @@ func postgresTestURL(t *testing.T) *url.URL {
return u return u
} }
func testPostgresDriver() *PostgresDriver {
drv := &PostgresDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB { func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
drv := PostgresDriver{} drv := testPostgresDriver()
// drop any existing database // drop any existing database
err := drv.DropDatabase(u) err := drv.DropDatabase(u)
@ -87,7 +94,7 @@ func TestNormalizePostgresURLForDump(t *testing.T) {
} }
func TestPostgresCreateDropDatabase(t *testing.T) { func TestPostgresCreateDropDatabase(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
u := postgresTestURL(t) u := postgresTestURL(t)
// drop any existing database // drop any existing database
@ -125,7 +132,8 @@ func TestPostgresCreateDropDatabase(t *testing.T) {
} }
func TestPostgresDumpSchema(t *testing.T) { func TestPostgresDumpSchema(t *testing.T) {
drv := PostgresDriver{} t.Run("default migrations table", func(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t) u := postgresTestURL(t)
// prepare database // prepare database
@ -160,10 +168,44 @@ func TestPostgresDumpSchema(t *testing.T) {
require.Nil(t, schema) require.Nil(t, schema)
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+ require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist") "\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
})
t.Run("custom migrations table with schema", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("camelSchema.testMigrations")
u := postgresTestURL(t)
// prepare database
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// insert migration
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = drv.InsertMigration(db, "abc2")
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"")
require.Contains(t, string(schema), "\n--\n"+
"-- PostgreSQL database dump complete\n"+
"--\n\n\n"+
"--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
"INSERT INTO \"camelSchema\".\"testMigrations\" (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n")
})
} }
func TestPostgresDatabaseExists(t *testing.T) { func TestPostgresDatabaseExists(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
u := postgresTestURL(t) u := postgresTestURL(t)
// drop any existing database // drop any existing database
@ -186,7 +228,7 @@ func TestPostgresDatabaseExists(t *testing.T) {
} }
func TestPostgresDatabaseExists_Error(t *testing.T) { func TestPostgresDatabaseExists_Error(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
u := postgresTestURL(t) u := postgresTestURL(t)
u.User = url.User("invalid") u.User = url.User("invalid")
@ -197,9 +239,8 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
} }
func TestPostgresCreateMigrationsTable(t *testing.T) { func TestPostgresCreateMigrationsTable(t *testing.T) {
drv := PostgresDriver{}
t.Run("default schema", func(t *testing.T) { t.Run("default schema", func(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t) u := postgresTestURL(t)
db := prepTestPostgresDB(t, u) db := prepTestPostgresDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -223,39 +264,81 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("custom schema", func(t *testing.T) { t.Run("custom search path", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo") drv := testPostgresDriver()
drv.SetMigrationsTableName("testMigrations")
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=camelFoo")
require.NoError(t, err) require.NoError(t, err)
db := prepTestPostgresDB(t, u) db := prepTestPostgresDB(t, u)
defer mustClose(db) defer mustClose(db)
// delete schema // delete schema
_, err = db.Exec("drop schema if exists foo") _, err = db.Exec("drop schema if exists \"camelFoo\"")
require.NoError(t, err) require.NoError(t, err)
// drop any schema_migrations table in public schema // drop any testMigrations table in public schema
_, err = db.Exec("drop table if exists public.schema_migrations") _, err = db.Exec("drop table if exists public.\"testMigrations\"")
require.NoError(t, err) require.NoError(t, err)
// migrations table should not exist in either schema // migrations table should not exist in either schema
count := 0 count := 0
err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count)
require.Error(t, err) require.Error(t, err)
require.Equal(t, "pq: relation \"foo.schema_migrations\" does not exist", err.Error()) require.Equal(t, "pq: relation \"camelFoo.testMigrations\" does not exist", err.Error())
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count)
require.Error(t, err) require.Error(t, err)
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error()) require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
// create table // create table
err = drv.CreateMigrationsTable(u, db) err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
// foo schema should be created, and migrations table should exist only in foo schema // camelFoo schema should be created, and migrations table should exist only in camelFoo schema
err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count)
require.Error(t, err) require.Error(t, err)
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error()) require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
t.Run("custom schema", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("camelSchema.testMigrations")
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// delete schemas
_, err = db.Exec("drop schema if exists foo")
require.NoError(t, err)
_, err = db.Exec("drop schema if exists \"camelSchema\"")
require.NoError(t, err)
// migrations table should not exist
count := 0
err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// camelSchema should be created, and testMigrations table should exist
err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count)
require.NoError(t, err)
// testMigrations table should not exist in foo schema because
// schema specified with migrations table name takes priority over search path
err = db.QueryRow("select count(*) from foo.\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error())
// create table should be idempotent // create table should be idempotent
err = drv.CreateMigrationsTable(u, db) err = drv.CreateMigrationsTable(u, db)
@ -264,7 +347,9 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
} }
func TestPostgresSelectMigrations(t *testing.T) { func TestPostgresSelectMigrations(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
u := postgresTestURL(t) u := postgresTestURL(t)
db := prepTestPostgresDB(t, u) db := prepTestPostgresDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -272,7 +357,7 @@ func TestPostgresSelectMigrations(t *testing.T) {
err := drv.CreateMigrationsTable(u, db) err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(`insert into public.schema_migrations (version) _, err = db.Exec(`insert into public.test_migrations (version)
values ('abc2'), ('abc1'), ('abc3')`) values ('abc2'), ('abc1'), ('abc3')`)
require.NoError(t, err) require.NoError(t, err)
@ -291,7 +376,9 @@ func TestPostgresSelectMigrations(t *testing.T) {
} }
func TestPostgresInsertMigration(t *testing.T) { func TestPostgresInsertMigration(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
u := postgresTestURL(t) u := postgresTestURL(t)
db := prepTestPostgresDB(t, u) db := prepTestPostgresDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -300,7 +387,7 @@ func TestPostgresInsertMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, count) require.Equal(t, 0, count)
@ -308,14 +395,16 @@ func TestPostgresInsertMigration(t *testing.T) {
err = drv.InsertMigration(db, "abc1") err = drv.InsertMigration(db, "abc1")
require.NoError(t, err) require.NoError(t, err)
err = db.QueryRow("select count(*) from public.schema_migrations where version = 'abc1'"). err = db.QueryRow("select count(*) from public.test_migrations where version = 'abc1'").
Scan(&count) Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestPostgresDeleteMigration(t *testing.T) { func TestPostgresDeleteMigration(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
u := postgresTestURL(t) u := postgresTestURL(t)
db := prepTestPostgresDB(t, u) db := prepTestPostgresDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -323,7 +412,7 @@ func TestPostgresDeleteMigration(t *testing.T) {
err := drv.CreateMigrationsTable(u, db) err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(`insert into public.schema_migrations (version) _, err = db.Exec(`insert into public.test_migrations (version)
values ('abc1'), ('abc2')`) values ('abc1'), ('abc2')`)
require.NoError(t, err) require.NoError(t, err)
@ -331,13 +420,13 @@ func TestPostgresDeleteMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestPostgresPing(t *testing.T) { func TestPostgresPing(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
u := postgresTestURL(t) u := postgresTestURL(t)
// drop any existing database // drop any existing database
@ -355,15 +444,15 @@ func TestPostgresPing(t *testing.T) {
require.Contains(t, err.Error(), "connect: connection refused") require.Contains(t, err.Error(), "connect: connection refused")
} }
func TestMigrationsTableName(t *testing.T) { func TestPostgresQuotedMigrationsTableName(t *testing.T) {
drv := PostgresDriver{} drv := testPostgresDriver()
t.Run("default schema", func(t *testing.T) { t.Run("default schema", func(t *testing.T) {
u := postgresTestURL(t) u := postgresTestURL(t)
db := prepTestPostgresDB(t, u) db := prepTestPostgresDB(t, u)
defer mustClose(db) defer mustClose(db)
name, err := drv.migrationsTableName(db) name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name) require.Equal(t, "public.schema_migrations", name)
}) })
@ -379,14 +468,14 @@ func TestMigrationsTableName(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec("drop schema if exists bar") _, err = db.Exec("drop schema if exists bar")
require.NoError(t, err) require.NoError(t, err)
name, err := drv.migrationsTableName(db) name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name) require.Equal(t, "public.schema_migrations", name)
// if "foo" schema exists, it should be used // if "foo" schema exists, it should be used
_, err = db.Exec("create schema foo") _, err = db.Exec("create schema foo")
require.NoError(t, err) require.NoError(t, err)
name, err = drv.migrationsTableName(db) name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "foo.schema_migrations", name) require.Equal(t, "foo.schema_migrations", name)
}) })
@ -401,8 +490,76 @@ func TestMigrationsTableName(t *testing.T) {
_, err := db.Exec("select pg_catalog.set_config('search_path', '', false)") _, err := db.Exec("select pg_catalog.set_config('search_path', '', false)")
require.NoError(t, err) require.NoError(t, err)
name, err := drv.migrationsTableName(db) name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name) require.Equal(t, "public.schema_migrations", name)
}) })
t.Run("custom table name", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.SetMigrationsTableName("simple_name")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.simple_name", name)
})
t.Run("custom table name quoted", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// this table name will need quoting
drv.SetMigrationsTableName("camelCase")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.\"camelCase\"", name)
})
t.Run("custom table name with custom schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
_, err = db.Exec("create schema if not exists foo")
require.NoError(t, err)
drv.SetMigrationsTableName("simple_name")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "foo.simple_name", name)
})
t.Run("custom table name overrides schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
_, err = db.Exec("create schema if not exists foo")
require.NoError(t, err)
_, err = db.Exec("create schema if not exists bar")
require.NoError(t, err)
// if schema is specified as part of table name, it should override search_path
drv.SetMigrationsTableName("bar.simple_name")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "bar.simple_name", name)
// schema and table name should be quoted if necessary
drv.SetMigrationsTableName("barName.camelTable")
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "\"barName\".\"camelTable\"", name)
// more than 2 components is unexpected but we will quote and pass it along anyway
drv.SetMigrationsTableName("whyWould.i.doThis")
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "\"whyWould\".i.\"doThis\"", name)
})
} }

View file

@ -11,16 +11,19 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql _ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql
) )
func init() { func init() {
RegisterDriver(SQLiteDriver{}, "sqlite") drv := &SQLiteDriver{}
RegisterDriver(SQLiteDriver{}, "sqlite3") RegisterDriver(drv, "sqlite")
RegisterDriver(drv, "sqlite3")
} }
// SQLiteDriver provides top level database functions // SQLiteDriver provides top level database functions
type SQLiteDriver struct { type SQLiteDriver struct {
migrationsTableName string
} }
func sqlitePath(u *url.URL) string { func sqlitePath(u *url.URL) string {
@ -31,13 +34,18 @@ func sqlitePath(u *url.URL) string {
return str return str
} }
// SetMigrationsTableName sets the schema migrations table name
func (drv *SQLiteDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection // Open creates a new database connection
func (drv SQLiteDriver) Open(u *url.URL) (*sql.DB, error) { func (drv *SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("sqlite3", sqlitePath(u)) return sql.Open("sqlite3", sqlitePath(u))
} }
// CreateDatabase creates the specified database // CreateDatabase creates the specified database
func (drv SQLiteDriver) CreateDatabase(u *url.URL) error { func (drv *SQLiteDriver) CreateDatabase(u *url.URL) error {
fmt.Printf("Creating: %s\n", sqlitePath(u)) fmt.Printf("Creating: %s\n", sqlitePath(u))
db, err := drv.Open(u) db, err := drv.Open(u)
@ -50,7 +58,7 @@ func (drv SQLiteDriver) CreateDatabase(u *url.URL) error {
} }
// DropDatabase drops the specified database (if it exists) // DropDatabase drops the specified database (if it exists)
func (drv SQLiteDriver) DropDatabase(u *url.URL) error { func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
path := sqlitePath(u) path := sqlitePath(u)
fmt.Printf("Dropping: %s\n", path) fmt.Printf("Dropping: %s\n", path)
@ -65,20 +73,23 @@ func (drv SQLiteDriver) DropDatabase(u *url.URL) error {
return os.Remove(path) return os.Remove(path)
} }
func sqliteSchemaMigrationsDump(db *sql.DB) ([]byte, error) { func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations // load applied migrations
migrations, err := queryColumn(db, migrations, err := queryColumn(db,
"select quote(version) from schema_migrations order by version asc") fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// build schema_migrations table data // build schema migrations table data
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("-- Dbmate schema migrations\n") buf.WriteString("-- Dbmate schema migrations\n")
if len(migrations) > 0 { if len(migrations) > 0 {
buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" + buf.WriteString(
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
strings.Join(migrations, "),\n (") + strings.Join(migrations, "),\n (") +
");\n") ");\n")
} }
@ -87,14 +98,14 @@ func sqliteSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
} }
// DumpSchema returns the current database schema // DumpSchema returns the current database schema
func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
path := sqlitePath(u) path := sqlitePath(u)
schema, err := runCommand("sqlite3", path, ".schema") schema, err := runCommand("sqlite3", path, ".schema")
if err != nil { if err != nil {
return nil, err return nil, err
} }
migrations, err := sqliteSchemaMigrationsDump(db) migrations, err := drv.schemaMigrationsDump(db)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -104,7 +115,7 @@ func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
} }
// DatabaseExists determines whether the database exists // DatabaseExists determines whether the database exists
func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) { func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
_, err := os.Stat(sqlitePath(u)) _, err := os.Stat(sqlitePath(u))
if os.IsNotExist(err) { if os.IsNotExist(err) {
return false, nil return false, nil
@ -116,9 +127,10 @@ func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
return true, nil return true, nil
} }
// CreateMigrationsTable creates the schema_migrations table // CreateMigrationsTable creates the schema migrations table
func (drv SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec("create table if not exists schema_migrations " + _, err := db.Exec(
fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) +
"(version varchar(255) primary key)") "(version varchar(255) primary key)")
return err return err
@ -126,8 +138,8 @@ func (drv SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
// SelectMigrations returns a list of applied migrations // SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order) // with an optional limit (in descending order)
func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations order by version desc" query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
if limit >= 0 { if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit) query = fmt.Sprintf("%s limit %d", query, limit)
} }
@ -156,15 +168,19 @@ func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
} }
// InsertMigration adds a new migration record // InsertMigration adds a new migration record
func (drv SQLiteDriver) InsertMigration(db Transaction, version string) error { func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version) _, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
return err return err
} }
// DeleteMigration removes a migration record // DeleteMigration removes a migration record
func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error { func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from schema_migrations where version = ?", version) _, err := db.Exec(
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
version)
return err return err
} }
@ -172,7 +188,7 @@ func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error {
// Ping verifies a connection to the database. Due to the way SQLite works, by // Ping verifies a connection to the database. Due to the way SQLite works, by
// testing whether the database is valid, it will automatically create the database // testing whether the database is valid, it will automatically create the database
// if it does not already exist. // if it does not already exist.
func (drv SQLiteDriver) Ping(u *url.URL) error { func (drv *SQLiteDriver) Ping(u *url.URL) error {
db, err := drv.Open(u) db, err := drv.Open(u)
if err != nil { if err != nil {
return err return err
@ -181,3 +197,14 @@ func (drv SQLiteDriver) Ping(u *url.URL) error {
return db.Ping() return db.Ping()
} }
func (drv *SQLiteDriver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}
// quoteIdentifier quotes a table or column name
// we fall back to lib/pq implementation since both use ansi standard (double quotes)
// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
func (drv *SQLiteDriver) quoteIdentifier(s string) string {
return pq.QuoteIdentifier(s)
}

View file

@ -18,8 +18,15 @@ func sqliteTestURL(t *testing.T) *url.URL {
return u return u
} }
func testSQLiteDriver() *SQLiteDriver {
drv := &SQLiteDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB { func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
drv := SQLiteDriver{} drv := testSQLiteDriver()
// drop any existing database // drop any existing database
err := drv.DropDatabase(u) err := drv.DropDatabase(u)
@ -37,7 +44,7 @@ func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
} }
func TestSQLiteCreateDropDatabase(t *testing.T) { func TestSQLiteCreateDropDatabase(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
u := sqliteTestURL(t) u := sqliteTestURL(t)
path := sqlitePath(u) path := sqlitePath(u)
@ -64,7 +71,9 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
} }
func TestSQLiteDumpSchema(t *testing.T) { func TestSQLiteDumpSchema(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t) u := sqliteTestURL(t)
// prepare database // prepare database
@ -82,9 +91,9 @@ func TestSQLiteDumpSchema(t *testing.T) {
// DumpSchema should return schema // DumpSchema should return schema
schema, err := drv.DumpSchema(u, db) schema, err := drv.DumpSchema(u, db)
require.NoError(t, err) require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE schema_migrations") require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"")
require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+ require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+
"INSERT INTO schema_migrations (version) VALUES\n"+ "INSERT INTO \"test_migrations\" (version) VALUES\n"+
" ('abc1'),\n"+ " ('abc1'),\n"+
" ('abc2');\n") " ('abc2');\n")
@ -97,7 +106,7 @@ func TestSQLiteDumpSchema(t *testing.T) {
} }
func TestSQLiteDatabaseExists(t *testing.T) { func TestSQLiteDatabaseExists(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
u := sqliteTestURL(t) u := sqliteTestURL(t)
// drop any existing database // drop any existing database
@ -120,7 +129,8 @@ func TestSQLiteDatabaseExists(t *testing.T) {
} }
func TestSQLiteCreateMigrationsTable(t *testing.T) { func TestSQLiteCreateMigrationsTable(t *testing.T) {
drv := SQLiteDriver{} t.Run("default table", func(t *testing.T) {
drv := testSQLiteDriver()
u := sqliteTestURL(t) u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u) db := prepTestSQLiteDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -141,10 +151,39 @@ func TestSQLiteCreateMigrationsTable(t *testing.T) {
// create table should be idempotent // create table should be idempotent
err = drv.CreateMigrationsTable(u, db) err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
})
t.Run("custom table", func(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.Regexp(t, "no such table: test_migrations", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
} }
func TestSQLiteSelectMigrations(t *testing.T) { func TestSQLiteSelectMigrations(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t) u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u) db := prepTestSQLiteDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -152,7 +191,7 @@ func TestSQLiteSelectMigrations(t *testing.T) {
err := drv.CreateMigrationsTable(u, db) err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version) _, err = db.Exec(`insert into test_migrations (version)
values ('abc2'), ('abc1'), ('abc3')`) values ('abc2'), ('abc1'), ('abc3')`)
require.NoError(t, err) require.NoError(t, err)
@ -171,7 +210,9 @@ func TestSQLiteSelectMigrations(t *testing.T) {
} }
func TestSQLiteInsertMigration(t *testing.T) { func TestSQLiteInsertMigration(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t) u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u) db := prepTestSQLiteDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -180,7 +221,7 @@ func TestSQLiteInsertMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, count) require.Equal(t, 0, count)
@ -188,14 +229,16 @@ func TestSQLiteInsertMigration(t *testing.T) {
err = drv.InsertMigration(db, "abc1") err = drv.InsertMigration(db, "abc1")
require.NoError(t, err) require.NoError(t, err)
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'"). err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
Scan(&count) Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestSQLiteDeleteMigration(t *testing.T) { func TestSQLiteDeleteMigration(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t) u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u) db := prepTestSQLiteDB(t, u)
defer mustClose(db) defer mustClose(db)
@ -203,7 +246,7 @@ func TestSQLiteDeleteMigration(t *testing.T) {
err := drv.CreateMigrationsTable(u, db) err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err) require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version) _, err = db.Exec(`insert into test_migrations (version)
values ('abc1'), ('abc2')`) values ('abc1'), ('abc2')`)
require.NoError(t, err) require.NoError(t, err)
@ -211,13 +254,13 @@ func TestSQLiteDeleteMigration(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
count := 0 count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, count) require.Equal(t, 1, count)
} }
func TestSQLitePing(t *testing.T) { func TestSQLitePing(t *testing.T) {
drv := SQLiteDriver{} drv := testSQLiteDriver()
u := sqliteTestURL(t) u := sqliteTestURL(t)
path := sqlitePath(u) path := sqlitePath(u)
@ -249,3 +292,19 @@ func TestSQLitePing(t *testing.T) {
err = drv.Ping(u) err = drv.Ping(u)
require.EqualError(t, err, "unable to open database file: is a directory") require.EqualError(t, err, "unable to open database file: is a directory")
} }
func TestSQLiteQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testSQLiteDriver()
name := drv.quotedMigrationsTableName()
require.Equal(t, `"schema_migrations"`, name)
})
t.Run("custom name", func(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("fooMigrations")
name := drv.quotedMigrationsTableName()
require.Equal(t, `"fooMigrations"`, name)
})
}

View file

@ -104,8 +104,8 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
// queryColumn runs a SQL statement and returns a slice of strings // queryColumn runs a SQL statement and returns a slice of strings
// it is assumed that the statement returns only one column // it is assumed that the statement returns only one column
// e.g. schema_migrations table // e.g. schema_migrations table
func queryColumn(db Transaction, query string) ([]string, error) { func queryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query) rows, err := db.Query(query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -128,12 +128,12 @@ func queryColumn(db Transaction, query string) ([]string, error) {
return result, nil return result, nil
} }
// queryRow runs a SQL statement and returns a single string // queryValue runs a SQL statement and returns a single string
// it is assumed that the statement returns only one row and one column // it is assumed that the statement returns only one row and one column
// sql NULL is returned as empty string // sql NULL is returned as empty string
func queryRow(db Transaction, query string) (string, error) { func queryValue(db Transaction, query string, args ...interface{}) (string, error) {
var result sql.NullString var result sql.NullString
err := db.QueryRow(query).Scan(&result) err := db.QueryRow(query, args...).Scan(&result)
if err != nil || !result.Valid { if err != nil || !result.Valid {
return "", err return "", err
} }

View file

@ -1,9 +1,11 @@
package dbmate package dbmate
import ( import (
"database/sql"
"net/url" "net/url"
"testing" "testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -33,3 +35,24 @@ func TestTrimLeadingSQLComments(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "real stuff\n-- end\n", string(out)) require.Equal(t, "real stuff\n-- end\n", string(out))
} }
func TestQueryColumn(t *testing.T) {
u := postgresTestURL(t)
db, err := sql.Open("postgres", u.String())
require.NoError(t, err)
val, err := queryColumn(db, "select concat('foo_', unnest($1::text[]))",
pq.Array([]string{"hi", "there"}))
require.NoError(t, err)
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
}
func TestQueryValue(t *testing.T) {
u := postgresTestURL(t)
db, err := sql.Open("postgres", u.String())
require.NoError(t, err)
val, err := queryValue(db, "select $1::int + $2::int", "5", 2)
require.NoError(t, err)
require.Equal(t, "7", val)
}