mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2025-12-11 23:50:04 +01:00
Ability to specify custom migrations table name (#178)
Supported via `--migrations-table` CLI flag or `DBMATE_MIGRATIONS_TABLE` environment variable. Specified table name is quoted when necessary. For PostgreSQL specifically, it's also possible to specify a custom schema (for example: `--migrations-table=foo.migrations`). Closes #168
This commit is contained in:
parent
656dc0253a
commit
c907c3f5c6
15 changed files with 807 additions and 325 deletions
7
main.go
7
main.go
|
|
@ -52,6 +52,12 @@ func NewApp() *cli.App {
|
|||
Value: dbmate.DefaultMigrationsDir,
|
||||
Usage: "specify the directory containing migration files",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "migrations-table",
|
||||
EnvVars: []string{"DBMATE_MIGRATIONS_TABLE"},
|
||||
Value: dbmate.DefaultMigrationsTableName,
|
||||
Usage: "specify the database table to record migrations in",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "schema-file",
|
||||
Aliases: []string{"s"},
|
||||
|
|
@ -222,6 +228,7 @@ func action(f func(*dbmate.DB, *cli.Context) error) cli.ActionFunc {
|
|||
db := dbmate.New(u)
|
||||
db.AutoDumpSchema = !c.Bool("no-dump-schema")
|
||||
db.MigrationsDir = c.String("migrations-dir")
|
||||
db.MigrationsTableName = c.String("migrations-table")
|
||||
db.SchemaFile = c.String("schema-file")
|
||||
db.WaitBefore = c.Bool("wait")
|
||||
overrideTimeout := c.Duration("wait-timeout")
|
||||
|
|
|
|||
|
|
@ -13,11 +13,12 @@ import (
|
|||
)
|
||||
|
||||
func init() {
|
||||
RegisterDriver(ClickHouseDriver{}, "clickhouse")
|
||||
RegisterDriver(&ClickHouseDriver{}, "clickhouse")
|
||||
}
|
||||
|
||||
// ClickHouseDriver provides top level database functions
|
||||
type ClickHouseDriver struct {
|
||||
migrationsTableName string
|
||||
}
|
||||
|
||||
func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
|
||||
|
|
@ -52,12 +53,17 @@ func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
|
|||
return &u
|
||||
}
|
||||
|
||||
// SetMigrationsTableName sets the schema migrations table name
|
||||
func (drv *ClickHouseDriver) SetMigrationsTableName(name string) {
|
||||
drv.migrationsTableName = name
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("clickhouse", normalizeClickHouseURL(u).String())
|
||||
}
|
||||
|
||||
func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to clickhouse database
|
||||
clickhouseURL := normalizeClickHouseURL(u)
|
||||
values := clickhouseURL.Query()
|
||||
|
|
@ -67,7 +73,7 @@ func (drv ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
|
|||
return drv.Open(clickhouseURL)
|
||||
}
|
||||
|
||||
func (drv ClickHouseDriver) databaseName(u *url.URL) string {
|
||||
func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
|
||||
name := normalizeClickHouseURL(u).Query().Get("database")
|
||||
if name == "" {
|
||||
name = "default"
|
||||
|
|
@ -77,7 +83,7 @@ func (drv ClickHouseDriver) databaseName(u *url.URL) string {
|
|||
|
||||
var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
|
||||
|
||||
func clickhouseQuoteIdentifier(str string) string {
|
||||
func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
|
||||
if clickhouseValidIdentifier.MatchString(str) {
|
||||
return str
|
||||
}
|
||||
|
|
@ -88,7 +94,7 @@ func clickhouseQuoteIdentifier(str string) string {
|
|||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error {
|
||||
func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
|
||||
name := drv.databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
|
|
@ -98,13 +104,13 @@ func (drv ClickHouseDriver) CreateDatabase(u *url.URL) error {
|
|||
}
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec("create database " + clickhouseQuoteIdentifier(name))
|
||||
_, err = db.Exec("create database " + drv.quoteIdentifier(name))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv ClickHouseDriver) DropDatabase(u *url.URL) error {
|
||||
func (drv *ClickHouseDriver) DropDatabase(u *url.URL) error {
|
||||
name := drv.databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
|
|
@ -114,15 +120,15 @@ func (drv ClickHouseDriver) DropDatabase(u *url.URL) error {
|
|||
}
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec("drop database if exists " + clickhouseQuoteIdentifier(name))
|
||||
_, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
|
||||
func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
|
||||
buf.WriteString("\n--\n-- Database schema\n--\n\n")
|
||||
|
||||
buf.WriteString("CREATE DATABASE " + clickhouseQuoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
|
||||
buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
|
||||
|
||||
tables, err := queryColumn(db, "show tables")
|
||||
if err != nil {
|
||||
|
|
@ -132,7 +138,7 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er
|
|||
|
||||
for _, table := range tables {
|
||||
var clause string
|
||||
err = db.QueryRow("show create table " + clickhouseQuoteIdentifier(table)).Scan(&clause)
|
||||
err = db.QueryRow("show create table " + drv.quoteIdentifier(table)).Scan(&clause)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -141,10 +147,13 @@ func clickhouseSchemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) er
|
|||
return nil
|
||||
}
|
||||
|
||||
func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
|
||||
func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
|
||||
migrationsTable := drv.quotedMigrationsTableName()
|
||||
|
||||
// load applied migrations
|
||||
migrations, err := queryColumn(db,
|
||||
"select version from schema_migrations final where applied order by version asc",
|
||||
fmt.Sprintf("select version from %s final ", migrationsTable)+
|
||||
"where applied order by version asc",
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -155,29 +164,30 @@ func clickhouseSchemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
|
|||
migrations[i] = "'" + quoter.Replace(migrations[i]) + "'"
|
||||
}
|
||||
|
||||
// build schema_migrations table data
|
||||
// build schema migrations table data
|
||||
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
|
||||
|
||||
if len(migrations) > 0 {
|
||||
buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
buf.WriteString(
|
||||
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DumpSchema returns the current database schema
|
||||
func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
var err error
|
||||
|
||||
err = clickhouseSchemaDump(db, &buf, drv.databaseName(u))
|
||||
err = drv.schemaDump(db, &buf, drv.databaseName(u))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = clickhouseSchemaMigrationsDump(db, &buf)
|
||||
err = drv.schemaMigrationsDump(db, &buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -186,7 +196,7 @@ func (drv ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
|||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := drv.databaseName(u)
|
||||
|
||||
db, err := drv.openClickHouseDB(u)
|
||||
|
|
@ -205,24 +215,27 @@ func (drv ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
return exists, err
|
||||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
_, err := db.Exec(`
|
||||
create table if not exists schema_migrations (
|
||||
// CreateMigrationsTable creates the schema migrations table
|
||||
func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
_, err := db.Exec(fmt.Sprintf(`
|
||||
create table if not exists %s (
|
||||
version String,
|
||||
ts DateTime default now(),
|
||||
applied UInt8 default 1
|
||||
) engine = ReplacingMergeTree(ts)
|
||||
primary key version
|
||||
order by version
|
||||
`)
|
||||
`, drv.quotedMigrationsTableName()))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations final where applied order by version desc"
|
||||
func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := fmt.Sprintf("select version from %s final where applied order by version desc",
|
||||
drv.quotedMigrationsTableName())
|
||||
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
|
|
@ -251,15 +264,19 @@ func (drv ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]
|
|||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv ClickHouseDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||
func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec(
|
||||
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
||||
version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
|
||||
func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec(
|
||||
"insert into schema_migrations (version, applied) values (?, ?)",
|
||||
fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
|
||||
drv.quotedMigrationsTableName()),
|
||||
version, false,
|
||||
)
|
||||
|
||||
|
|
@ -268,7 +285,7 @@ func (drv ClickHouseDriver) DeleteMigration(db Transaction, version string) erro
|
|||
|
||||
// Ping verifies a connection to the database server. It does not verify whether the
|
||||
// specified database exists.
|
||||
func (drv ClickHouseDriver) Ping(u *url.URL) error {
|
||||
func (drv *ClickHouseDriver) Ping(u *url.URL) error {
|
||||
// attempt connection to primary database, not "clickhouse" database
|
||||
// to support servers with no "clickhouse" database
|
||||
// (see https://github.com/amacneil/dbmate/issues/78)
|
||||
|
|
@ -291,3 +308,7 @@ func (drv ClickHouseDriver) Ping(u *url.URL) error {
|
|||
|
||||
return err
|
||||
}
|
||||
|
||||
func (drv *ClickHouseDriver) quotedMigrationsTableName() string {
|
||||
return drv.quoteIdentifier(drv.migrationsTableName)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,8 +15,15 @@ func clickhouseTestURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func testClickHouseDriver() *ClickHouseDriver {
|
||||
drv := &ClickHouseDriver{}
|
||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
||||
|
||||
return drv
|
||||
}
|
||||
|
||||
func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -50,7 +57,7 @@ func TestNormalizeClickHouseURLCanonical(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClickHouseCreateDropDatabase(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
u := clickhouseTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -87,7 +94,9 @@ func TestClickHouseCreateDropDatabase(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClickHouseDumpSchema(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := clickhouseTestURL(t)
|
||||
|
||||
// prepare database
|
||||
|
|
@ -113,11 +122,11 @@ func TestClickHouseDumpSchema(t *testing.T) {
|
|||
// DumpSchema should return schema
|
||||
schema, err := drv.DumpSchema(u, db)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".schema_migrations")
|
||||
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".test_migrations")
|
||||
require.Contains(t, string(schema), "--\n"+
|
||||
"-- Dbmate schema migrations\n"+
|
||||
"--\n\n"+
|
||||
"INSERT INTO schema_migrations (version) VALUES\n"+
|
||||
"INSERT INTO test_migrations (version) VALUES\n"+
|
||||
" ('abc1'),\n"+
|
||||
" ('abc2');\n")
|
||||
|
||||
|
|
@ -134,7 +143,7 @@ func TestClickHouseDumpSchema(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClickHouseDatabaseExists(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
u := clickhouseTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -157,7 +166,7 @@ func TestClickHouseDatabaseExists(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClickHouseDatabaseExists_Error(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
u := clickhouseTestURL(t)
|
||||
values := u.Query()
|
||||
values.Set("username", "invalid")
|
||||
|
|
@ -169,31 +178,61 @@ func TestClickHouseDatabaseExists_Error(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClickHouseCreateMigrationsTable(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
u := clickhouseTestURL(t)
|
||||
db := prepTestClickHouseDB(t, u)
|
||||
defer mustClose(db)
|
||||
t.Run("default table", func(t *testing.T) {
|
||||
drv := testClickHouseDriver()
|
||||
u := clickhouseTestURL(t)
|
||||
db := prepTestClickHouseDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.")
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.")
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("custom table", func(t *testing.T) {
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("testMigrations")
|
||||
|
||||
u := clickhouseTestURL(t)
|
||||
db := prepTestClickHouseDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
|
||||
require.EqualError(t, err, "code: 60, message: Table dbmate.testMigrations doesn't exist.")
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClickHouseSelectMigrations(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := clickhouseTestURL(t)
|
||||
db := prepTestClickHouseDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -203,7 +242,7 @@ func TestClickHouseSelectMigrations(t *testing.T) {
|
|||
|
||||
tx, err := db.Begin()
|
||||
require.NoError(t, err)
|
||||
stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)")
|
||||
stmt, err := tx.Prepare("insert into test_migrations (version) values (?)")
|
||||
require.NoError(t, err)
|
||||
_, err = stmt.Exec("abc2")
|
||||
require.NoError(t, err)
|
||||
|
|
@ -229,7 +268,9 @@ func TestClickHouseSelectMigrations(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestClickHouseInsertMigration(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := clickhouseTestURL(t)
|
||||
db := prepTestClickHouseDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -238,7 +279,7 @@ func TestClickHouseInsertMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
|
|
@ -250,13 +291,15 @@ func TestClickHouseInsertMigration(t *testing.T) {
|
|||
err = tx.Commit()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestClickHouseDeleteMigration(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := clickhouseTestURL(t)
|
||||
db := prepTestClickHouseDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -266,7 +309,7 @@ func TestClickHouseDeleteMigration(t *testing.T) {
|
|||
|
||||
tx, err := db.Begin()
|
||||
require.NoError(t, err)
|
||||
stmt, err := tx.Prepare("insert into schema_migrations (version) values (?)")
|
||||
stmt, err := tx.Prepare("insert into test_migrations (version) values (?)")
|
||||
require.NoError(t, err)
|
||||
_, err = stmt.Exec("abc2")
|
||||
require.NoError(t, err)
|
||||
|
|
@ -283,13 +326,13 @@ func TestClickHouseDeleteMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations final where applied").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations final where applied").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestClickHousePing(t *testing.T) {
|
||||
drv := ClickHouseDriver{}
|
||||
drv := testClickHouseDriver()
|
||||
u := clickhouseTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -306,3 +349,27 @@ func TestClickHousePing(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "connect: connection refused")
|
||||
}
|
||||
|
||||
func TestClickHouseQuotedMigrationsTableName(t *testing.T) {
|
||||
t.Run("default name", func(t *testing.T) {
|
||||
drv := testClickHouseDriver()
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, "schema_migrations", name)
|
||||
})
|
||||
|
||||
t.Run("custom name", func(t *testing.T) {
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("fooMigrations")
|
||||
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, "fooMigrations", name)
|
||||
})
|
||||
|
||||
t.Run("quoted name", func(t *testing.T) {
|
||||
drv := testClickHouseDriver()
|
||||
drv.SetMigrationsTableName("bizarre\"$name")
|
||||
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, `"bizarre""$name"`, name)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ import (
|
|||
// DefaultMigrationsDir specifies default directory to find migration files
|
||||
const DefaultMigrationsDir = "./db/migrations"
|
||||
|
||||
// DefaultMigrationsTableName specifies default database tables to record migraitons in
|
||||
const DefaultMigrationsTableName = "schema_migrations"
|
||||
|
||||
// DefaultSchemaFile specifies default location for schema.sql
|
||||
const DefaultSchemaFile = "./db/schema.sql"
|
||||
|
||||
|
|
@ -26,14 +29,15 @@ const DefaultWaitTimeout = 60 * time.Second
|
|||
|
||||
// DB allows dbmate actions to be performed on a specified database
|
||||
type DB struct {
|
||||
AutoDumpSchema bool
|
||||
DatabaseURL *url.URL
|
||||
MigrationsDir string
|
||||
SchemaFile string
|
||||
Verbose bool
|
||||
WaitBefore bool
|
||||
WaitInterval time.Duration
|
||||
WaitTimeout time.Duration
|
||||
AutoDumpSchema bool
|
||||
DatabaseURL *url.URL
|
||||
MigrationsDir string
|
||||
MigrationsTableName string
|
||||
SchemaFile string
|
||||
Verbose bool
|
||||
WaitBefore bool
|
||||
WaitInterval time.Duration
|
||||
WaitTimeout time.Duration
|
||||
}
|
||||
|
||||
// migrationFileRegexp pattern for valid migration files
|
||||
|
|
@ -47,19 +51,27 @@ type statusResult struct {
|
|||
// New initializes a new dbmate database
|
||||
func New(databaseURL *url.URL) *DB {
|
||||
return &DB{
|
||||
AutoDumpSchema: true,
|
||||
DatabaseURL: databaseURL,
|
||||
MigrationsDir: DefaultMigrationsDir,
|
||||
SchemaFile: DefaultSchemaFile,
|
||||
WaitBefore: false,
|
||||
WaitInterval: DefaultWaitInterval,
|
||||
WaitTimeout: DefaultWaitTimeout,
|
||||
AutoDumpSchema: true,
|
||||
DatabaseURL: databaseURL,
|
||||
MigrationsDir: DefaultMigrationsDir,
|
||||
MigrationsTableName: DefaultMigrationsTableName,
|
||||
SchemaFile: DefaultSchemaFile,
|
||||
WaitBefore: false,
|
||||
WaitInterval: DefaultWaitInterval,
|
||||
WaitTimeout: DefaultWaitTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// GetDriver loads the required database driver
|
||||
func (db *DB) GetDriver() (Driver, error) {
|
||||
return GetDriver(db.DatabaseURL.Scheme)
|
||||
drv, err := getDriver(db.DatabaseURL.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
drv.SetMigrationsTableName(db.MigrationsTableName)
|
||||
|
||||
return drv, err
|
||||
}
|
||||
|
||||
// Wait blocks until the database server is available. It does not verify that
|
||||
|
|
|
|||
|
|
@ -38,12 +38,26 @@ func TestNew(t *testing.T) {
|
|||
require.True(t, db.AutoDumpSchema)
|
||||
require.Equal(t, u.String(), db.DatabaseURL.String())
|
||||
require.Equal(t, "./db/migrations", db.MigrationsDir)
|
||||
require.Equal(t, "schema_migrations", db.MigrationsTableName)
|
||||
require.Equal(t, "./db/schema.sql", db.SchemaFile)
|
||||
require.False(t, db.WaitBefore)
|
||||
require.Equal(t, time.Second, db.WaitInterval)
|
||||
require.Equal(t, 60*time.Second, db.WaitTimeout)
|
||||
}
|
||||
|
||||
func TestGetDriver(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := New(u)
|
||||
|
||||
drv, err := db.GetDriver()
|
||||
require.NoError(t, err)
|
||||
|
||||
// driver should have default migrations table set
|
||||
pgDrv, ok := drv.(*PostgresDriver)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "schema_migrations", pgDrv.migrationsTableName)
|
||||
}
|
||||
|
||||
func TestWait(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := newTestDB(t, u)
|
||||
|
|
@ -242,7 +256,7 @@ func testMigrateURL(t *testing.T, u *url.URL) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// verify results
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
sqlDB, err := getDriverOpen(u)
|
||||
require.NoError(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
|
|
@ -275,7 +289,7 @@ func testUpURL(t *testing.T, u *url.URL) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// verify results
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
sqlDB, err := getDriverOpen(u)
|
||||
require.NoError(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
|
|
@ -308,7 +322,7 @@ func testRollbackURL(t *testing.T, u *url.URL) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// verify migration
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
sqlDB, err := getDriverOpen(u)
|
||||
require.NoError(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
|
|
@ -351,7 +365,7 @@ func testStatusURL(t *testing.T, u *url.URL) {
|
|||
require.NoError(t, err)
|
||||
|
||||
// verify migration
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
sqlDB, err := getDriverOpen(u)
|
||||
require.NoError(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ type Driver interface {
|
|||
CreateDatabase(*url.URL) error
|
||||
DropDatabase(*url.URL) error
|
||||
DumpSchema(*url.URL, *sql.DB) ([]byte, error)
|
||||
SetMigrationsTableName(string)
|
||||
CreateMigrationsTable(*url.URL, *sql.DB) error
|
||||
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
||||
InsertMigration(Transaction, string) error
|
||||
|
|
@ -34,18 +35,20 @@ type Transaction interface {
|
|||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// GetDriver loads a database driver by name
|
||||
func GetDriver(name string) (Driver, error) {
|
||||
if val, ok := drivers[name]; ok {
|
||||
return val, nil
|
||||
// getDriver loads a database driver by name
|
||||
func getDriver(name string) (Driver, error) {
|
||||
if drv, ok := drivers[name]; ok {
|
||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
||||
|
||||
return drv, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported driver: %s", name)
|
||||
}
|
||||
|
||||
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
|
||||
func GetDriverOpen(u *url.URL) (*sql.DB, error) {
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
// getDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
|
||||
func getDriverOpen(u *url.URL) (*sql.DB, error) {
|
||||
drv, err := getDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,21 +7,21 @@ import (
|
|||
)
|
||||
|
||||
func TestGetDriver_Postgres(t *testing.T) {
|
||||
drv, err := GetDriver("postgres")
|
||||
drv, err := getDriver("postgres")
|
||||
require.NoError(t, err)
|
||||
_, ok := drv.(PostgresDriver)
|
||||
_, ok := drv.(*PostgresDriver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGetDriver_MySQL(t *testing.T) {
|
||||
drv, err := GetDriver("mysql")
|
||||
drv, err := getDriver("mysql")
|
||||
require.NoError(t, err)
|
||||
_, ok := drv.(MySQLDriver)
|
||||
_, ok := drv.(*MySQLDriver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGetDriver_Error(t *testing.T) {
|
||||
drv, err := GetDriver("foo")
|
||||
drv, err := getDriver("foo")
|
||||
require.EqualError(t, err, "unsupported driver: foo")
|
||||
require.Nil(t, drv)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ import (
|
|||
)
|
||||
|
||||
func init() {
|
||||
RegisterDriver(MySQLDriver{}, "mysql")
|
||||
RegisterDriver(&MySQLDriver{}, "mysql")
|
||||
}
|
||||
|
||||
// MySQLDriver provides top level database functions
|
||||
type MySQLDriver struct {
|
||||
migrationsTableName string
|
||||
}
|
||||
|
||||
func normalizeMySQLURL(u *url.URL) string {
|
||||
|
|
@ -52,12 +53,17 @@ func normalizeMySQLURL(u *url.URL) string {
|
|||
return normalizedString
|
||||
}
|
||||
|
||||
// SetMigrationsTableName sets the schema migrations table name
|
||||
func (drv *MySQLDriver) SetMigrationsTableName(name string) {
|
||||
drv.migrationsTableName = name
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("mysql", normalizeMySQLURL(u))
|
||||
}
|
||||
|
||||
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to no particular database
|
||||
rootURL := *u
|
||||
rootURL.Path = "/"
|
||||
|
|
@ -65,14 +71,14 @@ func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
|||
return drv.Open(&rootURL)
|
||||
}
|
||||
|
||||
func mysqlQuoteIdentifier(str string) string {
|
||||
func (drv *MySQLDriver) quoteIdentifier(str string) string {
|
||||
str = strings.Replace(str, "`", "\\`", -1)
|
||||
|
||||
return fmt.Sprintf("`%s`", str)
|
||||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
|
||||
func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
|
|
@ -83,13 +89,13 @@ func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
|
|||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec(fmt.Sprintf("create database %s",
|
||||
mysqlQuoteIdentifier(name)))
|
||||
drv.quoteIdentifier(name)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv MySQLDriver) DropDatabase(u *url.URL) error {
|
||||
func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
|
|
@ -100,12 +106,12 @@ func (drv MySQLDriver) DropDatabase(u *url.URL) error {
|
|||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
||||
mysqlQuoteIdentifier(name)))
|
||||
drv.quoteIdentifier(name)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func mysqldumpArgs(u *url.URL) []string {
|
||||
func (drv *MySQLDriver) mysqldumpArgs(u *url.URL) []string {
|
||||
// generate CLI arguments
|
||||
args := []string{"--opt", "--routines", "--no-data",
|
||||
"--skip-dump-date", "--skip-add-drop-table"}
|
||||
|
|
@ -131,10 +137,12 @@ func mysqldumpArgs(u *url.URL) []string {
|
|||
return args
|
||||
}
|
||||
|
||||
func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
migrationsTable := drv.quotedMigrationsTableName()
|
||||
|
||||
// load applied migrations
|
||||
migrations, err := queryColumn(db,
|
||||
"select quote(version) from schema_migrations order by version asc")
|
||||
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -142,12 +150,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
|||
// build schema_migrations table data
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n" +
|
||||
"LOCK TABLES `schema_migrations` WRITE;\n")
|
||||
fmt.Sprintf("LOCK TABLES %s WRITE;\n", migrationsTable))
|
||||
|
||||
if len(migrations) > 0 {
|
||||
buf.WriteString("INSERT INTO `schema_migrations` (version) VALUES\n (" +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
buf.WriteString(
|
||||
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
}
|
||||
|
||||
buf.WriteString("UNLOCK TABLES;\n")
|
||||
|
|
@ -156,13 +165,13 @@ func mysqlSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
|||
}
|
||||
|
||||
// DumpSchema returns the current database schema
|
||||
func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
schema, err := runCommand("mysqldump", mysqldumpArgs(u)...)
|
||||
func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
schema, err := runCommand("mysqldump", drv.mysqldumpArgs(u)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations, err := mysqlSchemaMigrationsDump(db)
|
||||
migrations, err := drv.schemaMigrationsDump(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -172,7 +181,7 @@ func (drv MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
|||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := databaseName(u)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
|
|
@ -192,17 +201,18 @@ func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
_, err := db.Exec("create table if not exists schema_migrations " +
|
||||
"(version varchar(255) primary key) character set latin1 collate latin1_bin")
|
||||
func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
_, err := db.Exec(fmt.Sprintf("create table if not exists %s "+
|
||||
"(version varchar(255) primary key) character set latin1 collate latin1_bin",
|
||||
drv.quotedMigrationsTableName()))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
|
|
@ -231,22 +241,26 @@ func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool,
|
|||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||
func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec(
|
||||
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
||||
version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
||||
func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec(
|
||||
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
|
||||
version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Ping verifies a connection to the database server. It does not verify whether the
|
||||
// specified database exists.
|
||||
func (drv MySQLDriver) Ping(u *url.URL) error {
|
||||
func (drv *MySQLDriver) Ping(u *url.URL) error {
|
||||
db, err := drv.openRootDB(u)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -255,3 +269,7 @@ func (drv MySQLDriver) Ping(u *url.URL) error {
|
|||
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
func (drv *MySQLDriver) quotedMigrationsTableName() string {
|
||||
return drv.quoteIdentifier(drv.migrationsTableName)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,8 +15,15 @@ func mySQLTestURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func testMySQLDriver() *MySQLDriver {
|
||||
drv := &MySQLDriver{}
|
||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
||||
|
||||
return drv
|
||||
}
|
||||
|
||||
func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -78,7 +85,7 @@ func TestNormalizeMySQLURLSocket(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLCreateDropDatabase(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -116,7 +123,9 @@ func TestMySQLCreateDropDatabase(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLDumpSchema(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// prepare database
|
||||
|
|
@ -134,13 +143,13 @@ func TestMySQLDumpSchema(t *testing.T) {
|
|||
// DumpSchema should return schema
|
||||
schema, err := drv.DumpSchema(u, db)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(schema), "CREATE TABLE `schema_migrations`")
|
||||
require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
|
||||
require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
|
||||
"--\n"+
|
||||
"-- Dbmate schema migrations\n"+
|
||||
"--\n\n"+
|
||||
"LOCK TABLES `schema_migrations` WRITE;\n"+
|
||||
"INSERT INTO `schema_migrations` (version) VALUES\n"+
|
||||
"LOCK TABLES `test_migrations` WRITE;\n"+
|
||||
"INSERT INTO `test_migrations` (version) VALUES\n"+
|
||||
" ('abc1'),\n"+
|
||||
" ('abc2');\n"+
|
||||
"UNLOCK TABLES;\n")
|
||||
|
|
@ -156,7 +165,7 @@ func TestMySQLDumpSchema(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLDatabaseExists(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -179,7 +188,7 @@ func TestMySQLDatabaseExists(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLDatabaseExists_Error(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
u := mySQLTestURL(t)
|
||||
u.User = url.User("invalid")
|
||||
|
||||
|
|
@ -189,22 +198,25 @@ func TestMySQLDatabaseExists_Error(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLCreateMigrationsTable(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := mySQLTestURL(t)
|
||||
db := prepTestMySQLDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Regexp(t, "Table 'dbmate.schema_migrations' doesn't exist", err.Error())
|
||||
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.Error(t, err)
|
||||
require.Regexp(t, "Table 'dbmate.test_migrations' doesn't exist", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
|
|
@ -213,7 +225,9 @@ func TestMySQLCreateMigrationsTable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLSelectMigrations(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := mySQLTestURL(t)
|
||||
db := prepTestMySQLDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -221,7 +235,7 @@ func TestMySQLSelectMigrations(t *testing.T) {
|
|||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
_, err = db.Exec(`insert into test_migrations (version)
|
||||
values ('abc2'), ('abc1'), ('abc3')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -240,7 +254,9 @@ func TestMySQLSelectMigrations(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestMySQLInsertMigration(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := mySQLTestURL(t)
|
||||
db := prepTestMySQLDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -249,7 +265,7 @@ func TestMySQLInsertMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
|
|
@ -257,14 +273,16 @@ func TestMySQLInsertMigration(t *testing.T) {
|
|||
err = drv.InsertMigration(db, "abc1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
|
||||
err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
|
||||
Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestMySQLDeleteMigration(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := mySQLTestURL(t)
|
||||
db := prepTestMySQLDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -272,7 +290,7 @@ func TestMySQLDeleteMigration(t *testing.T) {
|
|||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
_, err = db.Exec(`insert into test_migrations (version)
|
||||
values ('abc1'), ('abc2')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -280,13 +298,13 @@ func TestMySQLDeleteMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestMySQLPing(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
drv := testMySQLDriver()
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -303,3 +321,19 @@ func TestMySQLPing(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "connect: connection refused")
|
||||
}
|
||||
|
||||
func TestMySQLQuotedMigrationsTableName(t *testing.T) {
|
||||
t.Run("default name", func(t *testing.T) {
|
||||
drv := testMySQLDriver()
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, "`schema_migrations`", name)
|
||||
})
|
||||
|
||||
t.Run("custom name", func(t *testing.T) {
|
||||
drv := testMySQLDriver()
|
||||
drv.SetMigrationsTableName("fooMigrations")
|
||||
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, "`fooMigrations`", name)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,12 +11,14 @@ import (
|
|||
)
|
||||
|
||||
func init() {
|
||||
RegisterDriver(PostgresDriver{}, "postgres")
|
||||
RegisterDriver(PostgresDriver{}, "postgresql")
|
||||
drv := &PostgresDriver{}
|
||||
RegisterDriver(drv, "postgres")
|
||||
RegisterDriver(drv, "postgresql")
|
||||
}
|
||||
|
||||
// PostgresDriver provides top level database functions
|
||||
type PostgresDriver struct {
|
||||
migrationsTableName string
|
||||
}
|
||||
|
||||
func normalizePostgresURL(u *url.URL) *url.URL {
|
||||
|
|
@ -78,12 +80,17 @@ func normalizePostgresURLForDump(u *url.URL) []string {
|
|||
return out
|
||||
}
|
||||
|
||||
// SetMigrationsTableName sets the schema migrations table name
|
||||
func (drv *PostgresDriver) SetMigrationsTableName(name string) {
|
||||
drv.migrationsTableName = name
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("postgres", normalizePostgresURL(u).String())
|
||||
}
|
||||
|
||||
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to postgres database
|
||||
postgresURL := *u
|
||||
postgresURL.Path = "postgres"
|
||||
|
|
@ -92,7 +99,7 @@ func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
|||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
|
||||
func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
|
|
@ -109,7 +116,7 @@ func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
|
|||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv PostgresDriver) DropDatabase(u *url.URL) error {
|
||||
func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
|
|
@ -125,8 +132,8 @@ func (drv PostgresDriver) DropDatabase(u *url.URL) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -152,7 +159,7 @@ func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, erro
|
|||
}
|
||||
|
||||
// DumpSchema returns the current database schema
|
||||
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
// load schema
|
||||
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
|
||||
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
|
||||
|
|
@ -161,7 +168,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
migrations, err := drv.postgresSchemaMigrationsDump(db)
|
||||
migrations, err := drv.schemaMigrationsDump(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -171,7 +178,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
|||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := databaseName(u)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
|
|
@ -191,48 +198,45 @@ func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
// get schema from URL search_path param
|
||||
searchPath := strings.Split(u.Query().Get("search_path"), ",")
|
||||
urlSchema := strings.TrimSpace(searchPath[0])
|
||||
if urlSchema == "" {
|
||||
urlSchema = "public"
|
||||
}
|
||||
|
||||
// get *unquoted* current schema from database
|
||||
dbSchema, err := queryRow(db, "select current_schema()")
|
||||
func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db, u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// if urlSchema and dbSchema are not equal, the most likely explanation is that the schema
|
||||
// has not yet been created
|
||||
if urlSchema != dbSchema {
|
||||
// in theory we could just execute this statement every time, but we do the comparison
|
||||
// above in case the user doesn't have permissions to create schemas and the schema
|
||||
// already exists
|
||||
fmt.Printf("Creating schema: %s\n", urlSchema)
|
||||
_, err = db.Exec("create schema if not exists " + pq.QuoteIdentifier(urlSchema))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// first attempt at creating migrations table
|
||||
createTableStmt := fmt.Sprintf("create table if not exists %s.%s", schema, migrationsTable) +
|
||||
" (version varchar(255) primary key)"
|
||||
_, err = db.Exec(createTableStmt)
|
||||
if err == nil {
|
||||
// table exists or created successfully
|
||||
return nil
|
||||
}
|
||||
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
// catch 'schema does not exist' error
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
if !ok || pqErr.Code != "3F000" {
|
||||
// unknown error
|
||||
return err
|
||||
}
|
||||
|
||||
// in theory we could attempt to create the schema every time, but we avoid that
|
||||
// in case the user doesn't have permissions to create schemas
|
||||
fmt.Printf("Creating schema: %s\n", schema)
|
||||
_, err = db.Exec(fmt.Sprintf("create schema if not exists %s", schema))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec("create table if not exists " + migrationsTable +
|
||||
" (version varchar(255) primary key)")
|
||||
|
||||
// second and final attempt at creating migrations table
|
||||
_, err = db.Exec(createTableStmt)
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -266,8 +270,8 @@ func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bo
|
|||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error {
|
||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -278,8 +282,8 @@ func (drv PostgresDriver) InsertMigration(db Transaction, version string) error
|
|||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -291,7 +295,7 @@ func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error
|
|||
|
||||
// Ping verifies a connection to the database server. It does not verify whether the
|
||||
// specified database exists.
|
||||
func (drv PostgresDriver) Ping(u *url.URL) error {
|
||||
func (drv *PostgresDriver) Ping(u *url.URL) error {
|
||||
// attempt connection to primary database, not "postgres" database
|
||||
// to support servers with no "postgres" database
|
||||
// (see https://github.com/amacneil/dbmate/issues/78)
|
||||
|
|
@ -306,7 +310,7 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// ignore 'database "foo" does not exist' error
|
||||
// ignore 'database does not exist' error
|
||||
pqErr, ok := err.(*pq.Error)
|
||||
if ok && pqErr.Code == "3D000" {
|
||||
return nil
|
||||
|
|
@ -315,17 +319,53 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (drv PostgresDriver) migrationsTableName(db Transaction) (string, error) {
|
||||
// get current schema
|
||||
schema, err := queryRow(db, "select quote_ident(current_schema())")
|
||||
func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, error) {
|
||||
schema, name, err := drv.quotedMigrationsTableNameParts(db, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// if the search path is empty, or does not contain a valid schema, default to public
|
||||
return schema + "." + name, nil
|
||||
}
|
||||
|
||||
func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url.URL) (string, string, error) {
|
||||
schema := ""
|
||||
tableNameParts := strings.Split(drv.migrationsTableName, ".")
|
||||
if len(tableNameParts) > 1 {
|
||||
// schema specified as part of table name
|
||||
schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
|
||||
}
|
||||
|
||||
if schema == "" && u != nil {
|
||||
// no schema specified with table name, try URL search path if available
|
||||
searchPath := strings.Split(u.Query().Get("search_path"), ",")
|
||||
schema = strings.TrimSpace(searchPath[0])
|
||||
}
|
||||
|
||||
var err error
|
||||
if schema == "" {
|
||||
// if no URL available, use current schema
|
||||
// this is a hack because we don't always have the URL context available
|
||||
schema, err = queryValue(db, "select current_schema()")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
// fall back to public schema as last resort
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
|
||||
return schema + ".schema_migrations", nil
|
||||
// quote all parts
|
||||
// use server rather than client to do this to avoid unnecessary quotes
|
||||
// (which would change schema.sql diff)
|
||||
tableNameParts = append([]string{schema}, tableNameParts...)
|
||||
quotedNameParts, err := queryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
// if more than one part, we already have a schema
|
||||
return quotedNameParts[0], strings.Join(quotedNameParts[1:], "."), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,8 +15,15 @@ func postgresTestURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func testPostgresDriver() *PostgresDriver {
|
||||
drv := &PostgresDriver{}
|
||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
||||
|
||||
return drv
|
||||
}
|
||||
|
||||
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -87,7 +94,7 @@ func TestNormalizePostgresURLForDump(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -125,45 +132,80 @@ func TestPostgresCreateDropDatabase(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPostgresDumpSchema(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
t.Run("default migrations table", func(t *testing.T) {
|
||||
drv := testPostgresDriver()
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// prepare database
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
// prepare database
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// insert migration
|
||||
err = drv.InsertMigration(db, "abc1")
|
||||
require.NoError(t, err)
|
||||
err = drv.InsertMigration(db, "abc2")
|
||||
require.NoError(t, err)
|
||||
// insert migration
|
||||
err = drv.InsertMigration(db, "abc1")
|
||||
require.NoError(t, err)
|
||||
err = drv.InsertMigration(db, "abc2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// DumpSchema should return schema
|
||||
schema, err := drv.DumpSchema(u, db)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
|
||||
require.Contains(t, string(schema), "\n--\n"+
|
||||
"-- PostgreSQL database dump complete\n"+
|
||||
"--\n\n\n"+
|
||||
"--\n"+
|
||||
"-- Dbmate schema migrations\n"+
|
||||
"--\n\n"+
|
||||
"INSERT INTO public.schema_migrations (version) VALUES\n"+
|
||||
" ('abc1'),\n"+
|
||||
" ('abc2');\n")
|
||||
// DumpSchema should return schema
|
||||
schema, err := drv.DumpSchema(u, db)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
|
||||
require.Contains(t, string(schema), "\n--\n"+
|
||||
"-- PostgreSQL database dump complete\n"+
|
||||
"--\n\n\n"+
|
||||
"--\n"+
|
||||
"-- Dbmate schema migrations\n"+
|
||||
"--\n\n"+
|
||||
"INSERT INTO public.schema_migrations (version) VALUES\n"+
|
||||
" ('abc1'),\n"+
|
||||
" ('abc2');\n")
|
||||
|
||||
// DumpSchema should return error if command fails
|
||||
u.Path = "/fakedb"
|
||||
schema, err = drv.DumpSchema(u, db)
|
||||
require.Nil(t, schema)
|
||||
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
|
||||
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
|
||||
// DumpSchema should return error if command fails
|
||||
u.Path = "/fakedb"
|
||||
schema, err = drv.DumpSchema(u, db)
|
||||
require.Nil(t, schema)
|
||||
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
|
||||
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
|
||||
})
|
||||
|
||||
t.Run("custom migrations table with schema", func(t *testing.T) {
|
||||
drv := testPostgresDriver()
|
||||
drv.SetMigrationsTableName("camelSchema.testMigrations")
|
||||
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// prepare database
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// insert migration
|
||||
err = drv.InsertMigration(db, "abc1")
|
||||
require.NoError(t, err)
|
||||
err = drv.InsertMigration(db, "abc2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// DumpSchema should return schema
|
||||
schema, err := drv.DumpSchema(u, db)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"")
|
||||
require.Contains(t, string(schema), "\n--\n"+
|
||||
"-- PostgreSQL database dump complete\n"+
|
||||
"--\n\n\n"+
|
||||
"--\n"+
|
||||
"-- Dbmate schema migrations\n"+
|
||||
"--\n\n"+
|
||||
"INSERT INTO \"camelSchema\".\"testMigrations\" (version) VALUES\n"+
|
||||
" ('abc1'),\n"+
|
||||
" ('abc2');\n")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgresDatabaseExists(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -186,7 +228,7 @@ func TestPostgresDatabaseExists(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPostgresDatabaseExists_Error(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
u := postgresTestURL(t)
|
||||
u.User = url.User("invalid")
|
||||
|
||||
|
|
@ -197,9 +239,8 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
|
||||
t.Run("default schema", func(t *testing.T) {
|
||||
drv := testPostgresDriver()
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -223,39 +264,81 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("custom schema", func(t *testing.T) {
|
||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
||||
t.Run("custom search path", func(t *testing.T) {
|
||||
drv := testPostgresDriver()
|
||||
drv.SetMigrationsTableName("testMigrations")
|
||||
|
||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=camelFoo")
|
||||
require.NoError(t, err)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// delete schema
|
||||
_, err = db.Exec("drop schema if exists foo")
|
||||
_, err = db.Exec("drop schema if exists \"camelFoo\"")
|
||||
require.NoError(t, err)
|
||||
|
||||
// drop any schema_migrations table in public schema
|
||||
_, err = db.Exec("drop table if exists public.schema_migrations")
|
||||
// drop any testMigrations table in public schema
|
||||
_, err = db.Exec("drop table if exists public.\"testMigrations\"")
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should not exist in either schema
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "pq: relation \"foo.schema_migrations\" does not exist", err.Error())
|
||||
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
|
||||
require.Equal(t, "pq: relation \"camelFoo.testMigrations\" does not exist", err.Error())
|
||||
err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
|
||||
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// foo schema should be created, and migrations table should exist only in foo schema
|
||||
err = db.QueryRow("select count(*) from foo.schema_migrations").Scan(&count)
|
||||
// camelFoo schema should be created, and migrations table should exist only in camelFoo schema
|
||||
err = db.QueryRow("select count(*) from \"camelFoo\".\"testMigrations\"").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from public.\"testMigrations\"").Scan(&count)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
|
||||
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("custom schema", func(t *testing.T) {
|
||||
drv := testPostgresDriver()
|
||||
drv.SetMigrationsTableName("camelSchema.testMigrations")
|
||||
|
||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
||||
require.NoError(t, err)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// delete schemas
|
||||
_, err = db.Exec("drop schema if exists foo")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec("drop schema if exists \"camelSchema\"")
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// camelSchema should be created, and testMigrations table should exist
|
||||
err = db.QueryRow("select count(*) from \"camelSchema\".\"testMigrations\"").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
// testMigrations table should not exist in foo schema because
|
||||
// schema specified with migrations table name takes priority over search path
|
||||
err = db.QueryRow("select count(*) from foo.\"testMigrations\"").Scan(&count)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error())
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
|
|
@ -264,7 +347,9 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPostgresSelectMigrations(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -272,7 +357,7 @@ func TestPostgresSelectMigrations(t *testing.T) {
|
|||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into public.schema_migrations (version)
|
||||
_, err = db.Exec(`insert into public.test_migrations (version)
|
||||
values ('abc2'), ('abc1'), ('abc3')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -291,7 +376,9 @@ func TestPostgresSelectMigrations(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestPostgresInsertMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -300,7 +387,7 @@ func TestPostgresInsertMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
|
|
@ -308,14 +395,16 @@ func TestPostgresInsertMigration(t *testing.T) {
|
|||
err = drv.InsertMigration(db, "abc1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from public.schema_migrations where version = 'abc1'").
|
||||
err = db.QueryRow("select count(*) from public.test_migrations where version = 'abc1'").
|
||||
Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestPostgresDeleteMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -323,7 +412,7 @@ func TestPostgresDeleteMigration(t *testing.T) {
|
|||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into public.schema_migrations (version)
|
||||
_, err = db.Exec(`insert into public.test_migrations (version)
|
||||
values ('abc1'), ('abc2')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -331,13 +420,13 @@ func TestPostgresDeleteMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from public.schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from public.test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestPostgresPing(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
drv := testPostgresDriver()
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -355,15 +444,15 @@ func TestPostgresPing(t *testing.T) {
|
|||
require.Contains(t, err.Error(), "connect: connection refused")
|
||||
}
|
||||
|
||||
func TestMigrationsTableName(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
func TestPostgresQuotedMigrationsTableName(t *testing.T) {
|
||||
drv := testPostgresDriver()
|
||||
|
||||
t.Run("default schema", func(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
name, err := drv.migrationsTableName(db)
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.schema_migrations", name)
|
||||
})
|
||||
|
|
@ -379,14 +468,14 @@ func TestMigrationsTableName(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
_, err = db.Exec("drop schema if exists bar")
|
||||
require.NoError(t, err)
|
||||
name, err := drv.migrationsTableName(db)
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.schema_migrations", name)
|
||||
|
||||
// if "foo" schema exists, it should be used
|
||||
_, err = db.Exec("create schema foo")
|
||||
require.NoError(t, err)
|
||||
name, err = drv.migrationsTableName(db)
|
||||
name, err = drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo.schema_migrations", name)
|
||||
})
|
||||
|
|
@ -401,8 +490,76 @@ func TestMigrationsTableName(t *testing.T) {
|
|||
_, err := db.Exec("select pg_catalog.set_config('search_path', '', false)")
|
||||
require.NoError(t, err)
|
||||
|
||||
name, err := drv.migrationsTableName(db)
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.schema_migrations", name)
|
||||
})
|
||||
|
||||
t.Run("custom table name", func(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
drv.SetMigrationsTableName("simple_name")
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.simple_name", name)
|
||||
})
|
||||
|
||||
t.Run("custom table name quoted", func(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// this table name will need quoting
|
||||
drv.SetMigrationsTableName("camelCase")
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.\"camelCase\"", name)
|
||||
})
|
||||
|
||||
t.Run("custom table name with custom schema", func(t *testing.T) {
|
||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
||||
require.NoError(t, err)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec("create schema if not exists foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv.SetMigrationsTableName("simple_name")
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo.simple_name", name)
|
||||
})
|
||||
|
||||
t.Run("custom table name overrides schema", func(t *testing.T) {
|
||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
||||
require.NoError(t, err)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec("create schema if not exists foo")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec("create schema if not exists bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
// if schema is specified as part of table name, it should override search_path
|
||||
drv.SetMigrationsTableName("bar.simple_name")
|
||||
name, err := drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "bar.simple_name", name)
|
||||
|
||||
// schema and table name should be quoted if necessary
|
||||
drv.SetMigrationsTableName("barName.camelTable")
|
||||
name, err = drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "\"barName\".\"camelTable\"", name)
|
||||
|
||||
// more than 2 components is unexpected but we will quote and pass it along anyway
|
||||
drv.SetMigrationsTableName("whyWould.i.doThis")
|
||||
name, err = drv.quotedMigrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "\"whyWould\".i.\"doThis\"", name)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,16 +11,19 @@ import (
|
|||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterDriver(SQLiteDriver{}, "sqlite")
|
||||
RegisterDriver(SQLiteDriver{}, "sqlite3")
|
||||
drv := &SQLiteDriver{}
|
||||
RegisterDriver(drv, "sqlite")
|
||||
RegisterDriver(drv, "sqlite3")
|
||||
}
|
||||
|
||||
// SQLiteDriver provides top level database functions
|
||||
type SQLiteDriver struct {
|
||||
migrationsTableName string
|
||||
}
|
||||
|
||||
func sqlitePath(u *url.URL) string {
|
||||
|
|
@ -31,13 +34,18 @@ func sqlitePath(u *url.URL) string {
|
|||
return str
|
||||
}
|
||||
|
||||
// SetMigrationsTableName sets the schema migrations table name
|
||||
func (drv *SQLiteDriver) SetMigrationsTableName(name string) {
|
||||
drv.migrationsTableName = name
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
func (drv *SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("sqlite3", sqlitePath(u))
|
||||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv SQLiteDriver) CreateDatabase(u *url.URL) error {
|
||||
func (drv *SQLiteDriver) CreateDatabase(u *url.URL) error {
|
||||
fmt.Printf("Creating: %s\n", sqlitePath(u))
|
||||
|
||||
db, err := drv.Open(u)
|
||||
|
|
@ -50,7 +58,7 @@ func (drv SQLiteDriver) CreateDatabase(u *url.URL) error {
|
|||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv SQLiteDriver) DropDatabase(u *url.URL) error {
|
||||
func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
|
||||
path := sqlitePath(u)
|
||||
fmt.Printf("Dropping: %s\n", path)
|
||||
|
||||
|
|
@ -65,36 +73,39 @@ func (drv SQLiteDriver) DropDatabase(u *url.URL) error {
|
|||
return os.Remove(path)
|
||||
}
|
||||
|
||||
func sqliteSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
migrationsTable := drv.quotedMigrationsTableName()
|
||||
|
||||
// load applied migrations
|
||||
migrations, err := queryColumn(db,
|
||||
"select quote(version) from schema_migrations order by version asc")
|
||||
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// build schema_migrations table data
|
||||
// build schema migrations table data
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("-- Dbmate schema migrations\n")
|
||||
|
||||
if len(migrations) > 0 {
|
||||
buf.WriteString("INSERT INTO schema_migrations (version) VALUES\n (" +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
buf.WriteString(
|
||||
fmt.Sprintf("INSERT INTO %s (version) VALUES\n (", migrationsTable) +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// DumpSchema returns the current database schema
|
||||
func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
path := sqlitePath(u)
|
||||
schema, err := runCommand("sqlite3", path, ".schema")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations, err := sqliteSchemaMigrationsDump(db)
|
||||
migrations, err := drv.schemaMigrationsDump(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -104,7 +115,7 @@ func (drv SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
|||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
_, err := os.Stat(sqlitePath(u))
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
|
|
@ -116,18 +127,19 @@ func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
return true, nil
|
||||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
_, err := db.Exec("create table if not exists schema_migrations " +
|
||||
"(version varchar(255) primary key)")
|
||||
// CreateMigrationsTable creates the schema migrations table
|
||||
func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||
_, err := db.Exec(
|
||||
fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) +
|
||||
"(version varchar(255) primary key)")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
|
|
@ -156,15 +168,19 @@ func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
|
|||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv SQLiteDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||
func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec(
|
||||
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
||||
version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
||||
func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec(
|
||||
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
|
||||
version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -172,7 +188,7 @@ func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error {
|
|||
// Ping verifies a connection to the database. Due to the way SQLite works, by
|
||||
// testing whether the database is valid, it will automatically create the database
|
||||
// if it does not already exist.
|
||||
func (drv SQLiteDriver) Ping(u *url.URL) error {
|
||||
func (drv *SQLiteDriver) Ping(u *url.URL) error {
|
||||
db, err := drv.Open(u)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -181,3 +197,14 @@ func (drv SQLiteDriver) Ping(u *url.URL) error {
|
|||
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
func (drv *SQLiteDriver) quotedMigrationsTableName() string {
|
||||
return drv.quoteIdentifier(drv.migrationsTableName)
|
||||
}
|
||||
|
||||
// quoteIdentifier quotes a table or column name
|
||||
// we fall back to lib/pq implementation since both use ansi standard (double quotes)
|
||||
// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
|
||||
func (drv *SQLiteDriver) quoteIdentifier(s string) string {
|
||||
return pq.QuoteIdentifier(s)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,8 +18,15 @@ func sqliteTestURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func testSQLiteDriver() *SQLiteDriver {
|
||||
drv := &SQLiteDriver{}
|
||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
||||
|
||||
return drv
|
||||
}
|
||||
|
||||
func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -37,7 +44,7 @@ func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
|
|||
}
|
||||
|
||||
func TestSQLiteCreateDropDatabase(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
u := sqliteTestURL(t)
|
||||
path := sqlitePath(u)
|
||||
|
||||
|
|
@ -64,7 +71,9 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSQLiteDumpSchema(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := sqliteTestURL(t)
|
||||
|
||||
// prepare database
|
||||
|
|
@ -82,9 +91,9 @@ func TestSQLiteDumpSchema(t *testing.T) {
|
|||
// DumpSchema should return schema
|
||||
schema, err := drv.DumpSchema(u, db)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(schema), "CREATE TABLE schema_migrations")
|
||||
require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"")
|
||||
require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+
|
||||
"INSERT INTO schema_migrations (version) VALUES\n"+
|
||||
"INSERT INTO \"test_migrations\" (version) VALUES\n"+
|
||||
" ('abc1'),\n"+
|
||||
" ('abc2');\n")
|
||||
|
||||
|
|
@ -97,7 +106,7 @@ func TestSQLiteDumpSchema(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSQLiteDatabaseExists(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
u := sqliteTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
|
|
@ -120,31 +129,61 @@ func TestSQLiteDatabaseExists(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSQLiteCreateMigrationsTable(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
u := sqliteTestURL(t)
|
||||
db := prepTestSQLiteDB(t, u)
|
||||
defer mustClose(db)
|
||||
t.Run("default table", func(t *testing.T) {
|
||||
drv := testSQLiteDriver()
|
||||
u := sqliteTestURL(t)
|
||||
db := prepTestSQLiteDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Regexp(t, "no such table: schema_migrations", err.Error())
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Regexp(t, "no such table: schema_migrations", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("custom table", func(t *testing.T) {
|
||||
drv := testSQLiteDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := sqliteTestURL(t)
|
||||
db := prepTestSQLiteDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.Regexp(t, "no such table: test_migrations", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQLiteSelectMigrations(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := sqliteTestURL(t)
|
||||
db := prepTestSQLiteDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -152,7 +191,7 @@ func TestSQLiteSelectMigrations(t *testing.T) {
|
|||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
_, err = db.Exec(`insert into test_migrations (version)
|
||||
values ('abc2'), ('abc1'), ('abc3')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -171,7 +210,9 @@ func TestSQLiteSelectMigrations(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSQLiteInsertMigration(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := sqliteTestURL(t)
|
||||
db := prepTestSQLiteDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -180,7 +221,7 @@ func TestSQLiteInsertMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
|
|
@ -188,14 +229,16 @@ func TestSQLiteInsertMigration(t *testing.T) {
|
|||
err = drv.InsertMigration(db, "abc1")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
|
||||
err = db.QueryRow("select count(*) from test_migrations where version = 'abc1'").
|
||||
Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestSQLiteDeleteMigration(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
drv.SetMigrationsTableName("test_migrations")
|
||||
|
||||
u := sqliteTestURL(t)
|
||||
db := prepTestSQLiteDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
|
@ -203,7 +246,7 @@ func TestSQLiteDeleteMigration(t *testing.T) {
|
|||
err := drv.CreateMigrationsTable(u, db)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
_, err = db.Exec(`insert into test_migrations (version)
|
||||
values ('abc1'), ('abc2')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -211,13 +254,13 @@ func TestSQLiteDeleteMigration(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
err = db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestSQLitePing(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
drv := testSQLiteDriver()
|
||||
u := sqliteTestURL(t)
|
||||
path := sqlitePath(u)
|
||||
|
||||
|
|
@ -249,3 +292,19 @@ func TestSQLitePing(t *testing.T) {
|
|||
err = drv.Ping(u)
|
||||
require.EqualError(t, err, "unable to open database file: is a directory")
|
||||
}
|
||||
|
||||
func TestSQLiteQuotedMigrationsTableName(t *testing.T) {
|
||||
t.Run("default name", func(t *testing.T) {
|
||||
drv := testSQLiteDriver()
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, `"schema_migrations"`, name)
|
||||
})
|
||||
|
||||
t.Run("custom name", func(t *testing.T) {
|
||||
drv := testSQLiteDriver()
|
||||
drv.SetMigrationsTableName("fooMigrations")
|
||||
|
||||
name := drv.quotedMigrationsTableName()
|
||||
require.Equal(t, `"fooMigrations"`, name)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,8 +104,8 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
|
|||
// queryColumn runs a SQL statement and returns a slice of strings
|
||||
// it is assumed that the statement returns only one column
|
||||
// e.g. schema_migrations table
|
||||
func queryColumn(db Transaction, query string) ([]string, error) {
|
||||
rows, err := db.Query(query)
|
||||
func queryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -128,12 +128,12 @@ func queryColumn(db Transaction, query string) ([]string, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// queryRow runs a SQL statement and returns a single string
|
||||
// queryValue runs a SQL statement and returns a single string
|
||||
// it is assumed that the statement returns only one row and one column
|
||||
// sql NULL is returned as empty string
|
||||
func queryRow(db Transaction, query string) (string, error) {
|
||||
func queryValue(db Transaction, query string, args ...interface{}) (string, error) {
|
||||
var result sql.NullString
|
||||
err := db.QueryRow(query).Scan(&result)
|
||||
err := db.QueryRow(query, args...).Scan(&result)
|
||||
if err != nil || !result.Valid {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
|
@ -33,3 +35,24 @@ func TestTrimLeadingSQLComments(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
require.Equal(t, "real stuff\n-- end\n", string(out))
|
||||
}
|
||||
|
||||
func TestQueryColumn(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db, err := sql.Open("postgres", u.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
val, err := queryColumn(db, "select concat('foo_', unnest($1::text[]))",
|
||||
pq.Array([]string{"hi", "there"}))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
|
||||
}
|
||||
|
||||
func TestQueryValue(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db, err := sql.Open("postgres", u.String())
|
||||
require.NoError(t, err)
|
||||
|
||||
val, err := queryValue(db, "select $1::int + $2::int", "5", 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "7", val)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue