mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2025-12-11 23:50:04 +01:00
Flatten package layout
This commit is contained in:
parent
275f5791f4
commit
aa5757f8f9
13 changed files with 193 additions and 239 deletions
25
commands.go
25
commands.go
|
|
@ -3,10 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/adrianmacneil/dbmate/driver"
|
|
||||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
|
||||||
"github.com/codegangsta/cli"
|
"github.com/codegangsta/cli"
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
|
@ -22,7 +19,7 @@ func UpCommand(ctx *cli.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, err := driver.Get(u.Scheme)
|
drv, err := GetDriver(u.Scheme)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -48,7 +45,7 @@ func CreateCommand(ctx *cli.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, err := driver.Get(u.Scheme)
|
drv, err := GetDriver(u.Scheme)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -63,7 +60,7 @@ func DropCommand(ctx *cli.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, err := driver.Get(u.Scheme)
|
drv, err := GetDriver(u.Scheme)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -73,12 +70,6 @@ func DropCommand(ctx *cli.Context) error {
|
||||||
|
|
||||||
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
|
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
|
||||||
|
|
||||||
func mustClose(c io.Closer) {
|
|
||||||
if err := c.Close(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCommand creates a new migration file
|
// NewCommand creates a new migration file
|
||||||
func NewCommand(ctx *cli.Context) error {
|
func NewCommand(ctx *cli.Context) error {
|
||||||
// new migration name
|
// new migration name
|
||||||
|
|
@ -126,7 +117,7 @@ func GetDatabaseURL(ctx *cli.Context) (u *url.URL, err error) {
|
||||||
return url.Parse(value)
|
return url.Parse(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
|
func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -143,13 +134,13 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) {
|
func openDatabaseForMigration(ctx *cli.Context) (Driver, *sql.DB, error) {
|
||||||
u, err := GetDatabaseURL(ctx)
|
u, err := GetDatabaseURL(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, err := driver.Get(u.Scheme)
|
drv, err := GetDriver(u.Scheme)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -205,7 +196,7 @@ func MigrateCommand(ctx *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// begin transaction
|
// begin transaction
|
||||||
err = doTransaction(db, func(tx shared.Transaction) error {
|
err = doTransaction(db, func(tx Transaction) error {
|
||||||
// run actual migration
|
// run actual migration
|
||||||
if _, err := tx.Exec(migration["up"]); err != nil {
|
if _, err := tx.Exec(migration["up"]); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -364,7 +355,7 @@ func RollbackCommand(ctx *cli.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// begin transaction
|
// begin transaction
|
||||||
err = doTransaction(db, func(tx shared.Transaction) error {
|
err = doTransaction(db, func(tx Transaction) error {
|
||||||
// rollback migration
|
// rollback migration
|
||||||
if _, err := tx.Exec(migration["down"]); err != nil {
|
if _, err := tx.Exec(migration["down"]); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
"github.com/adrianmacneil/dbmate/driver"
|
|
||||||
"github.com/codegangsta/cli"
|
"github.com/codegangsta/cli"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
@ -35,39 +34,10 @@ func testContext(t *testing.T, u *url.URL) *cli.Context {
|
||||||
return cli.NewContext(app, flagset, nil)
|
return cli.NewContext(app, flagset, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func postgresTestURL(t *testing.T) *url.URL {
|
|
||||||
str := os.Getenv("POSTGRES_PORT")
|
|
||||||
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
|
|
||||||
|
|
||||||
u, err := url.Parse(str)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
u.Scheme = "postgres"
|
|
||||||
u.User = url.User("postgres")
|
|
||||||
u.Path = "/dbmate"
|
|
||||||
u.RawQuery = "sslmode=disable"
|
|
||||||
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func mysqlTestURL(t *testing.T) *url.URL {
|
|
||||||
str := os.Getenv("MYSQL_PORT")
|
|
||||||
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
|
|
||||||
|
|
||||||
u, err := url.Parse(str)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
u.Scheme = "mysql"
|
|
||||||
u.User = url.UserPassword("root", "root")
|
|
||||||
u.Path = "/dbmate"
|
|
||||||
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func testURLs(t *testing.T) []*url.URL {
|
func testURLs(t *testing.T) []*url.URL {
|
||||||
return []*url.URL{
|
return []*url.URL{
|
||||||
postgresTestURL(t),
|
postgresTestURL(t),
|
||||||
mysqlTestURL(t),
|
mySQLTestURL(t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -98,7 +68,7 @@ func testMigrateCommandURL(t *testing.T, u *url.URL) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// verify results
|
// verify results
|
||||||
db, err := driver.Open(u)
|
db, err := GetDriverOpen(u)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
|
|
@ -131,7 +101,7 @@ func testUpCommandURL(t *testing.T, u *url.URL) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// verify results
|
// verify results
|
||||||
db, err := driver.Open(u)
|
db, err := GetDriverOpen(u)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
|
|
@ -164,7 +134,7 @@ func testRollbackCommandURL(t *testing.T, u *url.URL) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// verify migration
|
// verify migration
|
||||||
db, err := driver.Open(u)
|
db, err := GetDriverOpen(u)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
|
|
|
||||||
46
driver.go
Normal file
46
driver.go
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Driver provides top level database functions
|
||||||
|
type Driver interface {
|
||||||
|
Open(*url.URL) (*sql.DB, error)
|
||||||
|
DatabaseExists(*url.URL) (bool, error)
|
||||||
|
CreateDatabase(*url.URL) error
|
||||||
|
DropDatabase(*url.URL) error
|
||||||
|
CreateMigrationsTable(*sql.DB) error
|
||||||
|
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
||||||
|
InsertMigration(Transaction, string) error
|
||||||
|
DeleteMigration(Transaction, string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transaction can represent a database or open transaction
|
||||||
|
type Transaction interface {
|
||||||
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDriver loads a database driver by name
|
||||||
|
func GetDriver(name string) (Driver, error) {
|
||||||
|
switch name {
|
||||||
|
case "mysql":
|
||||||
|
return MySQLDriver{}, nil
|
||||||
|
case "postgres":
|
||||||
|
return PostgresDriver{}, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("Unknown driver: %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
|
||||||
|
func GetDriverOpen(u *url.URL) (*sql.DB, error) {
|
||||||
|
drv, err := GetDriver(u.Scheme)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return drv.Open(u)
|
||||||
|
}
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
package driver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"github.com/adrianmacneil/dbmate/driver/mysql"
|
|
||||||
"github.com/adrianmacneil/dbmate/driver/postgres"
|
|
||||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Driver provides top level database functions
|
|
||||||
type Driver interface {
|
|
||||||
Open(*url.URL) (*sql.DB, error)
|
|
||||||
DatabaseExists(*url.URL) (bool, error)
|
|
||||||
CreateDatabase(*url.URL) error
|
|
||||||
DropDatabase(*url.URL) error
|
|
||||||
CreateMigrationsTable(*sql.DB) error
|
|
||||||
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
|
||||||
InsertMigration(shared.Transaction, string) error
|
|
||||||
DeleteMigration(shared.Transaction, string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get loads a database driver by name
|
|
||||||
func Get(name string) (Driver, error) {
|
|
||||||
switch name {
|
|
||||||
case "mysql":
|
|
||||||
return mysql.Driver{}, nil
|
|
||||||
case "postgres":
|
|
||||||
return postgres.Driver{}, nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("Unknown driver: %s", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open is a shortcut for driver.Get(u.Scheme).Open(u)
|
|
||||||
func Open(u *url.URL) (*sql.DB, error) {
|
|
||||||
drv, err := Get(u.Scheme)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return drv.Open(u)
|
|
||||||
}
|
|
||||||
|
|
@ -1,20 +0,0 @@
|
||||||
package driver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/adrianmacneil/dbmate/driver/postgres"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGet_Postgres(t *testing.T) {
|
|
||||||
drv, err := Get("postgres")
|
|
||||||
require.Nil(t, err)
|
|
||||||
_, ok := drv.(postgres.Driver)
|
|
||||||
require.Equal(t, true, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGet_Error(t *testing.T) {
|
|
||||||
drv, err := Get("foo")
|
|
||||||
require.Equal(t, "Unknown driver: foo", err.Error())
|
|
||||||
require.Nil(t, drv)
|
|
||||||
}
|
|
||||||
|
|
@ -1,21 +0,0 @@
|
||||||
package shared
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"net/url"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DatabaseName returns the database name from a URL
|
|
||||||
func DatabaseName(u *url.URL) string {
|
|
||||||
name := u.Path
|
|
||||||
if len(name) > 0 && name[:1] == "/" {
|
|
||||||
name = name[1:len(name)]
|
|
||||||
}
|
|
||||||
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transaction can represent a database or open transaction
|
|
||||||
type Transaction interface {
|
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
|
||||||
}
|
|
||||||
26
driver_test.go
Normal file
26
driver_test.go
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDriver_Postgres(t *testing.T) {
|
||||||
|
drv, err := GetDriver("postgres")
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, ok := drv.(PostgresDriver)
|
||||||
|
require.Equal(t, true, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDriver_MySQL(t *testing.T) {
|
||||||
|
drv, err := GetDriver("mysql")
|
||||||
|
require.Nil(t, err)
|
||||||
|
_, ok := drv.(MySQLDriver)
|
||||||
|
require.Equal(t, true, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDriver_Error(t *testing.T) {
|
||||||
|
drv, err := GetDriver("foo")
|
||||||
|
require.Equal(t, "Unknown driver: foo", err.Error())
|
||||||
|
require.Nil(t, drv)
|
||||||
|
}
|
||||||
|
|
@ -1,20 +1,18 @@
|
||||||
package mysql
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
|
||||||
_ "github.com/adrianmacneil/go-mysql" // mysql driver
|
_ "github.com/adrianmacneil/go-mysql" // mysql driver
|
||||||
"io"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Driver provides top level database functions
|
// MySQLDriver provides top level database functions
|
||||||
type Driver struct {
|
type MySQLDriver struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeURL(u *url.URL) string {
|
func normalizeMySQLURL(u *url.URL) string {
|
||||||
normalizedURL := *u
|
normalizedURL := *u
|
||||||
normalizedURL.Scheme = ""
|
normalizedURL.Scheme = ""
|
||||||
normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host)
|
normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host)
|
||||||
|
|
@ -28,11 +26,11 @@ func normalizeURL(u *url.URL) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv Driver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||||
return sql.Open("mysql", normalizeURL(u))
|
return sql.Open("mysql", normalizeMySQLURL(u))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) {
|
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||||
// connect to no particular database
|
// connect to no particular database
|
||||||
rootURL := *u
|
rootURL := *u
|
||||||
rootURL.Path = "/"
|
rootURL.Path = "/"
|
||||||
|
|
@ -40,12 +38,6 @@ func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||||
return drv.Open(&rootURL)
|
return drv.Open(&rootURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustClose(c io.Closer) {
|
|
||||||
if err := c.Close(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func quoteIdentifier(str string) string {
|
func quoteIdentifier(str string) string {
|
||||||
str = strings.Replace(str, "`", "\\`", -1)
|
str = strings.Replace(str, "`", "\\`", -1)
|
||||||
|
|
||||||
|
|
@ -53,8 +45,8 @@ func quoteIdentifier(str string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase creates the specified database
|
// CreateDatabase creates the specified database
|
||||||
func (drv Driver) CreateDatabase(u *url.URL) error {
|
func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
|
||||||
name := shared.DatabaseName(u)
|
name := databaseName(u)
|
||||||
fmt.Printf("Creating: %s\n", name)
|
fmt.Printf("Creating: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB(u)
|
||||||
|
|
@ -70,8 +62,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops the specified database (if it exists)
|
// DropDatabase drops the specified database (if it exists)
|
||||||
func (drv Driver) DropDatabase(u *url.URL) error {
|
func (drv MySQLDriver) DropDatabase(u *url.URL) error {
|
||||||
name := shared.DatabaseName(u)
|
name := databaseName(u)
|
||||||
fmt.Printf("Dropping: %s\n", name)
|
fmt.Printf("Dropping: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB(u)
|
||||||
|
|
@ -87,8 +79,8 @@ func (drv Driver) DropDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseExists determines whether the database exists
|
// DatabaseExists determines whether the database exists
|
||||||
func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
name := shared.DatabaseName(u)
|
name := databaseName(u)
|
||||||
|
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -107,7 +99,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMigrationsTable creates the schema_migrations table
|
// CreateMigrationsTable creates the schema_migrations table
|
||||||
func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
func (drv MySQLDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||||
version varchar(255) primary key)`)
|
version varchar(255) primary key)`)
|
||||||
|
|
||||||
|
|
@ -116,7 +108,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
// with an optional limit (in descending order)
|
// with an optional limit (in descending order)
|
||||||
func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||||
query := "select version from schema_migrations order by version desc"
|
query := "select version from schema_migrations order by version desc"
|
||||||
if limit >= 0 {
|
if limit >= 0 {
|
||||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||||
|
|
@ -142,14 +134,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertMigration adds a new migration record
|
// InsertMigration adds a new migration record
|
||||||
func (drv Driver) InsertMigration(db shared.Transaction, version string) error {
|
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error {
|
||||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMigration removes a migration record
|
// DeleteMigration removes a migration record
|
||||||
func (drv Driver) DeleteMigration(db shared.Transaction, version string) error {
|
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
||||||
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package mysql
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testURL(t *testing.T) *url.URL {
|
func mySQLTestURL(t *testing.T) *url.URL {
|
||||||
str := os.Getenv("MYSQL_PORT")
|
str := os.Getenv("MYSQL_PORT")
|
||||||
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
|
require.NotEmpty(t, str, "missing MYSQL_PORT environment variable")
|
||||||
|
|
||||||
|
|
@ -22,9 +22,9 @@ func testURL(t *testing.T) *url.URL {
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepTestDB(t *testing.T) *sql.DB {
|
func prepTestMySQLDB(t *testing.T) *sql.DB {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
u := testURL(t)
|
u := mySQLTestURL(t)
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase(u)
|
||||||
|
|
@ -41,9 +41,9 @@ func prepTestDB(t *testing.T) *sql.DB {
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateDropDatabase(t *testing.T) {
|
func TestMySQLCreateDropDatabase(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
u := testURL(t)
|
u := mySQLTestURL(t)
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase(u)
|
||||||
|
|
@ -79,9 +79,9 @@ func TestCreateDropDatabase(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDatabaseExists(t *testing.T) {
|
func TestMySQLDatabaseExists(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
u := testURL(t)
|
u := mySQLTestURL(t)
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase(u)
|
||||||
|
|
@ -102,9 +102,9 @@ func TestDatabaseExists(t *testing.T) {
|
||||||
require.Equal(t, true, exists)
|
require.Equal(t, true, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDatabaseExists_Error(t *testing.T) {
|
func TestMySQLDatabaseExists_Error(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
u := testURL(t)
|
u := mySQLTestURL(t)
|
||||||
u.User = url.User("invalid")
|
u.User = url.User("invalid")
|
||||||
|
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists(u)
|
||||||
|
|
@ -112,9 +112,9 @@ func TestDatabaseExists_Error(t *testing.T) {
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateMigrationsTable(t *testing.T) {
|
func TestMySQLCreateMigrationsTable(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestMySQLDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
|
|
@ -135,9 +135,9 @@ func TestCreateMigrationsTable(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectMigrations(t *testing.T) {
|
func TestMySQLSelectMigrations(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestMySQLDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
|
|
@ -161,9 +161,9 @@ func TestSelectMigrations(t *testing.T) {
|
||||||
require.Equal(t, false, migrations["abc2"])
|
require.Equal(t, false, migrations["abc2"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertMigration(t *testing.T) {
|
func TestMySQLInsertMigration(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestMySQLDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
|
|
@ -184,9 +184,9 @@ func TestInsertMigration(t *testing.T) {
|
||||||
require.Equal(t, 1, count)
|
require.Equal(t, 1, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteMigration(t *testing.T) {
|
func TestMySQLDeleteMigration(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := MySQLDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestMySQLDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
|
|
@ -1,24 +1,22 @@
|
||||||
package postgres
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/adrianmacneil/dbmate/driver/shared"
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
"io"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Driver provides top level database functions
|
// PostgresDriver provides top level database functions
|
||||||
type Driver struct {
|
type PostgresDriver struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv Driver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||||
return sql.Open("postgres", u.String())
|
return sql.Open("postgres", u.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||||
// connect to postgres database
|
// connect to postgres database
|
||||||
postgresURL := *u
|
postgresURL := *u
|
||||||
postgresURL.Path = "postgres"
|
postgresURL.Path = "postgres"
|
||||||
|
|
@ -26,15 +24,9 @@ func (drv Driver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||||
return drv.Open(&postgresURL)
|
return drv.Open(&postgresURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mustClose(c io.Closer) {
|
|
||||||
if err := c.Close(); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateDatabase creates the specified database
|
// CreateDatabase creates the specified database
|
||||||
func (drv Driver) CreateDatabase(u *url.URL) error {
|
func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
|
||||||
name := shared.DatabaseName(u)
|
name := databaseName(u)
|
||||||
fmt.Printf("Creating: %s\n", name)
|
fmt.Printf("Creating: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openPostgresDB(u)
|
db, err := drv.openPostgresDB(u)
|
||||||
|
|
@ -50,8 +42,8 @@ func (drv Driver) CreateDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops the specified database (if it exists)
|
// DropDatabase drops the specified database (if it exists)
|
||||||
func (drv Driver) DropDatabase(u *url.URL) error {
|
func (drv PostgresDriver) DropDatabase(u *url.URL) error {
|
||||||
name := shared.DatabaseName(u)
|
name := databaseName(u)
|
||||||
fmt.Printf("Dropping: %s\n", name)
|
fmt.Printf("Dropping: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openPostgresDB(u)
|
db, err := drv.openPostgresDB(u)
|
||||||
|
|
@ -67,8 +59,8 @@ func (drv Driver) DropDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseExists determines whether the database exists
|
// DatabaseExists determines whether the database exists
|
||||||
func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
name := shared.DatabaseName(u)
|
name := databaseName(u)
|
||||||
|
|
||||||
db, err := drv.openPostgresDB(u)
|
db, err := drv.openPostgresDB(u)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -87,7 +79,7 @@ func (drv Driver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMigrationsTable creates the schema_migrations table
|
// CreateMigrationsTable creates the schema_migrations table
|
||||||
func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||||
version varchar(255) primary key)`)
|
version varchar(255) primary key)`)
|
||||||
|
|
||||||
|
|
@ -96,7 +88,7 @@ func (drv Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
// with an optional limit (in descending order)
|
// with an optional limit (in descending order)
|
||||||
func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||||
query := "select version from schema_migrations order by version desc"
|
query := "select version from schema_migrations order by version desc"
|
||||||
if limit >= 0 {
|
if limit >= 0 {
|
||||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||||
|
|
@ -122,14 +114,14 @@ func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, erro
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertMigration adds a new migration record
|
// InsertMigration adds a new migration record
|
||||||
func (drv Driver) InsertMigration(db shared.Transaction, version string) error {
|
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
|
||||||
_, err := db.Exec("insert into schema_migrations (version) values ($1)", version)
|
_, err := db.Exec("insert into schema_migrations (version) values ($1)", version)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMigration removes a migration record
|
// DeleteMigration removes a migration record
|
||||||
func (drv Driver) DeleteMigration(db shared.Transaction, version string) error {
|
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
||||||
_, err := db.Exec("delete from schema_migrations where version = $1", version)
|
_, err := db.Exec("delete from schema_migrations where version = $1", version)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package postgres
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
|
@ -8,7 +8,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testURL(t *testing.T) *url.URL {
|
func postgresTestURL(t *testing.T) *url.URL {
|
||||||
str := os.Getenv("POSTGRES_PORT")
|
str := os.Getenv("POSTGRES_PORT")
|
||||||
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
|
require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
|
||||||
|
|
||||||
|
|
@ -23,9 +23,9 @@ func testURL(t *testing.T) *url.URL {
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepTestDB(t *testing.T) *sql.DB {
|
func prepTestPostgresDB(t *testing.T) *sql.DB {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
u := testURL(t)
|
u := postgresTestURL(t)
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase(u)
|
||||||
|
|
@ -42,9 +42,9 @@ func prepTestDB(t *testing.T) *sql.DB {
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateDropDatabase(t *testing.T) {
|
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
u := testURL(t)
|
u := postgresTestURL(t)
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase(u)
|
||||||
|
|
@ -80,9 +80,9 @@ func TestCreateDropDatabase(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDatabaseExists(t *testing.T) {
|
func TestPostgresDatabaseExists(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
u := testURL(t)
|
u := postgresTestURL(t)
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase(u)
|
||||||
|
|
@ -103,9 +103,9 @@ func TestDatabaseExists(t *testing.T) {
|
||||||
require.Equal(t, true, exists)
|
require.Equal(t, true, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDatabaseExists_Error(t *testing.T) {
|
func TestPostgresDatabaseExists_Error(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
u := testURL(t)
|
u := postgresTestURL(t)
|
||||||
u.User = url.User("invalid")
|
u.User = url.User("invalid")
|
||||||
|
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists(u)
|
||||||
|
|
@ -113,9 +113,9 @@ func TestDatabaseExists_Error(t *testing.T) {
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCreateMigrationsTable(t *testing.T) {
|
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
|
|
@ -136,9 +136,9 @@ func TestCreateMigrationsTable(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSelectMigrations(t *testing.T) {
|
func TestPostgresSelectMigrations(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
|
|
@ -162,9 +162,9 @@ func TestSelectMigrations(t *testing.T) {
|
||||||
require.Equal(t, false, migrations["abc2"])
|
require.Equal(t, false, migrations["abc2"])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsertMigration(t *testing.T) {
|
func TestPostgresInsertMigration(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
|
|
@ -185,9 +185,9 @@ func TestInsertMigration(t *testing.T) {
|
||||||
require.Equal(t, 1, count)
|
require.Equal(t, 1, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDeleteMigration(t *testing.T) {
|
func TestPostgresDeleteMigration(t *testing.T) {
|
||||||
drv := Driver{}
|
drv := PostgresDriver{}
|
||||||
db := prepTestDB(t)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer mustClose(db)
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
22
utils.go
Normal file
22
utils.go
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
// databaseName returns the database name from a URL
|
||||||
|
func databaseName(u *url.URL) string {
|
||||||
|
name := u.Path
|
||||||
|
if len(name) > 0 && name[:1] == "/" {
|
||||||
|
name = name[1:len(name)]
|
||||||
|
}
|
||||||
|
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustClose(c io.Closer) {
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package shared
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
@ -10,7 +10,7 @@ func TestDatabaseName(t *testing.T) {
|
||||||
u, err := url.Parse("ignore://localhost/foo?query")
|
u, err := url.Parse("ignore://localhost/foo?query")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
name := DatabaseName(u)
|
name := databaseName(u)
|
||||||
require.Equal(t, "foo", name)
|
require.Equal(t, "foo", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -18,6 +18,6 @@ func TestDatabaseName_Empty(t *testing.T) {
|
||||||
u, err := url.Parse("ignore://localhost")
|
u, err := url.Parse("ignore://localhost")
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
name := DatabaseName(u)
|
name := databaseName(u)
|
||||||
require.Equal(t, "", name)
|
require.Equal(t, "", name)
|
||||||
}
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue