Allow database URL to be specified via command line

This commit is contained in:
Adrian Macneil 2015-11-25 11:32:13 -08:00
parent 99a4f266e6
commit ece5d3cf0e
4 changed files with 36 additions and 9 deletions

View file

@ -16,7 +16,7 @@ import (
// CreateCommand creates the current database // CreateCommand creates the current database
func CreateCommand(ctx *cli.Context) error { func CreateCommand(ctx *cli.Context) error {
u, err := GetDatabaseURL() u, err := GetDatabaseURL(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -31,7 +31,7 @@ func CreateCommand(ctx *cli.Context) error {
// DropCommand drops the current database (if it exists) // DropCommand drops the current database (if it exists)
func DropCommand(ctx *cli.Context) error { func DropCommand(ctx *cli.Context) error {
u, err := GetDatabaseURL() u, err := GetDatabaseURL(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -86,8 +86,11 @@ func NewCommand(ctx *cli.Context) error {
} }
// GetDatabaseURL returns the current environment database url // GetDatabaseURL returns the current environment database url
func GetDatabaseURL() (u *url.URL, err error) { func GetDatabaseURL(ctx *cli.Context) (u *url.URL, err error) {
return url.Parse(os.Getenv("DATABASE_URL")) env := ctx.GlobalString("env")
value := os.Getenv(env)
return url.Parse(value)
} }
func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error { func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
@ -116,7 +119,7 @@ func MigrateCommand(ctx *cli.Context) error {
return fmt.Errorf("No migration files found.") return fmt.Errorf("No migration files found.")
} }
u, err := GetDatabaseURL() u, err := GetDatabaseURL(ctx)
if err != nil { if err != nil {
return err return err
} }

View file

@ -1,16 +1,29 @@
package main_test package main_test
import ( import (
"flag"
"github.com/adrianmacneil/dbmate" "github.com/adrianmacneil/dbmate"
"github.com/codegangsta/cli"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"os" "os"
"testing" "testing"
) )
func TestGetDatabaseUrl(t *testing.T) { func newContext() *cli.Context {
app := main.NewApp()
flagset := flag.NewFlagSet(app.Name, flag.ContinueOnError)
for _, f := range app.Flags {
f.Apply(flagset)
}
return cli.NewContext(app, flagset, nil)
}
func TestGetDatabaseUrl_Default(t *testing.T) {
os.Setenv("DATABASE_URL", "postgres://example.org/db") os.Setenv("DATABASE_URL", "postgres://example.org/db")
u, err := main.GetDatabaseURL() ctx := newContext()
u, err := main.GetDatabaseURL(ctx)
require.Nil(t, err) require.Nil(t, err)
require.Equal(t, "postgres", u.Scheme) require.Equal(t, "postgres", u.Scheme)

View file

@ -11,7 +11,7 @@ import (
func testURL(t *testing.T) *url.URL { func testURL(t *testing.T) *url.URL {
str := os.Getenv("POSTGRES_PORT") str := os.Getenv("POSTGRES_PORT")
require.NotEmpty(t, str) require.NotEmpty(t, str, "missing POSTGRES_PORT environment variable")
u, err := url.Parse(str) u, err := url.Parse(str)
require.Nil(t, err) require.Nil(t, err)

13
main.go
View file

@ -11,6 +11,12 @@ import (
func main() { func main() {
loadDotEnv() loadDotEnv()
app := NewApp()
app.Run(os.Args)
}
// NewApp creates a new command line app
func NewApp() *cli.App {
app := cli.NewApp() app := cli.NewApp()
app.Name = "dbmate" app.Name = "dbmate"
app.Usage = "A lightweight, framework-independent database migration tool." app.Usage = "A lightweight, framework-independent database migration tool."
@ -21,6 +27,11 @@ func main() {
Value: "./db/migrations", Value: "./db/migrations",
Usage: "specify the directory containing migration files", Usage: "specify the directory containing migration files",
}, },
cli.StringFlag{
Name: "env, e",
Value: "DATABASE_URL",
Usage: "specify an environment variable containing the database URL",
},
} }
app.Commands = []cli.Command{ app.Commands = []cli.Command{
@ -54,7 +65,7 @@ func main() {
}, },
} }
app.Run(os.Args) return app
} }
type command func(*cli.Context) error type command func(*cli.Context) error