Add tests for migrate/up/rollback commands

This commit is contained in:
Adrian Macneil 2015-11-28 23:45:23 -07:00
parent e87dd1e608
commit a2e16a66d2
2 changed files with 147 additions and 3 deletions

View file

@ -1,15 +1,34 @@
package main_test package main_test
import ( import (
"database/sql"
"flag" "flag"
"github.com/adrianmacneil/dbmate" "github.com/adrianmacneil/dbmate"
"github.com/codegangsta/cli" "github.com/codegangsta/cli"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io"
"net/url"
"os" "os"
"path/filepath"
"testing" "testing"
) )
func newContext() *cli.Context { var stubsDir string
func testContext(t *testing.T) *cli.Context {
var err error
if stubsDir == "" {
stubsDir, err = filepath.Abs("./stubs")
require.Nil(t, err)
}
err = os.Chdir(stubsDir)
require.Nil(t, err)
u := testURL(t)
err = os.Setenv("DATABASE_URL", u.String())
require.Nil(t, err)
app := main.NewApp() app := main.NewApp()
flagset := flag.NewFlagSet(app.Name, flag.ContinueOnError) flagset := flag.NewFlagSet(app.Name, flag.ContinueOnError)
for _, f := range app.Flags { for _, f := range app.Flags {
@ -19,11 +38,33 @@ func newContext() *cli.Context {
return cli.NewContext(app, flagset, nil) return cli.NewContext(app, flagset, nil)
} }
func TestGetDatabaseUrl_Default(t *testing.T) { func testURL(t *testing.T) *url.URL {
str := os.Getenv("POSTGRES_PORT")
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
u, err := url.Parse(str)
require.Nil(t, err)
u.Scheme = "postgres"
u.User = url.User("postgres")
u.Path = "/dbmate"
u.RawQuery = "sslmode=disable"
return u
}
func mustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}
func TestGetDatabaseUrl(t *testing.T) {
ctx := testContext(t)
err := os.Setenv("DATABASE_URL", "postgres://example.org/db") err := os.Setenv("DATABASE_URL", "postgres://example.org/db")
require.Nil(t, err) require.Nil(t, err)
ctx := newContext()
u, err := main.GetDatabaseURL(ctx) u, err := main.GetDatabaseURL(ctx)
require.Nil(t, err) require.Nil(t, err)
@ -31,3 +72,97 @@ func TestGetDatabaseUrl_Default(t *testing.T) {
require.Equal(t, "example.org", u.Host) require.Equal(t, "example.org", u.Host)
require.Equal(t, "/db", u.Path) require.Equal(t, "/db", u.Path)
} }
func TestMigrateCommand(t *testing.T) {
ctx := testContext(t)
// drop and recreate database
err := main.DropCommand(ctx)
require.Nil(t, err)
err = main.CreateCommand(ctx)
require.Nil(t, err)
// migrate
err = main.MigrateCommand(ctx)
require.Nil(t, err)
// verify results
u := testURL(t)
db, err := sql.Open("postgres", u.String())
require.Nil(t, err)
defer mustClose(db)
count := 0
err = db.QueryRow(`select count(*) from schema_migrations
where version = '20151129054053'`).Scan(&count)
require.Nil(t, err)
require.Equal(t, 1, count)
err = db.QueryRow("select count(*) from users").Scan(&count)
require.Nil(t, err)
require.Equal(t, 1, count)
}
func TestUpCommand(t *testing.T) {
ctx := testContext(t)
// drop database
err := main.DropCommand(ctx)
require.Nil(t, err)
// create and migrate
err = main.UpCommand(ctx)
require.Nil(t, err)
// verify results
u := testURL(t)
db, err := sql.Open("postgres", u.String())
require.Nil(t, err)
defer mustClose(db)
count := 0
err = db.QueryRow(`select count(*) from schema_migrations
where version = '20151129054053'`).Scan(&count)
require.Nil(t, err)
require.Equal(t, 1, count)
err = db.QueryRow("select count(*) from users").Scan(&count)
require.Nil(t, err)
require.Equal(t, 1, count)
}
func TestRollbackCommand(t *testing.T) {
ctx := testContext(t)
// drop, recreate, and migrate database
err := main.DropCommand(ctx)
require.Nil(t, err)
err = main.CreateCommand(ctx)
require.Nil(t, err)
err = main.MigrateCommand(ctx)
require.Nil(t, err)
// verify migration
u := testURL(t)
db, err := sql.Open("postgres", u.String())
require.Nil(t, err)
defer mustClose(db)
count := 0
err = db.QueryRow(`select count(*) from schema_migrations
where version = '20151129054053'`).Scan(&count)
require.Nil(t, err)
require.Equal(t, 1, count)
// rollback
err = main.RollbackCommand(ctx)
require.Nil(t, err)
// verify rollback
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.Nil(t, err)
require.Equal(t, 0, count)
err = db.QueryRow("select count(*) from users").Scan(&count)
require.Equal(t, "pq: relation \"users\" does not exist", err.Error())
}

View file

@ -0,0 +1,9 @@
-- migrate:up
CREATE TABLE users (
id integer,
name varchar
);
INSERT INTO users (id, name) VALUES (1, 'alice');
-- migrate:down
DROP TABLE users;