From c907c3f5c60f85165835a27c0c833610297870c6 Mon Sep 17 00:00:00 2001 From: Adrian Macneil Date: Tue, 17 Nov 2020 18:11:24 +1300 Subject: [PATCH] Ability to specify custom migrations table name (#178) Supported via `--migrations-table` CLI flag or `DBMATE_MIGRATIONS_TABLE` environment variable. Specified table name is quoted when necessary. For PostgreSQL specifically, it's also possible to specify a custom schema (for example: `--migrations-table=foo.migrations`). Closes #168 --- main.go | 7 + pkg/dbmate/clickhouse.go | 89 +++++++---- pkg/dbmate/clickhouse_test.go | 133 ++++++++++++---- pkg/dbmate/db.go | 44 ++++-- pkg/dbmate/db_test.go | 22 ++- pkg/dbmate/driver.go | 17 +- pkg/dbmate/driver_test.go | 10 +- pkg/dbmate/mysql.go | 76 +++++---- pkg/dbmate/mysql_test.go | 76 ++++++--- pkg/dbmate/postgres.go | 138 ++++++++++------ pkg/dbmate/postgres_test.go | 287 ++++++++++++++++++++++++++-------- pkg/dbmate/sqlite.go | 77 ++++++--- pkg/dbmate/sqlite_test.go | 123 +++++++++++---- pkg/dbmate/utils.go | 10 +- pkg/dbmate/utils_test.go | 23 +++ 15 files changed, 807 insertions(+), 325 deletions(-) diff --git a/main.go b/main.go index 7d70497..04a777f 100644 --- a/main.go +++ b/main.go @@ -52,6 +52,12 @@ func NewApp() *cli.App { Value: dbmate.DefaultMigrationsDir, Usage: "specify the directory containing migration files", }, + &cli.StringFlag{ + Name: "migrations-table", + EnvVars: []string{"DBMATE_MIGRATIONS_TABLE"}, + Value: dbmate.DefaultMigrationsTableName, + Usage: "specify the database table to record migrations in", + }, &cli.StringFlag{ Name: "schema-file", Aliases: []string{"s"}, @@ -222,6 +228,7 @@ func action(f func(*dbmate.DB, *cli.Context) error) cli.ActionFunc { db := dbmate.New(u) db.AutoDumpSchema = !c.Bool("no-dump-schema") db.MigrationsDir = c.String("migrations-dir") + db.MigrationsTableName = c.String("migrations-table") db.SchemaFile = c.String("schema-file") db.WaitBefore = c.Bool("wait") overrideTimeout := c.Duration("wait-timeout") diff --git a/pkg/dbmate/clickhouse.go b/pkg/dbmate/clickhouse.go index 63070e3..f0acd3b 100644 --- a/pkg/dbmate/clickhouse.go +++ b/pkg/dbmate/clickhouse.go @@ -13,11 +13,12 @@ import ( ) func init() { - RegisterDriver(ClickHouseDriver{}, "clickhouse") + RegisterDriver(&ClickHouseDriver{}, "clickhouse") } // ClickHouseDriver provides top level database functions type ClickHouseDriver struct { + migrationsTableName string } func normalizeClickHouseURL(initialURL *url.URL) *url.URL { @@ -52,12 +53,17 @@ func normalizeClickHouseURL(initialURL *url.URL) *url.URL { return &u } +// SetMigrationsTableName sets the schema migrations table name +func (drv *ClickHouseDriver) SetMigrationsTableName(name string) { + drv.migrationsTableName = name +} + // Open creates a new database connection -func (drv ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) { +func (drv *ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) { return sql.Open("clickhouse", normalizeClickHouseURL(u).String()) } -func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) { +func (drv *ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) { // connect to clickhouse database clickhouseURL := normalizeClickHouseURL(u) values := clickhouseURL.Query() @@ -67,7 +73,7 @@ func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) { return drv.Open(clickhouseURL) } -func (drv ClickHouseDriver) databaseName(u *url.URL) string { +func (drv *ClickHouseDriver) databaseName(u *url.URL) string { name := normalizeClickHouseURL(u).Query().Get("database") if name == "" { name = "default" @@ -77,7 +83,7 @@ func (drv ClickHouseDriver) databaseName(u *url.URL) string { var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`) -func clickhouseQuoteIdentifier(str string) string { +func (drv *ClickHouseDriver) quoteIdentifier(str string) string { if clickhouseValidIdentifier.MatchString(str) { return str } @@ -88,7 +94,7 @@ func clickhouseQuoteIdentifier(str string) string { } // CreateDatabase creates the specified database -func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error { +func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error { name := drv.databaseName(u) fmt.Printf("Creating: %s\n", name) @@ -98,13 +104,13 @@ func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error { } defer mustClose(db) - _, err = db.Exec("create database " + clickhouseQuoteIdentifier(name)) + _, err = db.Exec("create database " + drv.quoteIdentifier(name)) return err } // DropDatabase drops the specified database (if it exists) -func (drv ClickHouseDriver) DropDatabase(u *url.URL) error { +func (drv *ClickHouseDriver) DropDatabase(u *url.URL) error { name := drv.databaseName(u) fmt.Printf("Dropping: %s\n", name) @@ -114,15 +120,15 @@ func (drv ClickHouseDriver) DropDatabase(u *url.URL) error { } defer mustClose(db) - _, err = db.Exec("drop database if exists " + clickhouseQuoteIdentifier(name)) + _, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name)) return err } -func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error { +func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error { buf.WriteString("\n--\n-- Database schema\n--\n\n") - buf.WriteString("CREATE DATABASE " + clickhouseQuoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n") + buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n") tables, err := queryColumn(db, "show tables") if err != nil { @@ -132,7 +138,7 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er for _, table := range tables { var clause string - err = db.QueryRow("show create table " + clickhouseQuoteIdentifier(table)).Scan(&clause) + err = db.QueryRow("show create table " + drv.quoteIdentifier(table)).Scan(&clause) if err != nil { return err } @@ -141,10 +147,13 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er return nil } -func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error { +func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error { + migrationsTable := drv.quotedMigrationsTableName() + // load applied migrations migrations, err := queryColumn(db, - "select version from schema_migrations final where applied order by version asc", + fmt.Sprintf("select version from %s final ", migrationsTable)+ + "where applied order by version asc", ) if err != nil { return err @@ -155,29 +164,30 @@ func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error { migrations[i] = "'" + quoter.Replace(migrations[i]) + "'" } - // build schema_migrations table data + // build schema migrations table data buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n") if len(migrations) > 0 { - buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" + - strings.Join(migrations, "),\n (") + - ");\n") + buf.WriteString( + fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) + + strings.Join(migrations, "),\n (") + + ");\n") } return nil } // DumpSchema returns the current database schema -func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { +func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { var buf bytes.Buffer var err error - err = clickhouseSchemaDump(db, &buf, drv.databaseName(u)) + err = drv.schemaDump(db, &buf, drv.databaseName(u)) if err != nil { return nil, err } - err = clickhouseSchemaMigrationsDump(db, &buf) + err = drv.schemaMigrationsDump(db, &buf) if err != nil { return nil, err } @@ -186,7 +196,7 @@ func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { } // DatabaseExists determines whether the database exists -func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) { +func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) { name := drv.databaseName(u) db, err := drv.openClickHouseDB(u) @@ -205,24 +215,27 @@ func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) { return exists, err } -// CreateMigrationsTable creates the schema_migrations table -func (drv ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { - _, err := db.Exec(` - create table if not exists schema_migrations ( +// CreateMigrationsTable creates the schema migrations table +func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { + _, err := db.Exec(fmt.Sprintf(` + create table if not exists %s ( version String, ts DateTime default now(), applied UInt8 default 1 ) engine = ReplacingMergeTree(ts) primary key version order by version - `) + `, drv.quotedMigrationsTableName())) + return err } // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { - query := "select version from schema_migrations final where applied order by version desc" +func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { + query := fmt.Sprintf("select version from %s final where applied order by version desc", + drv.quotedMigrationsTableName()) + if limit >= 0 { query = fmt.Sprintf("%s limit %d", query, limit) } @@ -251,15 +264,19 @@ func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string] } // InsertMigration adds a new migration record -func (drv ClickHouseDriver) InsertMigration(db Transaction, version string) error { - _, err := db.Exec("insert into schema_migrations (version) values (?)", version) +func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) error { + _, err := db.Exec( + fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()), + version) + return err } // DeleteMigration removes a migration record -func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) error { +func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) error { _, err := db.Exec( - "insert into schema_migrations (version, applied) values (?, ?)", + fmt.Sprintf("insert into %s (version, applied) values (?, ?)", + drv.quotedMigrationsTableName()), version, false, ) @@ -268,7 +285,7 @@ func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) erro // Ping verifies a connection to the database server. It does not verify whether the // specified database exists. -func (drv ClickHouseDriver) Ping(u *url.URL) error { +func (drv *ClickHouseDriver) Ping(u *url.URL) error { // attempt connection to primary database, not "clickhouse" database // to support servers with no "clickhouse" database // (see https://github.com/amacneil/dbmate/issues/78) @@ -291,3 +308,7 @@ func (drv ClickHouseDriver) Ping(u *url.URL) error { return err } + +func (drv *ClickHouseDriver) quotedMigrationsTableName() string { + return drv.quoteIdentifier(drv.migrationsTableName) +} diff --git a/pkg/dbmate/clickhouse_test.go b/pkg/dbmate/clickhouse_test.go index 4503535..c1aeefd 100644 --- a/pkg/dbmate/clickhouse_test.go +++ b/pkg/dbmate/clickhouse_test.go @@ -15,8 +15,15 @@ func clickhouseTestURL(t *testing.T) *url.URL { return u } +func testClickHouseDriver() *ClickHouseDriver { + drv := &ClickHouseDriver{} + drv.SetMigrationsTableName(DefaultMigrationsTableName) + + return drv +} + func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() // drop any existing database err := drv.DropDatabase(u) @@ -50,7 +57,7 @@ func TestNormalizeClickHouseURLCanonical(t *testing.T) { } func TestClickHouseCreateDropDatabase(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() u := clickhouseTestURL(t) // drop any existing database @@ -87,7 +94,9 @@ func TestClickHouseCreateDropDatabase(t *testing.T) { } func TestClickHouseDumpSchema(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() + drv.SetMigrationsTableName("test_migrations") + u := clickhouseTestURL(t) // prepare database @@ -113,11 +122,11 @@ func TestClickHouseDumpSchema(t *testing.T) { // DumpSchema should return schema schema, err := drv.DumpSchema(u, db) require.NoError(t, err) - require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".schema_migrations") + require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".test_migrations") require.Contains(t, string(schema), "--\n"+ "-- Dbmate schema migrations\n"+ "--\n\n"+ - "INSERT INTO schema_migrations (version) VALUES\n"+ + "INSERT INTO test_migrations (version) VALUES\n"+ " ('abc1'),\n"+ " ('abc2');\n") @@ -134,7 +143,7 @@ func TestClickHouseDumpSchema(t *testing.T) { } func TestClickHouseDatabaseExists(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() u := clickhouseTestURL(t) // drop any existing database @@ -157,7 +166,7 @@ func TestClickHouseDatabaseExists(t *testing.T) { } func TestClickHouseDatabaseExists_Error(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() u := clickhouseTestURL(t) values := u.Query() values.Set("username", "invalid") @@ -169,31 +178,61 @@ func TestClickHouseDatabaseExists_Error(t *testing.T) { } func TestClickHouseCreateMigrationsTable(t *testing.T) { - drv := ClickHouseDriver{} - u := clickhouseTestURL(t) - db := prepTestClickHouseDB(t, u) - defer mustClose(db) + t.Run("default table", func(t *testing.T) { + drv := testClickHouseDriver() + u := clickhouseTestURL(t) + db := prepTestClickHouseDB(t, u) + defer mustClose(db) - // migrations table should not exist - count := 0 - err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) - require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.") + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.") - // create table - err = drv.CreateMigrationsTable(u, db) - require.NoError(t, err) + // create table + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) - // migrations table should exist - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) - require.NoError(t, err) + // migrations table should exist + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.NoError(t, err) - // create table should be idempotent - err = drv.CreateMigrationsTable(u, db) - require.NoError(t, err) + // create table should be idempotent + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + }) + + t.Run("custom table", func(t *testing.T) { + drv := testClickHouseDriver() + drv.SetMigrationsTableName("testMigrations") + + u := clickhouseTestURL(t) + db := prepTestClickHouseDB(t, u) + defer mustClose(db) + + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count) + require.EqualError(t, err, "code: 60, message: Table dbmate.testMigrations doesn't exist.") + + // create table + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + + // migrations table should exist + err = db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count) + require.NoError(t, err) + + // create table should be idempotent + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + }) } func TestClickHouseSelectMigrations(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() + drv.SetMigrationsTableName("test_migrations") + u := clickhouseTestURL(t) db := prepTestClickHouseDB(t, u) defer mustClose(db) @@ -203,7 +242,7 @@ func TestClickHouseSelectMigrations(t *testing.T) { tx, err := db.Begin() require.NoError(t, err) - stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)") + stmt, err := tx.Prepare("insert into test_migrations (version) values (?)") require.NoError(t, err) _, err = stmt.Exec("abc2") require.NoError(t, err) @@ -229,7 +268,9 @@ func TestClickHouseSelectMigrations(t *testing.T) { } func TestClickHouseInsertMigration(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() + drv.SetMigrationsTableName("test_migrations") + u := clickhouseTestURL(t) db := prepTestClickHouseDB(t, u) defer mustClose(db) @@ -238,7 +279,7 @@ func TestClickHouseInsertMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 0, count) @@ -250,13 +291,15 @@ func TestClickHouseInsertMigration(t *testing.T) { err = tx.Commit() require.NoError(t, err) - err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestClickHouseDeleteMigration(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() + drv.SetMigrationsTableName("test_migrations") + u := clickhouseTestURL(t) db := prepTestClickHouseDB(t, u) defer mustClose(db) @@ -266,7 +309,7 @@ func TestClickHouseDeleteMigration(t *testing.T) { tx, err := db.Begin() require.NoError(t, err) - stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)") + stmt, err := tx.Prepare("insert into test_migrations (version) values (?)") require.NoError(t, err) _, err = stmt.Exec("abc2") require.NoError(t, err) @@ -283,13 +326,13 @@ func TestClickHouseDeleteMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from schema_migrations final where applied").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations final where applied").Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestClickHousePing(t *testing.T) { - drv := ClickHouseDriver{} + drv := testClickHouseDriver() u := clickhouseTestURL(t) // drop any existing database @@ -306,3 +349,27 @@ func TestClickHousePing(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "connect: connection refused") } + +func TestClickHouseQuotedMigrationsTableName(t *testing.T) { + t.Run("default name", func(t *testing.T) { + drv := testClickHouseDriver() + name := drv.quotedMigrationsTableName() + require.Equal(t, "schema_migrations", name) + }) + + t.Run("custom name", func(t *testing.T) { + drv := testClickHouseDriver() + drv.SetMigrationsTableName("fooMigrations") + + name := drv.quotedMigrationsTableName() + require.Equal(t, "fooMigrations", name) + }) + + t.Run("quoted name", func(t *testing.T) { + drv := testClickHouseDriver() + drv.SetMigrationsTableName("bizarre\"$name") + + name := drv.quotedMigrationsTableName() + require.Equal(t, `"bizarre""$name"`, name) + }) +} diff --git a/pkg/dbmate/db.go b/pkg/dbmate/db.go index bd9ec6b..2ddc0be 100644 --- a/pkg/dbmate/db.go +++ b/pkg/dbmate/db.go @@ -15,6 +15,9 @@ import ( // DefaultMigrationsDir specifies default directory to find migration files const DefaultMigrationsDir = "./db/migrations" +// DefaultMigrationsTableName specifies default database tables to record migraitons in +const DefaultMigrationsTableName = "schema_migrations" + // DefaultSchemaFile specifies default location for schema.sql const DefaultSchemaFile = "./db/schema.sql" @@ -26,14 +29,15 @@ const DefaultWaitTimeout = 60 * time.Second // DB allows dbmate actions to be performed on a specified database type DB struct { - AutoDumpSchema bool - DatabaseURL *url.URL - MigrationsDir string - SchemaFile string - Verbose bool - WaitBefore bool - WaitInterval time.Duration - WaitTimeout time.Duration + AutoDumpSchema bool + DatabaseURL *url.URL + MigrationsDir string + MigrationsTableName string + SchemaFile string + Verbose bool + WaitBefore bool + WaitInterval time.Duration + WaitTimeout time.Duration } // migrationFileRegexp pattern for valid migration files @@ -47,19 +51,27 @@ type statusResult struct { // New initializes a new dbmate database func New(databaseURL *url.URL) *DB { return &DB{ - AutoDumpSchema: true, - DatabaseURL: databaseURL, - MigrationsDir: DefaultMigrationsDir, - SchemaFile: DefaultSchemaFile, - WaitBefore: false, - WaitInterval: DefaultWaitInterval, - WaitTimeout: DefaultWaitTimeout, + AutoDumpSchema: true, + DatabaseURL: databaseURL, + MigrationsDir: DefaultMigrationsDir, + MigrationsTableName: DefaultMigrationsTableName, + SchemaFile: DefaultSchemaFile, + WaitBefore: false, + WaitInterval: DefaultWaitInterval, + WaitTimeout: DefaultWaitTimeout, } } // GetDriver loads the required database driver func (db *DB) GetDriver() (Driver, error) { - return GetDriver(db.DatabaseURL.Scheme) + drv, err := getDriver(db.DatabaseURL.Scheme) + if err != nil { + return nil, err + } + + drv.SetMigrationsTableName(db.MigrationsTableName) + + return drv, err } // Wait blocks until the database server is available. It does not verify that diff --git a/pkg/dbmate/db_test.go b/pkg/dbmate/db_test.go index 4b6a890..f3c8c5e 100644 --- a/pkg/dbmate/db_test.go +++ b/pkg/dbmate/db_test.go @@ -38,12 +38,26 @@ func TestNew(t *testing.T) { require.True(t, db.AutoDumpSchema) require.Equal(t, u.String(), db.DatabaseURL.String()) require.Equal(t, "./db/migrations", db.MigrationsDir) + require.Equal(t, "schema_migrations", db.MigrationsTableName) require.Equal(t, "./db/schema.sql", db.SchemaFile) require.False(t, db.WaitBefore) require.Equal(t, time.Second, db.WaitInterval) require.Equal(t, 60*time.Second, db.WaitTimeout) } +func TestGetDriver(t *testing.T) { + u := postgresTestURL(t) + db := New(u) + + drv, err := db.GetDriver() + require.NoError(t, err) + + // driver should have default migrations table set + pgDrv, ok := drv.(*PostgresDriver) + require.True(t, ok) + require.Equal(t, "schema_migrations", pgDrv.migrationsTableName) +} + func TestWait(t *testing.T) { u := postgresTestURL(t) db := newTestDB(t, u) @@ -242,7 +256,7 @@ func testMigrateURL(t *testing.T, u *url.URL) { require.NoError(t, err) // verify results - sqlDB, err := GetDriverOpen(u) + sqlDB, err := getDriverOpen(u) require.NoError(t, err) defer mustClose(sqlDB) @@ -275,7 +289,7 @@ func testUpURL(t *testing.T, u *url.URL) { require.NoError(t, err) // verify results - sqlDB, err := GetDriverOpen(u) + sqlDB, err := getDriverOpen(u) require.NoError(t, err) defer mustClose(sqlDB) @@ -308,7 +322,7 @@ func testRollbackURL(t *testing.T, u *url.URL) { require.NoError(t, err) // verify migration - sqlDB, err := GetDriverOpen(u) + sqlDB, err := getDriverOpen(u) require.NoError(t, err) defer mustClose(sqlDB) @@ -351,7 +365,7 @@ func testStatusURL(t *testing.T, u *url.URL) { require.NoError(t, err) // verify migration - sqlDB, err := GetDriverOpen(u) + sqlDB, err := getDriverOpen(u) require.NoError(t, err) defer mustClose(sqlDB) diff --git a/pkg/dbmate/driver.go b/pkg/dbmate/driver.go index 8a007b4..f5674ad 100644 --- a/pkg/dbmate/driver.go +++ b/pkg/dbmate/driver.go @@ -13,6 +13,7 @@ type Driver interface { CreateDatabase(*url.URL) error DropDatabase(*url.URL) error DumpSchema(*url.URL, *sql.DB) ([]byte, error) + SetMigrationsTableName(string) CreateMigrationsTable(*url.URL, *sql.DB) error SelectMigrations(*sql.DB, int) (map[string]bool, error) InsertMigration(Transaction, string) error @@ -34,18 +35,20 @@ type Transaction interface { QueryRow(query string, args ...interface{}) *sql.Row } -// GetDriver loads a database driver by name -func GetDriver(name string) (Driver, error) { - if val, ok := drivers[name]; ok { - return val, nil +// getDriver loads a database driver by name +func getDriver(name string) (Driver, error) { + if drv, ok := drivers[name]; ok { + drv.SetMigrationsTableName(DefaultMigrationsTableName) + + return drv, nil } return nil, fmt.Errorf("unsupported driver: %s", name) } -// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u) -func GetDriverOpen(u *url.URL) (*sql.DB, error) { - drv, err := GetDriver(u.Scheme) +// getDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u) +func getDriverOpen(u *url.URL) (*sql.DB, error) { + drv, err := getDriver(u.Scheme) if err != nil { return nil, err } diff --git a/pkg/dbmate/driver_test.go b/pkg/dbmate/driver_test.go index d8f5f3e..0a23eee 100644 --- a/pkg/dbmate/driver_test.go +++ b/pkg/dbmate/driver_test.go @@ -7,21 +7,21 @@ import ( ) func TestGetDriver_Postgres(t *testing.T) { - drv, err := GetDriver("postgres") + drv, err := getDriver("postgres") require.NoError(t, err) - _, ok := drv.(PostgresDriver) + _, ok := drv.(*PostgresDriver) require.Equal(t, true, ok) } func TestGetDriver_MySQL(t *testing.T) { - drv, err := GetDriver("mysql") + drv, err := getDriver("mysql") require.NoError(t, err) - _, ok := drv.(MySQLDriver) + _, ok := drv.(*MySQLDriver) require.Equal(t, true, ok) } func TestGetDriver_Error(t *testing.T) { - drv, err := GetDriver("foo") + drv, err := getDriver("foo") require.EqualError(t, err, "unsupported driver: foo") require.Nil(t, drv) } diff --git a/pkg/dbmate/mysql.go b/pkg/dbmate/mysql.go index dd80e60..6443351 100644 --- a/pkg/dbmate/mysql.go +++ b/pkg/dbmate/mysql.go @@ -11,11 +11,12 @@ import ( ) func init() { - RegisterDriver(MySQLDriver{}, "mysql") + RegisterDriver(&MySQLDriver{}, "mysql") } // MySQLDriver provides top level database functions type MySQLDriver struct { + migrationsTableName string } func normalizeMySQLURL(u *url.URL) string { @@ -52,12 +53,17 @@ func normalizeMySQLURL(u *url.URL) string { return normalizedString } +// SetMigrationsTableName sets the schema migrations table name +func (drv *MySQLDriver) SetMigrationsTableName(name string) { + drv.migrationsTableName = name +} + // Open creates a new database connection -func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) { +func (drv *MySQLDriver) Open(u *url.URL) (*sql.DB, error) { return sql.Open("mysql", normalizeMySQLURL(u)) } -func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) { +func (drv *MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) { // connect to no particular database rootURL := *u rootURL.Path = "/" @@ -65,14 +71,14 @@ func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) { return drv.Open(&rootURL) } -func mysqlQuoteIdentifier(str string) string { +func (drv *MySQLDriver) quoteIdentifier(str string) string { str = strings.Replace(str, "`", "\\`", -1) return fmt.Sprintf("`%s`", str) } // CreateDatabase creates the specified database -func (drv MySQLDriver) CreateDatabase(u *url.URL) error { +func (drv *MySQLDriver) CreateDatabase(u *url.URL) error { name := databaseName(u) fmt.Printf("Creating: %s\n", name) @@ -83,13 +89,13 @@ func (drv MySQLDriver) CreateDatabase(u *url.URL) error { defer mustClose(db) _, err = db.Exec(fmt.Sprintf("create database %s", - mysqlQuoteIdentifier(name))) + drv.quoteIdentifier(name))) return err } // DropDatabase drops the specified database (if it exists) -func (drv MySQLDriver) DropDatabase(u *url.URL) error { +func (drv *MySQLDriver) DropDatabase(u *url.URL) error { name := databaseName(u) fmt.Printf("Dropping: %s\n", name) @@ -100,12 +106,12 @@ func (drv MySQLDriver) DropDatabase(u *url.URL) error { defer mustClose(db) _, err = db.Exec(fmt.Sprintf("drop database if exists %s", - mysqlQuoteIdentifier(name))) + drv.quoteIdentifier(name))) return err } -func mysqldumpArgs(u *url.URL) []string { +func (drv *MySQLDriver) mysqldumpArgs(u *url.URL) []string { // generate CLI arguments args := []string{"--opt", "--routines", "--no-data", "--skip-dump-date", "--skip-add-drop-table"} @@ -131,10 +137,12 @@ func mysqldumpArgs(u *url.URL) []string { return args } -func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) { +func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) { + migrationsTable := drv.quotedMigrationsTableName() + // load applied migrations migrations, err := queryColumn(db, - "select quote(version) from schema_migrations order by version asc") + fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable)) if err != nil { return nil, err } @@ -142,12 +150,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) { // build schema_migrations table data var buf bytes.Buffer buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" + - "LOCK TABLES `schema_migrations` WRITE;\n") + fmt.Sprintf("LOCK TABLES %s WRITE;\n", migrationsTable)) if len(migrations) > 0 { - buf.WriteString("INSERT INTO `schema_migrations` (version) VALUES\n (" + - strings.Join(migrations, "),\n (") + - ");\n") + buf.WriteString( + fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) + + strings.Join(migrations, "),\n (") + + ");\n") } buf.WriteString("UNLOCK TABLES;\n") @@ -156,13 +165,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) { } // DumpSchema returns the current database schema -func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { - schema, err := runCommand("mysqldump", mysqldumpArgs(u)...) +func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { + schema, err := runCommand("mysqldump", drv.mysqldumpArgs(u)...) if err != nil { return nil, err } - migrations, err := mysqlSchemaMigrationsDump(db) + migrations, err := drv.schemaMigrationsDump(db) if err != nil { return nil, err } @@ -172,7 +181,7 @@ func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { } // DatabaseExists determines whether the database exists -func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) { +func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) { name := databaseName(u) db, err := drv.openRootDB(u) @@ -192,17 +201,18 @@ func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) { } // CreateMigrationsTable creates the schema_migrations table -func (drv MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { - _, err := db.Exec("create table if not exists schema_migrations " + - "(version varchar(255) primary key) character set latin1 collate latin1_bin") +func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { + _, err := db.Exec(fmt.Sprintf("create table if not exists %s "+ + "(version varchar(255) primary key) character set latin1 collate latin1_bin", + drv.quotedMigrationsTableName())) return err } // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { - query := "select version from schema_migrations order by version desc" +func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { + query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName()) if limit >= 0 { query = fmt.Sprintf("%s limit %d", query, limit) } @@ -231,22 +241,26 @@ func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, } // InsertMigration adds a new migration record -func (drv MySQLDriver) InsertMigration(db Transaction, version string) error { - _, err := db.Exec("insert into schema_migrations (version) values (?)", version) +func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error { + _, err := db.Exec( + fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()), + version) return err } // DeleteMigration removes a migration record -func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error { - _, err := db.Exec("delete from schema_migrations where version = ?", version) +func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error { + _, err := db.Exec( + fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()), + version) return err } // Ping verifies a connection to the database server. It does not verify whether the // specified database exists. -func (drv MySQLDriver) Ping(u *url.URL) error { +func (drv *MySQLDriver) Ping(u *url.URL) error { db, err := drv.openRootDB(u) if err != nil { return err @@ -255,3 +269,7 @@ func (drv MySQLDriver) Ping(u *url.URL) error { return db.Ping() } + +func (drv *MySQLDriver) quotedMigrationsTableName() string { + return drv.quoteIdentifier(drv.migrationsTableName) +} diff --git a/pkg/dbmate/mysql_test.go b/pkg/dbmate/mysql_test.go index 8e34d6e..c5483cc 100644 --- a/pkg/dbmate/mysql_test.go +++ b/pkg/dbmate/mysql_test.go @@ -15,8 +15,15 @@ func mySQLTestURL(t *testing.T) *url.URL { return u } +func testMySQLDriver() *MySQLDriver { + drv := &MySQLDriver{} + drv.SetMigrationsTableName(DefaultMigrationsTableName) + + return drv +} + func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB { - drv := MySQLDriver{} + drv := testMySQLDriver() // drop any existing database err := drv.DropDatabase(u) @@ -78,7 +85,7 @@ func TestNormalizeMySQLURLSocket(t *testing.T) { } func TestMySQLCreateDropDatabase(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() u := mySQLTestURL(t) // drop any existing database @@ -116,7 +123,9 @@ func TestMySQLCreateDropDatabase(t *testing.T) { } func TestMySQLDumpSchema(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() + drv.SetMigrationsTableName("test_migrations") + u := mySQLTestURL(t) // prepare database @@ -134,13 +143,13 @@ func TestMySQLDumpSchema(t *testing.T) { // DumpSchema should return schema schema, err := drv.DumpSchema(u, db) require.NoError(t, err) - require.Contains(t, string(schema), "CREATE TABLE `schema_migrations`") + require.Contains(t, string(schema), "CREATE TABLE `test_migrations`") require.Contains(t, string(schema), "\n-- Dump completed\n\n"+ "--\n"+ "-- Dbmate schema migrations\n"+ "--\n\n"+ - "LOCK TABLES `schema_migrations` WRITE;\n"+ - "INSERT INTO `schema_migrations` (version) VALUES\n"+ + "LOCK TABLES `test_migrations` WRITE;\n"+ + "INSERT INTO `test_migrations` (version) VALUES\n"+ " ('abc1'),\n"+ " ('abc2');\n"+ "UNLOCK TABLES;\n") @@ -156,7 +165,7 @@ func TestMySQLDumpSchema(t *testing.T) { } func TestMySQLDatabaseExists(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() u := mySQLTestURL(t) // drop any existing database @@ -179,7 +188,7 @@ func TestMySQLDatabaseExists(t *testing.T) { } func TestMySQLDatabaseExists_Error(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() u := mySQLTestURL(t) u.User = url.User("invalid") @@ -189,22 +198,25 @@ func TestMySQLDatabaseExists_Error(t *testing.T) { } func TestMySQLCreateMigrationsTable(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() + drv.SetMigrationsTableName("test_migrations") + u := mySQLTestURL(t) db := prepTestMySQLDB(t, u) defer mustClose(db) // migrations table should not exist count := 0 - err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) - require.Regexp(t, "Table 'dbmate.schema_migrations' doesn't exist", err.Error()) + err := db.QueryRow("select count(*) from test_migrations").Scan(&count) + require.Error(t, err) + require.Regexp(t, "Table 'dbmate.test_migrations' doesn't exist", err.Error()) // create table err = drv.CreateMigrationsTable(u, db) require.NoError(t, err) // migrations table should exist - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) require.NoError(t, err) // create table should be idempotent @@ -213,7 +225,9 @@ func TestMySQLCreateMigrationsTable(t *testing.T) { } func TestMySQLSelectMigrations(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() + drv.SetMigrationsTableName("test_migrations") + u := mySQLTestURL(t) db := prepTestMySQLDB(t, u) defer mustClose(db) @@ -221,7 +235,7 @@ func TestMySQLSelectMigrations(t *testing.T) { err := drv.CreateMigrationsTable(u, db) require.NoError(t, err) - _, err = db.Exec(`insert into schema_migrations (version) + _, err = db.Exec(`insert into test_migrations (version) values ('abc2'), ('abc1'), ('abc3')`) require.NoError(t, err) @@ -240,7 +254,9 @@ func TestMySQLSelectMigrations(t *testing.T) { } func TestMySQLInsertMigration(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() + drv.SetMigrationsTableName("test_migrations") + u := mySQLTestURL(t) db := prepTestMySQLDB(t, u) defer mustClose(db) @@ -249,7 +265,7 @@ func TestMySQLInsertMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 0, count) @@ -257,14 +273,16 @@ func TestMySQLInsertMigration(t *testing.T) { err = drv.InsertMigration(db, "abc1") require.NoError(t, err) - err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'"). + err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'"). Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestMySQLDeleteMigration(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() + drv.SetMigrationsTableName("test_migrations") + u := mySQLTestURL(t) db := prepTestMySQLDB(t, u) defer mustClose(db) @@ -272,7 +290,7 @@ func TestMySQLDeleteMigration(t *testing.T) { err := drv.CreateMigrationsTable(u, db) require.NoError(t, err) - _, err = db.Exec(`insert into schema_migrations (version) + _, err = db.Exec(`insert into test_migrations (version) values ('abc1'), ('abc2')`) require.NoError(t, err) @@ -280,13 +298,13 @@ func TestMySQLDeleteMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestMySQLPing(t *testing.T) { - drv := MySQLDriver{} + drv := testMySQLDriver() u := mySQLTestURL(t) // drop any existing database @@ -303,3 +321,19 @@ func TestMySQLPing(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "connect: connection refused") } + +func TestMySQLQuotedMigrationsTableName(t *testing.T) { + t.Run("default name", func(t *testing.T) { + drv := testMySQLDriver() + name := drv.quotedMigrationsTableName() + require.Equal(t, "`schema_migrations`", name) + }) + + t.Run("custom name", func(t *testing.T) { + drv := testMySQLDriver() + drv.SetMigrationsTableName("fooMigrations") + + name := drv.quotedMigrationsTableName() + require.Equal(t, "`fooMigrations`", name) + }) +} diff --git a/pkg/dbmate/postgres.go b/pkg/dbmate/postgres.go index 7cbc8a0..172ccc3 100644 --- a/pkg/dbmate/postgres.go +++ b/pkg/dbmate/postgres.go @@ -11,12 +11,14 @@ import ( ) func init() { - RegisterDriver(PostgresDriver{}, "postgres") - RegisterDriver(PostgresDriver{}, "postgresql") + drv := &PostgresDriver{} + RegisterDriver(drv, "postgres") + RegisterDriver(drv, "postgresql") } // PostgresDriver provides top level database functions type PostgresDriver struct { + migrationsTableName string } func normalizePostgresURL(u *url.URL) *url.URL { @@ -78,12 +80,17 @@ func normalizePostgresURLForDump(u *url.URL) []string { return out } +// SetMigrationsTableName sets the schema migrations table name +func (drv *PostgresDriver) SetMigrationsTableName(name string) { + drv.migrationsTableName = name +} + // Open creates a new database connection -func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) { +func (drv *PostgresDriver) Open(u *url.URL) (*sql.DB, error) { return sql.Open("postgres", normalizePostgresURL(u).String()) } -func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) { +func (drv *PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) { // connect to postgres database postgresURL := *u postgresURL.Path = "postgres" @@ -92,7 +99,7 @@ func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) { } // CreateDatabase creates the specified database -func (drv PostgresDriver) CreateDatabase(u *url.URL) error { +func (drv *PostgresDriver) CreateDatabase(u *url.URL) error { name := databaseName(u) fmt.Printf("Creating: %s\n", name) @@ -109,7 +116,7 @@ func (drv PostgresDriver) CreateDatabase(u *url.URL) error { } // DropDatabase drops the specified database (if it exists) -func (drv PostgresDriver) DropDatabase(u *url.URL) error { +func (drv *PostgresDriver) DropDatabase(u *url.URL) error { name := databaseName(u) fmt.Printf("Dropping: %s\n", name) @@ -125,8 +132,8 @@ func (drv PostgresDriver) DropDatabase(u *url.URL) error { return err } -func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) { - migrationsTable, err := drv.migrationsTableName(db) +func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) { + migrationsTable, err := drv.quotedMigrationsTableName(db) if err != nil { return nil, err } @@ -152,7 +159,7 @@ func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, erro } // DumpSchema returns the current database schema -func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { +func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { // load schema args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only", "--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...) @@ -161,7 +168,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { return nil, err } - migrations, err := drv.postgresSchemaMigrationsDump(db) + migrations, err := drv.schemaMigrationsDump(db) if err != nil { return nil, err } @@ -171,7 +178,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { } // DatabaseExists determines whether the database exists -func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) { +func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) { name := databaseName(u) db, err := drv.openPostgresDB(u) @@ -191,48 +198,45 @@ func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) { } // CreateMigrationsTable creates the schema_migrations table -func (drv PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { - // get schema from URL search_path param - searchPath := strings.Split(u.Query().Get("search_path"), ",") - urlSchema := strings.TrimSpace(searchPath[0]) - if urlSchema == "" { - urlSchema = "public" - } - - // get *unquoted* current schema from database - dbSchema, err := queryRow(db, "select current_schema()") +func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { + schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db, u) if err != nil { return err } - // if urlSchema and dbSchema are not equal, the most likely explanation is that the schema - // has not yet been created - if urlSchema != dbSchema { - // in theory we could just execute this statement every time, but we do the comparison - // above in case the user doesn't have permissions to create schemas and the schema - // already exists - fmt.Printf("Creating schema: %s\n", urlSchema) - _, err = db.Exec("create schema if not exists " + pq.QuoteIdentifier(urlSchema)) - if err != nil { - return err - } + // first attempt at creating migrations table + createTableStmt := fmt.Sprintf("create table if not exists %s.%s", schema, migrationsTable) + + " (version varchar(255) primary key)" + _, err = db.Exec(createTableStmt) + if err == nil { + // table exists or created successfully + return nil } - migrationsTable, err := drv.migrationsTableName(db) + // catch 'schema does not exist' error + pqErr, ok := err.(*pq.Error) + if !ok || pqErr.Code != "3F000" { + // unknown error + return err + } + + // in theory we could attempt to create the schema every time, but we avoid that + // in case the user doesn't have permissions to create schemas + fmt.Printf("Creating schema: %s\n", schema) + _, err = db.Exec(fmt.Sprintf("create schema if not exists %s", schema)) if err != nil { return err } - _, err = db.Exec("create table if not exists " + migrationsTable + - " (version varchar(255) primary key)") - + // second and final attempt at creating migrations table + _, err = db.Exec(createTableStmt) return err } // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { - migrationsTable, err := drv.migrationsTableName(db) +func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { + migrationsTable, err := drv.quotedMigrationsTableName(db) if err != nil { return nil, err } @@ -266,8 +270,8 @@ func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bo } // InsertMigration adds a new migration record -func (drv PostgresDriver) InsertMigration(db Transaction, version string) error { - migrationsTable, err := drv.migrationsTableName(db) +func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error { + migrationsTable, err := drv.quotedMigrationsTableName(db) if err != nil { return err } @@ -278,8 +282,8 @@ func (drv PostgresDriver) InsertMigration(db Transaction, version string) error } // DeleteMigration removes a migration record -func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error { - migrationsTable, err := drv.migrationsTableName(db) +func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error { + migrationsTable, err := drv.quotedMigrationsTableName(db) if err != nil { return err } @@ -291,7 +295,7 @@ func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error // Ping verifies a connection to the database server. It does not verify whether the // specified database exists. -func (drv PostgresDriver) Ping(u *url.URL) error { +func (drv *PostgresDriver) Ping(u *url.URL) error { // attempt connection to primary database, not "postgres" database // to support servers with no "postgres" database // (see https://github.com/amacneil/dbmate/issues/78) @@ -306,7 +310,7 @@ func (drv PostgresDriver) Ping(u *url.URL) error { return nil } - // ignore 'database "foo" does not exist' error + // ignore 'database does not exist' error pqErr, ok := err.(*pq.Error) if ok && pqErr.Code == "3D000" { return nil @@ -315,17 +319,53 @@ func (drv PostgresDriver) Ping(u *url.URL) error { return err } -func (drv PostgresDriver) migrationsTableName(db Transaction) (string, error) { - // get current schema - schema, err := queryRow(db, "select quote_ident(current_schema())") +func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, error) { + schema, name, err := drv.quotedMigrationsTableNameParts(db, nil) if err != nil { return "", err } - // if the search path is empty, or does not contain a valid schema, default to public + return schema + "." + name, nil +} + +func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url.URL) (string, string, error) { + schema := "" + tableNameParts := strings.Split(drv.migrationsTableName, ".") + if len(tableNameParts) > 1 { + // schema specified as part of table name + schema, tableNameParts = tableNameParts[0], tableNameParts[1:] + } + + if schema == "" && u != nil { + // no schema specified with table name, try URL search path if available + searchPath := strings.Split(u.Query().Get("search_path"), ",") + schema = strings.TrimSpace(searchPath[0]) + } + + var err error + if schema == "" { + // if no URL available, use current schema + // this is a hack because we don't always have the URL context available + schema, err = queryValue(db, "select current_schema()") + if err != nil { + return "", "", err + } + } + + // fall back to public schema as last resort if schema == "" { schema = "public" } - return schema + ".schema_migrations", nil + // quote all parts + // use server rather than client to do this to avoid unnecessary quotes + // (which would change schema.sql diff) + tableNameParts = append([]string{schema}, tableNameParts...) + quotedNameParts, err := queryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts)) + if err != nil { + return "", "", err + } + + // if more than one part, we already have a schema + return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil } diff --git a/pkg/dbmate/postgres_test.go b/pkg/dbmate/postgres_test.go index aafb177..a2e63d0 100644 --- a/pkg/dbmate/postgres_test.go +++ b/pkg/dbmate/postgres_test.go @@ -15,8 +15,15 @@ func postgresTestURL(t *testing.T) *url.URL { return u } +func testPostgresDriver() *PostgresDriver { + drv := &PostgresDriver{} + drv.SetMigrationsTableName(DefaultMigrationsTableName) + + return drv +} + func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB { - drv := PostgresDriver{} + drv := testPostgresDriver() // drop any existing database err := drv.DropDatabase(u) @@ -87,7 +94,7 @@ func TestNormalizePostgresURLForDump(t *testing.T) { } func TestPostgresCreateDropDatabase(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() u := postgresTestURL(t) // drop any existing database @@ -125,45 +132,80 @@ func TestPostgresCreateDropDatabase(t *testing.T) { } func TestPostgresDumpSchema(t *testing.T) { - drv := PostgresDriver{} - u := postgresTestURL(t) + t.Run("default migrations table", func(t *testing.T) { + drv := testPostgresDriver() + u := postgresTestURL(t) - // prepare database - db := prepTestPostgresDB(t, u) - defer mustClose(db) - err := drv.CreateMigrationsTable(u, db) - require.NoError(t, err) + // prepare database + db := prepTestPostgresDB(t, u) + defer mustClose(db) + err := drv.CreateMigrationsTable(u, db) + require.NoError(t, err) - // insert migration - err = drv.InsertMigration(db, "abc1") - require.NoError(t, err) - err = drv.InsertMigration(db, "abc2") - require.NoError(t, err) + // insert migration + err = drv.InsertMigration(db, "abc1") + require.NoError(t, err) + err = drv.InsertMigration(db, "abc2") + require.NoError(t, err) - // DumpSchema should return schema - schema, err := drv.DumpSchema(u, db) - require.NoError(t, err) - require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations") - require.Contains(t, string(schema), "\n--\n"+ - "-- PostgreSQL database dump complete\n"+ - "--\n\n\n"+ - "--\n"+ - "-- Dbmate schema migrations\n"+ - "--\n\n"+ - "INSERT INTO public.schema_migrations (version) VALUES\n"+ - " ('abc1'),\n"+ - " ('abc2');\n") + // DumpSchema should return schema + schema, err := drv.DumpSchema(u, db) + require.NoError(t, err) + require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations") + require.Contains(t, string(schema), "\n--\n"+ + "-- PostgreSQL database dump complete\n"+ + "--\n\n\n"+ + "--\n"+ + "-- Dbmate schema migrations\n"+ + "--\n\n"+ + "INSERT INTO public.schema_migrations (version) VALUES\n"+ + " ('abc1'),\n"+ + " ('abc2');\n") - // DumpSchema should return error if command fails - u.Path = "/fakedb" - schema, err = drv.DumpSchema(u, db) - require.Nil(t, schema) - require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+ - "\"fakedb\" failed: FATAL: database \"fakedb\" does not exist") + // DumpSchema should return error if command fails + u.Path = "/fakedb" + schema, err = drv.DumpSchema(u, db) + require.Nil(t, schema) + require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+ + "\"fakedb\" failed: FATAL: database \"fakedb\" does not exist") + }) + + t.Run("custom migrations table with schema", func(t *testing.T) { + drv := testPostgresDriver() + drv.SetMigrationsTableName("camelSchema.testMigrations") + + u := postgresTestURL(t) + + // prepare database + db := prepTestPostgresDB(t, u) + defer mustClose(db) + err := drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + + // insert migration + err = drv.InsertMigration(db, "abc1") + require.NoError(t, err) + err = drv.InsertMigration(db, "abc2") + require.NoError(t, err) + + // DumpSchema should return schema + schema, err := drv.DumpSchema(u, db) + require.NoError(t, err) + require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"") + require.Contains(t, string(schema), "\n--\n"+ + "-- PostgreSQL database dump complete\n"+ + "--\n\n\n"+ + "--\n"+ + "-- Dbmate schema migrations\n"+ + "--\n\n"+ + "INSERT INTO \"camelSchema\".\"testMigrations\" (version) VALUES\n"+ + " ('abc1'),\n"+ + " ('abc2');\n") + }) } func TestPostgresDatabaseExists(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() u := postgresTestURL(t) // drop any existing database @@ -186,7 +228,7 @@ func TestPostgresDatabaseExists(t *testing.T) { } func TestPostgresDatabaseExists_Error(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() u := postgresTestURL(t) u.User = url.User("invalid") @@ -197,9 +239,8 @@ func TestPostgresDatabaseExists_Error(t *testing.T) { } func TestPostgresCreateMigrationsTable(t *testing.T) { - drv := PostgresDriver{} - t.Run("default schema", func(t *testing.T) { + drv := testPostgresDriver() u := postgresTestURL(t) db := prepTestPostgresDB(t, u) defer mustClose(db) @@ -223,39 +264,81 @@ func TestPostgresCreateMigrationsTable(t *testing.T) { require.NoError(t, err) }) - t.Run("custom schema", func(t *testing.T) { - u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo") + t.Run("custom search path", func(t *testing.T) { + drv := testPostgresDriver() + drv.SetMigrationsTableName("testMigrations") + + u, err := url.Parse(postgresTestURL(t).String() + "&search_path=camelFoo") require.NoError(t, err) db := prepTestPostgresDB(t, u) defer mustClose(db) // delete schema - _, err = db.Exec("drop schema if exists foo") + _, err = db.Exec("drop schema if exists \"camelFoo\"") require.NoError(t, err) - // drop any schema_migrations table in public schema - _, err = db.Exec("drop table if exists public.schema_migrations") + // drop any testMigrations table in public schema + _, err = db.Exec("drop table if exists public.\"testMigrations\"") require.NoError(t, err) // migrations table should not exist in either schema count := 0 - err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count) require.Error(t, err) - require.Equal(t, "pq: relation \"foo.schema_migrations\" does not exist", err.Error()) - err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) + require.Equal(t, "pq: relation \"camelFoo.testMigrations\" does not exist", err.Error()) + err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count) require.Error(t, err) - require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error()) + require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error()) // create table err = drv.CreateMigrationsTable(u, db) require.NoError(t, err) - // foo schema should be created, and migrations table should exist only in foo schema - err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count) + // camelFoo schema should be created, and migrations table should exist only in camelFoo schema + err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count) require.NoError(t, err) - err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count) require.Error(t, err) - require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error()) + require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error()) + + // create table should be idempotent + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + }) + + t.Run("custom schema", func(t *testing.T) { + drv := testPostgresDriver() + drv.SetMigrationsTableName("camelSchema.testMigrations") + + u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo") + require.NoError(t, err) + db := prepTestPostgresDB(t, u) + defer mustClose(db) + + // delete schemas + _, err = db.Exec("drop schema if exists foo") + require.NoError(t, err) + _, err = db.Exec("drop schema if exists \"camelSchema\"") + require.NoError(t, err) + + // migrations table should not exist + count := 0 + err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count) + require.Error(t, err) + require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error()) + + // create table + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + + // camelSchema should be created, and testMigrations table should exist + err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count) + require.NoError(t, err) + // testMigrations table should not exist in foo schema because + // schema specified with migrations table name takes priority over search path + err = db.QueryRow("select count(*) from foo.\"testMigrations\"").Scan(&count) + require.Error(t, err) + require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error()) // create table should be idempotent err = drv.CreateMigrationsTable(u, db) @@ -264,7 +347,9 @@ func TestPostgresCreateMigrationsTable(t *testing.T) { } func TestPostgresSelectMigrations(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() + drv.SetMigrationsTableName("test_migrations") + u := postgresTestURL(t) db := prepTestPostgresDB(t, u) defer mustClose(db) @@ -272,7 +357,7 @@ func TestPostgresSelectMigrations(t *testing.T) { err := drv.CreateMigrationsTable(u, db) require.NoError(t, err) - _, err = db.Exec(`insert into public.schema_migrations (version) + _, err = db.Exec(`insert into public.test_migrations (version) values ('abc2'), ('abc1'), ('abc3')`) require.NoError(t, err) @@ -291,7 +376,9 @@ func TestPostgresSelectMigrations(t *testing.T) { } func TestPostgresInsertMigration(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() + drv.SetMigrationsTableName("test_migrations") + u := postgresTestURL(t) db := prepTestPostgresDB(t, u) defer mustClose(db) @@ -300,7 +387,7 @@ func TestPostgresInsertMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 0, count) @@ -308,14 +395,16 @@ func TestPostgresInsertMigration(t *testing.T) { err = drv.InsertMigration(db, "abc1") require.NoError(t, err) - err = db.QueryRow("select count(*) from public.schema_migrations where version = 'abc1'"). + err = db.QueryRow("select count(*) from public.test_migrations where version = 'abc1'"). Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestPostgresDeleteMigration(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() + drv.SetMigrationsTableName("test_migrations") + u := postgresTestURL(t) db := prepTestPostgresDB(t, u) defer mustClose(db) @@ -323,7 +412,7 @@ func TestPostgresDeleteMigration(t *testing.T) { err := drv.CreateMigrationsTable(u, db) require.NoError(t, err) - _, err = db.Exec(`insert into public.schema_migrations (version) + _, err = db.Exec(`insert into public.test_migrations (version) values ('abc1'), ('abc2')`) require.NoError(t, err) @@ -331,13 +420,13 @@ func TestPostgresDeleteMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestPostgresPing(t *testing.T) { - drv := PostgresDriver{} + drv := testPostgresDriver() u := postgresTestURL(t) // drop any existing database @@ -355,15 +444,15 @@ func TestPostgresPing(t *testing.T) { require.Contains(t, err.Error(), "connect: connection refused") } -func TestMigrationsTableName(t *testing.T) { - drv := PostgresDriver{} +func TestPostgresQuotedMigrationsTableName(t *testing.T) { + drv := testPostgresDriver() t.Run("default schema", func(t *testing.T) { u := postgresTestURL(t) db := prepTestPostgresDB(t, u) defer mustClose(db) - name, err := drv.migrationsTableName(db) + name, err := drv.quotedMigrationsTableName(db) require.NoError(t, err) require.Equal(t, "public.schema_migrations", name) }) @@ -379,14 +468,14 @@ func TestMigrationsTableName(t *testing.T) { require.NoError(t, err) _, err = db.Exec("drop schema if exists bar") require.NoError(t, err) - name, err := drv.migrationsTableName(db) + name, err := drv.quotedMigrationsTableName(db) require.NoError(t, err) require.Equal(t, "public.schema_migrations", name) // if "foo" schema exists, it should be used _, err = db.Exec("create schema foo") require.NoError(t, err) - name, err = drv.migrationsTableName(db) + name, err = drv.quotedMigrationsTableName(db) require.NoError(t, err) require.Equal(t, "foo.schema_migrations", name) }) @@ -401,8 +490,76 @@ func TestMigrationsTableName(t *testing.T) { _, err := db.Exec("select pg_catalog.set_config('search_path', '', false)") require.NoError(t, err) - name, err := drv.migrationsTableName(db) + name, err := drv.quotedMigrationsTableName(db) require.NoError(t, err) require.Equal(t, "public.schema_migrations", name) }) + + t.Run("custom table name", func(t *testing.T) { + u := postgresTestURL(t) + db := prepTestPostgresDB(t, u) + defer mustClose(db) + + drv.SetMigrationsTableName("simple_name") + name, err := drv.quotedMigrationsTableName(db) + require.NoError(t, err) + require.Equal(t, "public.simple_name", name) + }) + + t.Run("custom table name quoted", func(t *testing.T) { + u := postgresTestURL(t) + db := prepTestPostgresDB(t, u) + defer mustClose(db) + + // this table name will need quoting + drv.SetMigrationsTableName("camelCase") + name, err := drv.quotedMigrationsTableName(db) + require.NoError(t, err) + require.Equal(t, "public.\"camelCase\"", name) + }) + + t.Run("custom table name with custom schema", func(t *testing.T) { + u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo") + require.NoError(t, err) + db := prepTestPostgresDB(t, u) + defer mustClose(db) + + _, err = db.Exec("create schema if not exists foo") + require.NoError(t, err) + + drv.SetMigrationsTableName("simple_name") + name, err := drv.quotedMigrationsTableName(db) + require.NoError(t, err) + require.Equal(t, "foo.simple_name", name) + }) + + t.Run("custom table name overrides schema", func(t *testing.T) { + u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo") + require.NoError(t, err) + db := prepTestPostgresDB(t, u) + defer mustClose(db) + + _, err = db.Exec("create schema if not exists foo") + require.NoError(t, err) + _, err = db.Exec("create schema if not exists bar") + require.NoError(t, err) + + // if schema is specified as part of table name, it should override search_path + drv.SetMigrationsTableName("bar.simple_name") + name, err := drv.quotedMigrationsTableName(db) + require.NoError(t, err) + require.Equal(t, "bar.simple_name", name) + + // schema and table name should be quoted if necessary + drv.SetMigrationsTableName("barName.camelTable") + name, err = drv.quotedMigrationsTableName(db) + require.NoError(t, err) + require.Equal(t, "\"barName\".\"camelTable\"", name) + + // more than 2 components is unexpected but we will quote and pass it along anyway + drv.SetMigrationsTableName("whyWould.i.doThis") + name, err = drv.quotedMigrationsTableName(db) + require.NoError(t, err) + require.Equal(t, "\"whyWould\".i.\"doThis\"", name) + }) } diff --git a/pkg/dbmate/sqlite.go b/pkg/dbmate/sqlite.go index 3ef10dd..f762d4c 100644 --- a/pkg/dbmate/sqlite.go +++ b/pkg/dbmate/sqlite.go @@ -11,16 +11,19 @@ import ( "regexp" "strings" + "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql ) func init() { - RegisterDriver(SQLiteDriver{}, "sqlite") - RegisterDriver(SQLiteDriver{}, "sqlite3") + drv := &SQLiteDriver{} + RegisterDriver(drv, "sqlite") + RegisterDriver(drv, "sqlite3") } // SQLiteDriver provides top level database functions type SQLiteDriver struct { + migrationsTableName string } func sqlitePath(u *url.URL) string { @@ -31,13 +34,18 @@ func sqlitePath(u *url.URL) string { return str } +// SetMigrationsTableName sets the schema migrations table name +func (drv *SQLiteDriver) SetMigrationsTableName(name string) { + drv.migrationsTableName = name +} + // Open creates a new database connection -func (drv SQLiteDriver) Open(u *url.URL) (*sql.DB, error) { +func (drv *SQLiteDriver) Open(u *url.URL) (*sql.DB, error) { return sql.Open("sqlite3", sqlitePath(u)) } // CreateDatabase creates the specified database -func (drv SQLiteDriver) CreateDatabase(u *url.URL) error { +func (drv *SQLiteDriver) CreateDatabase(u *url.URL) error { fmt.Printf("Creating: %s\n", sqlitePath(u)) db, err := drv.Open(u) @@ -50,7 +58,7 @@ func (drv SQLiteDriver) CreateDatabase(u *url.URL) error { } // DropDatabase drops the specified database (if it exists) -func (drv SQLiteDriver) DropDatabase(u *url.URL) error { +func (drv *SQLiteDriver) DropDatabase(u *url.URL) error { path := sqlitePath(u) fmt.Printf("Dropping: %s\n", path) @@ -65,36 +73,39 @@ func (drv SQLiteDriver) DropDatabase(u *url.URL) error { return os.Remove(path) } -func sqliteSchemaMigrationsDump(db *sql.DB) ([]byte, error) { +func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) { + migrationsTable := drv.quotedMigrationsTableName() + // load applied migrations migrations, err := queryColumn(db, - "select quote(version) from schema_migrations order by version asc") + fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable)) if err != nil { return nil, err } - // build schema_migrations table data + // build schema migrations table data var buf bytes.Buffer buf.WriteString("-- Dbmate schema migrations\n") if len(migrations) > 0 { - buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" + - strings.Join(migrations, "),\n (") + - ");\n") + buf.WriteString( + fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) + + strings.Join(migrations, "),\n (") + + ");\n") } return buf.Bytes(), nil } // DumpSchema returns the current database schema -func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { +func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { path := sqlitePath(u) schema, err := runCommand("sqlite3", path, ".schema") if err != nil { return nil, err } - migrations, err := sqliteSchemaMigrationsDump(db) + migrations, err := drv.schemaMigrationsDump(db) if err != nil { return nil, err } @@ -104,7 +115,7 @@ func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { } // DatabaseExists determines whether the database exists -func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) { +func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) { _, err := os.Stat(sqlitePath(u)) if os.IsNotExist(err) { return false, nil @@ -116,18 +127,19 @@ func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) { return true, nil } -// CreateMigrationsTable creates the schema_migrations table -func (drv SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { - _, err := db.Exec("create table if not exists schema_migrations " + - "(version varchar(255) primary key)") +// CreateMigrationsTable creates the schema migrations table +func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error { + _, err := db.Exec( + fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) + + "(version varchar(255) primary key)") return err } // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { - query := "select version from schema_migrations order by version desc" +func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { + query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName()) if limit >= 0 { query = fmt.Sprintf("%s limit %d", query, limit) } @@ -156,15 +168,19 @@ func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool } // InsertMigration adds a new migration record -func (drv SQLiteDriver) InsertMigration(db Transaction, version string) error { - _, err := db.Exec("insert into schema_migrations (version) values (?)", version) +func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error { + _, err := db.Exec( + fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()), + version) return err } // DeleteMigration removes a migration record -func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error { - _, err := db.Exec("delete from schema_migrations where version = ?", version) +func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error { + _, err := db.Exec( + fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()), + version) return err } @@ -172,7 +188,7 @@ func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error { // Ping verifies a connection to the database. Due to the way SQLite works, by // testing whether the database is valid, it will automatically create the database // if it does not already exist. -func (drv SQLiteDriver) Ping(u *url.URL) error { +func (drv *SQLiteDriver) Ping(u *url.URL) error { db, err := drv.Open(u) if err != nil { return err @@ -181,3 +197,14 @@ func (drv SQLiteDriver) Ping(u *url.URL) error { return db.Ping() } + +func (drv *SQLiteDriver) quotedMigrationsTableName() string { + return drv.quoteIdentifier(drv.migrationsTableName) +} + +// quoteIdentifier quotes a table or column name +// we fall back to lib/pq implementation since both use ansi standard (double quotes) +// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent +func (drv *SQLiteDriver) quoteIdentifier(s string) string { + return pq.QuoteIdentifier(s) +} diff --git a/pkg/dbmate/sqlite_test.go b/pkg/dbmate/sqlite_test.go index 3d32889..5602188 100644 --- a/pkg/dbmate/sqlite_test.go +++ b/pkg/dbmate/sqlite_test.go @@ -18,8 +18,15 @@ func sqliteTestURL(t *testing.T) *url.URL { return u } +func testSQLiteDriver() *SQLiteDriver { + drv := &SQLiteDriver{} + drv.SetMigrationsTableName(DefaultMigrationsTableName) + + return drv +} + func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB { - drv := SQLiteDriver{} + drv := testSQLiteDriver() // drop any existing database err := drv.DropDatabase(u) @@ -37,7 +44,7 @@ func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB { } func TestSQLiteCreateDropDatabase(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() u := sqliteTestURL(t) path := sqlitePath(u) @@ -64,7 +71,9 @@ func TestSQLiteCreateDropDatabase(t *testing.T) { } func TestSQLiteDumpSchema(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() + drv.SetMigrationsTableName("test_migrations") + u := sqliteTestURL(t) // prepare database @@ -82,9 +91,9 @@ func TestSQLiteDumpSchema(t *testing.T) { // DumpSchema should return schema schema, err := drv.DumpSchema(u, db) require.NoError(t, err) - require.Contains(t, string(schema), "CREATE TABLE schema_migrations") + require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"") require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+ - "INSERT INTO schema_migrations (version) VALUES\n"+ + "INSERT INTO \"test_migrations\" (version) VALUES\n"+ " ('abc1'),\n"+ " ('abc2');\n") @@ -97,7 +106,7 @@ func TestSQLiteDumpSchema(t *testing.T) { } func TestSQLiteDatabaseExists(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() u := sqliteTestURL(t) // drop any existing database @@ -120,31 +129,61 @@ func TestSQLiteDatabaseExists(t *testing.T) { } func TestSQLiteCreateMigrationsTable(t *testing.T) { - drv := SQLiteDriver{} - u := sqliteTestURL(t) - db := prepTestSQLiteDB(t, u) - defer mustClose(db) + t.Run("default table", func(t *testing.T) { + drv := testSQLiteDriver() + u := sqliteTestURL(t) + db := prepTestSQLiteDB(t, u) + defer mustClose(db) - // migrations table should not exist - count := 0 - err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) - require.Regexp(t, "no such table: schema_migrations", err.Error()) + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Regexp(t, "no such table: schema_migrations", err.Error()) - // create table - err = drv.CreateMigrationsTable(u, db) - require.NoError(t, err) + // create table + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) - // migrations table should exist - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) - require.NoError(t, err) + // migrations table should exist + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.NoError(t, err) - // create table should be idempotent - err = drv.CreateMigrationsTable(u, db) - require.NoError(t, err) + // create table should be idempotent + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + }) + + t.Run("custom table", func(t *testing.T) { + drv := testSQLiteDriver() + drv.SetMigrationsTableName("test_migrations") + + u := sqliteTestURL(t) + db := prepTestSQLiteDB(t, u) + defer mustClose(db) + + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from test_migrations").Scan(&count) + require.Regexp(t, "no such table: test_migrations", err.Error()) + + // create table + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + + // migrations table should exist + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) + require.NoError(t, err) + + // create table should be idempotent + err = drv.CreateMigrationsTable(u, db) + require.NoError(t, err) + }) } func TestSQLiteSelectMigrations(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() + drv.SetMigrationsTableName("test_migrations") + u := sqliteTestURL(t) db := prepTestSQLiteDB(t, u) defer mustClose(db) @@ -152,7 +191,7 @@ func TestSQLiteSelectMigrations(t *testing.T) { err := drv.CreateMigrationsTable(u, db) require.NoError(t, err) - _, err = db.Exec(`insert into schema_migrations (version) + _, err = db.Exec(`insert into test_migrations (version) values ('abc2'), ('abc1'), ('abc3')`) require.NoError(t, err) @@ -171,7 +210,9 @@ func TestSQLiteSelectMigrations(t *testing.T) { } func TestSQLiteInsertMigration(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() + drv.SetMigrationsTableName("test_migrations") + u := sqliteTestURL(t) db := prepTestSQLiteDB(t, u) defer mustClose(db) @@ -180,7 +221,7 @@ func TestSQLiteInsertMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 0, count) @@ -188,14 +229,16 @@ func TestSQLiteInsertMigration(t *testing.T) { err = drv.InsertMigration(db, "abc1") require.NoError(t, err) - err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'"). + err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'"). Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestSQLiteDeleteMigration(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() + drv.SetMigrationsTableName("test_migrations") + u := sqliteTestURL(t) db := prepTestSQLiteDB(t, u) defer mustClose(db) @@ -203,7 +246,7 @@ func TestSQLiteDeleteMigration(t *testing.T) { err := drv.CreateMigrationsTable(u, db) require.NoError(t, err) - _, err = db.Exec(`insert into schema_migrations (version) + _, err = db.Exec(`insert into test_migrations (version) values ('abc1'), ('abc2')`) require.NoError(t, err) @@ -211,13 +254,13 @@ func TestSQLiteDeleteMigration(t *testing.T) { require.NoError(t, err) count := 0 - err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + err = db.QueryRow("select count(*) from test_migrations").Scan(&count) require.NoError(t, err) require.Equal(t, 1, count) } func TestSQLitePing(t *testing.T) { - drv := SQLiteDriver{} + drv := testSQLiteDriver() u := sqliteTestURL(t) path := sqlitePath(u) @@ -249,3 +292,19 @@ func TestSQLitePing(t *testing.T) { err = drv.Ping(u) require.EqualError(t, err, "unable to open database file: is a directory") } + +func TestSQLiteQuotedMigrationsTableName(t *testing.T) { + t.Run("default name", func(t *testing.T) { + drv := testSQLiteDriver() + name := drv.quotedMigrationsTableName() + require.Equal(t, `"schema_migrations"`, name) + }) + + t.Run("custom name", func(t *testing.T) { + drv := testSQLiteDriver() + drv.SetMigrationsTableName("fooMigrations") + + name := drv.quotedMigrationsTableName() + require.Equal(t, `"fooMigrations"`, name) + }) +} diff --git a/pkg/dbmate/utils.go b/pkg/dbmate/utils.go index cee8476..41913fc 100644 --- a/pkg/dbmate/utils.go +++ b/pkg/dbmate/utils.go @@ -104,8 +104,8 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) { // queryColumn runs a SQL statement and returns a slice of strings // it is assumed that the statement returns only one column // e.g. schema_migrations table -func queryColumn(db Transaction, query string) ([]string, error) { - rows, err := db.Query(query) +func queryColumn(db Transaction, query string, args ...interface{}) ([]string, error) { + rows, err := db.Query(query, args...) if err != nil { return nil, err } @@ -128,12 +128,12 @@ func queryColumn(db Transaction, query string) ([]string, error) { return result, nil } -// queryRow runs a SQL statement and returns a single string +// queryValue runs a SQL statement and returns a single string // it is assumed that the statement returns only one row and one column // sql NULL is returned as empty string -func queryRow(db Transaction, query string) (string, error) { +func queryValue(db Transaction, query string, args ...interface{}) (string, error) { var result sql.NullString - err := db.QueryRow(query).Scan(&result) + err := db.QueryRow(query, args...).Scan(&result) if err != nil || !result.Valid { return "", err } diff --git a/pkg/dbmate/utils_test.go b/pkg/dbmate/utils_test.go index e8a86a7..63db33c 100644 --- a/pkg/dbmate/utils_test.go +++ b/pkg/dbmate/utils_test.go @@ -1,9 +1,11 @@ package dbmate import ( + "database/sql" "net/url" "testing" + "github.com/lib/pq" "github.com/stretchr/testify/require" ) @@ -33,3 +35,24 @@ func TestTrimLeadingSQLComments(t *testing.T) { require.NoError(t, err) require.Equal(t, "real stuff\n-- end\n", string(out)) } + +func TestQueryColumn(t *testing.T) { + u := postgresTestURL(t) + db, err := sql.Open("postgres", u.String()) + require.NoError(t, err) + + val, err := queryColumn(db, "select concat('foo_', unnest($1::text[]))", + pq.Array([]string{"hi", "there"})) + require.NoError(t, err) + require.Equal(t, []string{"foo_hi", "foo_there"}, val) +} + +func TestQueryValue(t *testing.T) { + u := postgresTestURL(t) + db, err := sql.Open("postgres", u.String()) + require.NoError(t, err) + + val, err := queryValue(db, "select $1::int + $2::int", "5", 2) + require.NoError(t, err) + require.Equal(t, "7", val) +}