mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2025-12-11 23:50:04 +01:00
postgres: Support custom schema for schema_migrations table (#167)
Instead of hardcoding `schema_migrations` table to the `public` schema, add support for specifying a schema via the `search_path` URL parameter. **Backwards compatibility note**: If anyone was using the previously undocumented `search_path` behavior (affecting migrations themselves, but always storing the `schema_migrations` table in `public`), you will need to either prepend `public` to your `search_path`, or migrate your `schema_migrations` table to your primary schema: ```sql ALTER TABLE public.schema_migrations SET SCHEMA myschema; ``` Closes #110
This commit is contained in:
parent
d4ecd0b259
commit
55a8065efe
5 changed files with 141 additions and 20 deletions
|
|
@ -30,6 +30,8 @@ func RegisterDriver(drv Driver, scheme string) {
|
|||
// Transaction can represent a database or open transaction
|
||||
type Transaction interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// GetDriver loads a database driver by name
|
||||
|
|
|
|||
|
|
@ -125,20 +125,25 @@ func (drv PostgresDriver) DropDatabase(u *url.URL) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
// load applied migrations
|
||||
migrations, err := queryColumn(db,
|
||||
"select quote_literal(version) from public.schema_migrations order by version asc")
|
||||
func (drv PostgresDriver) postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// build schema_migrations table data
|
||||
// load applied migrations
|
||||
migrations, err := queryColumn(db,
|
||||
"select quote_literal(version) from "+migrationsTable+" order by version asc")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// build migrations table data
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("\n--\n-- Dbmate schema migrations\n--\n\n")
|
||||
|
||||
if len(migrations) > 0 {
|
||||
buf.WriteString("INSERT INTO public.schema_migrations (version) VALUES\n (" +
|
||||
buf.WriteString("INSERT INTO " + migrationsTable + " (version) VALUES\n (" +
|
||||
strings.Join(migrations, "),\n (") +
|
||||
");\n")
|
||||
}
|
||||
|
|
@ -156,7 +161,7 @@ func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
migrations, err := postgresSchemaMigrationsDump(db)
|
||||
migrations, err := drv.postgresSchemaMigrationsDump(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -187,8 +192,13 @@ func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||
_, err := db.Exec("create table if not exists public.schema_migrations " +
|
||||
"(version varchar(255) primary key)")
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec("create table if not exists " + migrationsTable +
|
||||
" (version varchar(255) primary key)")
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -196,7 +206,12 @@ func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
|
|||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from public.schema_migrations order by version desc"
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := "select version from " + migrationsTable + " order by version desc"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
|
|
@ -222,14 +237,24 @@ func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bo
|
|||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into public.schema_migrations (version) values ($1)", version)
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec("insert into "+migrationsTable+" (version) values ($1)", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from public.schema_migrations where version = $1", version)
|
||||
migrationsTable, err := drv.migrationsTableName(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec("delete from "+migrationsTable+" where version = $1", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -259,3 +284,18 @@ func (drv PostgresDriver) Ping(u *url.URL) error {
|
|||
|
||||
return err
|
||||
}
|
||||
|
||||
func (drv PostgresDriver) migrationsTableName(db Transaction) (string, error) {
|
||||
// get current schema
|
||||
schema, err := queryRow(db, "select quote_ident(current_schema())")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// if the search path is empty, or does not contain a valid schema, default to public
|
||||
if schema == "" {
|
||||
schema = "public"
|
||||
}
|
||||
|
||||
return schema + ".schema_migrations", nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,9 +15,8 @@ func postgresTestURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func prepTestPostgresDB(t *testing.T) *sql.DB {
|
||||
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -130,7 +129,7 @@ func TestPostgresDumpSchema(t *testing.T) {
|
|||
u := postgresTestURL(t)
|
||||
|
||||
// prepare database
|
||||
db := prepTestPostgresDB(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -198,7 +197,8 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
|
|||
|
||||
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
|
|
@ -221,7 +221,8 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
|||
|
||||
func TestPostgresSelectMigrations(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -247,7 +248,8 @@ func TestPostgresSelectMigrations(t *testing.T) {
|
|||
|
||||
func TestPostgresInsertMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -270,7 +272,8 @@ func TestPostgresInsertMigration(t *testing.T) {
|
|||
|
||||
func TestPostgresDeleteMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -307,3 +310,55 @@ func TestPostgresPing(t *testing.T) {
|
|||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "connect: connection refused")
|
||||
}
|
||||
|
||||
func TestMigrationsTableName(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
|
||||
t.Run("default schema", func(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
name, err := drv.migrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.schema_migrations", name)
|
||||
})
|
||||
|
||||
t.Run("custom schema", func(t *testing.T) {
|
||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo,bar,public")
|
||||
require.NoError(t, err)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// if "foo" schema does not exist, current schema should be "public"
|
||||
_, err = db.Exec("drop schema if exists foo")
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec("drop schema if exists bar")
|
||||
require.NoError(t, err)
|
||||
name, err := drv.migrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.schema_migrations", name)
|
||||
|
||||
// if "foo" schema exists, it should be used
|
||||
_, err = db.Exec("create schema foo")
|
||||
require.NoError(t, err)
|
||||
name, err = drv.migrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo.schema_migrations", name)
|
||||
})
|
||||
|
||||
t.Run("no schema", func(t *testing.T) {
|
||||
u := postgresTestURL(t)
|
||||
db := prepTestPostgresDB(t, u)
|
||||
defer mustClose(db)
|
||||
|
||||
// this is an unlikely edge case, but if for some reason there is
|
||||
// no current schema then we should default to "public"
|
||||
_, err := db.Exec("select pg_catalog.set_config('search_path', '', false)")
|
||||
require.NoError(t, err)
|
||||
|
||||
name, err := drv.migrationsTableName(db)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "public.schema_migrations", name)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
|
|||
// queryColumn runs a SQL statement and returns a slice of strings
|
||||
// it is assumed that the statement returns only one column
|
||||
// e.g. schema_migrations table
|
||||
func queryColumn(db *sql.DB, query string) ([]string, error) {
|
||||
func queryColumn(db Transaction, query string) ([]string, error) {
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -128,6 +128,19 @@ func queryColumn(db *sql.DB, query string) ([]string, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// queryRow runs a SQL statement and returns a single string
|
||||
// it is assumed that the statement returns only one row and one column
|
||||
// sql NULL is returned as empty string
|
||||
func queryRow(db Transaction, query string) (string, error) {
|
||||
var result sql.NullString
|
||||
err := db.QueryRow(query).Scan(&result)
|
||||
if err != nil || !result.Valid {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return result.String, nil
|
||||
}
|
||||
|
||||
func printVerbose(result sql.Result) {
|
||||
lastInsertID, err := result.LastInsertId()
|
||||
if err == nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue