Ability to specify custom migrations table name (#178)

Supported via `--migrations-table` CLI flag or `DBMATE_MIGRATIONS_TABLE` environment variable.

Specified table name is quoted when necessary.

For PostgreSQL specifically, it's also possible to specify a custom schema (for example: `--migrations-table=foo.migrations`).

Closes #168
This commit is contained in:
Adrian Macneil 2020-11-17 18:11:24 +13:00 committed by GitHub
parent 656dc0253a
commit c907c3f5c6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 807 additions and 325 deletions

View file

@ -52,6 +52,12 @@ func NewApp() *cli.App {
Value: dbmate.DefaultMigrationsDir,
Usage: "specify the directory containing migration files",
},
&cli.StringFlag{
Name: "migrations-table",
EnvVars: []string{"DBMATE_MIGRATIONS_TABLE"},
Value: dbmate.DefaultMigrationsTableName,
Usage: "specify the database table to record migrations in",
},
&cli.StringFlag{
Name: "schema-file",
Aliases: []string{"s"},
@ -222,6 +228,7 @@ func action(f func(*dbmate.DB, *cli.Context) error) cli.ActionFunc {
db := dbmate.New(u)
db.AutoDumpSchema = !c.Bool("no-dump-schema")
db.MigrationsDir = c.String("migrations-dir")
db.MigrationsTableName = c.String("migrations-table")
db.SchemaFile = c.String("schema-file")
db.WaitBefore = c.Bool("wait")
overrideTimeout := c.Duration("wait-timeout")

View file

@ -13,11 +13,12 @@ import (
)
func init() {
RegisterDriver(ClickHouseDriver{}, "clickhouse")
RegisterDriver(&ClickHouseDriver{}, "clickhouse")
}
// ClickHouseDriver provides top level database functions
type ClickHouseDriver struct {
migrationsTableName string
}
func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
@ -52,12 +53,17 @@ func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
return &u
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *ClickHouseDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
func (drv *ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("clickhouse", normalizeClickHouseURL(u).String())
}
func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
func (drv *ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
// connect to clickhouse database
clickhouseURL := normalizeClickHouseURL(u)
values := clickhouseURL.Query()
@ -67,7 +73,7 @@ func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
return drv.Open(clickhouseURL)
}
func (drv ClickHouseDriver) databaseName(u *url.URL) string {
func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
name := normalizeClickHouseURL(u).Query().Get("database")
if name == "" {
name = "default"
@ -77,7 +83,7 @@ func (drv ClickHouseDriver) databaseName(u *url.URL) string {
var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
func clickhouseQuoteIdentifier(str string) string {
func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
if clickhouseValidIdentifier.MatchString(str) {
return str
}
@ -88,7 +94,7 @@ func clickhouseQuoteIdentifier(str string) string {
}
// CreateDatabase creates the specified database
func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error {
func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
name := drv.databaseName(u)
fmt.Printf("Creating: %s\n", name)
@ -98,13 +104,13 @@ func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error {
}
defer mustClose(db)
_, err = db.Exec("create database " + clickhouseQuoteIdentifier(name))
_, err = db.Exec("create database " + drv.quoteIdentifier(name))
return err
}
// DropDatabase drops the specified database (if it exists)
func (drv ClickHouseDriver) DropDatabase(u *url.URL) error {
func (drv *ClickHouseDriver) DropDatabase(u *url.URL) error {
name := drv.databaseName(u)
fmt.Printf("Dropping: %s\n", name)
@ -114,15 +120,15 @@ func (drv ClickHouseDriver) DropDatabase(u *url.URL) error {
}
defer mustClose(db)
_, err = db.Exec("drop database if exists " + clickhouseQuoteIdentifier(name))
_, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
return err
}
func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
buf.WriteString("\n--\n-- Database schema\n--\n\n")
buf.WriteString("CREATE DATABASE " + clickhouseQuoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
tables, err := queryColumn(db, "show tables")
if err != nil {
@ -132,7 +138,7 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er
for _, table := range tables {
var clause string
err = db.QueryRow("show create table " + clickhouseQuoteIdentifier(table)).Scan(&clause)
err = db.QueryRow("show create table " + drv.quoteIdentifier(table)).Scan(&clause)
if err != nil {
return err
}
@ -141,10 +147,13 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er
return nil
}
func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations
migrations, err := queryColumn(db,
"select version from schema_migrations final where applied order by version asc",
fmt.Sprintf("select version from %s final ", migrationsTable)+
"where applied order by version asc",
)
if err != nil {
return err
@ -155,29 +164,30 @@ func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
migrations[i] = "'" + quoter.Replace(migrations[i]) + "'"
}
// build schema_migrations table data
// build schema migrations table data
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
if len(migrations) > 0 {
buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" +
strings.Join(migrations, "),\n (") +
");\n")
buf.WriteString(
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
strings.Join(migrations, "),\n (") +
");\n")
}
return nil
}
// DumpSchema returns the current database schema
func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
var buf bytes.Buffer
var err error
err = clickhouseSchemaDump(db, &buf, drv.databaseName(u))
err = drv.schemaDump(db, &buf, drv.databaseName(u))
if err != nil {
return nil, err
}
err = clickhouseSchemaMigrationsDump(db, &buf)
err = drv.schemaMigrationsDump(db, &buf)
if err != nil {
return nil, err
}
@ -186,7 +196,7 @@ func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
// DatabaseExists determines whether the database exists
func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
name := drv.databaseName(u)
db, err := drv.openClickHouseDB(u)
@ -205,24 +215,27 @@ func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
return exists, err
}
// CreateMigrationsTable creates the schema_migrations table
func (drv ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec(`
create table if not exists schema_migrations (
// CreateMigrationsTable creates the schema migrations table
func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec(fmt.Sprintf(`
create table if not exists %s (
version String,
ts DateTime default now(),
applied UInt8 default 1
) engine = ReplacingMergeTree(ts)
primary key version
order by version
`)
`, drv.quotedMigrationsTableName()))
return err
}
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations final where applied order by version desc"
func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := fmt.Sprintf("select version from %s final where applied order by version desc",
drv.quotedMigrationsTableName())
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
}
@ -251,15 +264,19 @@ func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]
}
// InsertMigration adds a new migration record
func (drv ClickHouseDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
return err
}
// DeleteMigration removes a migration record
func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec(
"insert into schema_migrations (version, applied) values (?, ?)",
fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
drv.quotedMigrationsTableName()),
version, false,
)
@ -268,7 +285,7 @@ func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) erro
// Ping verifies a connection to the database server. It does not verify whether the
// specified database exists.
func (drv ClickHouseDriver) Ping(u *url.URL) error {
func (drv *ClickHouseDriver) Ping(u *url.URL) error {
// attempt connection to primary database, not "clickhouse" database
// to support servers with no "clickhouse" database
// (see https://github.com/amacneil/dbmate/issues/78)
@ -291,3 +308,7 @@ func (drv ClickHouseDriver) Ping(u *url.URL) error {
return err
}
func (drv *ClickHouseDriver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}

View file

@ -15,8 +15,15 @@ func clickhouseTestURL(t *testing.T) *url.URL {
return u
}
func testClickHouseDriver() *ClickHouseDriver {
drv := &ClickHouseDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
// drop any existing database
err := drv.DropDatabase(u)
@ -50,7 +57,7 @@ func TestNormalizeClickHouseURLCanonical(t *testing.T) {
}
func TestClickHouseCreateDropDatabase(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
// drop any existing database
@ -87,7 +94,9 @@ func TestClickHouseCreateDropDatabase(t *testing.T) {
}
func TestClickHouseDumpSchema(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t)
// prepare database
@ -113,11 +122,11 @@ func TestClickHouseDumpSchema(t *testing.T) {
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".schema_migrations")
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".test_migrations")
require.Contains(t, string(schema), "--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
"INSERT INTO schema_migrations (version) VALUES\n"+
"INSERT INTO test_migrations (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n")
@ -134,7 +143,7 @@ func TestClickHouseDumpSchema(t *testing.T) {
}
func TestClickHouseDatabaseExists(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
// drop any existing database
@ -157,7 +166,7 @@ func TestClickHouseDatabaseExists(t *testing.T) {
}
func TestClickHouseDatabaseExists_Error(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
values := u.Query()
values.Set("username", "invalid")
@ -169,31 +178,61 @@ func TestClickHouseDatabaseExists_Error(t *testing.T) {
}
func TestClickHouseCreateMigrationsTable(t *testing.T) {
drv := ClickHouseDriver{}
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
t.Run("default table", func(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.")
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.")
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
t.Run("custom table", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("testMigrations")
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
require.EqualError(t, err, "code: 60, message: Table dbmate.testMigrations doesn't exist.")
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
}
func TestClickHouseSelectMigrations(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
@ -203,7 +242,7 @@ func TestClickHouseSelectMigrations(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)")
stmt, err := tx.Prepare("insert into test_migrations (version) values (?)")
require.NoError(t, err)
_, err = stmt.Exec("abc2")
require.NoError(t, err)
@ -229,7 +268,9 @@ func TestClickHouseSelectMigrations(t *testing.T) {
}
func TestClickHouseInsertMigration(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
@ -238,7 +279,7 @@ func TestClickHouseInsertMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 0, count)
@ -250,13 +291,15 @@ func TestClickHouseInsertMigration(t *testing.T) {
err = tx.Commit()
require.NoError(t, err)
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestClickHouseDeleteMigration(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
@ -266,7 +309,7 @@ func TestClickHouseDeleteMigration(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)")
stmt, err := tx.Prepare("insert into test_migrations (version) values (?)")
require.NoError(t, err)
_, err = stmt.Exec("abc2")
require.NoError(t, err)
@ -283,13 +326,13 @@ func TestClickHouseDeleteMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from schema_migrations final where applied").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations final where applied").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestClickHousePing(t *testing.T) {
drv := ClickHouseDriver{}
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
// drop any existing database
@ -306,3 +349,27 @@ func TestClickHousePing(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestClickHouseQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testClickHouseDriver()
name := drv.quotedMigrationsTableName()
require.Equal(t, "schema_migrations", name)
})
t.Run("custom name", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("fooMigrations")
name := drv.quotedMigrationsTableName()
require.Equal(t, "fooMigrations", name)
})
t.Run("quoted name", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("bizarre\"$name")
name := drv.quotedMigrationsTableName()
require.Equal(t, `"bizarre""$name"`, name)
})
}

View file

@ -15,6 +15,9 @@ import (
// DefaultMigrationsDir specifies default directory to find migration files
const DefaultMigrationsDir = "./db/migrations"
// DefaultMigrationsTableName specifies default database tables to record migraitons in
const DefaultMigrationsTableName = "schema_migrations"
// DefaultSchemaFile specifies default location for schema.sql
const DefaultSchemaFile = "./db/schema.sql"
@ -26,14 +29,15 @@ const DefaultWaitTimeout = 60 * time.Second
// DB allows dbmate actions to be performed on a specified database
type DB struct {
AutoDumpSchema bool
DatabaseURL *url.URL
MigrationsDir string
SchemaFile string
Verbose bool
WaitBefore bool
WaitInterval time.Duration
WaitTimeout time.Duration
AutoDumpSchema bool
DatabaseURL *url.URL
MigrationsDir string
MigrationsTableName string
SchemaFile string
Verbose bool
WaitBefore bool
WaitInterval time.Duration
WaitTimeout time.Duration
}
// migrationFileRegexp pattern for valid migration files
@ -47,19 +51,27 @@ type statusResult struct {
// New initializes a new dbmate database
func New(databaseURL *url.URL) *DB {
return &DB{
AutoDumpSchema: true,
DatabaseURL: databaseURL,
MigrationsDir: DefaultMigrationsDir,
SchemaFile: DefaultSchemaFile,
WaitBefore: false,
WaitInterval: DefaultWaitInterval,
WaitTimeout: DefaultWaitTimeout,
AutoDumpSchema: true,
DatabaseURL: databaseURL,
MigrationsDir: DefaultMigrationsDir,
MigrationsTableName: DefaultMigrationsTableName,
SchemaFile: DefaultSchemaFile,
WaitBefore: false,
WaitInterval: DefaultWaitInterval,
WaitTimeout: DefaultWaitTimeout,
}
}
// GetDriver loads the required database driver
func (db *DB) GetDriver() (Driver, error) {
return GetDriver(db.DatabaseURL.Scheme)
drv, err := getDriver(db.DatabaseURL.Scheme)
if err != nil {
return nil, err
}
drv.SetMigrationsTableName(db.MigrationsTableName)
return drv, err
}
// Wait blocks until the database server is available. It does not verify that

View file

@ -38,12 +38,26 @@ func TestNew(t *testing.T) {
require.True(t, db.AutoDumpSchema)
require.Equal(t, u.String(), db.DatabaseURL.String())
require.Equal(t, "./db/migrations", db.MigrationsDir)
require.Equal(t, "schema_migrations", db.MigrationsTableName)
require.Equal(t, "./db/schema.sql", db.SchemaFile)
require.False(t, db.WaitBefore)
require.Equal(t, time.Second, db.WaitInterval)
require.Equal(t, 60*time.Second, db.WaitTimeout)
}
func TestGetDriver(t *testing.T) {
u := postgresTestURL(t)
db := New(u)
drv, err := db.GetDriver()
require.NoError(t, err)
// driver should have default migrations table set
pgDrv, ok := drv.(*PostgresDriver)
require.True(t, ok)
require.Equal(t, "schema_migrations", pgDrv.migrationsTableName)
}
func TestWait(t *testing.T) {
u := postgresTestURL(t)
db := newTestDB(t, u)
@ -242,7 +256,7 @@ func testMigrateURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify results
sqlDB, err := GetDriverOpen(u)
sqlDB, err := getDriverOpen(u)
require.NoError(t, err)
defer mustClose(sqlDB)
@ -275,7 +289,7 @@ func testUpURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify results
sqlDB, err := GetDriverOpen(u)
sqlDB, err := getDriverOpen(u)
require.NoError(t, err)
defer mustClose(sqlDB)
@ -308,7 +322,7 @@ func testRollbackURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify migration
sqlDB, err := GetDriverOpen(u)
sqlDB, err := getDriverOpen(u)
require.NoError(t, err)
defer mustClose(sqlDB)
@ -351,7 +365,7 @@ func testStatusURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify migration
sqlDB, err := GetDriverOpen(u)
sqlDB, err := getDriverOpen(u)
require.NoError(t, err)
defer mustClose(sqlDB)

View file

@ -13,6 +13,7 @@ type Driver interface {
CreateDatabase(*url.URL) error
DropDatabase(*url.URL) error
DumpSchema(*url.URL, *sql.DB) ([]byte, error)
SetMigrationsTableName(string)
CreateMigrationsTable(*url.URL, *sql.DB) error
SelectMigrations(*sql.DB, int) (map[string]bool, error)
InsertMigration(Transaction, string) error
@ -34,18 +35,20 @@ type Transaction interface {
QueryRow(query string, args ...interface{}) *sql.Row
}
// GetDriver loads a database driver by name
func GetDriver(name string) (Driver, error) {
if val, ok := drivers[name]; ok {
return val, nil
// getDriver loads a database driver by name
func getDriver(name string) (Driver, error) {
if drv, ok := drivers[name]; ok {
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv, nil
}
return nil, fmt.Errorf("unsupported driver: %s", name)
}
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
func GetDriverOpen(u *url.URL) (*sql.DB, error) {
drv, err := GetDriver(u.Scheme)
// getDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
func getDriverOpen(u *url.URL) (*sql.DB, error) {
drv, err := getDriver(u.Scheme)
if err != nil {
return nil, err
}

View file

@ -7,21 +7,21 @@ import (
)
func TestGetDriver_Postgres(t *testing.T) {
drv, err := GetDriver("postgres")
drv, err := getDriver("postgres")
require.NoError(t, err)
_, ok := drv.(PostgresDriver)
_, ok := drv.(*PostgresDriver)
require.Equal(t, true, ok)
}
func TestGetDriver_MySQL(t *testing.T) {
drv, err := GetDriver("mysql")
drv, err := getDriver("mysql")
require.NoError(t, err)
_, ok := drv.(MySQLDriver)
_, ok := drv.(*MySQLDriver)
require.Equal(t, true, ok)
}
func TestGetDriver_Error(t *testing.T) {
drv, err := GetDriver("foo")
drv, err := getDriver("foo")
require.EqualError(t, err, "unsupported driver: foo")
require.Nil(t, drv)
}

View file

@ -11,11 +11,12 @@ import (
)
func init() {
RegisterDriver(MySQLDriver{}, "mysql")
RegisterDriver(&MySQLDriver{}, "mysql")
}
// MySQLDriver provides top level database functions
type MySQLDriver struct {
migrationsTableName string
}
func normalizeMySQLURL(u *url.URL) string {
@ -52,12 +53,17 @@ func normalizeMySQLURL(u *url.URL) string {
return normalizedString
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *MySQLDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
func (drv *MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("mysql", normalizeMySQLURL(u))
}
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
func (drv *MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
// connect to no particular database
rootURL := *u
rootURL.Path = "/"
@ -65,14 +71,14 @@ func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
return drv.Open(&rootURL)
}
func mysqlQuoteIdentifier(str string) string {
func (drv *MySQLDriver) quoteIdentifier(str string) string {
str = strings.Replace(str, "`", "\\`", -1)
return fmt.Sprintf("`%s`", str)
}
// CreateDatabase creates the specified database
func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Creating: %s\n", name)
@ -83,13 +89,13 @@ func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
defer mustClose(db)
_, err = db.Exec(fmt.Sprintf("create database %s",
mysqlQuoteIdentifier(name)))
drv.quoteIdentifier(name)))
return err
}
// DropDatabase drops the specified database (if it exists)
func (drv MySQLDriver) DropDatabase(u *url.URL) error {
func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Dropping: %s\n", name)
@ -100,12 +106,12 @@ func (drv MySQLDriver) DropDatabase(u *url.URL) error {
defer mustClose(db)
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
mysqlQuoteIdentifier(name)))
drv.quoteIdentifier(name)))
return err
}
func mysqldumpArgs(u *url.URL) []string {
func (drv *MySQLDriver) mysqldumpArgs(u *url.URL) []string {
// generate CLI arguments
args := []string{"--opt", "--routines", "--no-data",
"--skip-dump-date", "--skip-add-drop-table"}
@ -131,10 +137,12 @@ func mysqldumpArgs(u *url.URL) []string {
return args
}
func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations
migrations, err := queryColumn(db,
"select quote(version) from schema_migrations order by version asc")
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
if err != nil {
return nil, err
}
@ -142,12 +150,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
// build schema_migrations table data
var buf bytes.Buffer
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" +
"LOCK TABLES `schema_migrations` WRITE;\n")
fmt.Sprintf("LOCK TABLES %s WRITE;\n", migrationsTable))
if len(migrations) > 0 {
buf.WriteString("INSERT INTO `schema_migrations` (version) VALUES\n (" +
strings.Join(migrations, "),\n (") +
");\n")
buf.WriteString(
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
strings.Join(migrations, "),\n (") +
");\n")
}
buf.WriteString("UNLOCK TABLES;\n")
@ -156,13 +165,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
}
// DumpSchema returns the current database schema
func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
schema, err := runCommand("mysqldump", mysqldumpArgs(u)...)
func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
schema, err := runCommand("mysqldump", drv.mysqldumpArgs(u)...)
if err != nil {
return nil, err
}
migrations, err := mysqlSchemaMigrationsDump(db)
migrations, err := drv.schemaMigrationsDump(db)
if err != nil {
return nil, err
}
@ -172,7 +181,7 @@ func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
// DatabaseExists determines whether the database exists
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u)
db, err := drv.openRootDB(u)
@ -192,17 +201,18 @@ func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema_migrations table
func (drv MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec("create table if not exists schema_migrations " +
"(version varchar(255) primary key) character set latin1 collate latin1_bin")
func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec(fmt.Sprintf("create table if not exists %s "+
"(version varchar(255) primary key) character set latin1 collate latin1_bin",
drv.quotedMigrationsTableName()))
return err
}
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations order by version desc"
func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
}
@ -231,22 +241,26 @@ func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool,
}
// InsertMigration adds a new migration record
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
return err
}
// DeleteMigration removes a migration record
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from schema_migrations where version = ?", version)
func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
version)
return err
}
// Ping verifies a connection to the database server. It does not verify whether the
// specified database exists.
func (drv MySQLDriver) Ping(u *url.URL) error {
func (drv *MySQLDriver) Ping(u *url.URL) error {
db, err := drv.openRootDB(u)
if err != nil {
return err
@ -255,3 +269,7 @@ func (drv MySQLDriver) Ping(u *url.URL) error {
return db.Ping()
}
func (drv *MySQLDriver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}

View file

@ -15,8 +15,15 @@ func mySQLTestURL(t *testing.T) *url.URL {
return u
}
func testMySQLDriver() *MySQLDriver {
drv := &MySQLDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB {
drv := MySQLDriver{}
drv := testMySQLDriver()
// drop any existing database
err := drv.DropDatabase(u)
@ -78,7 +85,7 @@ func TestNormalizeMySQLURLSocket(t *testing.T) {
}
func TestMySQLCreateDropDatabase(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
u := mySQLTestURL(t)
// drop any existing database
@ -116,7 +123,9 @@ func TestMySQLCreateDropDatabase(t *testing.T) {
}
func TestMySQLDumpSchema(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t)
// prepare database
@ -134,13 +143,13 @@ func TestMySQLDumpSchema(t *testing.T) {
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE `schema_migrations`")
require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
"--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
"LOCK TABLES `schema_migrations` WRITE;\n"+
"INSERT INTO `schema_migrations` (version) VALUES\n"+
"LOCK TABLES `test_migrations` WRITE;\n"+
"INSERT INTO `test_migrations` (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n"+
"UNLOCK TABLES;\n")
@ -156,7 +165,7 @@ func TestMySQLDumpSchema(t *testing.T) {
}
func TestMySQLDatabaseExists(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
u := mySQLTestURL(t)
// drop any existing database
@ -179,7 +188,7 @@ func TestMySQLDatabaseExists(t *testing.T) {
}
func TestMySQLDatabaseExists_Error(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
u := mySQLTestURL(t)
u.User = url.User("invalid")
@ -189,22 +198,25 @@ func TestMySQLDatabaseExists_Error(t *testing.T) {
}
func TestMySQLCreateMigrationsTable(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.Regexp(t, "Table 'dbmate.schema_migrations' doesn't exist", err.Error())
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.Error(t, err)
require.Regexp(t, "Table 'dbmate.test_migrations' doesn't exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
@ -213,7 +225,9 @@ func TestMySQLCreateMigrationsTable(t *testing.T) {
}
func TestMySQLSelectMigrations(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
@ -221,7 +235,7 @@ func TestMySQLSelectMigrations(t *testing.T) {
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version)
_, err = db.Exec(`insert into test_migrations (version)
values ('abc2'), ('abc1'), ('abc3')`)
require.NoError(t, err)
@ -240,7 +254,9 @@ func TestMySQLSelectMigrations(t *testing.T) {
}
func TestMySQLInsertMigration(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
@ -249,7 +265,7 @@ func TestMySQLInsertMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 0, count)
@ -257,14 +273,16 @@ func TestMySQLInsertMigration(t *testing.T) {
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestMySQLDeleteMigration(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
@ -272,7 +290,7 @@ func TestMySQLDeleteMigration(t *testing.T) {
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version)
_, err = db.Exec(`insert into test_migrations (version)
values ('abc1'), ('abc2')`)
require.NoError(t, err)
@ -280,13 +298,13 @@ func TestMySQLDeleteMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestMySQLPing(t *testing.T) {
drv := MySQLDriver{}
drv := testMySQLDriver()
u := mySQLTestURL(t)
// drop any existing database
@ -303,3 +321,19 @@ func TestMySQLPing(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestMySQLQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testMySQLDriver()
name := drv.quotedMigrationsTableName()
require.Equal(t, "`schema_migrations`", name)
})
t.Run("custom name", func(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("fooMigrations")
name := drv.quotedMigrationsTableName()
require.Equal(t, "`fooMigrations`", name)
})
}

View file

@ -11,12 +11,14 @@ import (
)
func init() {
RegisterDriver(PostgresDriver{}, "postgres")
RegisterDriver(PostgresDriver{}, "postgresql")
drv := &PostgresDriver{}
RegisterDriver(drv, "postgres")
RegisterDriver(drv, "postgresql")
}
// PostgresDriver provides top level database functions
type PostgresDriver struct {
migrationsTableName string
}
func normalizePostgresURL(u *url.URL) *url.URL {
@ -78,12 +80,17 @@ func normalizePostgresURLForDump(u *url.URL) []string {
return out
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *PostgresDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
func (drv *PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("postgres", normalizePostgresURL(u).String())
}
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
func (drv *PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
// connect to postgres database
postgresURL := *u
postgresURL.Path = "postgres"
@ -92,7 +99,7 @@ func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
}
// CreateDatabase creates the specified database
func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Creating: %s\n", name)
@ -109,7 +116,7 @@ func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv PostgresDriver) DropDatabase(u *url.URL) error {
func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Dropping: %s\n", name)
@ -125,8 +132,8 @@ func (drv PostgresDriver) DropDatabase(u *url.URL) error {
return err
}
func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable, err := drv.migrationsTableName(db)
func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return nil, err
}
@ -152,7 +159,7 @@ func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, erro
}
// DumpSchema returns the current database schema
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
// load schema
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
@ -161,7 +168,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
return nil, err
}
migrations, err := drv.postgresSchemaMigrationsDump(db)
migrations, err := drv.schemaMigrationsDump(db)
if err != nil {
return nil, err
}
@ -171,7 +178,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
// DatabaseExists determines whether the database exists
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u)
db, err := drv.openPostgresDB(u)
@ -191,48 +198,45 @@ func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema_migrations table
func (drv PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
// get schema from URL search_path param
searchPath := strings.Split(u.Query().Get("search_path"), ",")
urlSchema := strings.TrimSpace(searchPath[0])
if urlSchema == "" {
urlSchema = "public"
}
// get *unquoted* current schema from database
dbSchema, err := queryRow(db, "select current_schema()")
func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db, u)
if err != nil {
return err
}
// if urlSchema and dbSchema are not equal, the most likely explanation is that the schema
// has not yet been created
if urlSchema != dbSchema {
// in theory we could just execute this statement every time, but we do the comparison
// above in case the user doesn't have permissions to create schemas and the schema
// already exists
fmt.Printf("Creating schema: %s\n", urlSchema)
_, err = db.Exec("create schema if not exists " + pq.QuoteIdentifier(urlSchema))
if err != nil {
return err
}
// first attempt at creating migrations table
createTableStmt := fmt.Sprintf("create table if not exists %s.%s", schema, migrationsTable) +
" (version varchar(255) primary key)"
_, err = db.Exec(createTableStmt)
if err == nil {
// table exists or created successfully
return nil
}
migrationsTable, err := drv.migrationsTableName(db)
// catch 'schema does not exist' error
pqErr, ok := err.(*pq.Error)
if !ok || pqErr.Code != "3F000" {
// unknown error
return err
}
// in theory we could attempt to create the schema every time, but we avoid that
// in case the user doesn't have permissions to create schemas
fmt.Printf("Creating schema: %s\n", schema)
_, err = db.Exec(fmt.Sprintf("create schema if not exists %s", schema))
if err != nil {
return err
}
_, err = db.Exec("create table if not exists " + migrationsTable +
" (version varchar(255) primary key)")
// second and final attempt at creating migrations table
_, err = db.Exec(createTableStmt)
return err
}
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
migrationsTable, err := drv.migrationsTableName(db)
func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return nil, err
}
@ -266,8 +270,8 @@ func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bo
}
// InsertMigration adds a new migration record
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
migrationsTable, err := drv.migrationsTableName(db)
func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return err
}
@ -278,8 +282,8 @@ func (drv PostgresDriver) InsertMigration(db Transaction, version string) error
}
// DeleteMigration removes a migration record
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
migrationsTable, err := drv.migrationsTableName(db)
func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return err
}
@ -291,7 +295,7 @@ func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error
// Ping verifies a connection to the database server. It does not verify whether the
// specified database exists.
func (drv PostgresDriver) Ping(u *url.URL) error {
func (drv *PostgresDriver) Ping(u *url.URL) error {
// attempt connection to primary database, not "postgres" database
// to support servers with no "postgres" database
// (see https://github.com/amacneil/dbmate/issues/78)
@ -306,7 +310,7 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
return nil
}
// ignore 'database "foo" does not exist' error
// ignore 'database does not exist' error
pqErr, ok := err.(*pq.Error)
if ok && pqErr.Code == "3D000" {
return nil
@ -315,17 +319,53 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
return err
}
func (drv PostgresDriver) migrationsTableName(db Transaction) (string, error) {
// get current schema
schema, err := queryRow(db, "select quote_ident(current_schema())")
func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, error) {
schema, name, err := drv.quotedMigrationsTableNameParts(db, nil)
if err != nil {
return "", err
}
// if the search path is empty, or does not contain a valid schema, default to public
return schema + "." + name, nil
}
func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url.URL) (string, string, error) {
schema := ""
tableNameParts := strings.Split(drv.migrationsTableName, ".")
if len(tableNameParts) > 1 {
// schema specified as part of table name
schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
}
if schema == "" && u != nil {
// no schema specified with table name, try URL search path if available
searchPath := strings.Split(u.Query().Get("search_path"), ",")
schema = strings.TrimSpace(searchPath[0])
}
var err error
if schema == "" {
// if no URL available, use current schema
// this is a hack because we don't always have the URL context available
schema, err = queryValue(db, "select current_schema()")
if err != nil {
return "", "", err
}
}
// fall back to public schema as last resort
if schema == "" {
schema = "public"
}
return schema + ".schema_migrations", nil
// quote all parts
// use server rather than client to do this to avoid unnecessary quotes
// (which would change schema.sql diff)
tableNameParts = append([]string{schema}, tableNameParts...)
quotedNameParts, err := queryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
if err != nil {
return "", "", err
}
// if more than one part, we already have a schema
return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil
}

View file

@ -15,8 +15,15 @@ func postgresTestURL(t *testing.T) *url.URL {
return u
}
func testPostgresDriver() *PostgresDriver {
drv := &PostgresDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
drv := PostgresDriver{}
drv := testPostgresDriver()
// drop any existing database
err := drv.DropDatabase(u)
@ -87,7 +94,7 @@ func TestNormalizePostgresURLForDump(t *testing.T) {
}
func TestPostgresCreateDropDatabase(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
u := postgresTestURL(t)
// drop any existing database
@ -125,45 +132,80 @@ func TestPostgresCreateDropDatabase(t *testing.T) {
}
func TestPostgresDumpSchema(t *testing.T) {
drv := PostgresDriver{}
u := postgresTestURL(t)
t.Run("default migrations table", func(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
// prepare database
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// prepare database
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// insert migration
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = drv.InsertMigration(db, "abc2")
require.NoError(t, err)
// insert migration
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = drv.InsertMigration(db, "abc2")
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
require.Contains(t, string(schema), "\n--\n"+
"-- PostgreSQL database dump complete\n"+
"--\n\n\n"+
"--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
"INSERT INTO public.schema_migrations (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n")
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
require.Contains(t, string(schema), "\n--\n"+
"-- PostgreSQL database dump complete\n"+
"--\n\n\n"+
"--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
"INSERT INTO public.schema_migrations (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n")
// DumpSchema should return error if command fails
u.Path = "/fakedb"
schema, err = drv.DumpSchema(u, db)
require.Nil(t, schema)
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
// DumpSchema should return error if command fails
u.Path = "/fakedb"
schema, err = drv.DumpSchema(u, db)
require.Nil(t, schema)
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
})
t.Run("custom migrations table with schema", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("camelSchema.testMigrations")
u := postgresTestURL(t)
// prepare database
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// insert migration
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = drv.InsertMigration(db, "abc2")
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"")
require.Contains(t, string(schema), "\n--\n"+
"-- PostgreSQL database dump complete\n"+
"--\n\n\n"+
"--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
"INSERT INTO \"camelSchema\".\"testMigrations\" (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n")
})
}
func TestPostgresDatabaseExists(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
u := postgresTestURL(t)
// drop any existing database
@ -186,7 +228,7 @@ func TestPostgresDatabaseExists(t *testing.T) {
}
func TestPostgresDatabaseExists_Error(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
u := postgresTestURL(t)
u.User = url.User("invalid")
@ -197,9 +239,8 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
}
func TestPostgresCreateMigrationsTable(t *testing.T) {
drv := PostgresDriver{}
t.Run("default schema", func(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
@ -223,39 +264,81 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
})
t.Run("custom schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
t.Run("custom search path", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("testMigrations")
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=camelFoo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// delete schema
_, err = db.Exec("drop schema if exists foo")
_, err = db.Exec("drop schema if exists \"camelFoo\"")
require.NoError(t, err)
// drop any schema_migrations table in public schema
_, err = db.Exec("drop table if exists public.schema_migrations")
// drop any testMigrations table in public schema
_, err = db.Exec("drop table if exists public.\"testMigrations\"")
require.NoError(t, err)
// migrations table should not exist in either schema
count := 0
err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"foo.schema_migrations\" does not exist", err.Error())
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
require.Equal(t, "pq: relation \"camelFoo.testMigrations\" does not exist", err.Error())
err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// foo schema should be created, and migrations table should exist only in foo schema
err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count)
// camelFoo schema should be created, and migrations table should exist only in camelFoo schema
err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count)
require.NoError(t, err)
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
t.Run("custom schema", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("camelSchema.testMigrations")
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// delete schemas
_, err = db.Exec("drop schema if exists foo")
require.NoError(t, err)
_, err = db.Exec("drop schema if exists \"camelSchema\"")
require.NoError(t, err)
// migrations table should not exist
count := 0
err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// camelSchema should be created, and testMigrations table should exist
err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count)
require.NoError(t, err)
// testMigrations table should not exist in foo schema because
// schema specified with migrations table name takes priority over search path
err = db.QueryRow("select count(*) from foo.\"testMigrations\"").Scan(&count)
require.Error(t, err)
require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error())
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
@ -264,7 +347,9 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
}
func TestPostgresSelectMigrations(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
@ -272,7 +357,7 @@ func TestPostgresSelectMigrations(t *testing.T) {
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
_, err = db.Exec(`insert into public.schema_migrations (version)
_, err = db.Exec(`insert into public.test_migrations (version)
values ('abc2'), ('abc1'), ('abc3')`)
require.NoError(t, err)
@ -291,7 +376,9 @@ func TestPostgresSelectMigrations(t *testing.T) {
}
func TestPostgresInsertMigration(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
@ -300,7 +387,7 @@ func TestPostgresInsertMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 0, count)
@ -308,14 +395,16 @@ func TestPostgresInsertMigration(t *testing.T) {
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = db.QueryRow("select count(*) from public.schema_migrations where version = 'abc1'").
err = db.QueryRow("select count(*) from public.test_migrations where version = 'abc1'").
Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestPostgresDeleteMigration(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
@ -323,7 +412,7 @@ func TestPostgresDeleteMigration(t *testing.T) {
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
_, err = db.Exec(`insert into public.schema_migrations (version)
_, err = db.Exec(`insert into public.test_migrations (version)
values ('abc1'), ('abc2')`)
require.NoError(t, err)
@ -331,13 +420,13 @@ func TestPostgresDeleteMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestPostgresPing(t *testing.T) {
drv := PostgresDriver{}
drv := testPostgresDriver()
u := postgresTestURL(t)
// drop any existing database
@ -355,15 +444,15 @@ func TestPostgresPing(t *testing.T) {
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestMigrationsTableName(t *testing.T) {
drv := PostgresDriver{}
func TestPostgresQuotedMigrationsTableName(t *testing.T) {
drv := testPostgresDriver()
t.Run("default schema", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
name, err := drv.migrationsTableName(db)
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
})
@ -379,14 +468,14 @@ func TestMigrationsTableName(t *testing.T) {
require.NoError(t, err)
_, err = db.Exec("drop schema if exists bar")
require.NoError(t, err)
name, err := drv.migrationsTableName(db)
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
// if "foo" schema exists, it should be used
_, err = db.Exec("create schema foo")
require.NoError(t, err)
name, err = drv.migrationsTableName(db)
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "foo.schema_migrations", name)
})
@ -401,8 +490,76 @@ func TestMigrationsTableName(t *testing.T) {
_, err := db.Exec("select pg_catalog.set_config('search_path', '', false)")
require.NoError(t, err)
name, err := drv.migrationsTableName(db)
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
})
t.Run("custom table name", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.SetMigrationsTableName("simple_name")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.simple_name", name)
})
t.Run("custom table name quoted", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// this table name will need quoting
drv.SetMigrationsTableName("camelCase")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.\"camelCase\"", name)
})
t.Run("custom table name with custom schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
_, err = db.Exec("create schema if not exists foo")
require.NoError(t, err)
drv.SetMigrationsTableName("simple_name")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "foo.simple_name", name)
})
t.Run("custom table name overrides schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
_, err = db.Exec("create schema if not exists foo")
require.NoError(t, err)
_, err = db.Exec("create schema if not exists bar")
require.NoError(t, err)
// if schema is specified as part of table name, it should override search_path
drv.SetMigrationsTableName("bar.simple_name")
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "bar.simple_name", name)
// schema and table name should be quoted if necessary
drv.SetMigrationsTableName("barName.camelTable")
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "\"barName\".\"camelTable\"", name)
// more than 2 components is unexpected but we will quote and pass it along anyway
drv.SetMigrationsTableName("whyWould.i.doThis")
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "\"whyWould\".i.\"doThis\"", name)
})
}

View file

@ -11,16 +11,19 @@ import (
"regexp"
"strings"
"github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql
)
func init() {
RegisterDriver(SQLiteDriver{}, "sqlite")
RegisterDriver(SQLiteDriver{}, "sqlite3")
drv := &SQLiteDriver{}
RegisterDriver(drv, "sqlite")
RegisterDriver(drv, "sqlite3")
}
// SQLiteDriver provides top level database functions
type SQLiteDriver struct {
migrationsTableName string
}
func sqlitePath(u *url.URL) string {
@ -31,13 +34,18 @@ func sqlitePath(u *url.URL) string {
return str
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *SQLiteDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
func (drv *SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("sqlite3", sqlitePath(u))
}
// CreateDatabase creates the specified database
func (drv SQLiteDriver) CreateDatabase(u *url.URL) error {
func (drv *SQLiteDriver) CreateDatabase(u *url.URL) error {
fmt.Printf("Creating: %s\n", sqlitePath(u))
db, err := drv.Open(u)
@ -50,7 +58,7 @@ func (drv SQLiteDriver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv SQLiteDriver) DropDatabase(u *url.URL) error {
func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
path := sqlitePath(u)
fmt.Printf("Dropping: %s\n", path)
@ -65,36 +73,39 @@ func (drv SQLiteDriver) DropDatabase(u *url.URL) error {
return os.Remove(path)
}
func sqliteSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations
migrations, err := queryColumn(db,
"select quote(version) from schema_migrations order by version asc")
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
if err != nil {
return nil, err
}
// build schema_migrations table data
// build schema migrations table data
var buf bytes.Buffer
buf.WriteString("-- Dbmate schema migrations\n")
if len(migrations) > 0 {
buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" +
strings.Join(migrations, "),\n (") +
");\n")
buf.WriteString(
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
strings.Join(migrations, "),\n (") +
");\n")
}
return buf.Bytes(), nil
}
// DumpSchema returns the current database schema
func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
path := sqlitePath(u)
schema, err := runCommand("sqlite3", path, ".schema")
if err != nil {
return nil, err
}
migrations, err := sqliteSchemaMigrationsDump(db)
migrations, err := drv.schemaMigrationsDump(db)
if err != nil {
return nil, err
}
@ -104,7 +115,7 @@ func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
// DatabaseExists determines whether the database exists
func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
_, err := os.Stat(sqlitePath(u))
if os.IsNotExist(err) {
return false, nil
@ -116,18 +127,19 @@ func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
return true, nil
}
// CreateMigrationsTable creates the schema_migrations table
func (drv SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec("create table if not exists schema_migrations " +
"(version varchar(255) primary key)")
// CreateMigrationsTable creates the schema migrations table
func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
_, err := db.Exec(
fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) +
"(version varchar(255) primary key)")
return err
}
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations order by version desc"
func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
}
@ -156,15 +168,19 @@ func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
}
// InsertMigration adds a new migration record
func (drv SQLiteDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
return err
}
// DeleteMigration removes a migration record
func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from schema_migrations where version = ?", version)
func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
version)
return err
}
@ -172,7 +188,7 @@ func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error {
// Ping verifies a connection to the database. Due to the way SQLite works, by
// testing whether the database is valid, it will automatically create the database
// if it does not already exist.
func (drv SQLiteDriver) Ping(u *url.URL) error {
func (drv *SQLiteDriver) Ping(u *url.URL) error {
db, err := drv.Open(u)
if err != nil {
return err
@ -181,3 +197,14 @@ func (drv SQLiteDriver) Ping(u *url.URL) error {
return db.Ping()
}
func (drv *SQLiteDriver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}
// quoteIdentifier quotes a table or column name
// we fall back to lib/pq implementation since both use ansi standard (double quotes)
// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
func (drv *SQLiteDriver) quoteIdentifier(s string) string {
return pq.QuoteIdentifier(s)
}

View file

@ -18,8 +18,15 @@ func sqliteTestURL(t *testing.T) *url.URL {
return u
}
func testSQLiteDriver() *SQLiteDriver {
drv := &SQLiteDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
// drop any existing database
err := drv.DropDatabase(u)
@ -37,7 +44,7 @@ func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
}
func TestSQLiteCreateDropDatabase(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
u := sqliteTestURL(t)
path := sqlitePath(u)
@ -64,7 +71,9 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
}
func TestSQLiteDumpSchema(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
// prepare database
@ -82,9 +91,9 @@ func TestSQLiteDumpSchema(t *testing.T) {
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE schema_migrations")
require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"")
require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+
"INSERT INTO schema_migrations (version) VALUES\n"+
"INSERT INTO \"test_migrations\" (version) VALUES\n"+
" ('abc1'),\n"+
" ('abc2');\n")
@ -97,7 +106,7 @@ func TestSQLiteDumpSchema(t *testing.T) {
}
func TestSQLiteDatabaseExists(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
u := sqliteTestURL(t)
// drop any existing database
@ -120,31 +129,61 @@ func TestSQLiteDatabaseExists(t *testing.T) {
}
func TestSQLiteCreateMigrationsTable(t *testing.T) {
drv := SQLiteDriver{}
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
t.Run("default table", func(t *testing.T) {
drv := testSQLiteDriver()
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.Regexp(t, "no such table: schema_migrations", err.Error())
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.Regexp(t, "no such table: schema_migrations", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
t.Run("custom table", func(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.Regexp(t, "no such table: test_migrations", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
// migrations table should exist
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
})
}
func TestSQLiteSelectMigrations(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
@ -152,7 +191,7 @@ func TestSQLiteSelectMigrations(t *testing.T) {
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version)
_, err = db.Exec(`insert into test_migrations (version)
values ('abc2'), ('abc1'), ('abc3')`)
require.NoError(t, err)
@ -171,7 +210,9 @@ func TestSQLiteSelectMigrations(t *testing.T) {
}
func TestSQLiteInsertMigration(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
@ -180,7 +221,7 @@ func TestSQLiteInsertMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 0, count)
@ -188,14 +229,16 @@ func TestSQLiteInsertMigration(t *testing.T) {
err = drv.InsertMigration(db, "abc1")
require.NoError(t, err)
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestSQLiteDeleteMigration(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
@ -203,7 +246,7 @@ func TestSQLiteDeleteMigration(t *testing.T) {
err := drv.CreateMigrationsTable(u, db)
require.NoError(t, err)
_, err = db.Exec(`insert into schema_migrations (version)
_, err = db.Exec(`insert into test_migrations (version)
values ('abc1'), ('abc2')`)
require.NoError(t, err)
@ -211,13 +254,13 @@ func TestSQLiteDeleteMigration(t *testing.T) {
require.NoError(t, err)
count := 0
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestSQLitePing(t *testing.T) {
drv := SQLiteDriver{}
drv := testSQLiteDriver()
u := sqliteTestURL(t)
path := sqlitePath(u)
@ -249,3 +292,19 @@ func TestSQLitePing(t *testing.T) {
err = drv.Ping(u)
require.EqualError(t, err, "unable to open database file: is a directory")
}
func TestSQLiteQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testSQLiteDriver()
name := drv.quotedMigrationsTableName()
require.Equal(t, `"schema_migrations"`, name)
})
t.Run("custom name", func(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("fooMigrations")
name := drv.quotedMigrationsTableName()
require.Equal(t, `"fooMigrations"`, name)
})
}

View file

@ -104,8 +104,8 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
// queryColumn runs a SQL statement and returns a slice of strings
// it is assumed that the statement returns only one column
// e.g. schema_migrations table
func queryColumn(db Transaction, query string) ([]string, error) {
rows, err := db.Query(query)
func queryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
@ -128,12 +128,12 @@ func queryColumn(db Transaction, query string) ([]string, error) {
return result, nil
}
// queryRow runs a SQL statement and returns a single string
// queryValue runs a SQL statement and returns a single string
// it is assumed that the statement returns only one row and one column
// sql NULL is returned as empty string
func queryRow(db Transaction, query string) (string, error) {
func queryValue(db Transaction, query string, args ...interface{}) (string, error) {
var result sql.NullString
err := db.QueryRow(query).Scan(&result)
err := db.QueryRow(query, args...).Scan(&result)
if err != nil || !result.Valid {
return "", err
}

View file

@ -1,9 +1,11 @@
package dbmate
import (
"database/sql"
"net/url"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
@ -33,3 +35,24 @@ func TestTrimLeadingSQLComments(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "real stuff\n-- end\n", string(out))
}
func TestQueryColumn(t *testing.T) {
u := postgresTestURL(t)
db, err := sql.Open("postgres", u.String())
require.NoError(t, err)
val, err := queryColumn(db, "select concat('foo_', unnest($1::text[]))",
pq.Array([]string{"hi", "there"}))
require.NoError(t, err)
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
}
func TestQueryValue(t *testing.T) {
u := postgresTestURL(t)
db, err := sql.Open("postgres", u.String())
require.NoError(t, err)
val, err := queryValue(db, "select $1::int + $2::int", "5", 2)
require.NoError(t, err)
require.Equal(t, "7", val)
}