Refactor drivers into separate packages (#179)

`dbmate` package was starting to get a bit polluted. This PR migrates each driver into a separate package, with clean separation between each.

In addition:

* Drivers are now initialized with a URL, avoiding the need to pass `*url.URL` to every method
* Sqlite supports a cleaner syntax for relative paths
* Driver tests now load their test URL from environment variables

Public API of `dbmate` package has not changed (no changes to `main` package).
This commit is contained in:
Adrian Macneil 2020-11-19 15:04:42 +13:00 committed by GitHub
parent c907c3f5c6
commit 61771e386d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1195 additions and 1078 deletions

View file

@ -26,3 +26,7 @@ linters-settings:
local-prefixes: github.com/amacneil/dbmate
misspell:
locale: US
issues:
include:
- EXC0002

View file

@ -3,14 +3,14 @@ LDFLAGS := -ldflags '-s'
# statically link binaries (to support alpine + scratch containers)
STATICLDFLAGS := -ldflags '-s -extldflags "-static"'
# avoid building code that is incompatible with static linking
TAGS := -tags netgo,osusergo,sqlite_omit_load_extension
TAGS := -tags netgo,osusergo,sqlite_omit_load_extension,sqlite_json
.PHONY: all
all: build lint test
all: build test lint
.PHONY: test
test:
go test -v $(TAGS) $(STATICLDFLAGS) ./...
go test -p 1 $(TAGS) $(STATICLDFLAGS) ./...
.PHONY: fix
fix:
@ -22,9 +22,9 @@ lint:
.PHONY: wait
wait:
dist/dbmate-linux-amd64 -e MYSQL_URL wait
dist/dbmate-linux-amd64 -e POSTGRESQL_URL wait
dist/dbmate-linux-amd64 -e CLICKHOUSE_URL wait
dist/dbmate-linux-amd64 -e CLICKHOUSE_TEST_URL wait
dist/dbmate-linux-amd64 -e MYSQL_TEST_URL wait
dist/dbmate-linux-amd64 -e POSTGRES_TEST_URL wait
.PHONY: clean
clean:

View file

@ -152,16 +152,16 @@ DATABASE_URL="postgres://username:password@127.0.0.1:5432/database_name?search_p
**SQLite**
SQLite databases are stored on the filesystem, so you do not need to specify a host. By default, files are relative to the current directory. For example, the following will create a database at `./db/database_name.sqlite3`:
SQLite databases are stored on the filesystem, so you do not need to specify a host. By default, files are relative to the current directory. For example, the following will create a database at `./db/database.sqlite3`:
```sh
DATABASE_URL="sqlite:///db/database_name.sqlite3"
DATABASE_URL="sqlite:db/database.sqlite3"
```
To specify an absolute path, add an additional forward slash to the path. The following will create a database at `/tmp/database_name.sqlite3`:
To specify an absolute path, add a forward slash to the path. The following will create a database at `/tmp/database.sqlite3`:
```sh
DATABASE_URL="sqlite:////tmp/database_name.sqlite3"
DATABASE_URL="sqlite:/tmp/database.sqlite3"
```
**ClickHouse**

View file

@ -11,9 +11,10 @@ services:
- postgres
- clickhouse
environment:
MYSQL_URL: mysql://root:root@mysql/dbmate
POSTGRESQL_URL: postgres://postgres:postgres@postgres/dbmate?sslmode=disable
CLICKHOUSE_URL: clickhouse://clickhouse:9000?database=dbmate
CLICKHOUSE_TEST_URL: clickhouse://clickhouse:9000?database=dbmate_test
MYSQL_TEST_URL: mysql://root:root@mysql/dbmate_test
POSTGRES_TEST_URL: postgres://postgres:postgres@postgres/dbmate_test?sslmode=disable
SQLITE_TEST_URL: sqlite3:/tmp/dbmate_test.sqlite3
dbmate:
build:

1
go.sum
View file

@ -1,3 +1,4 @@
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/ClickHouse/clickhouse-go v1.4.3 h1:iAFMa2UrQdR5bHJ2/yaSLffZkxpcOYQMCUuKeNXGdqc=
github.com/ClickHouse/clickhouse-go v1.4.3/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI=

View file

@ -11,6 +11,10 @@ import (
"github.com/urfave/cli/v2"
"github.com/amacneil/dbmate/pkg/dbmate"
_ "github.com/amacneil/dbmate/pkg/driver/clickhouse"
_ "github.com/amacneil/dbmate/pkg/driver/mysql"
_ "github.com/amacneil/dbmate/pkg/driver/postgres"
_ "github.com/amacneil/dbmate/pkg/driver/sqlite"
)
func main() {

View file

@ -2,6 +2,7 @@ package dbmate
import (
"database/sql"
"errors"
"fmt"
"io/ioutil"
"net/url"
@ -10,6 +11,8 @@ import (
"regexp"
"sort"
"time"
"github.com/amacneil/dbmate/pkg/dbutil"
)
// DefaultMigrationsDir specifies default directory to find migration files
@ -43,9 +46,10 @@ type DB struct {
// migrationFileRegexp pattern for valid migration files
var migrationFileRegexp = regexp.MustCompile(`^\d.*\.sql$`)
type statusResult struct {
filename string
applied bool
// StatusResult represents an available migration status
type StatusResult struct {
Filename string
Applied bool
}
// New initializes a new dbmate database
@ -62,16 +66,23 @@ func New(databaseURL *url.URL) *DB {
}
}
// GetDriver loads the required database driver
// GetDriver initializes the appropriate database driver
func (db *DB) GetDriver() (Driver, error) {
drv, err := getDriver(db.DatabaseURL.Scheme)
if err != nil {
return nil, err
if db.DatabaseURL == nil || db.DatabaseURL.Scheme == "" {
return nil, errors.New("invalid url")
}
drv.SetMigrationsTableName(db.MigrationsTableName)
driverFunc := drivers[db.DatabaseURL.Scheme]
if driverFunc == nil {
return nil, fmt.Errorf("unsupported driver: %s", db.DatabaseURL.Scheme)
}
return drv, err
config := DriverConfig{
DatabaseURL: db.DatabaseURL,
MigrationsTableName: db.MigrationsTableName,
}
return driverFunc(config), nil
}
// Wait blocks until the database server is available. It does not verify that
@ -82,8 +93,12 @@ func (db *DB) Wait() error {
return err
}
return db.wait(drv)
}
func (db *DB) wait(drv Driver) error {
// attempt connection to database server
err = drv.Ping(db.DatabaseURL)
err := drv.Ping()
if err == nil {
// connection successful
return nil
@ -95,7 +110,7 @@ func (db *DB) Wait() error {
time.Sleep(db.WaitInterval)
// attempt connection to database server
err = drv.Ping(db.DatabaseURL)
err = drv.Ping()
if err == nil {
// connection successful
fmt.Print("\n")
@ -110,82 +125,91 @@ func (db *DB) Wait() error {
// CreateAndMigrate creates the database (if necessary) and runs migrations
func (db *DB) CreateAndMigrate() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, err := db.GetDriver()
if err != nil {
return err
}
if db.WaitBefore {
err := db.wait(drv)
if err != nil {
return err
}
}
// create database if it does not already exist
// skip this step if we cannot determine status
// (e.g. user does not have list database permission)
exists, err := drv.DatabaseExists(db.DatabaseURL)
exists, err := drv.DatabaseExists()
if err == nil && !exists {
if err := drv.CreateDatabase(db.DatabaseURL); err != nil {
if err := drv.CreateDatabase(); err != nil {
return err
}
}
// migrate
return db.Migrate()
return db.migrate(drv)
}
// Create creates the current database
func (db *DB) Create() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, err := db.GetDriver()
if err != nil {
return err
}
return drv.CreateDatabase(db.DatabaseURL)
if db.WaitBefore {
err := db.wait(drv)
if err != nil {
return err
}
}
return drv.CreateDatabase()
}
// Drop drops the current database (if it exists)
func (db *DB) Drop() error {
if db.WaitBefore {
err := db.Wait()
if err != nil {
return err
}
}
drv, err := db.GetDriver()
if err != nil {
return err
}
return drv.DropDatabase(db.DatabaseURL)
if db.WaitBefore {
err := db.wait(drv)
if err != nil {
return err
}
}
return drv.DropDatabase()
}
// DumpSchema writes the current database schema to a file
func (db *DB) DumpSchema() error {
drv, err := db.GetDriver()
if err != nil {
return err
}
return db.dumpSchema(drv)
}
func (db *DB) dumpSchema(drv Driver) error {
if db.WaitBefore {
err := db.Wait()
err := db.wait(drv)
if err != nil {
return err
}
}
drv, sqlDB, err := db.openDatabaseForMigration()
sqlDB, err := db.openDatabaseForMigration(drv)
if err != nil {
return err
}
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
schema, err := drv.DumpSchema(db.DatabaseURL, sqlDB)
schema, err := drv.DumpSchema(sqlDB)
if err != nil {
return err
}
@ -201,6 +225,15 @@ func (db *DB) DumpSchema() error {
return ioutil.WriteFile(db.SchemaFile, schema, 0644)
}
// ensureDir creates a directory if it does not already exist
func ensureDir(dir string) error {
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("unable to create directory `%s`", dir)
}
return nil
}
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
// NewMigration creates a new migration file
@ -231,13 +264,13 @@ func (db *DB) NewMigration(name string) error {
return err
}
defer mustClose(file)
defer dbutil.MustClose(file)
_, err = file.WriteString(migrationTemplate)
return err
}
func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
tx, err := db.Begin()
func doTransaction(sqlDB *sql.DB, txFunc func(dbutil.Transaction) error) error {
tx, err := sqlDB.Begin()
if err != nil {
return err
}
@ -253,27 +286,31 @@ func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
return tx.Commit()
}
func (db *DB) openDatabaseForMigration() (Driver, *sql.DB, error) {
drv, err := db.GetDriver()
func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) {
sqlDB, err := drv.Open()
if err != nil {
return nil, nil, err
return nil, err
}
sqlDB, err := drv.Open(db.DatabaseURL)
if err != nil {
return nil, nil, err
if err := drv.CreateMigrationsTable(sqlDB); err != nil {
dbutil.MustClose(sqlDB)
return nil, err
}
if err := drv.CreateMigrationsTable(db.DatabaseURL, sqlDB); err != nil {
mustClose(sqlDB)
return nil, nil, err
}
return drv, sqlDB, nil
return sqlDB, nil
}
// Migrate migrates database to the latest version
func (db *DB) Migrate() error {
drv, err := db.GetDriver()
if err != nil {
return err
}
return db.migrate(drv)
}
func (db *DB) migrate(drv Driver) error {
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
if err != nil {
return err
@ -284,17 +321,17 @@ func (db *DB) Migrate() error {
}
if db.WaitBefore {
err := db.Wait()
err := db.wait(drv)
if err != nil {
return err
}
}
drv, sqlDB, err := db.openDatabaseForMigration()
sqlDB, err := db.openDatabaseForMigration(drv)
if err != nil {
return err
}
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, -1)
if err != nil {
@ -315,7 +352,7 @@ func (db *DB) Migrate() error {
return err
}
execMigration := func(tx Transaction) error {
execMigration := func(tx dbutil.Transaction) error {
// run actual migration
result, err := tx.Exec(up.Contents)
if err != nil {
@ -343,12 +380,23 @@ func (db *DB) Migrate() error {
// automatically update schema file, silence errors
if db.AutoDumpSchema {
_ = db.DumpSchema()
_ = db.dumpSchema(drv)
}
return nil
}
func printVerbose(result sql.Result) {
lastInsertID, err := result.LastInsertId()
if err == nil {
fmt.Printf("Last insert ID: %d\n", lastInsertID)
}
rowsAffected, err := result.RowsAffected()
if err == nil {
fmt.Printf("Rows affected: %d\n", rowsAffected)
}
}
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
files, err := ioutil.ReadDir(dir)
if err != nil {
@ -400,18 +448,23 @@ func migrationVersion(filename string) string {
// Rollback rolls back the most recent migration
func (db *DB) Rollback() error {
drv, err := db.GetDriver()
if err != nil {
return err
}
if db.WaitBefore {
err := db.Wait()
err := db.wait(drv)
if err != nil {
return err
}
}
drv, sqlDB, err := db.openDatabaseForMigration()
sqlDB, err := db.openDatabaseForMigration(drv)
if err != nil {
return err
}
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, 1)
if err != nil {
@ -439,7 +492,7 @@ func (db *DB) Rollback() error {
return err
}
execMigration := func(tx Transaction) error {
execMigration := func(tx dbutil.Transaction) error {
// rollback migration
result, err := tx.Exec(down.Contents)
if err != nil {
@ -466,53 +519,20 @@ func (db *DB) Rollback() error {
// automatically update schema file, silence errors
if db.AutoDumpSchema {
_ = db.DumpSchema()
_ = db.dumpSchema(drv)
}
return nil
}
func checkMigrationsStatus(db *DB) ([]statusResult, error) {
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
if err != nil {
return nil, err
}
if len(files) == 0 {
return nil, fmt.Errorf("no migration files found")
}
drv, sqlDB, err := db.openDatabaseForMigration()
if err != nil {
return nil, err
}
defer mustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, -1)
if err != nil {
return nil, err
}
var results []statusResult
for _, filename := range files {
ver := migrationVersion(filename)
res := statusResult{filename: filename}
if ok := applied[ver]; ok {
res.applied = true
} else {
res.applied = false
}
results = append(results, res)
}
return results, nil
}
// Status shows the status of all migrations
func (db *DB) Status(quiet bool) (int, error) {
results, err := checkMigrationsStatus(db)
drv, err := db.GetDriver()
if err != nil {
return -1, err
}
results, err := db.CheckMigrationsStatus(drv)
if err != nil {
return -1, err
}
@ -521,11 +541,11 @@ func (db *DB) Status(quiet bool) (int, error) {
var line string
for _, res := range results {
if res.applied {
line = fmt.Sprintf("[X] %s", res.filename)
if res.Applied {
line = fmt.Sprintf("[X] %s", res.Filename)
totalApplied++
} else {
line = fmt.Sprintf("[ ] %s", res.filename)
line = fmt.Sprintf("[ ] %s", res.Filename)
}
if !quiet {
fmt.Println(line)
@ -541,3 +561,42 @@ func (db *DB) Status(quiet bool) (int, error) {
return totalPending, nil
}
// CheckMigrationsStatus returns the status of all available mgirations
func (db *DB) CheckMigrationsStatus(drv Driver) ([]StatusResult, error) {
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
if err != nil {
return nil, err
}
if len(files) == 0 {
return nil, fmt.Errorf("no migration files found")
}
sqlDB, err := db.openDatabaseForMigration(drv)
if err != nil {
return nil, err
}
defer dbutil.MustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, -1)
if err != nil {
return nil, err
}
var results []StatusResult
for _, filename := range files {
ver := migrationVersion(filename)
res := StatusResult{Filename: filename}
if ok := applied[ver]; ok {
res.Applied = true
} else {
res.Applied = false
}
results = append(results, res)
}
return results, nil
}

View file

@ -1,4 +1,4 @@
package dbmate
package dbmate_test
import (
"io/ioutil"
@ -8,13 +8,19 @@ import (
"testing"
"time"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
_ "github.com/amacneil/dbmate/pkg/driver/mysql"
_ "github.com/amacneil/dbmate/pkg/driver/postgres"
_ "github.com/amacneil/dbmate/pkg/driver/sqlite"
"github.com/kami-zh/go-capturer"
"github.com/stretchr/testify/require"
)
var testdataDir string
func newTestDB(t *testing.T, u *url.URL) *DB {
func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
var err error
// only chdir once, because testdata is relative to current directory
@ -26,17 +32,16 @@ func newTestDB(t *testing.T, u *url.URL) *DB {
require.NoError(t, err)
}
db := New(u)
db := dbmate.New(u)
db.AutoDumpSchema = false
return db
}
func TestNew(t *testing.T) {
u := postgresTestURL(t)
db := New(u)
db := dbmate.New(dbutil.MustParseURL("foo:test"))
require.True(t, db.AutoDumpSchema)
require.Equal(t, u.String(), db.DatabaseURL.String())
require.Equal(t, "foo:test", db.DatabaseURL.String())
require.Equal(t, "./db/migrations", db.MigrationsDir)
require.Equal(t, "schema_migrations", db.MigrationsTableName)
require.Equal(t, "./db/schema.sql", db.SchemaFile)
@ -46,20 +51,30 @@ func TestNew(t *testing.T) {
}
func TestGetDriver(t *testing.T) {
u := postgresTestURL(t)
db := New(u)
t.Run("missing URL", func(t *testing.T) {
db := dbmate.New(nil)
drv, err := db.GetDriver()
require.NoError(t, err)
require.Nil(t, drv)
require.EqualError(t, err, "invalid url")
})
// driver should have default migrations table set
pgDrv, ok := drv.(*PostgresDriver)
require.True(t, ok)
require.Equal(t, "schema_migrations", pgDrv.migrationsTableName)
t.Run("missing schema", func(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("//hi"))
drv, err := db.GetDriver()
require.Nil(t, drv)
require.EqualError(t, err, "invalid url")
})
t.Run("invalid driver", func(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("foo://bar"))
drv, err := db.GetDriver()
require.EqualError(t, err, "unsupported driver: foo")
require.Nil(t, drv)
})
}
func TestWait(t *testing.T) {
u := postgresTestURL(t)
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
db := newTestDB(t, u)
// speed up our retry loop for testing
@ -83,7 +98,7 @@ func TestWait(t *testing.T) {
}
func TestDumpSchema(t *testing.T) {
u := postgresTestURL(t)
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
db := newTestDB(t, u)
// create custom schema file directory
@ -120,7 +135,7 @@ func TestDumpSchema(t *testing.T) {
}
func TestAutoDumpSchema(t *testing.T) {
u := postgresTestURL(t)
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
db := newTestDB(t, u)
db.AutoDumpSchema = true
@ -177,7 +192,7 @@ func checkWaitCalled(t *testing.T, u *url.URL, command func() error) {
}
func testWaitBefore(t *testing.T, verbose bool) {
u := postgresTestURL(t)
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
db := newTestDB(t, u)
db.Verbose = verbose
db.WaitBefore = true
@ -234,19 +249,23 @@ Rows affected: 0`)
Rows affected: 0`)
}
func testURLs(t *testing.T) []*url.URL {
func testURLs() []*url.URL {
return []*url.URL{
postgresTestURL(t),
mySQLTestURL(t),
sqliteTestURL(t),
dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL")),
dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL")),
dbutil.MustParseURL(os.Getenv("SQLITE_TEST_URL")),
}
}
func testMigrateURL(t *testing.T, u *url.URL) {
func TestMigrate(t *testing.T) {
for _, u := range testURLs() {
t.Run(u.Scheme, func(t *testing.T) {
db := newTestDB(t, u)
drv, err := db.GetDriver()
require.NoError(t, err)
// drop and recreate database
err := db.Drop()
err = db.Drop()
require.NoError(t, err)
err = db.Create()
require.NoError(t, err)
@ -256,9 +275,9 @@ func testMigrateURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify results
sqlDB, err := getDriverOpen(u)
sqlDB, err := drv.Open()
require.NoError(t, err)
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
count := 0
err = sqlDB.QueryRow(`select count(*) from schema_migrations
@ -269,19 +288,19 @@ func testMigrateURL(t *testing.T, u *url.URL) {
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestMigrate(t *testing.T) {
for _, u := range testURLs(t) {
testMigrateURL(t, u)
})
}
}
func testUpURL(t *testing.T, u *url.URL) {
func TestUp(t *testing.T) {
for _, u := range testURLs() {
t.Run(u.Scheme, func(t *testing.T) {
db := newTestDB(t, u)
drv, err := db.GetDriver()
require.NoError(t, err)
// drop database
err := db.Drop()
err = db.Drop()
require.NoError(t, err)
// create and migrate
@ -289,9 +308,9 @@ func testUpURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify results
sqlDB, err := getDriverOpen(u)
sqlDB, err := drv.Open()
require.NoError(t, err)
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
count := 0
err = sqlDB.QueryRow(`select count(*) from schema_migrations
@ -302,19 +321,19 @@ func testUpURL(t *testing.T, u *url.URL) {
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestUp(t *testing.T) {
for _, u := range testURLs(t) {
testUpURL(t, u)
})
}
}
func testRollbackURL(t *testing.T, u *url.URL) {
func TestRollback(t *testing.T) {
for _, u := range testURLs() {
t.Run(u.Scheme, func(t *testing.T) {
db := newTestDB(t, u)
drv, err := db.GetDriver()
require.NoError(t, err)
// drop, recreate, and migrate database
err := db.Drop()
err = db.Drop()
require.NoError(t, err)
err = db.Create()
require.NoError(t, err)
@ -322,9 +341,9 @@ func testRollbackURL(t *testing.T, u *url.URL) {
require.NoError(t, err)
// verify migration
sqlDB, err := getDriverOpen(u)
sqlDB, err := drv.Open()
require.NoError(t, err)
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
count := 0
err = sqlDB.QueryRow(`select count(*) from schema_migrations
@ -347,60 +366,56 @@ func testRollbackURL(t *testing.T, u *url.URL) {
err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
require.NotNil(t, err)
require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
}
func TestRollback(t *testing.T) {
for _, u := range testURLs(t) {
testRollbackURL(t, u)
})
}
}
func testStatusURL(t *testing.T, u *url.URL) {
func TestStatus(t *testing.T) {
for _, u := range testURLs() {
t.Run(u.Scheme, func(t *testing.T) {
db := newTestDB(t, u)
drv, err := db.GetDriver()
require.NoError(t, err)
// drop, recreate, and migrate database
err := db.Drop()
err = db.Drop()
require.NoError(t, err)
err = db.Create()
require.NoError(t, err)
// verify migration
sqlDB, err := getDriverOpen(u)
sqlDB, err := drv.Open()
require.NoError(t, err)
defer mustClose(sqlDB)
defer dbutil.MustClose(sqlDB)
// two pending
results, err := checkMigrationsStatus(db)
results, err := db.CheckMigrationsStatus(drv)
require.NoError(t, err)
require.Len(t, results, 2)
require.False(t, results[0].applied)
require.False(t, results[1].applied)
require.False(t, results[0].Applied)
require.False(t, results[1].Applied)
// run migrations
err = db.Migrate()
require.NoError(t, err)
// two applied
results, err = checkMigrationsStatus(db)
results, err = db.CheckMigrationsStatus(drv)
require.NoError(t, err)
require.Len(t, results, 2)
require.True(t, results[0].applied)
require.True(t, results[1].applied)
require.True(t, results[0].Applied)
require.True(t, results[1].Applied)
// rollback last migration
err = db.Rollback()
require.NoError(t, err)
// one applied, one pending
results, err = checkMigrationsStatus(db)
results, err = db.CheckMigrationsStatus(drv)
require.NoError(t, err)
require.Len(t, results, 2)
require.True(t, results[0].applied)
require.False(t, results[1].applied)
}
func TestStatus(t *testing.T) {
for _, u := range testURLs(t) {
testStatusURL(t, u)
require.True(t, results[0].Applied)
require.False(t, results[1].Applied)
})
}
}

View file

@ -2,56 +2,37 @@ package dbmate
import (
"database/sql"
"fmt"
"net/url"
"github.com/amacneil/dbmate/pkg/dbutil"
)
// Driver provides top level database functions
type Driver interface {
Open(*url.URL) (*sql.DB, error)
DatabaseExists(*url.URL) (bool, error)
CreateDatabase(*url.URL) error
DropDatabase(*url.URL) error
DumpSchema(*url.URL, *sql.DB) ([]byte, error)
SetMigrationsTableName(string)
CreateMigrationsTable(*url.URL, *sql.DB) error
Open() (*sql.DB, error)
DatabaseExists() (bool, error)
CreateDatabase() error
DropDatabase() error
DumpSchema(*sql.DB) ([]byte, error)
CreateMigrationsTable(*sql.DB) error
SelectMigrations(*sql.DB, int) (map[string]bool, error)
InsertMigration(Transaction, string) error
DeleteMigration(Transaction, string) error
Ping(*url.URL) error
InsertMigration(dbutil.Transaction, string) error
DeleteMigration(dbutil.Transaction, string) error
Ping() error
}
var drivers = map[string]Driver{}
// RegisterDriver registers a driver for a URL scheme
func RegisterDriver(drv Driver, scheme string) {
drivers[scheme] = drv
// DriverConfig holds configuration passed to driver constructors
type DriverConfig struct {
DatabaseURL *url.URL
MigrationsTableName string
}
// Transaction can represent a database or open transaction
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// DriverFunc represents a driver constructor
type DriverFunc func(DriverConfig) Driver
// getDriver loads a database driver by name
func getDriver(name string) (Driver, error) {
if drv, ok := drivers[name]; ok {
drv.SetMigrationsTableName(DefaultMigrationsTableName)
var drivers = map[string]DriverFunc{}
return drv, nil
}
return nil, fmt.Errorf("unsupported driver: %s", name)
}
// getDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
func getDriverOpen(u *url.URL) (*sql.DB, error) {
drv, err := getDriver(u.Scheme)
if err != nil {
return nil, err
}
return drv.Open(u)
// RegisterDriver registers a driver constructor for a given URL scheme
func RegisterDriver(f DriverFunc, scheme string) {
drivers[scheme] = f
}

View file

@ -1,27 +0,0 @@
package dbmate
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestGetDriver_Postgres(t *testing.T) {
drv, err := getDriver("postgres")
require.NoError(t, err)
_, ok := drv.(*PostgresDriver)
require.Equal(t, true, ok)
}
func TestGetDriver_MySQL(t *testing.T) {
drv, err := getDriver("mysql")
require.NoError(t, err)
_, ok := drv.(*MySQLDriver)
require.Equal(t, true, ok)
}
func TestGetDriver_Error(t *testing.T) {
drv, err := getDriver("foo")
require.EqualError(t, err, "unsupported driver: foo")
require.Nil(t, drv)
}

View file

@ -1,58 +0,0 @@
package dbmate
import (
"database/sql"
"net/url"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/require"
)
func TestDatabaseName(t *testing.T) {
u, err := url.Parse("ignore://localhost/foo?query")
require.NoError(t, err)
name := databaseName(u)
require.Equal(t, "foo", name)
}
func TestDatabaseName_Empty(t *testing.T) {
u, err := url.Parse("ignore://localhost")
require.NoError(t, err)
name := databaseName(u)
require.Equal(t, "", name)
}
func TestTrimLeadingSQLComments(t *testing.T) {
in := "--\n" +
"-- foo\n\n" +
"-- bar\n\n" +
"real stuff\n" +
"-- end\n"
out, err := trimLeadingSQLComments([]byte(in))
require.NoError(t, err)
require.Equal(t, "real stuff\n-- end\n", string(out))
}
func TestQueryColumn(t *testing.T) {
u := postgresTestURL(t)
db, err := sql.Open("postgres", u.String())
require.NoError(t, err)
val, err := queryColumn(db, "select concat('foo_', unnest($1::text[]))",
pq.Array([]string{"hi", "there"}))
require.NoError(t, err)
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
}
func TestQueryValue(t *testing.T) {
u := postgresTestURL(t)
db, err := sql.Open("postgres", u.String())
require.NoError(t, err)
val, err := queryValue(db, "select $1::int + $2::int", "5", 2)
require.NoError(t, err)
require.Equal(t, "7", val)
}

View file

@ -1,21 +1,26 @@
package dbmate
package dbutil
import (
"bufio"
"bytes"
"database/sql"
"errors"
"fmt"
"io"
"net/url"
"os"
"os/exec"
"strings"
"unicode"
)
// databaseName returns the database name from a URL
func databaseName(u *url.URL) string {
// Transaction can represent a database or open transaction
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// DatabaseName returns the database name from a URL
func DatabaseName(u *url.URL) string {
name := u.Path
if len(name) > 0 && name[:1] == "/" {
name = name[1:]
@ -24,24 +29,15 @@ func databaseName(u *url.URL) string {
return name
}
// mustClose ensures a stream is closed
func mustClose(c io.Closer) {
// MustClose ensures a stream is closed
func MustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}
// ensureDir creates a directory if it does not already exist
func ensureDir(dir string) error {
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("unable to create directory `%s`", dir)
}
return nil
}
// runCommand runs a command and returns the stdout if successful
func runCommand(name string, args ...string) ([]byte, error) {
// RunCommand runs a command and returns the stdout if successful
func RunCommand(name string, args ...string) ([]byte, error) {
var stdout, stderr bytes.Buffer
cmd := exec.Command(name, args...)
cmd.Stdout = &stdout
@ -61,10 +57,10 @@ func runCommand(name string, args ...string) ([]byte, error) {
return stdout.Bytes(), nil
}
// trimLeadingSQLComments removes sql comments and blank lines from the beginning of text
// TrimLeadingSQLComments removes sql comments and blank lines from the beginning of text
// generally when performing sql dumps these contain host-specific information such as
// client/server version numbers
func trimLeadingSQLComments(data []byte) ([]byte, error) {
func TrimLeadingSQLComments(data []byte) ([]byte, error) {
// create decent size buffer
out := bytes.NewBuffer(make([]byte, 0, len(data)))
@ -101,15 +97,15 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
return out.Bytes(), nil
}
// queryColumn runs a SQL statement and returns a slice of strings
// QueryColumn runs a SQL statement and returns a slice of strings
// it is assumed that the statement returns only one column
// e.g. schema_migrations table
func queryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
func QueryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer mustClose(rows)
defer MustClose(rows)
// read into slice
var result []string
@ -128,10 +124,10 @@ func queryColumn(db Transaction, query string, args ...interface{}) ([]string, e
return result, nil
}
// queryValue runs a SQL statement and returns a single string
// QueryValue runs a SQL statement and returns a single string
// it is assumed that the statement returns only one row and one column
// sql NULL is returned as empty string
func queryValue(db Transaction, query string, args ...interface{}) (string, error) {
func QueryValue(db Transaction, query string, args ...interface{}) (string, error) {
var result sql.NullString
err := db.QueryRow(query, args...).Scan(&result)
if err != nil || !result.Valid {
@ -141,13 +137,17 @@ func queryValue(db Transaction, query string, args ...interface{}) (string, erro
return result.String, nil
}
func printVerbose(result sql.Result) {
lastInsertID, err := result.LastInsertId()
if err == nil {
fmt.Printf("Last insert ID: %d\n", lastInsertID)
// MustParseURL parses a URL from string, and panics if it fails.
// It is used during testing and in cases where we are parsing a generated URL.
func MustParseURL(s string) *url.URL {
if s == "" {
panic("missing url")
}
rowsAffected, err := result.RowsAffected()
if err == nil {
fmt.Printf("Rows affected: %d\n", rowsAffected)
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}

58
pkg/dbutil/dbutil_test.go Normal file
View file

@ -0,0 +1,58 @@
package dbutil_test
import (
"database/sql"
"testing"
"github.com/amacneil/dbmate/pkg/dbutil"
_ "github.com/mattn/go-sqlite3" // database/sql driver
"github.com/stretchr/testify/require"
)
func TestDatabaseName(t *testing.T) {
t.Run("valid", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host/dbname?query")
name := dbutil.DatabaseName(u)
require.Equal(t, "dbname", name)
})
t.Run("empty", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host")
name := dbutil.DatabaseName(u)
require.Equal(t, "", name)
})
}
func TestTrimLeadingSQLComments(t *testing.T) {
in := "--\n" +
"-- foo\n\n" +
"-- bar\n\n" +
"real stuff\n" +
"-- end\n"
out, err := dbutil.TrimLeadingSQLComments([]byte(in))
require.NoError(t, err)
require.Equal(t, "real stuff\n-- end\n", string(out))
}
// connect to in-memory sqlite database for testing
const sqliteMemoryDB = "file:dbutil.sqlite3?mode=memory&cache=shared"
func TestQueryColumn(t *testing.T) {
db, err := sql.Open("sqlite3", sqliteMemoryDB)
require.NoError(t, err)
val, err := dbutil.QueryColumn(db, "select 'foo_' || val from (select ? as val union select ?)",
"hi", "there")
require.NoError(t, err)
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
}
func TestQueryValue(t *testing.T) {
db, err := sql.Open("sqlite3", sqliteMemoryDB)
require.NoError(t, err)
val, err := dbutil.QueryValue(db, "select $1 + $2", "5", 2)
require.NoError(t, err)
require.Equal(t, "7", val)
}

View file

@ -1,4 +1,4 @@
package dbmate
package clickhouse
import (
"bytes"
@ -9,19 +9,31 @@ import (
"sort"
"strings"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/ClickHouse/clickhouse-go"
)
func init() {
RegisterDriver(&ClickHouseDriver{}, "clickhouse")
dbmate.RegisterDriver(NewDriver, "clickhouse")
}
// ClickHouseDriver provides top level database functions
type ClickHouseDriver struct {
// Driver provides top level database functions
type Driver struct {
migrationsTableName string
databaseURL *url.URL
}
func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
// NewDriver initializes the driver
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
return &Driver{
migrationsTableName: config.MigrationsTableName,
databaseURL: config.DatabaseURL,
}
}
func connectionString(initialURL *url.URL) string {
u := *initialURL
u.Scheme = "tcp"
@ -50,31 +62,31 @@ func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
}
u.RawQuery = query.Encode()
return &u
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *ClickHouseDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
return u.String()
}
// Open creates a new database connection
func (drv *ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("clickhouse", normalizeClickHouseURL(u).String())
func (drv *Driver) Open() (*sql.DB, error) {
return sql.Open("clickhouse", connectionString(drv.databaseURL))
}
func (drv *Driver) openClickHouseDB() (*sql.DB, error) {
// clone databaseURL
clickhouseURL, err := url.Parse(connectionString(drv.databaseURL))
if err != nil {
return nil, err
}
func (drv *ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
// connect to clickhouse database
clickhouseURL := normalizeClickHouseURL(u)
values := clickhouseURL.Query()
values.Set("database", "default")
clickhouseURL.RawQuery = values.Encode()
return drv.Open(clickhouseURL)
return sql.Open("clickhouse", clickhouseURL.String())
}
func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
name := normalizeClickHouseURL(u).Query().Get("database")
func (drv *Driver) databaseName() string {
name := dbutil.MustParseURL(connectionString(drv.databaseURL)).Query().Get("database")
if name == "" {
name = "default"
}
@ -83,7 +95,7 @@ func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
func (drv *Driver) quoteIdentifier(str string) string {
if clickhouseValidIdentifier.MatchString(str) {
return str
}
@ -94,15 +106,15 @@ func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
}
// CreateDatabase creates the specified database
func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
name := drv.databaseName(u)
func (drv *Driver) CreateDatabase() error {
name := drv.databaseName()
fmt.Printf("Creating: %s\n", name)
db, err := drv.openClickHouseDB(u)
db, err := drv.openClickHouseDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
_, err = db.Exec("create database " + drv.quoteIdentifier(name))
@ -110,27 +122,27 @@ func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv *ClickHouseDriver) DropDatabase(u *url.URL) error {
name := drv.databaseName(u)
func (drv *Driver) DropDatabase() error {
name := drv.databaseName()
fmt.Printf("Dropping: %s\n", name)
db, err := drv.openClickHouseDB(u)
db, err := drv.openClickHouseDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
_, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
return err
}
func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
func (drv *Driver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
buf.WriteString("\n--\n-- Database schema\n--\n\n")
buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
tables, err := queryColumn(db, "show tables")
tables, err := dbutil.QueryColumn(db, "show tables")
if err != nil {
return err
}
@ -147,11 +159,11 @@ func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseN
return nil
}
func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
func (drv *Driver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations
migrations, err := queryColumn(db,
migrations, err := dbutil.QueryColumn(db,
fmt.Sprintf("select version from %s final ", migrationsTable)+
"where applied order by version asc",
)
@ -178,11 +190,11 @@ func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer)
}
// DumpSchema returns the current database schema
func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
var buf bytes.Buffer
var err error
err = drv.schemaDump(db, &buf, drv.databaseName(u))
err = drv.schemaDump(db, &buf, drv.databaseName())
if err != nil {
return nil, err
}
@ -196,14 +208,14 @@ func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error)
}
// DatabaseExists determines whether the database exists
func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
name := drv.databaseName(u)
func (drv *Driver) DatabaseExists() (bool, error) {
name := drv.databaseName()
db, err := drv.openClickHouseDB(u)
db, err := drv.openClickHouseDB()
if err != nil {
return false, err
}
defer mustClose(db)
defer dbutil.MustClose(db)
exists := false
err = db.QueryRow("SELECT 1 FROM system.databases where name = ?", name).
@ -216,7 +228,7 @@ func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema migrations table
func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
_, err := db.Exec(fmt.Sprintf(`
create table if not exists %s (
version String,
@ -232,7 +244,7 @@ func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := fmt.Sprintf("select version from %s final where applied order by version desc",
drv.quotedMigrationsTableName())
@ -244,7 +256,7 @@ func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string
return nil, err
}
defer mustClose(rows)
defer dbutil.MustClose(rows)
migrations := map[string]bool{}
for rows.Next() {
@ -264,7 +276,7 @@ func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string
}
// InsertMigration adds a new migration record
func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) error {
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
@ -273,7 +285,7 @@ func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) err
}
// DeleteMigration removes a migration record
func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
drv.quotedMigrationsTableName()),
@ -285,15 +297,15 @@ func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) err
// Ping verifies a connection to the database server. It does not verify whether the
// specified database exists.
func (drv *ClickHouseDriver) Ping(u *url.URL) error {
func (drv *Driver) Ping() error {
// attempt connection to primary database, not "clickhouse" database
// to support servers with no "clickhouse" database
// (see https://github.com/amacneil/dbmate/issues/78)
db, err := drv.Open(u)
db, err := drv.Open()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
if err == nil {
@ -309,6 +321,6 @@ func (drv *ClickHouseDriver) Ping(u *url.URL) error {
return err
}
func (drv *ClickHouseDriver) quotedMigrationsTableName() string {
func (drv *Driver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}

View file

@ -1,108 +1,117 @@
package dbmate
package clickhouse
import (
"database/sql"
"net/url"
"os"
"testing"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/stretchr/testify/require"
)
func clickhouseTestURL(t *testing.T) *url.URL {
u, err := url.Parse("clickhouse://clickhouse:9000?database=dbmate")
func testClickHouseDriver(t *testing.T) *Driver {
u := dbutil.MustParseURL(os.Getenv("CLICKHOUSE_TEST_URL"))
drv, err := dbmate.New(u).GetDriver()
require.NoError(t, err)
return u
return drv.(*Driver)
}
func testClickHouseDriver() *ClickHouseDriver {
drv := &ClickHouseDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB {
drv := testClickHouseDriver()
func prepTestClickHouseDB(t *testing.T) *sql.DB {
drv := testClickHouseDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// connect database
db, err := sql.Open("clickhouse", u.String())
db, err := sql.Open("clickhouse", drv.databaseURL.String())
require.NoError(t, err)
return db
}
func TestNormalizeClickHouseURLSimplified(t *testing.T) {
func TestGetDriver(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("clickhouse://"))
drvInterface, err := db.GetDriver()
require.NoError(t, err)
// driver should have URL and default migrations table set
drv, ok := drvInterface.(*Driver)
require.True(t, ok)
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
require.Equal(t, "schema_migrations", drv.migrationsTableName)
}
func TestConnectionString(t *testing.T) {
t.Run("simple", func(t *testing.T) {
u, err := url.Parse("clickhouse://user:pass@host/db")
require.NoError(t, err)
s := normalizeClickHouseURL(u).String()
s := connectionString(u)
require.Equal(t, "tcp://host:9000?database=db&password=pass&username=user", s)
}
})
func TestNormalizeClickHouseURLCanonical(t *testing.T) {
t.Run("canonical", func(t *testing.T) {
u, err := url.Parse("clickhouse://host:9000?database=db&password=pass&username=user")
require.NoError(t, err)
s := normalizeClickHouseURL(u).String()
s := connectionString(u)
require.Equal(t, "tcp://host:9000?database=db&password=pass&username=user", s)
})
}
func TestClickHouseCreateDropDatabase(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
drv := testClickHouseDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// check that database exists and we can connect to it
func() {
db, err := sql.Open("clickhouse", u.String())
db, err := sql.Open("clickhouse", drv.databaseURL.String())
require.NoError(t, err)
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
require.NoError(t, err)
}()
// drop the database
err = drv.DropDatabase(u)
err = drv.DropDatabase()
require.NoError(t, err)
// check that database no longer exists
func() {
db, err := sql.Open("clickhouse", u.String())
db, err := sql.Open("clickhouse", drv.databaseURL.String())
require.NoError(t, err)
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
require.EqualError(t, err, "code: 81, message: Database dbmate doesn't exist")
require.EqualError(t, err, "code: 81, message: Database dbmate_test doesn't exist")
}()
}
func TestClickHouseDumpSchema(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
u := clickhouseTestURL(t)
drv := testClickHouseDriver(t)
drv.migrationsTableName = "test_migrations"
// prepare database
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
db := prepTestClickHouseDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
// insert migration
@ -120,9 +129,9 @@ func TestClickHouseDumpSchema(t *testing.T) {
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
schema, err := drv.DumpSchema(db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".test_migrations")
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName()+".test_migrations")
require.Contains(t, string(schema), "--\n"+
"-- Dbmate schema migrations\n"+
"--\n\n"+
@ -131,66 +140,63 @@ func TestClickHouseDumpSchema(t *testing.T) {
" ('abc2');\n")
// DumpSchema should return error if command fails
values := u.Query()
values := drv.databaseURL.Query()
values.Set("database", "fakedb")
u.RawQuery = values.Encode()
db, err = sql.Open("clickhouse", u.String())
drv.databaseURL.RawQuery = values.Encode()
db, err = sql.Open("clickhouse", drv.databaseURL.String())
require.NoError(t, err)
schema, err = drv.DumpSchema(u, db)
schema, err = drv.DumpSchema(db)
require.Nil(t, schema)
require.EqualError(t, err, "code: 81, message: Database fakedb doesn't exist")
}
func TestClickHouseDatabaseExists(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
drv := testClickHouseDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// DatabaseExists should return false
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, false, exists)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// DatabaseExists should return true
exists, err = drv.DatabaseExists(u)
exists, err = drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, true, exists)
}
func TestClickHouseDatabaseExists_Error(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
values := u.Query()
drv := testClickHouseDriver(t)
values := drv.databaseURL.Query()
values.Set("username", "invalid")
u.RawQuery = values.Encode()
drv.databaseURL.RawQuery = values.Encode()
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.EqualError(t, err, "code: 192, message: Unknown user invalid")
require.Equal(t, false, exists)
}
func TestClickHouseCreateMigrationsTable(t *testing.T) {
t.Run("default table", func(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
drv := testClickHouseDriver(t)
db := prepTestClickHouseDB(t)
defer dbutil.MustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.")
require.EqualError(t, err, "code: 60, message: Table dbmate_test.schema_migrations doesn't exist.")
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// migrations table should exist
@ -198,25 +204,24 @@ func TestClickHouseCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
t.Run("custom table", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("testMigrations")
drv := testClickHouseDriver(t)
drv.migrationsTableName = "testMigrations"
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
db := prepTestClickHouseDB(t)
defer dbutil.MustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
require.EqualError(t, err, "code: 60, message: Table dbmate.testMigrations doesn't exist.")
require.EqualError(t, err, "code: 60, message: Table dbmate_test.testMigrations doesn't exist.")
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// migrations table should exist
@ -224,20 +229,19 @@ func TestClickHouseCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
}
func TestClickHouseSelectMigrations(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testClickHouseDriver(t)
drv.migrationsTableName = "test_migrations"
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
db := prepTestClickHouseDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
tx, err := db.Begin()
@ -268,14 +272,13 @@ func TestClickHouseSelectMigrations(t *testing.T) {
}
func TestClickHouseInsertMigration(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testClickHouseDriver(t)
drv.migrationsTableName = "test_migrations"
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
db := prepTestClickHouseDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
count := 0
@ -297,14 +300,13 @@ func TestClickHouseInsertMigration(t *testing.T) {
}
func TestClickHouseDeleteMigration(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testClickHouseDriver(t)
drv.migrationsTableName = "test_migrations"
u := clickhouseTestURL(t)
db := prepTestClickHouseDB(t, u)
defer mustClose(db)
db := prepTestClickHouseDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
tx, err := db.Begin()
@ -332,42 +334,41 @@ func TestClickHouseDeleteMigration(t *testing.T) {
}
func TestClickHousePing(t *testing.T) {
drv := testClickHouseDriver()
u := clickhouseTestURL(t)
drv := testClickHouseDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// ping database
err = drv.Ping(u)
err = drv.Ping()
require.NoError(t, err)
// ping invalid host should return error
u.Host = "clickhouse:404"
err = drv.Ping(u)
drv.databaseURL.Host = "clickhouse:404"
err = drv.Ping()
require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestClickHouseQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testClickHouseDriver()
drv := testClickHouseDriver(t)
name := drv.quotedMigrationsTableName()
require.Equal(t, "schema_migrations", name)
})
t.Run("custom name", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("fooMigrations")
drv := testClickHouseDriver(t)
drv.migrationsTableName = "fooMigrations"
name := drv.quotedMigrationsTableName()
require.Equal(t, "fooMigrations", name)
})
t.Run("quoted name", func(t *testing.T) {
drv := testClickHouseDriver()
drv.SetMigrationsTableName("bizarre\"$name")
drv := testClickHouseDriver(t)
drv.migrationsTableName = "bizarre\"$name"
name := drv.quotedMigrationsTableName()
require.Equal(t, `"bizarre""$name"`, name)

View file

@ -1,4 +1,4 @@
package dbmate
package mysql
import (
"bytes"
@ -7,19 +7,31 @@ import (
"net/url"
"strings"
_ "github.com/go-sql-driver/mysql" // mysql driver for database/sql
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
_ "github.com/go-sql-driver/mysql" // database/sql driver
)
func init() {
RegisterDriver(&MySQLDriver{}, "mysql")
dbmate.RegisterDriver(NewDriver, "mysql")
}
// MySQLDriver provides top level database functions
type MySQLDriver struct {
// Driver provides top level database functions
type Driver struct {
migrationsTableName string
databaseURL *url.URL
}
func normalizeMySQLURL(u *url.URL) string {
// NewDriver initializes the driver
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
return &Driver{
migrationsTableName: config.MigrationsTableName,
databaseURL: config.DatabaseURL,
}
}
func connectionString(u *url.URL) string {
query := u.Query()
query.Set("multiStatements", "true")
@ -53,40 +65,40 @@ func normalizeMySQLURL(u *url.URL) string {
return normalizedString
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *MySQLDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv *MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("mysql", normalizeMySQLURL(u))
func (drv *Driver) Open() (*sql.DB, error) {
return sql.Open("mysql", connectionString(drv.databaseURL))
}
func (drv *Driver) openRootDB() (*sql.DB, error) {
// clone databaseURL
rootURL, err := url.Parse(drv.databaseURL.String())
if err != nil {
return nil, err
}
func (drv *MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
// connect to no particular database
rootURL := *u
rootURL.Path = "/"
return drv.Open(&rootURL)
return sql.Open("mysql", connectionString(rootURL))
}
func (drv *MySQLDriver) quoteIdentifier(str string) string {
func (drv *Driver) quoteIdentifier(str string) string {
str = strings.Replace(str, "`", "\\`", -1)
return fmt.Sprintf("`%s`", str)
}
// CreateDatabase creates the specified database
func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u)
func (drv *Driver) CreateDatabase() error {
name := dbutil.DatabaseName(drv.databaseURL)
fmt.Printf("Creating: %s\n", name)
db, err := drv.openRootDB(u)
db, err := drv.openRootDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
_, err = db.Exec(fmt.Sprintf("create database %s",
drv.quoteIdentifier(name)))
@ -95,15 +107,15 @@ func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
name := databaseName(u)
func (drv *Driver) DropDatabase() error {
name := dbutil.DatabaseName(drv.databaseURL)
fmt.Printf("Dropping: %s\n", name)
db, err := drv.openRootDB(u)
db, err := drv.openRootDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
drv.quoteIdentifier(name)))
@ -111,37 +123,37 @@ func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
return err
}
func (drv *MySQLDriver) mysqldumpArgs(u *url.URL) []string {
func (drv *Driver) mysqldumpArgs() []string {
// generate CLI arguments
args := []string{"--opt", "--routines", "--no-data",
"--skip-dump-date", "--skip-add-drop-table"}
if hostname := u.Hostname(); hostname != "" {
if hostname := drv.databaseURL.Hostname(); hostname != "" {
args = append(args, "--host="+hostname)
}
if port := u.Port(); port != "" {
if port := drv.databaseURL.Port(); port != "" {
args = append(args, "--port="+port)
}
if username := u.User.Username(); username != "" {
if username := drv.databaseURL.User.Username(); username != "" {
args = append(args, "--user="+username)
}
// mysql recommends against using environment variables to supply password
// https://dev.mysql.com/doc/refman/5.7/en/password-security-user.html
if password, set := u.User.Password(); set {
if password, set := drv.databaseURL.User.Password(); set {
args = append(args, "--password="+password)
}
// add database name
args = append(args, strings.TrimLeft(u.Path, "/"))
args = append(args, dbutil.DatabaseName(drv.databaseURL))
return args
}
func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations
migrations, err := queryColumn(db,
migrations, err := dbutil.QueryColumn(db,
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
if err != nil {
return nil, err
@ -165,8 +177,8 @@ func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
}
// DumpSchema returns the current database schema
func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
schema, err := runCommand("mysqldump", drv.mysqldumpArgs(u)...)
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
schema, err := dbutil.RunCommand("mysqldump", drv.mysqldumpArgs()...)
if err != nil {
return nil, err
}
@ -177,18 +189,18 @@ func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
schema = append(schema, migrations...)
return trimLeadingSQLComments(schema)
return dbutil.TrimLeadingSQLComments(schema)
}
// DatabaseExists determines whether the database exists
func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u)
func (drv *Driver) DatabaseExists() (bool, error) {
name := dbutil.DatabaseName(drv.databaseURL)
db, err := drv.openRootDB(u)
db, err := drv.openRootDB()
if err != nil {
return false, err
}
defer mustClose(db)
defer dbutil.MustClose(db)
exists := false
err = db.QueryRow("select true from information_schema.schemata "+
@ -201,7 +213,7 @@ func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema_migrations table
func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
_, err := db.Exec(fmt.Sprintf("create table if not exists %s "+
"(version varchar(255) primary key) character set latin1 collate latin1_bin",
drv.quotedMigrationsTableName()))
@ -211,7 +223,7 @@ func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
@ -221,7 +233,7 @@ func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
return nil, err
}
defer mustClose(rows)
defer dbutil.MustClose(rows)
migrations := map[string]bool{}
for rows.Next() {
@ -241,7 +253,7 @@ func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
}
// InsertMigration adds a new migration record
func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
@ -250,7 +262,7 @@ func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
}
// DeleteMigration removes a migration record
func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
version)
@ -260,16 +272,16 @@ func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
// Ping verifies a connection to the database server. It does not verify whether the
// specified database exists.
func (drv *MySQLDriver) Ping(u *url.URL) error {
db, err := drv.openRootDB(u)
func (drv *Driver) Ping() error {
db, err := drv.openRootDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
return db.Ping()
}
func (drv *MySQLDriver) quotedMigrationsTableName() string {
func (drv *Driver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}

View file

@ -1,137 +1,146 @@
package dbmate
package mysql
import (
"database/sql"
"net/url"
"os"
"testing"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/stretchr/testify/require"
)
func mySQLTestURL(t *testing.T) *url.URL {
u, err := url.Parse("mysql://root:root@mysql/dbmate")
func testMySQLDriver(t *testing.T) *Driver {
u := dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL"))
drv, err := dbmate.New(u).GetDriver()
require.NoError(t, err)
return u
return drv.(*Driver)
}
func testMySQLDriver() *MySQLDriver {
drv := &MySQLDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB {
drv := testMySQLDriver()
func prepTestMySQLDB(t *testing.T) *sql.DB {
drv := testMySQLDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// connect database
db, err := drv.Open(u)
db, err := drv.Open()
require.NoError(t, err)
return db
}
func TestNormalizeMySQLURLDefaults(t *testing.T) {
func TestGetDriver(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("mysql://"))
drvInterface, err := db.GetDriver()
require.NoError(t, err)
// driver should have URL and default migrations table set
drv, ok := drvInterface.(*Driver)
require.True(t, ok)
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
require.Equal(t, "schema_migrations", drv.migrationsTableName)
}
func TestConnectionString(t *testing.T) {
t.Run("defaults", func(t *testing.T) {
u, err := url.Parse("mysql://host/foo")
require.NoError(t, err)
require.Equal(t, "", u.Port())
s := normalizeMySQLURL(u)
s := connectionString(u)
require.Equal(t, "tcp(host:3306)/foo?multiStatements=true", s)
}
})
func TestNormalizeMySQLURLCustom(t *testing.T) {
t.Run("custom", func(t *testing.T) {
u, err := url.Parse("mysql://bob:secret@host:123/foo?flag=on")
require.NoError(t, err)
require.Equal(t, "123", u.Port())
s := normalizeMySQLURL(u)
s := connectionString(u)
require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
}
})
func TestNormalizeMySQLURLCustomSpecialChars(t *testing.T) {
t.Run("special chars", func(t *testing.T) {
u, err := url.Parse("mysql://duhfsd7s:123!@123!@@host:123/foo?flag=on")
require.NoError(t, err)
require.Equal(t, "123", u.Port())
s := normalizeMySQLURL(u)
s := connectionString(u)
require.Equal(t, "duhfsd7s:123!@123!@@tcp(host:123)/foo?flag=on&multiStatements=true", s)
}
})
func TestNormalizeMySQLURLSocket(t *testing.T) {
t.Run("socket", func(t *testing.T) {
// test with no user/pass
u, err := url.Parse("mysql:///foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
require.NoError(t, err)
require.Equal(t, "", u.Host)
s := normalizeMySQLURL(u)
s := connectionString(u)
require.Equal(t, "unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
// test with user/pass
u, err = url.Parse("mysql://bob:secret@fakehost/foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
require.NoError(t, err)
s = normalizeMySQLURL(u)
s = connectionString(u)
require.Equal(t, "bob:secret@unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
})
}
func TestMySQLCreateDropDatabase(t *testing.T) {
drv := testMySQLDriver()
u := mySQLTestURL(t)
drv := testMySQLDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// check that database exists and we can connect to it
func() {
db, err := drv.Open(u)
db, err := drv.Open()
require.NoError(t, err)
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
require.NoError(t, err)
}()
// drop the database
err = drv.DropDatabase(u)
err = drv.DropDatabase()
require.NoError(t, err)
// check that database no longer exists
func() {
db, err := drv.Open(u)
db, err := drv.Open()
require.NoError(t, err)
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
require.NotNil(t, err)
require.Regexp(t, "Unknown database 'dbmate'", err.Error())
require.Error(t, err)
require.Regexp(t, "Unknown database 'dbmate_test'", err.Error())
}()
}
func TestMySQLDumpSchema(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
u := mySQLTestURL(t)
drv := testMySQLDriver(t)
drv.migrationsTableName = "test_migrations"
// prepare database
db := prepTestMySQLDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
db := prepTestMySQLDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
// insert migration
@ -141,7 +150,7 @@ func TestMySQLDumpSchema(t *testing.T) {
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
schema, err := drv.DumpSchema(db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
@ -155,8 +164,8 @@ func TestMySQLDumpSchema(t *testing.T) {
"UNLOCK TABLES;\n")
// DumpSchema should return error if command fails
u.Path = "/fakedb"
schema, err = drv.DumpSchema(u, db)
drv.databaseURL.Path = "/fakedb"
schema, err = drv.DumpSchema(db)
require.Nil(t, schema)
require.EqualError(t, err, "mysqldump: [Warning] Using a password "+
"on the command line interface can be insecure.\n"+
@ -165,54 +174,52 @@ func TestMySQLDumpSchema(t *testing.T) {
}
func TestMySQLDatabaseExists(t *testing.T) {
drv := testMySQLDriver()
u := mySQLTestURL(t)
drv := testMySQLDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// DatabaseExists should return false
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, false, exists)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// DatabaseExists should return true
exists, err = drv.DatabaseExists(u)
exists, err = drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, true, exists)
}
func TestMySQLDatabaseExists_Error(t *testing.T) {
drv := testMySQLDriver()
u := mySQLTestURL(t)
u.User = url.User("invalid")
drv := testMySQLDriver(t)
drv.databaseURL.User = url.User("invalid")
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.Error(t, err)
require.Regexp(t, "Access denied for user 'invalid'@", err.Error())
require.Equal(t, false, exists)
}
func TestMySQLCreateMigrationsTable(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testMySQLDriver(t)
drv.migrationsTableName = "test_migrations"
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
db := prepTestMySQLDB(t)
defer dbutil.MustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.Error(t, err)
require.Regexp(t, "Table 'dbmate.test_migrations' doesn't exist", err.Error())
require.Regexp(t, "Table 'dbmate_test.test_migrations' doesn't exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// migrations table should exist
@ -220,19 +227,18 @@ func TestMySQLCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
}
func TestMySQLSelectMigrations(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testMySQLDriver(t)
drv.migrationsTableName = "test_migrations"
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
db := prepTestMySQLDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
_, err = db.Exec(`insert into test_migrations (version)
@ -254,14 +260,13 @@ func TestMySQLSelectMigrations(t *testing.T) {
}
func TestMySQLInsertMigration(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testMySQLDriver(t)
drv.migrationsTableName = "test_migrations"
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
db := prepTestMySQLDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
count := 0
@ -280,14 +285,13 @@ func TestMySQLInsertMigration(t *testing.T) {
}
func TestMySQLDeleteMigration(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testMySQLDriver(t)
drv.migrationsTableName = "test_migrations"
u := mySQLTestURL(t)
db := prepTestMySQLDB(t, u)
defer mustClose(db)
db := prepTestMySQLDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
_, err = db.Exec(`insert into test_migrations (version)
@ -304,34 +308,33 @@ func TestMySQLDeleteMigration(t *testing.T) {
}
func TestMySQLPing(t *testing.T) {
drv := testMySQLDriver()
u := mySQLTestURL(t)
drv := testMySQLDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// ping database
err = drv.Ping(u)
err = drv.Ping()
require.NoError(t, err)
// ping invalid host should return error
u.Host = "mysql:404"
err = drv.Ping(u)
drv.databaseURL.Host = "mysql:404"
err = drv.Ping()
require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestMySQLQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testMySQLDriver()
drv := testMySQLDriver(t)
name := drv.quotedMigrationsTableName()
require.Equal(t, "`schema_migrations`", name)
})
t.Run("custom name", func(t *testing.T) {
drv := testMySQLDriver()
drv.SetMigrationsTableName("fooMigrations")
drv := testMySQLDriver(t)
drv.migrationsTableName = "fooMigrations"
name := drv.quotedMigrationsTableName()
require.Equal(t, "`fooMigrations`", name)

View file

@ -1,4 +1,4 @@
package dbmate
package postgres
import (
"bytes"
@ -7,21 +7,32 @@ import (
"net/url"
"strings"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/lib/pq"
)
func init() {
drv := &PostgresDriver{}
RegisterDriver(drv, "postgres")
RegisterDriver(drv, "postgresql")
dbmate.RegisterDriver(NewDriver, "postgres")
dbmate.RegisterDriver(NewDriver, "postgresql")
}
// PostgresDriver provides top level database functions
type PostgresDriver struct {
// Driver provides top level database functions
type Driver struct {
migrationsTableName string
databaseURL *url.URL
}
func normalizePostgresURL(u *url.URL) *url.URL {
// NewDriver initializes the driver
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
return &Driver{
migrationsTableName: config.MigrationsTableName,
databaseURL: config.DatabaseURL,
}
}
func connectionString(u *url.URL) string {
hostname := u.Hostname()
port := u.Port()
query := u.Query()
@ -56,11 +67,11 @@ func normalizePostgresURL(u *url.URL) *url.URL {
out.Host = fmt.Sprintf("%s:%s", hostname, port)
out.RawQuery = query.Encode()
return out
return out.String()
}
func normalizePostgresURLForDump(u *url.URL) []string {
u = normalizePostgresURL(u)
func connectionArgsForDump(u *url.URL) []string {
u = dbutil.MustParseURL(connectionString(u))
// find schemas from search_path
query := u.Query()
@ -80,34 +91,34 @@ func normalizePostgresURLForDump(u *url.URL) []string {
return out
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *PostgresDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv *PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("postgres", normalizePostgresURL(u).String())
func (drv *Driver) Open() (*sql.DB, error) {
return sql.Open("postgres", connectionString(drv.databaseURL))
}
func (drv *Driver) openPostgresDB() (*sql.DB, error) {
// clone databaseURL
postgresURL, err := url.Parse(connectionString(drv.databaseURL))
if err != nil {
return nil, err
}
func (drv *PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
// connect to postgres database
postgresURL := *u
postgresURL.Path = "postgres"
return drv.Open(&postgresURL)
return sql.Open("postgres", postgresURL.String())
}
// CreateDatabase creates the specified database
func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
name := databaseName(u)
func (drv *Driver) CreateDatabase() error {
name := dbutil.DatabaseName(drv.databaseURL)
fmt.Printf("Creating: %s\n", name)
db, err := drv.openPostgresDB(u)
db, err := drv.openPostgresDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
_, err = db.Exec(fmt.Sprintf("create database %s",
pq.QuoteIdentifier(name)))
@ -116,15 +127,15 @@ func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
}
// DropDatabase drops the specified database (if it exists)
func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
name := databaseName(u)
func (drv *Driver) DropDatabase() error {
name := dbutil.DatabaseName(drv.databaseURL)
fmt.Printf("Dropping: %s\n", name)
db, err := drv.openPostgresDB(u)
db, err := drv.openPostgresDB()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
pq.QuoteIdentifier(name)))
@ -132,14 +143,14 @@ func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
return err
}
func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return nil, err
}
// load applied migrations
migrations, err := queryColumn(db,
migrations, err := dbutil.QueryColumn(db,
"select quote_literal(version) from "+migrationsTable+" order by version asc")
if err != nil {
return nil, err
@ -159,11 +170,11 @@ func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
}
// DumpSchema returns the current database schema
func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
// load schema
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
schema, err := runCommand("pg_dump", args...)
"--no-privileges", "--no-owner"}, connectionArgsForDump(drv.databaseURL)...)
schema, err := dbutil.RunCommand("pg_dump", args...)
if err != nil {
return nil, err
}
@ -174,18 +185,18 @@ func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
schema = append(schema, migrations...)
return trimLeadingSQLComments(schema)
return dbutil.TrimLeadingSQLComments(schema)
}
// DatabaseExists determines whether the database exists
func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
name := databaseName(u)
func (drv *Driver) DatabaseExists() (bool, error) {
name := dbutil.DatabaseName(drv.databaseURL)
db, err := drv.openPostgresDB(u)
db, err := drv.openPostgresDB()
if err != nil {
return false, err
}
defer mustClose(db)
defer dbutil.MustClose(db)
exists := false
err = db.QueryRow("select true from pg_database where datname = $1", name).
@ -198,8 +209,8 @@ func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema_migrations table
func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db, u)
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db)
if err != nil {
return err
}
@ -235,7 +246,7 @@ func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return nil, err
@ -250,7 +261,7 @@ func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]b
return nil, err
}
defer mustClose(rows)
defer dbutil.MustClose(rows)
migrations := map[string]bool{}
for rows.Next() {
@ -270,7 +281,7 @@ func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]b
}
// InsertMigration adds a new migration record
func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error {
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return err
@ -282,7 +293,7 @@ func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error
}
// DeleteMigration removes a migration record
func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error {
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
migrationsTable, err := drv.quotedMigrationsTableName(db)
if err != nil {
return err
@ -295,15 +306,15 @@ func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error
// Ping verifies a connection to the database server. It does not verify whether the
// specified database exists.
func (drv *PostgresDriver) Ping(u *url.URL) error {
func (drv *Driver) Ping() error {
// attempt connection to primary database, not "postgres" database
// to support servers with no "postgres" database
// (see https://github.com/amacneil/dbmate/issues/78)
db, err := drv.Open(u)
db, err := drv.Open()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
if err == nil {
@ -319,8 +330,8 @@ func (drv *PostgresDriver) Ping(u *url.URL) error {
return err
}
func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, error) {
schema, name, err := drv.quotedMigrationsTableNameParts(db, nil)
func (drv *Driver) quotedMigrationsTableName(db dbutil.Transaction) (string, error) {
schema, name, err := drv.quotedMigrationsTableNameParts(db)
if err != nil {
return "", err
}
@ -328,7 +339,7 @@ func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, er
return schema + "." + name, nil
}
func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url.URL) (string, string, error) {
func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string, string, error) {
schema := ""
tableNameParts := strings.Split(drv.migrationsTableName, ".")
if len(tableNameParts) > 1 {
@ -336,9 +347,9 @@ func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url
schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
}
if schema == "" && u != nil {
if schema == "" {
// no schema specified with table name, try URL search path if available
searchPath := strings.Split(u.Query().Get("search_path"), ",")
searchPath := strings.Split(drv.databaseURL.Query().Get("search_path"), ",")
schema = strings.TrimSpace(searchPath[0])
}
@ -346,7 +357,7 @@ func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url
if schema == "" {
// if no URL available, use current schema
// this is a hack because we don't always have the URL context available
schema, err = queryValue(db, "select current_schema()")
schema, err = dbutil.QueryValue(db, "select current_schema()")
if err != nil {
return "", "", err
}
@ -361,7 +372,7 @@ func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url
// use server rather than client to do this to avoid unnecessary quotes
// (which would change schema.sql diff)
tableNameParts = append([]string{schema}, tableNameParts...)
quotedNameParts, err := queryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
quotedNameParts, err := dbutil.QueryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
if err != nil {
return "", "", err
}

View file

@ -1,46 +1,56 @@
package dbmate
package postgres
import (
"database/sql"
"net/url"
"os"
"testing"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/stretchr/testify/require"
)
func postgresTestURL(t *testing.T) *url.URL {
u, err := url.Parse("postgres://postgres:postgres@postgres/dbmate?sslmode=disable")
func testPostgresDriver(t *testing.T) *Driver {
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
drv, err := dbmate.New(u).GetDriver()
require.NoError(t, err)
return u
return drv.(*Driver)
}
func testPostgresDriver() *PostgresDriver {
drv := &PostgresDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
drv := testPostgresDriver()
func prepTestPostgresDB(t *testing.T) *sql.DB {
drv := testPostgresDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// connect database
db, err := sql.Open("postgres", u.String())
db, err := sql.Open("postgres", drv.databaseURL.String())
require.NoError(t, err)
return db
}
func TestNormalizePostgresURL(t *testing.T) {
func TestGetDriver(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("postgres://"))
drvInterface, err := db.GetDriver()
require.NoError(t, err)
// driver should have URL and default migrations table set
drv, ok := drvInterface.(*Driver)
require.True(t, ok)
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
require.Equal(t, "schema_migrations", drv.migrationsTableName)
}
func TestConnectionString(t *testing.T) {
cases := []struct {
input string
expected string
@ -63,13 +73,13 @@ func TestNormalizePostgresURL(t *testing.T) {
u, err := url.Parse(c.input)
require.NoError(t, err)
actual := normalizePostgresURL(u).String()
actual := connectionString(u)
require.Equal(t, c.expected, actual)
})
}
}
func TestNormalizePostgresURLForDump(t *testing.T) {
func TestConnectionArgsForDump(t *testing.T) {
cases := []struct {
input string
expected []string
@ -87,59 +97,57 @@ func TestNormalizePostgresURLForDump(t *testing.T) {
u, err := url.Parse(c.input)
require.NoError(t, err)
actual := normalizePostgresURLForDump(u)
actual := connectionArgsForDump(u)
require.Equal(t, c.expected, actual)
})
}
}
func TestPostgresCreateDropDatabase(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
drv := testPostgresDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// check that database exists and we can connect to it
func() {
db, err := sql.Open("postgres", u.String())
db, err := sql.Open("postgres", drv.databaseURL.String())
require.NoError(t, err)
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
require.NoError(t, err)
}()
// drop the database
err = drv.DropDatabase(u)
err = drv.DropDatabase()
require.NoError(t, err)
// check that database no longer exists
func() {
db, err := sql.Open("postgres", u.String())
db, err := sql.Open("postgres", drv.databaseURL.String())
require.NoError(t, err)
defer mustClose(db)
defer dbutil.MustClose(db)
err = db.Ping()
require.Error(t, err)
require.Equal(t, "pq: database \"dbmate\" does not exist", err.Error())
require.Equal(t, "pq: database \"dbmate_test\" does not exist", err.Error())
}()
}
func TestPostgresDumpSchema(t *testing.T) {
t.Run("default migrations table", func(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
drv := testPostgresDriver(t)
// prepare database
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
// insert migration
@ -149,7 +157,7 @@ func TestPostgresDumpSchema(t *testing.T) {
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
schema, err := drv.DumpSchema(db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
require.Contains(t, string(schema), "\n--\n"+
@ -163,23 +171,21 @@ func TestPostgresDumpSchema(t *testing.T) {
" ('abc2');\n")
// DumpSchema should return error if command fails
u.Path = "/fakedb"
schema, err = drv.DumpSchema(u, db)
drv.databaseURL.Path = "/fakedb"
schema, err = drv.DumpSchema(db)
require.Nil(t, schema)
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
})
t.Run("custom migrations table with schema", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("camelSchema.testMigrations")
u := postgresTestURL(t)
drv := testPostgresDriver(t)
drv.migrationsTableName = "camelSchema.testMigrations"
// prepare database
db := prepTestPostgresDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
// insert migration
@ -189,7 +195,7 @@ func TestPostgresDumpSchema(t *testing.T) {
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
schema, err := drv.DumpSchema(db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"")
require.Contains(t, string(schema), "\n--\n"+
@ -205,34 +211,32 @@ func TestPostgresDumpSchema(t *testing.T) {
}
func TestPostgresDatabaseExists(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
drv := testPostgresDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// DatabaseExists should return false
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, false, exists)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// DatabaseExists should return true
exists, err = drv.DatabaseExists(u)
exists, err = drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, true, exists)
}
func TestPostgresDatabaseExists_Error(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
u.User = url.User("invalid")
drv := testPostgresDriver(t)
drv.databaseURL.User = url.User("invalid")
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.Error(t, err)
require.Equal(t, "pq: password authentication failed for user \"invalid\"", err.Error())
require.Equal(t, false, exists)
@ -240,10 +244,9 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
func TestPostgresCreateMigrationsTable(t *testing.T) {
t.Run("default schema", func(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv := testPostgresDriver(t)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
// migrations table should not exist
count := 0
@ -252,7 +255,7 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// migrations table should exist
@ -260,18 +263,20 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
t.Run("custom search path", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("testMigrations")
drv := testPostgresDriver(t)
drv.migrationsTableName = "testMigrations"
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=camelFoo")
u, err := url.Parse(drv.databaseURL.String() + "&search_path=camelFoo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.databaseURL = u
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
// delete schema
_, err = db.Exec("drop schema if exists \"camelFoo\"")
@ -291,7 +296,7 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// camelFoo schema should be created, and migrations table should exist only in camelFoo schema
@ -302,18 +307,20 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
t.Run("custom schema", func(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("camelSchema.testMigrations")
drv := testPostgresDriver(t)
drv.migrationsTableName = "camelSchema.testMigrations"
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.databaseURL = u
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
// delete schemas
_, err = db.Exec("drop schema if exists foo")
@ -328,7 +335,7 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// camelSchema should be created, and testMigrations table should exist
@ -341,20 +348,19 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error())
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
}
func TestPostgresSelectMigrations(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testPostgresDriver(t)
drv.migrationsTableName = "test_migrations"
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
_, err = db.Exec(`insert into public.test_migrations (version)
@ -376,14 +382,13 @@ func TestPostgresSelectMigrations(t *testing.T) {
}
func TestPostgresInsertMigration(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testPostgresDriver(t)
drv.migrationsTableName = "test_migrations"
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
count := 0
@ -402,14 +407,13 @@ func TestPostgresInsertMigration(t *testing.T) {
}
func TestPostgresDeleteMigration(t *testing.T) {
drv := testPostgresDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testPostgresDriver(t)
drv.migrationsTableName = "test_migrations"
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
_, err = db.Exec(`insert into public.test_migrations (version)
@ -426,31 +430,28 @@ func TestPostgresDeleteMigration(t *testing.T) {
}
func TestPostgresPing(t *testing.T) {
drv := testPostgresDriver()
u := postgresTestURL(t)
drv := testPostgresDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// ping database
err = drv.Ping(u)
err = drv.Ping()
require.NoError(t, err)
// ping invalid host should return error
u.Host = "postgres:404"
err = drv.Ping(u)
drv.databaseURL.Host = "postgres:404"
err = drv.Ping()
require.Error(t, err)
require.Contains(t, err.Error(), "connect: connection refused")
}
func TestPostgresQuotedMigrationsTableName(t *testing.T) {
drv := testPostgresDriver()
t.Run("default schema", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv := testPostgresDriver(t)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
@ -458,32 +459,29 @@ func TestPostgresQuotedMigrationsTableName(t *testing.T) {
})
t.Run("custom schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo,bar,public")
drv := testPostgresDriver(t)
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo,bar,public")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.databaseURL = u
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
// if "foo" schema does not exist, current schema should be "public"
_, err = db.Exec("drop schema if exists foo")
require.NoError(t, err)
_, err = db.Exec("drop schema if exists bar")
require.NoError(t, err)
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.schema_migrations", name)
// if "foo" schema exists, it should be used
_, err = db.Exec("create schema foo")
require.NoError(t, err)
name, err = drv.quotedMigrationsTableName(db)
// should use first schema from search path
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "foo.schema_migrations", name)
})
t.Run("no schema", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv := testPostgresDriver(t)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
// this is an unlikely edge case, but if for some reason there is
// no current schema then we should default to "public"
@ -496,48 +494,54 @@ func TestPostgresQuotedMigrationsTableName(t *testing.T) {
})
t.Run("custom table name", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv := testPostgresDriver(t)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
drv.SetMigrationsTableName("simple_name")
drv.migrationsTableName = "simple_name"
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.simple_name", name)
})
t.Run("custom table name quoted", func(t *testing.T) {
u := postgresTestURL(t)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv := testPostgresDriver(t)
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
// this table name will need quoting
drv.SetMigrationsTableName("camelCase")
drv.migrationsTableName = "camelCase"
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "public.\"camelCase\"", name)
})
t.Run("custom table name with custom schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
drv := testPostgresDriver(t)
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.databaseURL = u
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
_, err = db.Exec("create schema if not exists foo")
require.NoError(t, err)
drv.SetMigrationsTableName("simple_name")
drv.migrationsTableName = "simple_name"
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "foo.simple_name", name)
})
t.Run("custom table name overrides schema", func(t *testing.T) {
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
drv := testPostgresDriver(t)
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo")
require.NoError(t, err)
db := prepTestPostgresDB(t, u)
defer mustClose(db)
drv.databaseURL = u
db := prepTestPostgresDB(t)
defer dbutil.MustClose(db)
_, err = db.Exec("create schema if not exists foo")
require.NoError(t, err)
@ -545,19 +549,19 @@ func TestPostgresQuotedMigrationsTableName(t *testing.T) {
require.NoError(t, err)
// if schema is specified as part of table name, it should override search_path
drv.SetMigrationsTableName("bar.simple_name")
drv.migrationsTableName = "bar.simple_name"
name, err := drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "bar.simple_name", name)
// schema and table name should be quoted if necessary
drv.SetMigrationsTableName("barName.camelTable")
drv.migrationsTableName = "barName.camelTable"
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "\"barName\".\"camelTable\"", name)
// more than 2 components is unexpected but we will quote and pass it along anyway
drv.SetMigrationsTableName("whyWould.i.doThis")
drv.migrationsTableName = "whyWould.i.doThis"
name, err = drv.quotedMigrationsTableName(db)
require.NoError(t, err)
require.Equal(t, "\"whyWould\".i.\"doThis\"", name)

View file

@ -1,6 +1,6 @@
// +build cgo
package dbmate
package sqlite
import (
"bytes"
@ -11,58 +11,68 @@ import (
"regexp"
"strings"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql
_ "github.com/mattn/go-sqlite3" // database/sql driver
)
func init() {
drv := &SQLiteDriver{}
RegisterDriver(drv, "sqlite")
RegisterDriver(drv, "sqlite3")
dbmate.RegisterDriver(NewDriver, "sqlite")
dbmate.RegisterDriver(NewDriver, "sqlite3")
}
// SQLiteDriver provides top level database functions
type SQLiteDriver struct {
// Driver provides top level database functions
type Driver struct {
migrationsTableName string
databaseURL *url.URL
}
func sqlitePath(u *url.URL) string {
// strip one leading slash
// absolute URLs can be specified as sqlite:////tmp/foo.sqlite3
str := regexp.MustCompile("^/").ReplaceAllString(u.Path, "")
// NewDriver initializes the driver
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
return &Driver{
migrationsTableName: config.MigrationsTableName,
databaseURL: config.DatabaseURL,
}
}
// ConnectionString converts a URL into a valid connection string
func ConnectionString(u *url.URL) string {
// duplicate URL and remove scheme
newURL := *u
newURL.Scheme = ""
// trim duplicate leading slashes
str := regexp.MustCompile("^//+").ReplaceAllString(newURL.String(), "/")
return str
}
// SetMigrationsTableName sets the schema migrations table name
func (drv *SQLiteDriver) SetMigrationsTableName(name string) {
drv.migrationsTableName = name
}
// Open creates a new database connection
func (drv *SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("sqlite3", sqlitePath(u))
func (drv *Driver) Open() (*sql.DB, error) {
return sql.Open("sqlite3", ConnectionString(drv.databaseURL))
}
// CreateDatabase creates the specified database
func (drv *SQLiteDriver) CreateDatabase(u *url.URL) error {
fmt.Printf("Creating: %s\n", sqlitePath(u))
func (drv *Driver) CreateDatabase() error {
fmt.Printf("Creating: %s\n", ConnectionString(drv.databaseURL))
db, err := drv.Open(u)
db, err := drv.Open()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
return db.Ping()
}
// DropDatabase drops the specified database (if it exists)
func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
path := sqlitePath(u)
func (drv *Driver) DropDatabase() error {
path := ConnectionString(drv.databaseURL)
fmt.Printf("Dropping: %s\n", path)
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
if err != nil {
return err
}
@ -73,11 +83,11 @@ func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
return os.Remove(path)
}
func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
migrationsTable := drv.quotedMigrationsTableName()
// load applied migrations
migrations, err := queryColumn(db,
migrations, err := dbutil.QueryColumn(db,
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
if err != nil {
return nil, err
@ -98,9 +108,9 @@ func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
}
// DumpSchema returns the current database schema
func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
path := sqlitePath(u)
schema, err := runCommand("sqlite3", path, ".schema")
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
path := ConnectionString(drv.databaseURL)
schema, err := dbutil.RunCommand("sqlite3", path, ".schema")
if err != nil {
return nil, err
}
@ -111,12 +121,12 @@ func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
}
schema = append(schema, migrations...)
return trimLeadingSQLComments(schema)
return dbutil.TrimLeadingSQLComments(schema)
}
// DatabaseExists determines whether the database exists
func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
_, err := os.Stat(sqlitePath(u))
func (drv *Driver) DatabaseExists() (bool, error) {
_, err := os.Stat(ConnectionString(drv.databaseURL))
if os.IsNotExist(err) {
return false, nil
}
@ -128,7 +138,7 @@ func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
}
// CreateMigrationsTable creates the schema migrations table
func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
_, err := db.Exec(
fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) +
"(version varchar(255) primary key)")
@ -138,7 +148,7 @@ func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
// SelectMigrations returns a list of applied migrations
// with an optional limit (in descending order)
func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
if limit >= 0 {
query = fmt.Sprintf("%s limit %d", query, limit)
@ -148,7 +158,7 @@ func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]boo
return nil, err
}
defer mustClose(rows)
defer dbutil.MustClose(rows)
migrations := map[string]bool{}
for rows.Next() {
@ -168,7 +178,7 @@ func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]boo
}
// InsertMigration adds a new migration record
func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
version)
@ -177,7 +187,7 @@ func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
}
// DeleteMigration removes a migration record
func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
_, err := db.Exec(
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
version)
@ -188,23 +198,23 @@ func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
// Ping verifies a connection to the database. Due to the way SQLite works, by
// testing whether the database is valid, it will automatically create the database
// if it does not already exist.
func (drv *SQLiteDriver) Ping(u *url.URL) error {
db, err := drv.Open(u)
func (drv *Driver) Ping() error {
db, err := drv.Open()
if err != nil {
return err
}
defer mustClose(db)
defer dbutil.MustClose(db)
return db.Ping()
}
func (drv *SQLiteDriver) quotedMigrationsTableName() string {
func (drv *Driver) quotedMigrationsTableName() string {
return drv.quoteIdentifier(drv.migrationsTableName)
}
// quoteIdentifier quotes a table or column name
// we fall back to lib/pq implementation since both use ansi standard (double quotes)
// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
func (drv *SQLiteDriver) quoteIdentifier(s string) string {
func (drv *Driver) quoteIdentifier(s string) string {
return pq.QuoteIdentifier(s)
}

View file

@ -1,59 +1,91 @@
// +build cgo
package dbmate
package sqlite
import (
"database/sql"
"net/url"
"os"
"testing"
"github.com/amacneil/dbmate/pkg/dbmate"
"github.com/amacneil/dbmate/pkg/dbutil"
"github.com/stretchr/testify/require"
)
func sqliteTestURL(t *testing.T) *url.URL {
u, err := url.Parse("sqlite3:////tmp/dbmate.sqlite3")
func testSQLiteDriver(t *testing.T) *Driver {
u := dbutil.MustParseURL(os.Getenv("SQLITE_TEST_URL"))
drv, err := dbmate.New(u).GetDriver()
require.NoError(t, err)
return u
return drv.(*Driver)
}
func testSQLiteDriver() *SQLiteDriver {
drv := &SQLiteDriver{}
drv.SetMigrationsTableName(DefaultMigrationsTableName)
return drv
}
func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
drv := testSQLiteDriver()
func prepTestSQLiteDB(t *testing.T) *sql.DB {
drv := testSQLiteDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// connect database
db, err := drv.Open(u)
db, err := drv.Open()
require.NoError(t, err)
return db
}
func TestGetDriver(t *testing.T) {
db := dbmate.New(dbutil.MustParseURL("sqlite://"))
drvInterface, err := db.GetDriver()
require.NoError(t, err)
// driver should have URL and default migrations table set
drv, ok := drvInterface.(*Driver)
require.True(t, ok)
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
require.Equal(t, "schema_migrations", drv.migrationsTableName)
}
func TestConnectionString(t *testing.T) {
t.Run("relative", func(t *testing.T) {
u := dbutil.MustParseURL("sqlite:foo/bar.sqlite3?mode=ro")
require.Equal(t, "foo/bar.sqlite3?mode=ro", ConnectionString(u))
})
t.Run("absolute", func(t *testing.T) {
u := dbutil.MustParseURL("sqlite:/tmp/foo.sqlite3?mode=ro")
require.Equal(t, "/tmp/foo.sqlite3?mode=ro", ConnectionString(u))
})
t.Run("three slashes", func(t *testing.T) {
// interpreted as absolute path
u := dbutil.MustParseURL("sqlite:///tmp/foo.sqlite3?mode=ro")
require.Equal(t, "/tmp/foo.sqlite3?mode=ro", ConnectionString(u))
})
t.Run("four slashes", func(t *testing.T) {
// interpreted as absolute path
// supported for backwards compatibility
u := dbutil.MustParseURL("sqlite:////tmp/foo.sqlite3?mode=ro")
require.Equal(t, "/tmp/foo.sqlite3?mode=ro", ConnectionString(u))
})
}
func TestSQLiteCreateDropDatabase(t *testing.T) {
drv := testSQLiteDriver()
u := sqliteTestURL(t)
path := sqlitePath(u)
drv := testSQLiteDriver(t)
path := ConnectionString(drv.databaseURL)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// check that database exists
@ -61,7 +93,7 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
require.NoError(t, err)
// drop the database
err = drv.DropDatabase(u)
err = drv.DropDatabase()
require.NoError(t, err)
// check that database no longer exists
@ -71,15 +103,13 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
}
func TestSQLiteDumpSchema(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
u := sqliteTestURL(t)
drv := testSQLiteDriver(t)
drv.migrationsTableName = "test_migrations"
// prepare database
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
err := drv.CreateMigrationsTable(u, db)
db := prepTestSQLiteDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
// insert migration
@ -89,7 +119,7 @@ func TestSQLiteDumpSchema(t *testing.T) {
require.NoError(t, err)
// DumpSchema should return schema
schema, err := drv.DumpSchema(u, db)
schema, err := drv.DumpSchema(db)
require.NoError(t, err)
require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"")
require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+
@ -98,50 +128,50 @@ func TestSQLiteDumpSchema(t *testing.T) {
" ('abc2');\n")
// DumpSchema should return error if command fails
u.Path = "/."
schema, err = drv.DumpSchema(u, db)
drv.databaseURL = dbutil.MustParseURL(".")
schema, err = drv.DumpSchema(db)
require.Nil(t, schema)
require.Error(t, err)
require.EqualError(t, err, "Error: unable to open database \".\": "+
"unable to open database file")
}
func TestSQLiteDatabaseExists(t *testing.T) {
drv := testSQLiteDriver()
u := sqliteTestURL(t)
drv := testSQLiteDriver(t)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// DatabaseExists should return false
exists, err := drv.DatabaseExists(u)
exists, err := drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, false, exists)
// create database
err = drv.CreateDatabase(u)
err = drv.CreateDatabase()
require.NoError(t, err)
// DatabaseExists should return true
exists, err = drv.DatabaseExists(u)
exists, err = drv.DatabaseExists()
require.NoError(t, err)
require.Equal(t, true, exists)
}
func TestSQLiteCreateMigrationsTable(t *testing.T) {
t.Run("default table", func(t *testing.T) {
drv := testSQLiteDriver()
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
drv := testSQLiteDriver(t)
db := prepTestSQLiteDB(t)
defer dbutil.MustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
require.Error(t, err)
require.Regexp(t, "no such table: schema_migrations", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// migrations table should exist
@ -149,25 +179,25 @@ func TestSQLiteCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
t.Run("custom table", func(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testSQLiteDriver(t)
drv.migrationsTableName = "test_migrations"
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
db := prepTestSQLiteDB(t)
defer dbutil.MustClose(db)
// migrations table should not exist
count := 0
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
require.Error(t, err)
require.Regexp(t, "no such table: test_migrations", err.Error())
// create table
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
// migrations table should exist
@ -175,20 +205,19 @@ func TestSQLiteCreateMigrationsTable(t *testing.T) {
require.NoError(t, err)
// create table should be idempotent
err = drv.CreateMigrationsTable(u, db)
err = drv.CreateMigrationsTable(db)
require.NoError(t, err)
})
}
func TestSQLiteSelectMigrations(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testSQLiteDriver(t)
drv.migrationsTableName = "test_migrations"
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
db := prepTestSQLiteDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
_, err = db.Exec(`insert into test_migrations (version)
@ -210,14 +239,13 @@ func TestSQLiteSelectMigrations(t *testing.T) {
}
func TestSQLiteInsertMigration(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testSQLiteDriver(t)
drv.migrationsTableName = "test_migrations"
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
db := prepTestSQLiteDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
count := 0
@ -236,14 +264,13 @@ func TestSQLiteInsertMigration(t *testing.T) {
}
func TestSQLiteDeleteMigration(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("test_migrations")
drv := testSQLiteDriver(t)
drv.migrationsTableName = "test_migrations"
u := sqliteTestURL(t)
db := prepTestSQLiteDB(t, u)
defer mustClose(db)
db := prepTestSQLiteDB(t)
defer dbutil.MustClose(db)
err := drv.CreateMigrationsTable(u, db)
err := drv.CreateMigrationsTable(db)
require.NoError(t, err)
_, err = db.Exec(`insert into test_migrations (version)
@ -260,16 +287,15 @@ func TestSQLiteDeleteMigration(t *testing.T) {
}
func TestSQLitePing(t *testing.T) {
drv := testSQLiteDriver()
u := sqliteTestURL(t)
path := sqlitePath(u)
drv := testSQLiteDriver(t)
path := ConnectionString(drv.databaseURL)
// drop any existing database
err := drv.DropDatabase(u)
err := drv.DropDatabase()
require.NoError(t, err)
// ping database
err = drv.Ping(u)
err = drv.Ping()
require.NoError(t, err)
// check that the database was created (sqlite-only behavior)
@ -277,7 +303,7 @@ func TestSQLitePing(t *testing.T) {
require.NoError(t, err)
// drop the database
err = drv.DropDatabase(u)
err = drv.DropDatabase()
require.NoError(t, err)
// create directory where database file is expected
@ -289,20 +315,20 @@ func TestSQLitePing(t *testing.T) {
}()
// ping database should fail
err = drv.Ping(u)
err = drv.Ping()
require.EqualError(t, err, "unable to open database file: is a directory")
}
func TestSQLiteQuotedMigrationsTableName(t *testing.T) {
t.Run("default name", func(t *testing.T) {
drv := testSQLiteDriver()
drv := testSQLiteDriver(t)
name := drv.quotedMigrationsTableName()
require.Equal(t, `"schema_migrations"`, name)
})
t.Run("custom name", func(t *testing.T) {
drv := testSQLiteDriver()
drv.SetMigrationsTableName("fooMigrations")
drv := testSQLiteDriver(t)
drv.migrationsTableName = "fooMigrations"
name := drv.quotedMigrationsTableName()
require.Equal(t, `"fooMigrations"`, name)