From aa5757f8f9dcd2cb3d081c4fa9643d98f68bf4d5 Mon Sep 17 00:00:00 2001 From: Adrian Macneil Date: Tue, 1 Dec 2015 09:43:35 -0800 Subject: [PATCH] Flatten package layout --- commands.go | 25 +++------ commands_test.go | 38 ++------------ driver.go | 46 ++++++++++++++++ driver/driver.go | 44 ---------------- driver/driver_test.go | 20 ------- driver/shared/shared.go | 21 -------- driver_test.go | 26 ++++++++++ driver/mysql/mysql.go => mysql.go | 42 ++++++--------- driver/mysql/mysql_test.go => mysql_test.go | 52 +++++++++---------- driver/postgres/postgres.go => postgres.go | 38 ++++++-------- .../postgres_test.go => postgres_test.go | 52 +++++++++---------- utils.go | 22 ++++++++ driver/shared/shared_test.go => utils_test.go | 6 +-- 13 files changed, 193 insertions(+), 239 deletions(-) create mode 100644 driver.go delete mode 100644 driver/driver.go delete mode 100644 driver/driver_test.go delete mode 100644 driver/shared/shared.go create mode 100644 driver_test.go rename driver/mysql/mysql.go => mysql.go (72%) rename driver/mysql/mysql_test.go => mysql_test.go (82%) rename driver/postgres/postgres.go => postgres.go (70%) rename driver/postgres/postgres_test.go => postgres_test.go (81%) create mode 100644 utils.go rename driver/shared/shared_test.go => utils_test.go (84%) diff --git a/commands.go b/commands.go index c4f9a48..bebd230 100644 --- a/commands.go +++ b/commands.go @@ -3,10 +3,7 @@ package main import ( "database/sql" "fmt" - "github.com/adrianmacneil/dbmate/driver" - "github.com/adrianmacneil/dbmate/driver/shared" "github.com/codegangsta/cli" - "io" "io/ioutil" "net/url" "os" @@ -22,7 +19,7 @@ func UpCommand(ctx *cli.Context) error { return err } - drv, err := driver.Get(u.Scheme) + drv, err := GetDriver(u.Scheme) if err != nil { return err } @@ -48,7 +45,7 @@ func CreateCommand(ctx *cli.Context) error { return err } - drv, err := driver.Get(u.Scheme) + drv, err := GetDriver(u.Scheme) if err != nil { return err } @@ -63,7 +60,7 @@ func DropCommand(ctx *cli.Context) error { return err } - drv, err := driver.Get(u.Scheme) + drv, err := GetDriver(u.Scheme) if err != nil { return err } @@ -73,12 +70,6 @@ func DropCommand(ctx *cli.Context) error { const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n" -func mustClose(c io.Closer) { - if err := c.Close(); err != nil { - panic(err) - } -} - // NewCommand creates a new migration file func NewCommand(ctx *cli.Context) error { // new migration name @@ -126,7 +117,7 @@ func GetDatabaseURL(ctx *cli.Context) (u *url.URL, err error) { return url.Parse(value) } -func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error { +func doTransaction(db *sql.DB, txFunc func(Transaction) error) error { tx, err := db.Begin() if err != nil { return err @@ -143,13 +134,13 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error { return tx.Commit() } -func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) { +func openDatabaseForMigration(ctx *cli.Context) (Driver, *sql.DB, error) { u, err := GetDatabaseURL(ctx) if err != nil { return nil, nil, err } - drv, err := driver.Get(u.Scheme) + drv, err := GetDriver(u.Scheme) if err != nil { return nil, nil, err } @@ -205,7 +196,7 @@ func MigrateCommand(ctx *cli.Context) error { } // begin transaction - err = doTransaction(db, func(tx shared.Transaction) error { + err = doTransaction(db, func(tx Transaction) error { // run actual migration if _, err := tx.Exec(migration["up"]); err != nil { return err @@ -364,7 +355,7 @@ func RollbackCommand(ctx *cli.Context) error { } // begin transaction - err = doTransaction(db, func(tx shared.Transaction) error { + err = doTransaction(db, func(tx Transaction) error { // rollback migration if _, err := tx.Exec(migration["down"]); err != nil { return err diff --git a/commands_test.go b/commands_test.go index 3dd6fb5..508d4ae 100644 --- a/commands_test.go +++ b/commands_test.go @@ -2,7 +2,6 @@ package main import ( "flag" - "github.com/adrianmacneil/dbmate/driver" "github.com/codegangsta/cli" "github.com/stretchr/testify/require" "net/url" @@ -35,39 +34,10 @@ func testContext(t *testing.T, u *url.URL) *cli.Context { return cli.NewContext(app, flagset, nil) } -func postgresTestURL(t *testing.T) *url.URL { - str := os.Getenv("POSTGRES_PORT") - require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable") - - u, err := url.Parse(str) - require.Nil(t, err) - - u.Scheme = "postgres" - u.User = url.User("postgres") - u.Path = "/dbmate" - u.RawQuery = "sslmode=disable" - - return u -} - -func mysqlTestURL(t *testing.T) *url.URL { - str := os.Getenv("MYSQL_PORT") - require.NotEmpty(t, str, "missing MYSQL_PORT environment variable") - - u, err := url.Parse(str) - require.Nil(t, err) - - u.Scheme = "mysql" - u.User = url.UserPassword("root", "root") - u.Path = "/dbmate" - - return u -} - func testURLs(t *testing.T) []*url.URL { return []*url.URL{ postgresTestURL(t), - mysqlTestURL(t), + mySQLTestURL(t), } } @@ -98,7 +68,7 @@ func testMigrateCommandURL(t *testing.T, u *url.URL) { require.Nil(t, err) // verify results - db, err := driver.Open(u) + db, err := GetDriverOpen(u) require.Nil(t, err) defer mustClose(db) @@ -131,7 +101,7 @@ func testUpCommandURL(t *testing.T, u *url.URL) { require.Nil(t, err) // verify results - db, err := driver.Open(u) + db, err := GetDriverOpen(u) require.Nil(t, err) defer mustClose(db) @@ -164,7 +134,7 @@ func testRollbackCommandURL(t *testing.T, u *url.URL) { require.Nil(t, err) // verify migration - db, err := driver.Open(u) + db, err := GetDriverOpen(u) require.Nil(t, err) defer mustClose(db) diff --git a/driver.go b/driver.go new file mode 100644 index 0000000..ad3d619 --- /dev/null +++ b/driver.go @@ -0,0 +1,46 @@ +package main + +import ( + "database/sql" + "fmt" + "net/url" +) + +// Driver provides top level database functions +type Driver interface { + Open(*url.URL) (*sql.DB, error) + DatabaseExists(*url.URL) (bool, error) + CreateDatabase(*url.URL) error + DropDatabase(*url.URL) error + CreateMigrationsTable(*sql.DB) error + SelectMigrations(*sql.DB, int) (map[string]bool, error) + InsertMigration(Transaction, string) error + DeleteMigration(Transaction, string) error +} + +// Transaction can represent a database or open transaction +type Transaction interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// GetDriver loads a database driver by name +func GetDriver(name string) (Driver, error) { + switch name { + case "mysql": + return MySQLDriver{}, nil + case "postgres": + return PostgresDriver{}, nil + default: + return nil, fmt.Errorf("Unknown driver: %s", name) + } +} + +// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u) +func GetDriverOpen(u *url.URL) (*sql.DB, error) { + drv, err := GetDriver(u.Scheme) + if err != nil { + return nil, err + } + + return drv.Open(u) +} diff --git a/driver/driver.go b/driver/driver.go deleted file mode 100644 index ef69446..0000000 --- a/driver/driver.go +++ /dev/null @@ -1,44 +0,0 @@ -package driver - -import ( - "database/sql" - "fmt" - "github.com/adrianmacneil/dbmate/driver/mysql" - "github.com/adrianmacneil/dbmate/driver/postgres" - "github.com/adrianmacneil/dbmate/driver/shared" - "net/url" -) - -// Driver provides top level database functions -type Driver interface { - Open(*url.URL) (*sql.DB, error) - DatabaseExists(*url.URL) (bool, error) - CreateDatabase(*url.URL) error - DropDatabase(*url.URL) error - CreateMigrationsTable(*sql.DB) error - SelectMigrations(*sql.DB, int) (map[string]bool, error) - InsertMigration(shared.Transaction, string) error - DeleteMigration(shared.Transaction, string) error -} - -// Get loads a database driver by name -func Get(name string) (Driver, error) { - switch name { - case "mysql": - return mysql.Driver{}, nil - case "postgres": - return postgres.Driver{}, nil - default: - return nil, fmt.Errorf("Unknown driver: %s", name) - } -} - -// Open is a shortcut for driver.Get(u.Scheme).Open(u) -func Open(u *url.URL) (*sql.DB, error) { - drv, err := Get(u.Scheme) - if err != nil { - return nil, err - } - - return drv.Open(u) -} diff --git a/driver/driver_test.go b/driver/driver_test.go deleted file mode 100644 index d4e43fd..0000000 --- a/driver/driver_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package driver - -import ( - "github.com/adrianmacneil/dbmate/driver/postgres" - "github.com/stretchr/testify/require" - "testing" -) - -func TestGet_Postgres(t *testing.T) { - drv, err := Get("postgres") - require.Nil(t, err) - _, ok := drv.(postgres.Driver) - require.Equal(t, true, ok) -} - -func TestGet_Error(t *testing.T) { - drv, err := Get("foo") - require.Equal(t, "Unknown driver: foo", err.Error()) - require.Nil(t, drv) -} diff --git a/driver/shared/shared.go b/driver/shared/shared.go deleted file mode 100644 index d0a9c40..0000000 --- a/driver/shared/shared.go +++ /dev/null @@ -1,21 +0,0 @@ -package shared - -import ( - "database/sql" - "net/url" -) - -// DatabaseName returns the database name from a URL -func DatabaseName(u *url.URL) string { - name := u.Path - if len(name) > 0 && name[:1] == "/" { - name = name[1:len(name)] - } - - return name -} - -// Transaction can represent a database or open transaction -type Transaction interface { - Exec(query string, args ...interface{}) (sql.Result, error) -} diff --git a/driver_test.go b/driver_test.go new file mode 100644 index 0000000..c3519b5 --- /dev/null +++ b/driver_test.go @@ -0,0 +1,26 @@ +package main + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestGetDriver_Postgres(t *testing.T) { + drv, err := GetDriver("postgres") + require.Nil(t, err) + _, ok := drv.(PostgresDriver) + require.Equal(t, true, ok) +} + +func TestGetDriver_MySQL(t *testing.T) { + drv, err := GetDriver("mysql") + require.Nil(t, err) + _, ok := drv.(MySQLDriver) + require.Equal(t, true, ok) +} + +func TestGetDriver_Error(t *testing.T) { + drv, err := GetDriver("foo") + require.Equal(t, "Unknown driver: foo", err.Error()) + require.Nil(t, drv) +} diff --git a/driver/mysql/mysql.go b/mysql.go similarity index 72% rename from driver/mysql/mysql.go rename to mysql.go index adfce26..6545851 100644 --- a/driver/mysql/mysql.go +++ b/mysql.go @@ -1,20 +1,18 @@ -package mysql +package main import ( "database/sql" "fmt" - "github.com/adrianmacneil/dbmate/driver/shared" _ "github.com/adrianmacneil/go-mysql" // mysql driver - "io" "net/url" "strings" ) -// Driver provides top level database functions -type Driver struct { +// MySQLDriver provides top level database functions +type MySQLDriver struct { } -func normalizeURL(u *url.URL) string { +func normalizeMySQLURL(u *url.URL) string { normalizedURL := *u normalizedURL.Scheme = "" normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host) @@ -28,11 +26,11 @@ func normalizeURL(u *url.URL) string { } // Open creates a new database connection -func (drv Driver) Open(u *url.URL) (*sql.DB, error) { - return sql.Open("mysql", normalizeURL(u)) +func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) { + return sql.Open("mysql", normalizeMySQLURL(u)) } -func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) { +func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) { // connect to no particular database rootURL := *u rootURL.Path = "/" @@ -40,12 +38,6 @@ func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) { return drv.Open(&rootURL) } -func mustClose(c io.Closer) { - if err := c.Close(); err != nil { - panic(err) - } -} - func quoteIdentifier(str string) string { str = strings.Replace(str, "`", "\\`", -1) @@ -53,8 +45,8 @@ func quoteIdentifier(str string) string { } // CreateDatabase creates the specified database -func (drv Driver) CreateDatabase(u *url.URL) error { - name := shared.DatabaseName(u) +func (drv MySQLDriver) CreateDatabase(u *url.URL) error { + name := databaseName(u) fmt.Printf("Creating: %s\n", name) db, err := drv.openRootDB(u) @@ -70,8 +62,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error { } // DropDatabase drops the specified database (if it exists) -func (drv Driver) DropDatabase(u *url.URL) error { - name := shared.DatabaseName(u) +func (drv MySQLDriver) DropDatabase(u *url.URL) error { + name := databaseName(u) fmt.Printf("Dropping: %s\n", name) db, err := drv.openRootDB(u) @@ -87,8 +79,8 @@ func (drv Driver) DropDatabase(u *url.URL) error { } // DatabaseExists determines whether the database exists -func (drv Driver) DatabaseExists(u *url.URL) (bool, error) { - name := shared.DatabaseName(u) +func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) { + name := databaseName(u) db, err := drv.openRootDB(u) if err != nil { @@ -107,7 +99,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) { } // CreateMigrationsTable creates the schema_migrations table -func (drv Driver) CreateMigrationsTable(db *sql.DB) error { +func (drv MySQLDriver) CreateMigrationsTable(db *sql.DB) error { _, err := db.Exec(`create table if not exists schema_migrations ( version varchar(255) primary key)`) @@ -116,7 +108,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error { // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { +func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { query := "select version from schema_migrations order by version desc" if limit >= 0 { query = fmt.Sprintf("%s limit %d", query, limit) @@ -142,14 +134,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro } // InsertMigration adds a new migration record -func (drv Driver) InsertMigration(db shared.Transaction, version string) error { +func (drv MySQLDriver) InsertMigration(db Transaction, version string) error { _, err := db.Exec("insert into schema_migrations (version) values (?)", version) return err } // DeleteMigration removes a migration record -func (drv Driver) DeleteMigration(db shared.Transaction, version string) error { +func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error { _, err := db.Exec("delete from schema_migrations where version = ?", version) return err diff --git a/driver/mysql/mysql_test.go b/mysql_test.go similarity index 82% rename from driver/mysql/mysql_test.go rename to mysql_test.go index 7cf9cef..4c5ac55 100644 --- a/driver/mysql/mysql_test.go +++ b/mysql_test.go @@ -1,4 +1,4 @@ -package mysql +package main import ( "database/sql" @@ -8,7 +8,7 @@ import ( "testing" ) -func testURL(t *testing.T) *url.URL { +func mySQLTestURL(t *testing.T) *url.URL { str := os.Getenv("MYSQL_PORT") require.NotEmpty(t, str, "missing MYSQL_PORT environment variable") @@ -22,9 +22,9 @@ func testURL(t *testing.T) *url.URL { return u } -func prepTestDB(t *testing.T) *sql.DB { - drv := Driver{} - u := testURL(t) +func prepTestMySQLDB(t *testing.T) *sql.DB { + drv := MySQLDriver{} + u := mySQLTestURL(t) // drop any existing database err := drv.DropDatabase(u) @@ -41,9 +41,9 @@ func prepTestDB(t *testing.T) *sql.DB { return db } -func TestCreateDropDatabase(t *testing.T) { - drv := Driver{} - u := testURL(t) +func TestMySQLCreateDropDatabase(t *testing.T) { + drv := MySQLDriver{} + u := mySQLTestURL(t) // drop any existing database err := drv.DropDatabase(u) @@ -79,9 +79,9 @@ func TestCreateDropDatabase(t *testing.T) { }() } -func TestDatabaseExists(t *testing.T) { - drv := Driver{} - u := testURL(t) +func TestMySQLDatabaseExists(t *testing.T) { + drv := MySQLDriver{} + u := mySQLTestURL(t) // drop any existing database err := drv.DropDatabase(u) @@ -102,9 +102,9 @@ func TestDatabaseExists(t *testing.T) { require.Equal(t, true, exists) } -func TestDatabaseExists_Error(t *testing.T) { - drv := Driver{} - u := testURL(t) +func TestMySQLDatabaseExists_Error(t *testing.T) { + drv := MySQLDriver{} + u := mySQLTestURL(t) u.User = url.User("invalid") exists, err := drv.DatabaseExists(u) @@ -112,9 +112,9 @@ func TestDatabaseExists_Error(t *testing.T) { require.Equal(t, false, exists) } -func TestCreateMigrationsTable(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestMySQLCreateMigrationsTable(t *testing.T) { + drv := MySQLDriver{} + db := prepTestMySQLDB(t) defer mustClose(db) // migrations table should not exist @@ -135,9 +135,9 @@ func TestCreateMigrationsTable(t *testing.T) { require.Nil(t, err) } -func TestSelectMigrations(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestMySQLSelectMigrations(t *testing.T) { + drv := MySQLDriver{} + db := prepTestMySQLDB(t) defer mustClose(db) err := drv.CreateMigrationsTable(db) @@ -161,9 +161,9 @@ func TestSelectMigrations(t *testing.T) { require.Equal(t, false, migrations["abc2"]) } -func TestInsertMigration(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestMySQLInsertMigration(t *testing.T) { + drv := MySQLDriver{} + db := prepTestMySQLDB(t) defer mustClose(db) err := drv.CreateMigrationsTable(db) @@ -184,9 +184,9 @@ func TestInsertMigration(t *testing.T) { require.Equal(t, 1, count) } -func TestDeleteMigration(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestMySQLDeleteMigration(t *testing.T) { + drv := MySQLDriver{} + db := prepTestMySQLDB(t) defer mustClose(db) err := drv.CreateMigrationsTable(db) diff --git a/driver/postgres/postgres.go b/postgres.go similarity index 70% rename from driver/postgres/postgres.go rename to postgres.go index 139eb2b..901bd0d 100644 --- a/driver/postgres/postgres.go +++ b/postgres.go @@ -1,24 +1,22 @@ -package postgres +package main import ( "database/sql" "fmt" - "github.com/adrianmacneil/dbmate/driver/shared" "github.com/lib/pq" - "io" "net/url" ) -// Driver provides top level database functions -type Driver struct { +// PostgresDriver provides top level database functions +type PostgresDriver struct { } // Open creates a new database connection -func (drv Driver) Open(u *url.URL) (*sql.DB, error) { +func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) { return sql.Open("postgres", u.String()) } -func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) { +func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) { // connect to postgres database postgresURL := *u postgresURL.Path = "postgres" @@ -26,15 +24,9 @@ func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) { return drv.Open(&postgresURL) } -func mustClose(c io.Closer) { - if err := c.Close(); err != nil { - panic(err) - } -} - // CreateDatabase creates the specified database -func (drv Driver) CreateDatabase(u *url.URL) error { - name := shared.DatabaseName(u) +func (drv PostgresDriver) CreateDatabase(u *url.URL) error { + name := databaseName(u) fmt.Printf("Creating: %s\n", name) db, err := drv.openPostgresDB(u) @@ -50,8 +42,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error { } // DropDatabase drops the specified database (if it exists) -func (drv Driver) DropDatabase(u *url.URL) error { - name := shared.DatabaseName(u) +func (drv PostgresDriver) DropDatabase(u *url.URL) error { + name := databaseName(u) fmt.Printf("Dropping: %s\n", name) db, err := drv.openPostgresDB(u) @@ -67,8 +59,8 @@ func (drv Driver) DropDatabase(u *url.URL) error { } // DatabaseExists determines whether the database exists -func (drv Driver) DatabaseExists(u *url.URL) (bool, error) { - name := shared.DatabaseName(u) +func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) { + name := databaseName(u) db, err := drv.openPostgresDB(u) if err != nil { @@ -87,7 +79,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) { } // CreateMigrationsTable creates the schema_migrations table -func (drv Driver) CreateMigrationsTable(db *sql.DB) error { +func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error { _, err := db.Exec(`create table if not exists schema_migrations ( version varchar(255) primary key)`) @@ -96,7 +88,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error { // SelectMigrations returns a list of applied migrations // with an optional limit (in descending order) -func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { +func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { query := "select version from schema_migrations order by version desc" if limit >= 0 { query = fmt.Sprintf("%s limit %d", query, limit) @@ -122,14 +114,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro } // InsertMigration adds a new migration record -func (drv Driver) InsertMigration(db shared.Transaction, version string) error { +func (drv PostgresDriver) InsertMigration(db Transaction, version string) error { _, err := db.Exec("insert into schema_migrations (version) values ($1)", version) return err } // DeleteMigration removes a migration record -func (drv Driver) DeleteMigration(db shared.Transaction, version string) error { +func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error { _, err := db.Exec("delete from schema_migrations where version = $1", version) return err diff --git a/driver/postgres/postgres_test.go b/postgres_test.go similarity index 81% rename from driver/postgres/postgres_test.go rename to postgres_test.go index 2d29fa3..05cd1a6 100644 --- a/driver/postgres/postgres_test.go +++ b/postgres_test.go @@ -1,4 +1,4 @@ -package postgres +package main import ( "database/sql" @@ -8,7 +8,7 @@ import ( "testing" ) -func testURL(t *testing.T) *url.URL { +func postgresTestURL(t *testing.T) *url.URL { str := os.Getenv("POSTGRES_PORT") require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable") @@ -23,9 +23,9 @@ func testURL(t *testing.T) *url.URL { return u } -func prepTestDB(t *testing.T) *sql.DB { - drv := Driver{} - u := testURL(t) +func prepTestPostgresDB(t *testing.T) *sql.DB { + drv := PostgresDriver{} + u := postgresTestURL(t) // drop any existing database err := drv.DropDatabase(u) @@ -42,9 +42,9 @@ func prepTestDB(t *testing.T) *sql.DB { return db } -func TestCreateDropDatabase(t *testing.T) { - drv := Driver{} - u := testURL(t) +func TestPostgresCreateDropDatabase(t *testing.T) { + drv := PostgresDriver{} + u := postgresTestURL(t) // drop any existing database err := drv.DropDatabase(u) @@ -80,9 +80,9 @@ func TestCreateDropDatabase(t *testing.T) { }() } -func TestDatabaseExists(t *testing.T) { - drv := Driver{} - u := testURL(t) +func TestPostgresDatabaseExists(t *testing.T) { + drv := PostgresDriver{} + u := postgresTestURL(t) // drop any existing database err := drv.DropDatabase(u) @@ -103,9 +103,9 @@ func TestDatabaseExists(t *testing.T) { require.Equal(t, true, exists) } -func TestDatabaseExists_Error(t *testing.T) { - drv := Driver{} - u := testURL(t) +func TestPostgresDatabaseExists_Error(t *testing.T) { + drv := PostgresDriver{} + u := postgresTestURL(t) u.User = url.User("invalid") exists, err := drv.DatabaseExists(u) @@ -113,9 +113,9 @@ func TestDatabaseExists_Error(t *testing.T) { require.Equal(t, false, exists) } -func TestCreateMigrationsTable(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestPostgresCreateMigrationsTable(t *testing.T) { + drv := PostgresDriver{} + db := prepTestPostgresDB(t) defer mustClose(db) // migrations table should not exist @@ -136,9 +136,9 @@ func TestCreateMigrationsTable(t *testing.T) { require.Nil(t, err) } -func TestSelectMigrations(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestPostgresSelectMigrations(t *testing.T) { + drv := PostgresDriver{} + db := prepTestPostgresDB(t) defer mustClose(db) err := drv.CreateMigrationsTable(db) @@ -162,9 +162,9 @@ func TestSelectMigrations(t *testing.T) { require.Equal(t, false, migrations["abc2"]) } -func TestInsertMigration(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestPostgresInsertMigration(t *testing.T) { + drv := PostgresDriver{} + db := prepTestPostgresDB(t) defer mustClose(db) err := drv.CreateMigrationsTable(db) @@ -185,9 +185,9 @@ func TestInsertMigration(t *testing.T) { require.Equal(t, 1, count) } -func TestDeleteMigration(t *testing.T) { - drv := Driver{} - db := prepTestDB(t) +func TestPostgresDeleteMigration(t *testing.T) { + drv := PostgresDriver{} + db := prepTestPostgresDB(t) defer mustClose(db) err := drv.CreateMigrationsTable(db) diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..a5cbe20 --- /dev/null +++ b/utils.go @@ -0,0 +1,22 @@ +package main + +import ( + "io" + "net/url" +) + +// databaseName returns the database name from a URL +func databaseName(u *url.URL) string { + name := u.Path + if len(name) > 0 && name[:1] == "/" { + name = name[1:len(name)] + } + + return name +} + +func mustClose(c io.Closer) { + if err := c.Close(); err != nil { + panic(err) + } +} diff --git a/driver/shared/shared_test.go b/utils_test.go similarity index 84% rename from driver/shared/shared_test.go rename to utils_test.go index f0e7a1a..8d6a075 100644 --- a/driver/shared/shared_test.go +++ b/utils_test.go @@ -1,4 +1,4 @@ -package shared +package main import ( "github.com/stretchr/testify/require" @@ -10,7 +10,7 @@ func TestDatabaseName(t *testing.T) { u, err := url.Parse("ignore://localhost/foo?query") require.Nil(t, err) - name := DatabaseName(u) + name := databaseName(u) require.Equal(t, "foo", name) } @@ -18,6 +18,6 @@ func TestDatabaseName_Empty(t *testing.T) { u, err := url.Parse("ignore://localhost") require.Nil(t, err) - name := DatabaseName(u) + name := databaseName(u) require.Equal(t, "", name) }