From 8cb676158eb1b78eb464a4a80c95163d872c3be9 Mon Sep 17 00:00:00 2001 From: Adrian Macneil Date: Mon, 30 Nov 2015 22:06:12 -0800 Subject: [PATCH] Abstract command tests to support multiple databases --- commands_test.go | 100 ++++++++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 40 deletions(-) diff --git a/commands_test.go b/commands_test.go index 64d65fb..72701ad 100644 --- a/commands_test.go +++ b/commands_test.go @@ -15,7 +15,7 @@ import ( var stubsDir string -func testContext(t *testing.T) *cli.Context { +func testContext(t *testing.T, u *url.URL) *cli.Context { var err error if stubsDir == "" { stubsDir, err = filepath.Abs("./stubs") @@ -25,7 +25,6 @@ func testContext(t *testing.T) *cli.Context { err = os.Chdir(stubsDir) require.Nil(t, err) - u := testURL(t) err = os.Setenv("DATABASE_URL", u.String()) require.Nil(t, err) @@ -38,7 +37,7 @@ func testContext(t *testing.T) *cli.Context { return cli.NewContext(app, flagset, nil) } -func testURL(t *testing.T) *url.URL { +func postgresTestURL(t *testing.T) *url.URL { str := os.Getenv("POSTGRES_PORT") require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable") @@ -53,6 +52,13 @@ func testURL(t *testing.T) *url.URL { return u } +func testURLs(t *testing.T) []*url.URL { + return []*url.URL{ + postgresTestURL(t), + } + +} + func mustClose(c io.Closer) { if err := c.Close(); err != nil { panic(err) @@ -60,21 +66,20 @@ func mustClose(c io.Closer) { } func TestGetDatabaseUrl(t *testing.T) { - ctx := testContext(t) - - err := os.Setenv("DATABASE_URL", "postgres://example.org/db") + envURL, err := url.Parse("foo://example.org/db") require.Nil(t, err) + ctx := testContext(t, envURL) u, err := main.GetDatabaseURL(ctx) require.Nil(t, err) - require.Equal(t, "postgres", u.Scheme) + require.Equal(t, "foo", u.Scheme) require.Equal(t, "example.org", u.Host) require.Equal(t, "/db", u.Path) } -func TestMigrateCommand(t *testing.T) { - ctx := testContext(t) +func testMigrateCommandURL(t *testing.T, u *url.URL) { + ctx := testContext(t, u) // drop and recreate database err := main.DropCommand(ctx) @@ -87,8 +92,40 @@ func TestMigrateCommand(t *testing.T) { require.Nil(t, err) // verify results - u := testURL(t) - db, err := sql.Open("postgres", u.String()) + db, err := sql.Open(u.Scheme, u.String()) + require.Nil(t, err) + defer mustClose(db) + + count := 0 + err = db.QueryRow(`select count(*) from schema_migrations + where version = '20151129054053'`).Scan(&count) + require.Nil(t, err) + require.Equal(t, 1, count) + + err = db.QueryRow("select count(*) from users").Scan(&count) + require.Nil(t, err) + require.Equal(t, 1, count) +} + +func TestMigrateCommand(t *testing.T) { + for _, u := range testURLs(t) { + testMigrateCommandURL(t, u) + } +} + +func testUpCommandURL(t *testing.T, u *url.URL) { + ctx := testContext(t, u) + + // drop database + err := main.DropCommand(ctx) + require.Nil(t, err) + + // create and migrate + err = main.UpCommand(ctx) + require.Nil(t, err) + + // verify results + db, err := sql.Open(u.Scheme, u.String()) require.Nil(t, err) defer mustClose(db) @@ -104,35 +141,13 @@ func TestMigrateCommand(t *testing.T) { } func TestUpCommand(t *testing.T) { - ctx := testContext(t) - - // drop database - err := main.DropCommand(ctx) - require.Nil(t, err) - - // create and migrate - err = main.UpCommand(ctx) - require.Nil(t, err) - - // verify results - u := testURL(t) - db, err := sql.Open("postgres", u.String()) - require.Nil(t, err) - defer mustClose(db) - - count := 0 - err = db.QueryRow(`select count(*) from schema_migrations - where version = '20151129054053'`).Scan(&count) - require.Nil(t, err) - require.Equal(t, 1, count) - - err = db.QueryRow("select count(*) from users").Scan(&count) - require.Nil(t, err) - require.Equal(t, 1, count) + for _, u := range testURLs(t) { + testUpCommandURL(t, u) + } } -func TestRollbackCommand(t *testing.T) { - ctx := testContext(t) +func testRollbackCommandURL(t *testing.T, u *url.URL) { + ctx := testContext(t, u) // drop, recreate, and migrate database err := main.DropCommand(ctx) @@ -143,8 +158,7 @@ func TestRollbackCommand(t *testing.T) { require.Nil(t, err) // verify migration - u := testURL(t) - db, err := sql.Open("postgres", u.String()) + db, err := sql.Open(u.Scheme, u.String()) require.Nil(t, err) defer mustClose(db) @@ -166,3 +180,9 @@ func TestRollbackCommand(t *testing.T) { err = db.QueryRow("select count(*) from users").Scan(&count) require.Equal(t, "pq: relation \"users\" does not exist", err.Error()) } + +func TestRollbackCommand(t *testing.T) { + for _, u := range testURLs(t) { + testRollbackCommandURL(t, u) + } +}