mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2025-12-11 23:50:04 +01:00
Flatten package layout
This commit is contained in:
parent
275f5791f4
commit
aa5757f8f9
13 changed files with 193 additions and 239 deletions
25
commands.go
25
commands.go
|
|
@ -3,10 +3,7 @@ package main
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/adrianmacneil/dbmate/driver"
|
||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
||||
"github.com/codegangsta/cli"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
|
|
@ -22,7 +19,7 @@ func UpCommand(ctx *cli.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
drv, err := driver.Get(u.Scheme)
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -48,7 +45,7 @@ func CreateCommand(ctx *cli.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
drv, err := driver.Get(u.Scheme)
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -63,7 +60,7 @@ func DropCommand(ctx *cli.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
drv, err := driver.Get(u.Scheme)
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -73,12 +70,6 @@ func DropCommand(ctx *cli.Context) error {
|
|||
|
||||
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
|
||||
|
||||
func mustClose(c io.Closer) {
|
||||
if err := c.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewCommand creates a new migration file
|
||||
func NewCommand(ctx *cli.Context) error {
|
||||
// new migration name
|
||||
|
|
@ -126,7 +117,7 @@ func GetDatabaseURL(ctx *cli.Context) (u *url.URL, err error) {
|
|||
return url.Parse(value)
|
||||
}
|
||||
|
||||
func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
|
||||
func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -143,13 +134,13 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) {
|
||||
func openDatabaseForMigration(ctx *cli.Context) (Driver, *sql.DB, error) {
|
||||
u, err := GetDatabaseURL(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
drv, err := driver.Get(u.Scheme)
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
|
@ -205,7 +196,7 @@ func MigrateCommand(ctx *cli.Context) error {
|
|||
}
|
||||
|
||||
// begin transaction
|
||||
err = doTransaction(db, func(tx shared.Transaction) error {
|
||||
err = doTransaction(db, func(tx Transaction) error {
|
||||
// run actual migration
|
||||
if _, err := tx.Exec(migration["up"]); err != nil {
|
||||
return err
|
||||
|
|
@ -364,7 +355,7 @@ func RollbackCommand(ctx *cli.Context) error {
|
|||
}
|
||||
|
||||
// begin transaction
|
||||
err = doTransaction(db, func(tx shared.Transaction) error {
|
||||
err = doTransaction(db, func(tx Transaction) error {
|
||||
// rollback migration
|
||||
if _, err := tx.Exec(migration["down"]); err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"flag"
|
||||
"github.com/adrianmacneil/dbmate/driver"
|
||||
"github.com/codegangsta/cli"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net/url"
|
||||
|
|
@ -35,39 +34,10 @@ func testContext(t *testing.T, u *url.URL) *cli.Context {
|
|||
return cli.NewContext(app, flagset, nil)
|
||||
}
|
||||
|
||||
func postgresTestURL(t *testing.T) *url.URL {
|
||||
str := os.Getenv("POSTGRES_PORT")
|
||||
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
|
||||
|
||||
u, err := url.Parse(str)
|
||||
require.Nil(t, err)
|
||||
|
||||
u.Scheme = "postgres"
|
||||
u.User = url.User("postgres")
|
||||
u.Path = "/dbmate"
|
||||
u.RawQuery = "sslmode=disable"
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func mysqlTestURL(t *testing.T) *url.URL {
|
||||
str := os.Getenv("MYSQL_PORT")
|
||||
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
|
||||
|
||||
u, err := url.Parse(str)
|
||||
require.Nil(t, err)
|
||||
|
||||
u.Scheme = "mysql"
|
||||
u.User = url.UserPassword("root", "root")
|
||||
u.Path = "/dbmate"
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func testURLs(t *testing.T) []*url.URL {
|
||||
return []*url.URL{
|
||||
postgresTestURL(t),
|
||||
mysqlTestURL(t),
|
||||
mySQLTestURL(t),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -98,7 +68,7 @@ func testMigrateCommandURL(t *testing.T, u *url.URL) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// verify results
|
||||
db, err := driver.Open(u)
|
||||
db, err := GetDriverOpen(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
|
|
@ -131,7 +101,7 @@ func testUpCommandURL(t *testing.T, u *url.URL) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// verify results
|
||||
db, err := driver.Open(u)
|
||||
db, err := GetDriverOpen(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
|
|
@ -164,7 +134,7 @@ func testRollbackCommandURL(t *testing.T, u *url.URL) {
|
|||
require.Nil(t, err)
|
||||
|
||||
// verify migration
|
||||
db, err := driver.Open(u)
|
||||
db, err := GetDriverOpen(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
|
|
|
|||
46
driver.go
Normal file
46
driver.go
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Driver provides top level database functions
|
||||
type Driver interface {
|
||||
Open(*url.URL) (*sql.DB, error)
|
||||
DatabaseExists(*url.URL) (bool, error)
|
||||
CreateDatabase(*url.URL) error
|
||||
DropDatabase(*url.URL) error
|
||||
CreateMigrationsTable(*sql.DB) error
|
||||
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
||||
InsertMigration(Transaction, string) error
|
||||
DeleteMigration(Transaction, string) error
|
||||
}
|
||||
|
||||
// Transaction can represent a database or open transaction
|
||||
type Transaction interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
// GetDriver loads a database driver by name
|
||||
func GetDriver(name string) (Driver, error) {
|
||||
switch name {
|
||||
case "mysql":
|
||||
return MySQLDriver{}, nil
|
||||
case "postgres":
|
||||
return PostgresDriver{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown driver: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
|
||||
func GetDriverOpen(u *url.URL) (*sql.DB, error) {
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return drv.Open(u)
|
||||
}
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
package driver
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/adrianmacneil/dbmate/driver/mysql"
|
||||
"github.com/adrianmacneil/dbmate/driver/postgres"
|
||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Driver provides top level database functions
|
||||
type Driver interface {
|
||||
Open(*url.URL) (*sql.DB, error)
|
||||
DatabaseExists(*url.URL) (bool, error)
|
||||
CreateDatabase(*url.URL) error
|
||||
DropDatabase(*url.URL) error
|
||||
CreateMigrationsTable(*sql.DB) error
|
||||
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
||||
InsertMigration(shared.Transaction, string) error
|
||||
DeleteMigration(shared.Transaction, string) error
|
||||
}
|
||||
|
||||
// Get loads a database driver by name
|
||||
func Get(name string) (Driver, error) {
|
||||
switch name {
|
||||
case "mysql":
|
||||
return mysql.Driver{}, nil
|
||||
case "postgres":
|
||||
return postgres.Driver{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknown driver: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// Open is a shortcut for driver.Get(u.Scheme).Open(u)
|
||||
func Open(u *url.URL) (*sql.DB, error) {
|
||||
drv, err := Get(u.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return drv.Open(u)
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
package driver
|
||||
|
||||
import (
|
||||
"github.com/adrianmacneil/dbmate/driver/postgres"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGet_Postgres(t *testing.T) {
|
||||
drv, err := Get("postgres")
|
||||
require.Nil(t, err)
|
||||
_, ok := drv.(postgres.Driver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGet_Error(t *testing.T) {
|
||||
drv, err := Get("foo")
|
||||
require.Equal(t, "Unknown driver: foo", err.Error())
|
||||
require.Nil(t, drv)
|
||||
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// DatabaseName returns the database name from a URL
|
||||
func DatabaseName(u *url.URL) string {
|
||||
name := u.Path
|
||||
if len(name) > 0 && name[:1] == "/" {
|
||||
name = name[1:len(name)]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
// Transaction can represent a database or open transaction
|
||||
type Transaction interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
26
driver_test.go
Normal file
26
driver_test.go
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetDriver_Postgres(t *testing.T) {
|
||||
drv, err := GetDriver("postgres")
|
||||
require.Nil(t, err)
|
||||
_, ok := drv.(PostgresDriver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGetDriver_MySQL(t *testing.T) {
|
||||
drv, err := GetDriver("mysql")
|
||||
require.Nil(t, err)
|
||||
_, ok := drv.(MySQLDriver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGetDriver_Error(t *testing.T) {
|
||||
drv, err := GetDriver("foo")
|
||||
require.Equal(t, "Unknown driver: foo", err.Error())
|
||||
require.Nil(t, drv)
|
||||
}
|
||||
|
|
@ -1,20 +1,18 @@
|
|||
package mysql
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
||||
_ "github.com/adrianmacneil/go-mysql" // mysql driver
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Driver provides top level database functions
|
||||
type Driver struct {
|
||||
// MySQLDriver provides top level database functions
|
||||
type MySQLDriver struct {
|
||||
}
|
||||
|
||||
func normalizeURL(u *url.URL) string {
|
||||
func normalizeMySQLURL(u *url.URL) string {
|
||||
normalizedURL := *u
|
||||
normalizedURL.Scheme = ""
|
||||
normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host)
|
||||
|
|
@ -28,11 +26,11 @@ func normalizeURL(u *url.URL) string {
|
|||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv Driver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("mysql", normalizeURL(u))
|
||||
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("mysql", normalizeMySQLURL(u))
|
||||
}
|
||||
|
||||
func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to no particular database
|
||||
rootURL := *u
|
||||
rootURL.Path = "/"
|
||||
|
|
@ -40,12 +38,6 @@ func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) {
|
|||
return drv.Open(&rootURL)
|
||||
}
|
||||
|
||||
func mustClose(c io.Closer) {
|
||||
if err := c.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func quoteIdentifier(str string) string {
|
||||
str = strings.Replace(str, "`", "\\`", -1)
|
||||
|
||||
|
|
@ -53,8 +45,8 @@ func quoteIdentifier(str string) string {
|
|||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv Driver) CreateDatabase(u *url.URL) error {
|
||||
name := shared.DatabaseName(u)
|
||||
func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
|
|
@ -70,8 +62,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error {
|
|||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv Driver) DropDatabase(u *url.URL) error {
|
||||
name := shared.DatabaseName(u)
|
||||
func (drv MySQLDriver) DropDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
|
|
@ -87,8 +79,8 @@ func (drv Driver) DropDatabase(u *url.URL) error {
|
|||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := shared.DatabaseName(u)
|
||||
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := databaseName(u)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
if err != nil {
|
||||
|
|
@ -107,7 +99,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||
func (drv MySQLDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||
version varchar(255) primary key)`)
|
||||
|
||||
|
|
@ -116,7 +108,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
|||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
|
|
@ -142,14 +134,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro
|
|||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv Driver) InsertMigration(db shared.Transaction, version string) error {
|
||||
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv Driver) DeleteMigration(db shared.Transaction, version string) error {
|
||||
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
||||
|
||||
return err
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package mysql
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
|
@ -8,7 +8,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func testURL(t *testing.T) *url.URL {
|
||||
func mySQLTestURL(t *testing.T) *url.URL {
|
||||
str := os.Getenv("MYSQL_PORT")
|
||||
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
|
||||
|
||||
|
|
@ -22,9 +22,9 @@ func testURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func prepTestDB(t *testing.T) *sql.DB {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func prepTestMySQLDB(t *testing.T) *sql.DB {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -41,9 +41,9 @@ func prepTestDB(t *testing.T) *sql.DB {
|
|||
return db
|
||||
}
|
||||
|
||||
func TestCreateDropDatabase(t *testing.T) {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func TestMySQLCreateDropDatabase(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -79,9 +79,9 @@ func TestCreateDropDatabase(t *testing.T) {
|
|||
}()
|
||||
}
|
||||
|
||||
func TestDatabaseExists(t *testing.T) {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func TestMySQLDatabaseExists(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -102,9 +102,9 @@ func TestDatabaseExists(t *testing.T) {
|
|||
require.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestDatabaseExists_Error(t *testing.T) {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func TestMySQLDatabaseExists_Error(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
u.User = url.User("invalid")
|
||||
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
|
|
@ -112,9 +112,9 @@ func TestDatabaseExists_Error(t *testing.T) {
|
|||
require.Equal(t, false, exists)
|
||||
}
|
||||
|
||||
func TestCreateMigrationsTable(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestMySQLCreateMigrationsTable(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
|
|
@ -135,9 +135,9 @@ func TestCreateMigrationsTable(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestSelectMigrations(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestMySQLSelectMigrations(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -161,9 +161,9 @@ func TestSelectMigrations(t *testing.T) {
|
|||
require.Equal(t, false, migrations["abc2"])
|
||||
}
|
||||
|
||||
func TestInsertMigration(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestMySQLInsertMigration(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -184,9 +184,9 @@ func TestInsertMigration(t *testing.T) {
|
|||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestDeleteMigration(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestMySQLDeleteMigration(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -1,24 +1,22 @@
|
|||
package postgres
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
||||
"github.com/lib/pq"
|
||||
"io"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Driver provides top level database functions
|
||||
type Driver struct {
|
||||
// PostgresDriver provides top level database functions
|
||||
type PostgresDriver struct {
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv Driver) Open(u *url.URL) (*sql.DB, error) {
|
||||
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("postgres", u.String())
|
||||
}
|
||||
|
||||
func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to postgres database
|
||||
postgresURL := *u
|
||||
postgresURL.Path = "postgres"
|
||||
|
|
@ -26,15 +24,9 @@ func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
|||
return drv.Open(&postgresURL)
|
||||
}
|
||||
|
||||
func mustClose(c io.Closer) {
|
||||
if err := c.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv Driver) CreateDatabase(u *url.URL) error {
|
||||
name := shared.DatabaseName(u)
|
||||
func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
|
|
@ -50,8 +42,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error {
|
|||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv Driver) DropDatabase(u *url.URL) error {
|
||||
name := shared.DatabaseName(u)
|
||||
func (drv PostgresDriver) DropDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
|
|
@ -67,8 +59,8 @@ func (drv Driver) DropDatabase(u *url.URL) error {
|
|||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := shared.DatabaseName(u)
|
||||
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := databaseName(u)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
if err != nil {
|
||||
|
|
@ -87,7 +79,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
|||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||
func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||
version varchar(255) primary key)`)
|
||||
|
||||
|
|
@ -96,7 +88,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
|||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
|
|
@ -122,14 +114,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro
|
|||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv Driver) InsertMigration(db shared.Transaction, version string) error {
|
||||
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values ($1)", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv Driver) DeleteMigration(db shared.Transaction, version string) error {
|
||||
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = $1", version)
|
||||
|
||||
return err
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package postgres
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
|
@ -8,7 +8,7 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
func testURL(t *testing.T) *url.URL {
|
||||
func postgresTestURL(t *testing.T) *url.URL {
|
||||
str := os.Getenv("POSTGRES_PORT")
|
||||
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
|
||||
|
||||
|
|
@ -23,9 +23,9 @@ func testURL(t *testing.T) *url.URL {
|
|||
return u
|
||||
}
|
||||
|
||||
func prepTestDB(t *testing.T) *sql.DB {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func prepTestPostgresDB(t *testing.T) *sql.DB {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -42,9 +42,9 @@ func prepTestDB(t *testing.T) *sql.DB {
|
|||
return db
|
||||
}
|
||||
|
||||
func TestCreateDropDatabase(t *testing.T) {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -80,9 +80,9 @@ func TestCreateDropDatabase(t *testing.T) {
|
|||
}()
|
||||
}
|
||||
|
||||
func TestDatabaseExists(t *testing.T) {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func TestPostgresDatabaseExists(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
|
|
@ -103,9 +103,9 @@ func TestDatabaseExists(t *testing.T) {
|
|||
require.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestDatabaseExists_Error(t *testing.T) {
|
||||
drv := Driver{}
|
||||
u := testURL(t)
|
||||
func TestPostgresDatabaseExists_Error(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
u.User = url.User("invalid")
|
||||
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
|
|
@ -113,9 +113,9 @@ func TestDatabaseExists_Error(t *testing.T) {
|
|||
require.Equal(t, false, exists)
|
||||
}
|
||||
|
||||
func TestCreateMigrationsTable(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
|
|
@ -136,9 +136,9 @@ func TestCreateMigrationsTable(t *testing.T) {
|
|||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestSelectMigrations(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestPostgresSelectMigrations(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -162,9 +162,9 @@ func TestSelectMigrations(t *testing.T) {
|
|||
require.Equal(t, false, migrations["abc2"])
|
||||
}
|
||||
|
||||
func TestInsertMigration(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestPostgresInsertMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
|
|
@ -185,9 +185,9 @@ func TestInsertMigration(t *testing.T) {
|
|||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestDeleteMigration(t *testing.T) {
|
||||
drv := Driver{}
|
||||
db := prepTestDB(t)
|
||||
func TestPostgresDeleteMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
22
utils.go
Normal file
22
utils.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// databaseName returns the database name from a URL
|
||||
func databaseName(u *url.URL) string {
|
||||
name := u.Path
|
||||
if len(name) > 0 && name[:1] == "/" {
|
||||
name = name[1:len(name)]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
func mustClose(c io.Closer) {
|
||||
if err := c.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package shared
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -10,7 +10,7 @@ func TestDatabaseName(t *testing.T) {
|
|||
u, err := url.Parse("ignore://localhost/foo?query")
|
||||
require.Nil(t, err)
|
||||
|
||||
name := DatabaseName(u)
|
||||
name := databaseName(u)
|
||||
require.Equal(t, "foo", name)
|
||||
}
|
||||
|
||||
|
|
@ -18,6 +18,6 @@ func TestDatabaseName_Empty(t *testing.T) {
|
|||
u, err := url.Parse("ignore://localhost")
|
||||
require.Nil(t, err)
|
||||
|
||||
name := DatabaseName(u)
|
||||
name := databaseName(u)
|
||||
require.Equal(t, "", name)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue