diff --git a/commands.go b/commands.go index 7ee9c82..48955df 100644 --- a/commands.go +++ b/commands.go @@ -193,7 +193,7 @@ func MigrateCommand(ctx *cli.Context) error { for filename := range available { ver := migrationVersion(filename) - if _, ok := applied[ver]; ok { + if ok := applied[ver]; ok { // migration already applied continue } diff --git a/driver/driver.go b/driver/driver.go index 6c88ffb..f621896 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -15,7 +15,7 @@ type Driver interface { CreateDatabase(*url.URL) error DropDatabase(*url.URL) error CreateMigrationsTable(*sql.DB) error - SelectMigrations(*sql.DB, int) (map[string]struct{}, error) + SelectMigrations(*sql.DB, int) (map[string]bool, error) InsertMigration(shared.Transaction, string) error DeleteMigration(shared.Transaction, string) error } diff --git a/driver/driver_test.go b/driver/driver_test.go new file mode 100644 index 0000000..5c4c576 --- /dev/null +++ b/driver/driver_test.go @@ -0,0 +1,21 @@ +package driver_test + +import ( + "github.com/adrianmacneil/dbmate/driver" + "github.com/adrianmacneil/dbmate/driver/postgres" + "github.com/stretchr/testify/require" + "testing" +) + +func TestGet_Postgres(t *testing.T) { + drv, err := driver.Get("postgres") + require.Nil(t, err) + _, ok := drv.(postgres.Driver) + require.Equal(t, true, ok) +} + +func TestGet_Error(t *testing.T) { + drv, err := driver.Get("foo") + require.Equal(t, "Unknown driver: foo", err.Error()) + require.Nil(t, drv) +} diff --git a/driver/postgres/postgres.go b/driver/postgres/postgres.go index 077552b..1002708 100644 --- a/driver/postgres/postgres.go +++ b/driver/postgres/postgres.go @@ -4,7 +4,7 @@ import ( "database/sql" "fmt" "github.com/adrianmacneil/dbmate/driver/shared" - pq "github.com/lib/pq" + "github.com/lib/pq" "io" "net/url" ) @@ -14,17 +14,16 @@ type Driver struct { } // Open creates a new database connection -func (postgres Driver) Open(u *url.URL) (*sql.DB, error) { +func (drv Driver) Open(u *url.URL) (*sql.DB, error) { return sql.Open("postgres", u.String()) } -// postgresExec runs a sql statement on the "postgres" database -func (postgres Driver) openPostgresDB(u *url.URL) (*sql.DB, error) { +func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) { // connect to postgres database postgresURL := *u postgresURL.Path = "postgres" - return postgres.Open(&postgresURL) + return drv.Open(&postgresURL) } func mustClose(c io.Closer) { @@ -34,11 +33,11 @@ func mustClose(c io.Closer) { } // CreateDatabase creates the specified database -func (postgres Driver) CreateDatabase(u *url.URL) error { +func (drv Driver) CreateDatabase(u *url.URL) error { name := shared.DatabaseName(u) fmt.Printf("Creating: %s\n", name) - db, err := postgres.openPostgresDB(u) + db, err := drv.openPostgresDB(u) if err != nil { return err } @@ -51,11 +50,11 @@ func (postgres Driver) CreateDatabase(u *url.URL) error { } // DropDatabase drops the specified database (if it exists) -func (postgres Driver) DropDatabase(u *url.URL) error { +func (drv Driver) DropDatabase(u *url.URL) error { name := shared.DatabaseName(u) fmt.Printf("Dropping: %s\n", name) - db, err := postgres.openPostgresDB(u) + db, err := drv.openPostgresDB(u) if err != nil { return err } @@ -68,10 +67,10 @@ func (postgres Driver) DropDatabase(u *url.URL) error { } // DatabaseExists determines whether the database exists -func (postgres Driver) DatabaseExists(u *url.URL) (bool, error) { +func (drv Driver) DatabaseExists(u *url.URL) (bool, error) { name := shared.DatabaseName(u) - db, err := postgres.openPostgresDB(u) + db, err := drv.openPostgresDB(u) if err != nil { return false, err } @@ -87,13 +86,8 @@ func (postgres Driver) DatabaseExists(u *url.URL) (bool, error) { return exists, err } -// HasMigrationsTable returns true if the schema_migrations table exists -func (postgres Driver) HasMigrationsTable(db *sql.DB) (bool, error) { - return false, fmt.Errorf("not implemented") -} - // CreateMigrationsTable creates the schema_migrations table -func (postgres Driver) CreateMigrationsTable(db *sql.DB) error { +func (drv Driver) CreateMigrationsTable(db *sql.DB) error { _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations ( version varchar(255) PRIMARY KEY)`) @@ -102,7 +96,7 @@ func (postgres Driver) CreateMigrationsTable(db *sql.DB) error { // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (postgres Driver) SelectMigrations(db *sql.DB, limit int) (map[string]struct{}, error) { +func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { query := "SELECT version FROM schema_migrations ORDER BY version DESC" if limit >= 0 { query = fmt.Sprintf("%s LIMIT %d", query, limit) @@ -114,28 +108,28 @@ func (postgres Driver) SelectMigrations(db *sql.DB, limit int) (map[string]struc defer mustClose(rows) - migrations := map[string]struct{}{} + migrations := map[string]bool{} for rows.Next() { var version string if err := rows.Scan(&version); err != nil { return nil, err } - migrations[version] = struct{}{} + migrations[version] = true } return migrations, nil } // InsertMigration adds a new migration record -func (postgres Driver) InsertMigration(db shared.Transaction, version string) error { +func (drv Driver) InsertMigration(db shared.Transaction, version string) error { _, err := db.Exec("INSERT INTO schema_migrations (version) VALUES ($1)", version) return err } // DeleteMigration removes a migration record -func (postgres Driver) DeleteMigration(db shared.Transaction, version string) error { +func (drv Driver) DeleteMigration(db shared.Transaction, version string) error { _, err := db.Exec("DELETE FROM schema_migrations WHERE version = $1", version) return err diff --git a/driver/postgres/postgres_test.go b/driver/postgres/postgres_test.go index 2b299f2..cd672c0 100644 --- a/driver/postgres/postgres_test.go +++ b/driver/postgres/postgres_test.go @@ -31,16 +31,35 @@ func mustClose(c io.Closer) { } } -func TestCreateDropDatabase(t *testing.T) { - d := postgres.Driver{} +func prepTestDB(t *testing.T) *sql.DB { + drv := postgres.Driver{} u := testURL(t) // drop any existing database - err := d.DropDatabase(u) + err := drv.DropDatabase(u) require.Nil(t, err) // create database - err = d.CreateDatabase(u) + err = drv.CreateDatabase(u) + require.Nil(t, err) + + // connect database + db, err := sql.Open("postgres", u.String()) + require.Nil(t, err) + + return db +} + +func TestCreateDropDatabase(t *testing.T) { + drv := postgres.Driver{} + u := testURL(t) + + // drop any existing database + err := drv.DropDatabase(u) + require.Nil(t, err) + + // create database + err = drv.CreateDatabase(u) require.Nil(t, err) // check that database exists and we can connect to it @@ -54,7 +73,7 @@ func TestCreateDropDatabase(t *testing.T) { }() // drop the database - err = d.DropDatabase(u) + err = drv.DropDatabase(u) require.Nil(t, err) // check that database no longer exists @@ -70,34 +89,127 @@ func TestCreateDropDatabase(t *testing.T) { } func TestDatabaseExists(t *testing.T) { - d := postgres.Driver{} + drv := postgres.Driver{} u := testURL(t) // drop any existing database - err := d.DropDatabase(u) + err := drv.DropDatabase(u) require.Nil(t, err) // DatabaseExists should return false - exists, err := d.DatabaseExists(u) + exists, err := drv.DatabaseExists(u) require.Nil(t, err) require.Equal(t, false, exists) // create database - err = d.CreateDatabase(u) + err = drv.CreateDatabase(u) require.Nil(t, err) // DatabaseExists should return true - exists, err = d.DatabaseExists(u) + exists, err = drv.DatabaseExists(u) require.Nil(t, err) require.Equal(t, true, exists) } -func TestDatabaseExists_error(t *testing.T) { - d := postgres.Driver{} +func TestDatabaseExists_Error(t *testing.T) { + drv := postgres.Driver{} u := testURL(t) u.User = url.User("invalid") - exists, err := d.DatabaseExists(u) + exists, err := drv.DatabaseExists(u) require.Equal(t, "pq: role \"invalid\" does not exist", err.Error()) require.Equal(t, false, exists) } + +func TestCreateMigrationsTable(t *testing.T) { + drv := postgres.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Equal(t, "pq: relation \"schema_migrations\" does not exist", err.Error()) + + // create table + err = drv.CreateMigrationsTable(db) + require.Nil(t, err) + + // migrations table should exist + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Nil(t, err) + + // create table should be idempotent + err = drv.CreateMigrationsTable(db) + require.Nil(t, err) +} + +func TestSelectMigrations(t *testing.T) { + drv := postgres.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + err := drv.CreateMigrationsTable(db) + require.Nil(t, err) + + _, err = db.Exec(`insert into schema_migrations (version) + values ('abc2'), ('abc1'), ('abc3')`) + require.Nil(t, err) + + migrations, err := drv.SelectMigrations(db, -1) + require.Nil(t, err) + require.Equal(t, true, migrations["abc1"]) + require.Equal(t, true, migrations["abc2"]) + require.Equal(t, true, migrations["abc2"]) + + // test limit param + migrations, err = drv.SelectMigrations(db, 1) + require.Nil(t, err) + require.Equal(t, true, migrations["abc3"]) + require.Equal(t, false, migrations["abc1"]) + require.Equal(t, false, migrations["abc2"]) +} + +func TestInsertMigration(t *testing.T) { + drv := postgres.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + err := drv.CreateMigrationsTable(db) + require.Nil(t, err) + + count := 0 + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Nil(t, err) + require.Equal(t, 0, count) + + // insert migration + err = drv.InsertMigration(db, "abc1") + require.Nil(t, err) + + err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'"). + Scan(&count) + require.Nil(t, err) + require.Equal(t, 1, count) +} + +func TestDeleteMigration(t *testing.T) { + drv := postgres.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + err := drv.CreateMigrationsTable(db) + require.Nil(t, err) + + _, err = db.Exec(`insert into schema_migrations (version) + values ('abc1'), ('abc2')`) + require.Nil(t, err) + + err = drv.DeleteMigration(db, "abc2") + require.Nil(t, err) + + count := 0 + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Nil(t, err) + require.Equal(t, 1, count) +}