Flatten package layout

This commit is contained in:
Adrian Macneil 2015-12-01 09:43:35 -08:00
parent 275f5791f4
commit aa5757f8f9
13 changed files with 193 additions and 239 deletions

View file

@ -3,10 +3,7 @@ package main
import (
"database/sql"
"fmt"
"github.com/adrianmacneil/dbmate/driver"
"github.com/adrianmacneil/dbmate/driver/shared"
"github.com/codegangsta/cli"
"io"
"io/ioutil"
"net/url"
"os"
@ -22,7 +19,7 @@ func UpCommand(ctx *cli.Context) error {
return err
}
drv, err := driver.Get(u.Scheme)
drv, err := GetDriver(u.Scheme)
if err != nil {
return err
}
@ -48,7 +45,7 @@ func CreateCommand(ctx *cli.Context) error {
return err
}
drv, err := driver.Get(u.Scheme)
drv, err := GetDriver(u.Scheme)
if err != nil {
return err
}
@ -63,7 +60,7 @@ func DropCommand(ctx *cli.Context) error {
return err
}
drv, err := driver.Get(u.Scheme)
drv, err := GetDriver(u.Scheme)
if err != nil {
return err
}
@ -73,12 +70,6 @@ func DropCommand(ctx *cli.Context) error {
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
func mustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}
// NewCommand creates a new migration file
func NewCommand(ctx *cli.Context) error {
// new migration name
@ -126,7 +117,7 @@ func GetDatabaseURL(ctx *cli.Context) (u *url.URL, err error) {
return url.Parse(value)
}
func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
tx, err := db.Begin()
if err != nil {
return err
@ -143,13 +134,13 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
return tx.Commit()
}
func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) {
func openDatabaseForMigration(ctx *cli.Context) (Driver, *sql.DB, error) {
u, err := GetDatabaseURL(ctx)
if err != nil {
return nil, nil, err
}
drv, err := driver.Get(u.Scheme)
drv, err := GetDriver(u.Scheme)
if err != nil {
return nil, nil, err
}
@ -205,7 +196,7 @@ func MigrateCommand(ctx *cli.Context) error {
}
// begin transaction
err = doTransaction(db, func(tx shared.Transaction) error {
err = doTransaction(db, func(tx Transaction) error {
// run actual migration
if _, err := tx.Exec(migration["up"]); err != nil {
return err
@ -364,7 +355,7 @@ func RollbackCommand(ctx *cli.Context) error {
}
// begin transaction
err = doTransaction(db, func(tx shared.Transaction) error {
err = doTransaction(db, func(tx Transaction) error {
// rollback migration
if _, err := tx.Exec(migration["down"]); err != nil {
return err

View file

@ -2,7 +2,6 @@ package main
import (
"flag"
"github.com/adrianmacneil/dbmate/driver"
"github.com/codegangsta/cli"
"github.com/stretchr/testify/require"
"net/url"
@ -35,39 +34,10 @@ func testContext(t *testing.T, u *url.URL) *cli.Context {
return cli.NewContext(app, flagset, nil)
}
func postgresTestURL(t *testing.T) *url.URL {
str := os.Getenv("POSTGRES_PORT")
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
u, err := url.Parse(str)
require.Nil(t, err)
u.Scheme = "postgres"
u.User = url.User("postgres")
u.Path = "/dbmate"
u.RawQuery = "sslmode=disable"
return u
}
func mysqlTestURL(t *testing.T) *url.URL {
str := os.Getenv("MYSQL_PORT")
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
u, err := url.Parse(str)
require.Nil(t, err)
u.Scheme = "mysql"
u.User = url.UserPassword("root", "root")
u.Path = "/dbmate"
return u
}
func testURLs(t *testing.T) []*url.URL {
return []*url.URL{
postgresTestURL(t),
mysqlTestURL(t),
mySQLTestURL(t),
}
}
@ -98,7 +68,7 @@ func testMigrateCommandURL(t *testing.T, u *url.URL) {
require.Nil(t, err)
// verify results
db, err := driver.Open(u)
db, err := GetDriverOpen(u)
require.Nil(t, err)
defer mustClose(db)
@ -131,7 +101,7 @@ func testUpCommandURL(t *testing.T, u *url.URL) {
require.Nil(t, err)
// verify results
db, err := driver.Open(u)
db, err := GetDriverOpen(u)
require.Nil(t, err)
defer mustClose(db)
@ -164,7 +134,7 @@ func testRollbackCommandURL(t *testing.T, u *url.URL) {
require.Nil(t, err)
// verify migration
db, err := driver.Open(u)
db, err := GetDriverOpen(u)
require.Nil(t, err)
defer mustClose(db)

46
driver.go Normal file
View file

@ -0,0 +1,46 @@
package main
import (
"database/sql"
"fmt"
"net/url"
)
// Driver provides top level database functions
type Driver interface {
Open(*url.URL) (*sql.DB, error)
DatabaseExists(*url.URL) (bool, error)
CreateDatabase(*url.URL) error
DropDatabase(*url.URL) error
CreateMigrationsTable(*sql.DB) error
SelectMigrations(*sql.DB, int) (map[string]bool, error)
InsertMigration(Transaction, string) error
DeleteMigration(Transaction, string) error
}
// Transaction can represent a database or open transaction
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}
// GetDriver loads a database driver by name
func GetDriver(name string) (Driver, error) {
switch name {
case "mysql":
return MySQLDriver{}, nil
case "postgres":
return PostgresDriver{}, nil
default:
return nil, fmt.Errorf("Unknown driver: %s", name)
}
}
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
func GetDriverOpen(u *url.URL) (*sql.DB, error) {
drv, err := GetDriver(u.Scheme)
if err != nil {
return nil, err
}
return drv.Open(u)
}

View file

@ -1,44 +0,0 @@
package driver
import (
"database/sql"
"fmt"
"github.com/adrianmacneil/dbmate/driver/mysql"
"github.com/adrianmacneil/dbmate/driver/postgres"
"github.com/adrianmacneil/dbmate/driver/shared"
"net/url"
)
// Driver provides top level database functions
type Driver interface {
Open(*url.URL) (*sql.DB, error)
DatabaseExists(*url.URL) (bool, error)
CreateDatabase(*url.URL) error
DropDatabase(*url.URL) error
CreateMigrationsTable(*sql.DB) error
SelectMigrations(*sql.DB, int) (map[string]bool, error)
InsertMigration(shared.Transaction, string) error
DeleteMigration(shared.Transaction, string) error
}
// Get loads a database driver by name
func Get(name string) (Driver, error) {
switch name {
case "mysql":
return mysql.Driver{}, nil
case "postgres":
return postgres.Driver{}, nil
default:
return nil, fmt.Errorf("Unknown driver: %s", name)
}
}
// Open is a shortcut for driver.Get(u.Scheme).Open(u)
func Open(u *url.URL) (*sql.DB, error) {
drv, err := Get(u.Scheme)
if err != nil {
return nil, err
}
return drv.Open(u)
}

View file

@ -1,20 +0,0 @@
package driver
import (
"github.com/adrianmacneil/dbmate/driver/postgres"
"github.com/stretchr/testify/require"
"testing"
)
func TestGet_Postgres(t *testing.T) {
drv, err := Get("postgres")
require.Nil(t, err)
_, ok := drv.(postgres.Driver)
require.Equal(t, true, ok)
}
func TestGet_Error(t *testing.T) {
drv, err := Get("foo")
require.Equal(t, "Unknown driver: foo", err.Error())
require.Nil(t, drv)
}

View file

@ -1,21 +0,0 @@
package shared
import (
"database/sql"
"net/url"
)
// DatabaseName returns the database name from a URL
func DatabaseName(u *url.URL) string {
name := u.Path
if len(name) > 0 && name[:1] == "/" {
name = name[1:len(name)]
}
return name
}
// Transaction can represent a database or open transaction
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
}

26
driver_test.go Normal file
View file

@ -0,0 +1,26 @@
package main
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestGetDriver_Postgres(t *testing.T) {
drv, err := GetDriver("postgres")
require.Nil(t, err)
_, ok := drv.(PostgresDriver)
require.Equal(t, true, ok)
}
func TestGetDriver_MySQL(t *testing.T) {
drv, err := GetDriver("mysql")
require.Nil(t, err)
_, ok := drv.(MySQLDriver)
require.Equal(t, true, ok)
}
func TestGetDriver_Error(t *testing.T) {
drv, err := GetDriver("foo")
require.Equal(t, "Unknown driver: foo", err.Error())
require.Nil(t, drv)
}

View file

@ -1,20 +1,18 @@
package mysql
package main
import (
"database/sql"
"fmt"
"github.com/adrianmacneil/dbmate/driver/shared"
_ "github.com/adrianmacneil/go-mysql" // mysql driver
"io"
"net/url"
"strings"
)
// Driver provides top level database functions
type Driver struct {
// MySQLDriver provides top level database functions
type MySQLDriver struct {
}
func normalizeURL(u *url.URL) string {
func normalizeMySQLURL(u *url.URL) string {
normalizedURL := *u
normalizedURL.Scheme = ""
normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host)
@ -28,11 +26,11 @@ func normalizeURL(u *url.URL) string {
}
// Open creates a new database connection
func (drv Driver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("mysql", normalizeURL(u))
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("mysql", normalizeMySQLURL(u))
}
func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) {
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
// connect to no particular database
rootURL := *u
rootURL.Path = "/"
@ -40,12 +38,6 @@ func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) {
return drv.Open(&rootURL)
}
func mustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}
func quoteIdentifier(str string) string {
str = strings.Replace(str, "`", "\\`", -1)
@ -53,8 +45,8 @@ func quoteIdentifier(str string) string {
}
// CreateDatabase creates the specified database
func (drv Driver) CreateDatabase(u *url.URL) error {
name := shared.DatabaseName(u)
func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Creating: %s\n", name)
db, err := drv.openRootDB(u)
@ -70,8 +62,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv Driver) DropDatabase(u *url.URL) error {
name := shared.DatabaseName(u)
func (drv MySQLDriver) DropDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Dropping: %s\n", name)
db, err := drv.openRootDB(u)
@ -87,8 +79,8 @@ func (drv Driver) DropDatabase(u *url.URL) error {
}
// DatabaseExists determines whether the database exists
func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
name := shared.DatabaseName(u)
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u)
db, err := drv.openRootDB(u)
if err != nil {
@ -107,7 +99,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema_migrations table
func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
func (drv MySQLDriver) CreateMigrationsTable(db *sql.DB) error {
_, err := db.Exec(`create table if not exists schema_migrations (
version varchar(255) primary key)`)
@ -116,7 +108,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations order by version desc"
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
@ -142,14 +134,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro
}
// InsertMigration adds a new migration record
func (drv Driver) InsertMigration(db shared.Transaction, version string) error {
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
return err
}
// DeleteMigration removes a migration record
func (drv Driver) DeleteMigration(db shared.Transaction, version string) error {
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from schema_migrations where version = ?", version)
return err

View file

@ -1,4 +1,4 @@
package mysql
package main
import (
"database/sql"
@ -8,7 +8,7 @@ import (
"testing"
)
func testURL(t *testing.T) *url.URL {
func mySQLTestURL(t *testing.T) *url.URL {
str := os.Getenv("MYSQL_PORT")
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
@ -22,9 +22,9 @@ func testURL(t *testing.T) *url.URL {
return u
}
func prepTestDB(t *testing.T) *sql.DB {
drv := Driver{}
u := testURL(t)
func prepTestMySQLDB(t *testing.T) *sql.DB {
drv := MySQLDriver{}
u := mySQLTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -41,9 +41,9 @@ func prepTestDB(t *testing.T) *sql.DB {
return db
}
func TestCreateDropDatabase(t *testing.T) {
drv := Driver{}
u := testURL(t)
func TestMySQLCreateDropDatabase(t *testing.T) {
drv := MySQLDriver{}
u := mySQLTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -79,9 +79,9 @@ func TestCreateDropDatabase(t *testing.T) {
}()
}
func TestDatabaseExists(t *testing.T) {
drv := Driver{}
u := testURL(t)
func TestMySQLDatabaseExists(t *testing.T) {
drv := MySQLDriver{}
u := mySQLTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -102,9 +102,9 @@ func TestDatabaseExists(t *testing.T) {
require.Equal(t, true, exists)
}
func TestDatabaseExists_Error(t *testing.T) {
drv := Driver{}
u := testURL(t)
func TestMySQLDatabaseExists_Error(t *testing.T) {
drv := MySQLDriver{}
u := mySQLTestURL(t)
u.User = url.User("invalid")
exists, err := drv.DatabaseExists(u)
@ -112,9 +112,9 @@ func TestDatabaseExists_Error(t *testing.T) {
require.Equal(t, false, exists)
}
func TestCreateMigrationsTable(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestMySQLCreateMigrationsTable(t *testing.T) {
drv := MySQLDriver{}
db := prepTestMySQLDB(t)
defer mustClose(db)
// migrations table should not exist
@ -135,9 +135,9 @@ func TestCreateMigrationsTable(t *testing.T) {
require.Nil(t, err)
}
func TestSelectMigrations(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestMySQLSelectMigrations(t *testing.T) {
drv := MySQLDriver{}
db := prepTestMySQLDB(t)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -161,9 +161,9 @@ func TestSelectMigrations(t *testing.T) {
require.Equal(t, false, migrations["abc2"])
}
func TestInsertMigration(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestMySQLInsertMigration(t *testing.T) {
drv := MySQLDriver{}
db := prepTestMySQLDB(t)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -184,9 +184,9 @@ func TestInsertMigration(t *testing.T) {
require.Equal(t, 1, count)
}
func TestDeleteMigration(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestMySQLDeleteMigration(t *testing.T) {
drv := MySQLDriver{}
db := prepTestMySQLDB(t)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)

View file

@ -1,24 +1,22 @@
package postgres
package main
import (
"database/sql"
"fmt"
"github.com/adrianmacneil/dbmate/driver/shared"
"github.com/lib/pq"
"io"
"net/url"
)
// Driver provides top level database functions
type Driver struct {
// PostgresDriver provides top level database functions
type PostgresDriver struct {
}
// Open creates a new database connection
func (drv Driver) Open(u *url.URL) (*sql.DB, error) {
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("postgres", u.String())
}
func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) {
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
// connect to postgres database
postgresURL := *u
postgresURL.Path = "postgres"
@ -26,15 +24,9 @@ func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) {
return drv.Open(&postgresURL)
}
func mustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}
// CreateDatabase creates the specified database
func (drv Driver) CreateDatabase(u *url.URL) error {
name := shared.DatabaseName(u)
func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Creating: %s\n", name)
db, err := drv.openPostgresDB(u)
@ -50,8 +42,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv Driver) DropDatabase(u *url.URL) error {
name := shared.DatabaseName(u)
func (drv PostgresDriver) DropDatabase(u *url.URL) error {
name := databaseName(u)
fmt.Printf("Dropping: %s\n", name)
db, err := drv.openPostgresDB(u)
@ -67,8 +59,8 @@ func (drv Driver) DropDatabase(u *url.URL) error {
}
// DatabaseExists determines whether the database exists
func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
name := shared.DatabaseName(u)
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u)
db, err := drv.openPostgresDB(u)
if err != nil {
@ -87,7 +79,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema_migrations table
func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
_, err := db.Exec(`create table if not exists schema_migrations (
version varchar(255) primary key)`)
@ -96,7 +88,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := "select version from schema_migrations order by version desc"
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
@ -122,14 +114,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro
}
// InsertMigration adds a new migration record
func (drv Driver) InsertMigration(db shared.Transaction, version string) error {
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
_, err := db.Exec("insert into schema_migrations (version) values ($1)", version)
return err
}
// DeleteMigration removes a migration record
func (drv Driver) DeleteMigration(db shared.Transaction, version string) error {
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
_, err := db.Exec("delete from schema_migrations where version = $1", version)
return err

View file

@ -1,4 +1,4 @@
package postgres
package main
import (
"database/sql"
@ -8,7 +8,7 @@ import (
"testing"
)
func testURL(t *testing.T) *url.URL {
func postgresTestURL(t *testing.T) *url.URL {
str := os.Getenv("POSTGRES_PORT")
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
@ -23,9 +23,9 @@ func testURL(t *testing.T) *url.URL {
return u
}
func prepTestDB(t *testing.T) *sql.DB {
drv := Driver{}
u := testURL(t)
func prepTestPostgresDB(t *testing.T) *sql.DB {
drv := PostgresDriver{}
u := postgresTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -42,9 +42,9 @@ func prepTestDB(t *testing.T) *sql.DB {
return db
}
func TestCreateDropDatabase(t *testing.T) {
drv := Driver{}
u := testURL(t)
func TestPostgresCreateDropDatabase(t *testing.T) {
drv := PostgresDriver{}
u := postgresTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -80,9 +80,9 @@ func TestCreateDropDatabase(t *testing.T) {
}()
}
func TestDatabaseExists(t *testing.T) {
drv := Driver{}
u := testURL(t)
func TestPostgresDatabaseExists(t *testing.T) {
drv := PostgresDriver{}
u := postgresTestURL(t)
// drop any existing database
err := drv.DropDatabase(u)
@ -103,9 +103,9 @@ func TestDatabaseExists(t *testing.T) {
require.Equal(t, true, exists)
}
func TestDatabaseExists_Error(t *testing.T) {
drv := Driver{}
u := testURL(t)
func TestPostgresDatabaseExists_Error(t *testing.T) {
drv := PostgresDriver{}
u := postgresTestURL(t)
u.User = url.User("invalid")
exists, err := drv.DatabaseExists(u)
@ -113,9 +113,9 @@ func TestDatabaseExists_Error(t *testing.T) {
require.Equal(t, false, exists)
}
func TestCreateMigrationsTable(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestPostgresCreateMigrationsTable(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
defer mustClose(db)
// migrations table should not exist
@ -136,9 +136,9 @@ func TestCreateMigrationsTable(t *testing.T) {
require.Nil(t, err)
}
func TestSelectMigrations(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestPostgresSelectMigrations(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -162,9 +162,9 @@ func TestSelectMigrations(t *testing.T) {
require.Equal(t, false, migrations["abc2"])
}
func TestInsertMigration(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestPostgresInsertMigration(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)
@ -185,9 +185,9 @@ func TestInsertMigration(t *testing.T) {
require.Equal(t, 1, count)
}
func TestDeleteMigration(t *testing.T) {
drv := Driver{}
db := prepTestDB(t)
func TestPostgresDeleteMigration(t *testing.T) {
drv := PostgresDriver{}
db := prepTestPostgresDB(t)
defer mustClose(db)
err := drv.CreateMigrationsTable(db)

22
utils.go Normal file
View file

@ -0,0 +1,22 @@
package main
import (
"io"
"net/url"
)
// databaseName returns the database name from a URL
func databaseName(u *url.URL) string {
name := u.Path
if len(name) > 0 && name[:1] == "/" {
name = name[1:len(name)]
}
return name
}
func mustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}

View file

@ -1,4 +1,4 @@
package shared
package main
import (
"github.com/stretchr/testify/require"
@ -10,7 +10,7 @@ func TestDatabaseName(t *testing.T) {
u, err := url.Parse("ignore://localhost/foo?query")
require.Nil(t, err)
name := DatabaseName(u)
name := databaseName(u)
require.Equal(t, "foo", name)
}
@ -18,6 +18,6 @@ func TestDatabaseName_Empty(t *testing.T) {
u, err := url.Parse("ignore://localhost")
require.Nil(t, err)
name := DatabaseName(u)
name := databaseName(u)
require.Equal(t, "", name)
}