diff --git a/README.md b/README.md index f5c235b..32453ab 100644 --- a/README.md +++ b/README.md @@ -271,6 +271,13 @@ Waiting for database.... Creating: myapp_development ``` +Alternatively you can use the `--wait` flag: +```sh +$ dbmate --wait up +Waiting for database.... +Creating: myapp_development +``` + If the database is still not available after 60 seconds, the command will return an error: ```sh @@ -289,6 +296,7 @@ The following command line options are available with all commands. You must use * `--migrations-dir, -d "./db/migrations"` - where to keep the migration files. * `--schema-file, -s "./db/schema.sql"` - a path to keep the schema.sql file. * `--no-dump-schema` - don't auto-update the schema.sql file on migrate/rollback +* `--wait` - wait for the db to become available before executing the subsequent command For example, before running your test suite, you may wish to drop and recreate the test database. One easy way to do this is to store your test database connection URL in the `TEST_DATABASE_URL` environment variable: diff --git a/main.go b/main.go index 61d24f0..807ce87 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,10 @@ func NewApp() *cli.App { Name: "no-dump-schema", Usage: "don't update the schema file on migrate/rollback", }, + cli.BoolFlag{ + Name: "wait", + Usage: "wait for the db to become available before executing the subsequent command", + }, } app.Commands = []cli.Command{ @@ -139,6 +143,7 @@ func action(f func(*dbmate.DB, *cli.Context) error) cli.ActionFunc { db.AutoDumpSchema = !c.GlobalBool("no-dump-schema") db.MigrationsDir = c.GlobalString("migrations-dir") db.SchemaFile = c.GlobalString("schema-file") + db.WaitBefore = c.GlobalBool("wait") return f(db, c) } diff --git a/pkg/dbmate/db.go b/pkg/dbmate/db.go index 5917b21..fd5ba53 100644 --- a/pkg/dbmate/db.go +++ b/pkg/dbmate/db.go @@ -30,6 +30,7 @@ type DB struct { DatabaseURL *url.URL MigrationsDir string SchemaFile string + WaitBefore bool WaitInterval time.Duration WaitTimeout time.Duration } @@ -41,6 +42,7 @@ func New(databaseURL *url.URL) *DB { DatabaseURL: databaseURL, MigrationsDir: DefaultMigrationsDir, SchemaFile: DefaultSchemaFile, + WaitBefore: false, WaitInterval: DefaultWaitInterval, WaitTimeout: DefaultWaitTimeout, } @@ -87,6 +89,13 @@ func (db *DB) Wait() error { // CreateAndMigrate creates the database (if necessary) and runs migrations func (db *DB) CreateAndMigrate() error { + if db.WaitBefore { + err := db.Wait() + if err != nil { + return err + } + } + drv, err := db.GetDriver() if err != nil { return err @@ -108,6 +117,13 @@ func (db *DB) CreateAndMigrate() error { // Create creates the current database func (db *DB) Create() error { + if db.WaitBefore { + err := db.Wait() + if err != nil { + return err + } + } + drv, err := db.GetDriver() if err != nil { return err @@ -118,6 +134,13 @@ func (db *DB) Create() error { // Drop drops the current database (if it exists) func (db *DB) Drop() error { + if db.WaitBefore { + err := db.Wait() + if err != nil { + return err + } + } + drv, err := db.GetDriver() if err != nil { return err @@ -128,6 +151,13 @@ func (db *DB) Drop() error { // DumpSchema writes the current database schema to a file func (db *DB) DumpSchema() error { + if db.WaitBefore { + err := db.Wait() + if err != nil { + return err + } + } + drv, sqlDB, err := db.openDatabaseForMigration() if err != nil { return err @@ -233,6 +263,13 @@ func (db *DB) Migrate() error { return fmt.Errorf("no migration files found") } + if db.WaitBefore { + err := db.Wait() + if err != nil { + return err + } + } + drv, sqlDB, err := db.openDatabaseForMigration() if err != nil { return err @@ -340,6 +377,13 @@ func migrationVersion(filename string) string { // Rollback rolls back the most recent migration func (db *DB) Rollback() error { + if db.WaitBefore { + err := db.Wait() + if err != nil { + return err + } + } + drv, sqlDB, err := db.openDatabaseForMigration() if err != nil { return err diff --git a/pkg/dbmate/db_test.go b/pkg/dbmate/db_test.go index 406f0e0..61691fb 100644 --- a/pkg/dbmate/db_test.go +++ b/pkg/dbmate/db_test.go @@ -38,6 +38,7 @@ func TestNew(t *testing.T) { require.Equal(t, u.String(), db.DatabaseURL.String()) require.Equal(t, "./db/migrations", db.MigrationsDir) require.Equal(t, "./db/schema.sql", db.SchemaFile) + require.False(t, db.WaitBefore) require.Equal(t, time.Second, db.WaitInterval) require.Equal(t, 60*time.Second, db.WaitTimeout) } @@ -150,6 +151,55 @@ func TestAutoDumpSchema(t *testing.T) { require.Contains(t, string(schema), "-- PostgreSQL database dump") } +func checkWaitCalled(t *testing.T, u *url.URL, command func() error) { + oldHost := u.Host + u.Host = "postgres:404" + err := command() + require.Error(t, err) + require.Contains(t, err.Error(), "unable to connect to database: dial tcp") + require.Contains(t, err.Error(), "connect: connection refused") + u.Host = oldHost +} + +func TestWaitBefore(t *testing.T) { + u := postgresTestURL(t) + db := newTestDB(t, u) + db.WaitBefore = true + // so that checkWaitCalled returns quickly + db.WaitInterval = time.Millisecond + db.WaitTimeout = 5 * time.Millisecond + + // drop database + err := db.Drop() + require.NoError(t, err) + checkWaitCalled(t, u, db.Drop) + + // create + err = db.Create() + require.NoError(t, err) + checkWaitCalled(t, u, db.Create) + + // create and migrate + err = db.CreateAndMigrate() + require.NoError(t, err) + checkWaitCalled(t, u, db.CreateAndMigrate) + + // migrate + err = db.Migrate() + require.NoError(t, err) + checkWaitCalled(t, u, db.Migrate) + + // rollback + err = db.Rollback() + require.NoError(t, err) + checkWaitCalled(t, u, db.Rollback) + + // dump + err = db.DumpSchema() + require.NoError(t, err) + checkWaitCalled(t, u, db.DumpSchema) +} + func testURLs(t *testing.T) []*url.URL { return []*url.URL{ postgresTestURL(t),