postgres: Support custom schema for schema_migrations table (#167)

Instead of hardcoding `schema_migrations` table to the `public` schema, add support for specifying a schema via the `search_path` URL parameter.

**Backwards compatibility note**: If anyone was using the previously undocumented `search_path` behavior (affecting migrations themselves, but always storing the `schema_migrations` table in `public`), you will need to either prepend `public` to your `search_path`, or migrate your `schema_migrations` table to your primary schema:

```sql
ALTER TABLE public.schema_migrations SET SCHEMA myschema;
```

Closes #110
This commit is contained in:
Adrian Macneil 2020-11-01 13:30:35 +13:00 committed by GitHub
parent d4ecd0b259
commit 55a8065efe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 20 deletions

View file

@ -30,6 +30,8 @@ func RegisterDriver(drv Driver, scheme string) {
// Transaction can represent a database or open transaction
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// GetDriver loads a database driver by name

View file

@ -125,20 +125,25 @@ func (drv PostgresDriver) DropDatabase(u *url.URL) error {
return err
}
func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
// load applied migrations
migrations, err := queryColumn(db,
"select quote_literal(version) from public.schema_migrations order by version asc")
func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable, err := drv.migrationsTableName(db)
if err != nil {
return nil, err
}
// build schema_migrations table data
// load applied migrations
migrations, err := queryColumn(db,
"select quote_literal(version) from "+migrationsTable+" order by version asc")
if err != nil {
return nil, err
}
// build migrations table data
var buf bytes.Buffer
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
if len(migrations) > 0 {
buf.WriteString("INSERT INTO public.schema_migrations (version) VALUES\n (" +
buf.WriteString("INSERT INTO " + migrationsTable + " (version) VALUES\n (" +
strings.Join(migrations, "),\n (") +
");\n")
}
@ -156,7 +161,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
return nil, err
}
migrations, err := postgresSchemaMigrationsDump(db)
migrations, err := drv.postgresSchemaMigrationsDump(db)
if err != nil {
return nil, err
}
@ -187,8 +192,13 @@ func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
// CreateMigrationsTable creates the schema_migrations table
func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
_, err := db.Exec("create table if not exists public.schema_migrations " +
"(version varchar(255) primary key)")
migrationsTable, err := drv.migrationsTableName(db)
if err != nil {
return err
}
_, err = db.Exec("create table if not exists " + migrationsTable +
" (version varchar(255) primary key)")
return err
}
@ -196,7 +206,12 @@ func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from public.schema_migrations order by version desc"
migrationsTable, err := drv.migrationsTableName(db)
if err != nil {
return nil, err
}
query := "select version from " + migrationsTable + " order by version desc"
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
}
@ -222,14 +237,24 @@ func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bo
// InsertMigration adds a new migration record
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into public.schema_migrations (version) values ($1)", version)
migrationsTable, err := drv.migrationsTableName(db)
if err != nil {
return err
}
_, err = db.Exec("insert into "+migrationsTable+" (version) values ($1)", version)
return err
}
// DeleteMigration removes a migration record
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from public.schema_migrations where version = $1", version)
migrationsTable, err := drv.migrationsTableName(db)
if err != nil {
return err
}
_, err = db.Exec("delete from "+migrationsTable+" where version = $1", version)
return err
}
@ -259,3 +284,18 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
return err
}
func (drv PostgresDriver) migrationsTableName(db Transaction) (string, error) {
// get current schema
schema, err := queryRow(db, "select quote_ident(current_schema())")
if err != nil {
return "", err
}
// if the search path is empty, or does not contain a valid schema, default to public
if schema == "" {
schema = "public"
}
return schema + ".schema_migrations", nil
}

View file

@ -15,9 +15,8 @@ func postgresTestURL(t *testing.T) *url.URL {
return u
}
func prepTestPostgresDB(t *testing.T) *sql.DB {
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
drv := PostgresDriver{}
u := postgresTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -130,7 +129,7 @@ func TestPostgresDumpSchema(t *testing.T) {
u := postgresTestURL(t)
// prepare database
db := prepTestPostgresDB(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
@ -198,7 +197,8 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
func TestPostgresCreateMigrationsTable(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// migrations table should not exist
@ -221,7 +221,8 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
func TestPostgresSelectMigrations(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -247,7 +248,8 @@ func TestPostgresSelectMigrations(t *testing.T) {
func TestPostgresInsertMigration(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -270,7 +272,8 @@ func TestPostgresInsertMigration(t *testing.T) {
func TestPostgresDeleteMigration(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -307,3 +310,55 @@ func TestPostgresPing(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestMigrationsTableName(t *testing.T) {
drv := PostgresDriver{}
t.Run("default schema", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
name, err := drv.migrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
})
t.Run("custom schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo,bar,public")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// if "foo" schema does not exist, current schema should be "public"
_, err = db.Exec("drop schema if exists foo")
require.NoError(t, err)
_, err = db.Exec("drop schema if exists bar")
require.NoError(t, err)
name, err := drv.migrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
// if "foo" schema exists, it should be used
_, err = db.Exec("create schema foo")
require.NoError(t, err)
name, err = drv.migrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "foo.schema_migrations", name)
})
t.Run("no schema", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
// this is an unlikely edge case, but if for some reason there is
// no current schema then we should default to "public"
_, err := db.Exec("select pg_catalog.set_config('search_path', '', false)")
require.NoError(t, err)
name, err := drv.migrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
})
}

View file

@ -104,7 +104,7 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
// queryColumn runs a SQL statement and returns a slice of strings
// it is assumed that the statement returns only one column
// e.g. schema_migrations table
func queryColumn(db *sql.DB, query string) ([]string, error) {
func queryColumn(db Transaction, query string) ([]string, error) {
rows, err := db.Query(query)
if err != nil {
return nil, err
@ -128,6 +128,19 @@ func queryColumn(db *sql.DB, query string) ([]string, error) {
return result, nil
}
// queryRow runs a SQL statement and returns a single string
// it is assumed that the statement returns only one row and one column
// sql NULL is returned as empty string
func queryRow(db Transaction, query string) (string, error) {
var result sql.NullString
err := db.QueryRow(query).Scan(&result)
if err != nil || !result.Valid {
return "", err
}
return result.String, nil
}
func printVerbose(result sql.Result) {
lastInsertID, err := result.LastInsertId()
if err == nil {