Add --wait flag (#112)

 When using dbmate as a container, since the base image is distroless, we can't do `dbmate up && dbmate wait`, which makes config a bit more cumbersome (e.g. in kubernetes an extra initContainer).

Closes #111
Closes #112
This commit is contained in:
Reuben Thomas-Davis 2019-12-31 20:47:27 +00:00 committed by Adrian Macneil
parent 98066fadaa
commit 1e45bd774c
4 changed files with 107 additions and 0 deletions

View file

@ -30,6 +30,7 @@ type DB struct {
DatabaseURL *url.URL
MigrationsDir string
SchemaFile string
WaitBefore bool
WaitInterval time.Duration
WaitTimeout time.Duration
}
@ -41,6 +42,7 @@ func New(databaseURL *url.URL) *DB {
DatabaseURL: databaseURL,
MigrationsDir: DefaultMigrationsDir,
SchemaFile: DefaultSchemaFile,
WaitBefore: false,
WaitInterval: DefaultWaitInterval,
WaitTimeout: DefaultWaitTimeout,
}
@ -87,6 +89,13 @@ func (db *DB) Wait() error {
// CreateAndMigrate creates the database (if necessary) and runs migrations
func (db *DB) CreateAndMigrate() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, err := db.GetDriver()
if err != nil {
return err
@ -108,6 +117,13 @@ func (db *DB) CreateAndMigrate() error {
// Create creates the current database
func (db *DB) Create() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, err := db.GetDriver()
if err != nil {
return err
@ -118,6 +134,13 @@ func (db *DB) Create() error {
// Drop drops the current database (if it exists)
func (db *DB) Drop() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, err := db.GetDriver()
if err != nil {
return err
@ -128,6 +151,13 @@ func (db *DB) Drop() error {
// DumpSchema writes the current database schema to a file
func (db *DB) DumpSchema() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, sqlDB, err := db.openDatabaseForMigration()
if err != nil {
return err
@ -233,6 +263,13 @@ func (db *DB) Migrate() error {
return fmt.Errorf("no migration files found")
}
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, sqlDB, err := db.openDatabaseForMigration()
if err != nil {
return err
@ -340,6 +377,13 @@ func migrationVersion(filename string) string {
// Rollback rolls back the most recent migration
func (db *DB) Rollback() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, sqlDB, err := db.openDatabaseForMigration()
if err != nil {
return err

View file

@ -38,6 +38,7 @@ func TestNew(t *testing.T) {
require.Equal(t, u.String(), db.DatabaseURL.String())
require.Equal(t, "./db/migrations", db.MigrationsDir)
require.Equal(t, "./db/schema.sql", db.SchemaFile)
require.False(t, db.WaitBefore)
require.Equal(t, time.Second, db.WaitInterval)
require.Equal(t, 60*time.Second, db.WaitTimeout)
}
@ -150,6 +151,55 @@ func TestAutoDumpSchema(t *testing.T) {
require.Contains(t, string(schema), "-- PostgreSQL database dump")
}
func checkWaitCalled(t *testing.T, u *url.URL, command func() error) {
oldHost := u.Host
u.Host = "postgres:404"
err := command()
require.Error(t, err)
require.Contains(t, err.Error(), "unable to connect to database: dial tcp")
require.Contains(t, err.Error(), "connect: connection refused")
u.Host = oldHost
}
func TestWaitBefore(t *testing.T) {
u := postgresTestURL(t)
db := newTestDB(t, u)
db.WaitBefore = true
// so that checkWaitCalled returns quickly
db.WaitInterval = time.Millisecond
db.WaitTimeout = 5 * time.Millisecond
// drop database
err := db.Drop()
require.NoError(t, err)
checkWaitCalled(t, u, db.Drop)
// create
err = db.Create()
require.NoError(t, err)
checkWaitCalled(t, u, db.Create)
// create and migrate
err = db.CreateAndMigrate()
require.NoError(t, err)
checkWaitCalled(t, u, db.CreateAndMigrate)
// migrate
err = db.Migrate()
require.NoError(t, err)
checkWaitCalled(t, u, db.Migrate)
// rollback
err = db.Rollback()
require.NoError(t, err)
checkWaitCalled(t, u, db.Rollback)
// dump
err = db.DumpSchema()
require.NoError(t, err)
checkWaitCalled(t, u, db.DumpSchema)
}
func testURLs(t *testing.T) []*url.URL {
return []*url.URL{
postgresTestURL(t),