mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2025-12-11 23:50:04 +01:00
Refactor drivers into separate packages (#179)
`dbmate` package was starting to get a bit polluted. This PR migrates each driver into a separate package, with clean separation between each. In addition: * Drivers are now initialized with a URL, avoiding the need to pass `*url.URL` to every method * Sqlite supports a cleaner syntax for relative paths * Driver tests now load their test URL from environment variables Public API of `dbmate` package has not changed (no changes to `main` package).
This commit is contained in:
parent
c907c3f5c6
commit
61771e386d
23 changed files with 1195 additions and 1078 deletions
|
|
@ -26,3 +26,7 @@ linters-settings:
|
||||||
local-prefixes: github.com/amacneil/dbmate
|
local-prefixes: github.com/amacneil/dbmate
|
||||||
misspell:
|
misspell:
|
||||||
locale: US
|
locale: US
|
||||||
|
|
||||||
|
issues:
|
||||||
|
include:
|
||||||
|
- EXC0002
|
||||||
|
|
|
||||||
12
Makefile
12
Makefile
|
|
@ -3,14 +3,14 @@ LDFLAGS := -ldflags '-s'
|
||||||
# statically link binaries (to support alpine + scratch containers)
|
# statically link binaries (to support alpine + scratch containers)
|
||||||
STATICLDFLAGS := -ldflags '-s -extldflags "-static"'
|
STATICLDFLAGS := -ldflags '-s -extldflags "-static"'
|
||||||
# avoid building code that is incompatible with static linking
|
# avoid building code that is incompatible with static linking
|
||||||
TAGS := -tags netgo,osusergo,sqlite_omit_load_extension
|
TAGS := -tags netgo,osusergo,sqlite_omit_load_extension,sqlite_json
|
||||||
|
|
||||||
.PHONY: all
|
.PHONY: all
|
||||||
all: build lint test
|
all: build test lint
|
||||||
|
|
||||||
.PHONY: test
|
.PHONY: test
|
||||||
test:
|
test:
|
||||||
go test -v $(TAGS) $(STATICLDFLAGS) ./...
|
go test -p 1 $(TAGS) $(STATICLDFLAGS) ./...
|
||||||
|
|
||||||
.PHONY: fix
|
.PHONY: fix
|
||||||
fix:
|
fix:
|
||||||
|
|
@ -22,9 +22,9 @@ lint:
|
||||||
|
|
||||||
.PHONY: wait
|
.PHONY: wait
|
||||||
wait:
|
wait:
|
||||||
dist/dbmate-linux-amd64 -e MYSQL_URL wait
|
dist/dbmate-linux-amd64 -e CLICKHOUSE_TEST_URL wait
|
||||||
dist/dbmate-linux-amd64 -e POSTGRESQL_URL wait
|
dist/dbmate-linux-amd64 -e MYSQL_TEST_URL wait
|
||||||
dist/dbmate-linux-amd64 -e CLICKHOUSE_URL wait
|
dist/dbmate-linux-amd64 -e POSTGRES_TEST_URL wait
|
||||||
|
|
||||||
.PHONY: clean
|
.PHONY: clean
|
||||||
clean:
|
clean:
|
||||||
|
|
|
||||||
|
|
@ -152,16 +152,16 @@ DATABASE_URL="postgres://username:password@127.0.0.1:5432/database_name?search_p
|
||||||
|
|
||||||
**SQLite**
|
**SQLite**
|
||||||
|
|
||||||
SQLite databases are stored on the filesystem, so you do not need to specify a host. By default, files are relative to the current directory. For example, the following will create a database at `./db/database_name.sqlite3`:
|
SQLite databases are stored on the filesystem, so you do not need to specify a host. By default, files are relative to the current directory. For example, the following will create a database at `./db/database.sqlite3`:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
DATABASE_URL="sqlite:///db/database_name.sqlite3"
|
DATABASE_URL="sqlite:db/database.sqlite3"
|
||||||
```
|
```
|
||||||
|
|
||||||
To specify an absolute path, add an additional forward slash to the path. The following will create a database at `/tmp/database_name.sqlite3`:
|
To specify an absolute path, add a forward slash to the path. The following will create a database at `/tmp/database.sqlite3`:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
DATABASE_URL="sqlite:////tmp/database_name.sqlite3"
|
DATABASE_URL="sqlite:/tmp/database.sqlite3"
|
||||||
```
|
```
|
||||||
|
|
||||||
**ClickHouse**
|
**ClickHouse**
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,10 @@ services:
|
||||||
- postgres
|
- postgres
|
||||||
- clickhouse
|
- clickhouse
|
||||||
environment:
|
environment:
|
||||||
MYSQL_URL: mysql://root:root@mysql/dbmate
|
CLICKHOUSE_TEST_URL: clickhouse://clickhouse:9000?database=dbmate_test
|
||||||
POSTGRESQL_URL: postgres://postgres:postgres@postgres/dbmate?sslmode=disable
|
MYSQL_TEST_URL: mysql://root:root@mysql/dbmate_test
|
||||||
CLICKHOUSE_URL: clickhouse://clickhouse:9000?database=dbmate
|
POSTGRES_TEST_URL: postgres://postgres:postgres@postgres/dbmate_test?sslmode=disable
|
||||||
|
SQLITE_TEST_URL: sqlite3:/tmp/dbmate_test.sqlite3
|
||||||
|
|
||||||
dbmate:
|
dbmate:
|
||||||
build:
|
build:
|
||||||
|
|
|
||||||
1
go.sum
1
go.sum
|
|
@ -1,3 +1,4 @@
|
||||||
|
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
|
||||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||||
github.com/ClickHouse/clickhouse-go v1.4.3 h1:iAFMa2UrQdR5bHJ2/yaSLffZkxpcOYQMCUuKeNXGdqc=
|
github.com/ClickHouse/clickhouse-go v1.4.3 h1:iAFMa2UrQdR5bHJ2/yaSLffZkxpcOYQMCUuKeNXGdqc=
|
||||||
github.com/ClickHouse/clickhouse-go v1.4.3/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI=
|
github.com/ClickHouse/clickhouse-go v1.4.3/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI=
|
||||||
|
|
|
||||||
4
main.go
4
main.go
|
|
@ -11,6 +11,10 @@ import (
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
"github.com/amacneil/dbmate/pkg/dbmate"
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/clickhouse"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/mysql"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/postgres"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
|
||||||
283
pkg/dbmate/db.go
283
pkg/dbmate/db.go
|
|
@ -2,6 +2,7 @@ package dbmate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
@ -10,6 +11,8 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"sort"
|
"sort"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultMigrationsDir specifies default directory to find migration files
|
// DefaultMigrationsDir specifies default directory to find migration files
|
||||||
|
|
@ -43,9 +46,10 @@ type DB struct {
|
||||||
// migrationFileRegexp pattern for valid migration files
|
// migrationFileRegexp pattern for valid migration files
|
||||||
var migrationFileRegexp = regexp.MustCompile(`^\d.*\.sql$`)
|
var migrationFileRegexp = regexp.MustCompile(`^\d.*\.sql$`)
|
||||||
|
|
||||||
type statusResult struct {
|
// StatusResult represents an available migration status
|
||||||
filename string
|
type StatusResult struct {
|
||||||
applied bool
|
Filename string
|
||||||
|
Applied bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New initializes a new dbmate database
|
// New initializes a new dbmate database
|
||||||
|
|
@ -62,16 +66,23 @@ func New(databaseURL *url.URL) *DB {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDriver loads the required database driver
|
// GetDriver initializes the appropriate database driver
|
||||||
func (db *DB) GetDriver() (Driver, error) {
|
func (db *DB) GetDriver() (Driver, error) {
|
||||||
drv, err := getDriver(db.DatabaseURL.Scheme)
|
if db.DatabaseURL == nil || db.DatabaseURL.Scheme == "" {
|
||||||
if err != nil {
|
return nil, errors.New("invalid url")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
drv.SetMigrationsTableName(db.MigrationsTableName)
|
driverFunc := drivers[db.DatabaseURL.Scheme]
|
||||||
|
if driverFunc == nil {
|
||||||
|
return nil, fmt.Errorf("unsupported driver: %s", db.DatabaseURL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
return drv, err
|
config := DriverConfig{
|
||||||
|
DatabaseURL: db.DatabaseURL,
|
||||||
|
MigrationsTableName: db.MigrationsTableName,
|
||||||
|
}
|
||||||
|
|
||||||
|
return driverFunc(config), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait blocks until the database server is available. It does not verify that
|
// Wait blocks until the database server is available. It does not verify that
|
||||||
|
|
@ -82,8 +93,12 @@ func (db *DB) Wait() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return db.wait(drv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) wait(drv Driver) error {
|
||||||
// attempt connection to database server
|
// attempt connection to database server
|
||||||
err = drv.Ping(db.DatabaseURL)
|
err := drv.Ping()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// connection successful
|
// connection successful
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -95,7 +110,7 @@ func (db *DB) Wait() error {
|
||||||
time.Sleep(db.WaitInterval)
|
time.Sleep(db.WaitInterval)
|
||||||
|
|
||||||
// attempt connection to database server
|
// attempt connection to database server
|
||||||
err = drv.Ping(db.DatabaseURL)
|
err = drv.Ping()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// connection successful
|
// connection successful
|
||||||
fmt.Print("\n")
|
fmt.Print("\n")
|
||||||
|
|
@ -110,82 +125,91 @@ func (db *DB) Wait() error {
|
||||||
|
|
||||||
// CreateAndMigrate creates the database (if necessary) and runs migrations
|
// CreateAndMigrate creates the database (if necessary) and runs migrations
|
||||||
func (db *DB) CreateAndMigrate() error {
|
func (db *DB) CreateAndMigrate() error {
|
||||||
if db.WaitBefore {
|
|
||||||
err := db.Wait()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
drv, err := db.GetDriver()
|
drv, err := db.GetDriver()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if db.WaitBefore {
|
||||||
|
err := db.wait(drv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// create database if it does not already exist
|
// create database if it does not already exist
|
||||||
// skip this step if we cannot determine status
|
// skip this step if we cannot determine status
|
||||||
// (e.g. user does not have list database permission)
|
// (e.g. user does not have list database permission)
|
||||||
exists, err := drv.DatabaseExists(db.DatabaseURL)
|
exists, err := drv.DatabaseExists()
|
||||||
if err == nil && !exists {
|
if err == nil && !exists {
|
||||||
if err := drv.CreateDatabase(db.DatabaseURL); err != nil {
|
if err := drv.CreateDatabase(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrate
|
// migrate
|
||||||
return db.Migrate()
|
return db.migrate(drv)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create creates the current database
|
// Create creates the current database
|
||||||
func (db *DB) Create() error {
|
func (db *DB) Create() error {
|
||||||
if db.WaitBefore {
|
|
||||||
err := db.Wait()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
drv, err := db.GetDriver()
|
drv, err := db.GetDriver()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return drv.CreateDatabase(db.DatabaseURL)
|
if db.WaitBefore {
|
||||||
|
err := db.wait(drv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return drv.CreateDatabase()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Drop drops the current database (if it exists)
|
// Drop drops the current database (if it exists)
|
||||||
func (db *DB) Drop() error {
|
func (db *DB) Drop() error {
|
||||||
if db.WaitBefore {
|
|
||||||
err := db.Wait()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
drv, err := db.GetDriver()
|
drv, err := db.GetDriver()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return drv.DropDatabase(db.DatabaseURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DumpSchema writes the current database schema to a file
|
|
||||||
func (db *DB) DumpSchema() error {
|
|
||||||
if db.WaitBefore {
|
if db.WaitBefore {
|
||||||
err := db.Wait()
|
err := db.wait(drv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, sqlDB, err := db.openDatabaseForMigration()
|
return drv.DropDatabase()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DumpSchema writes the current database schema to a file
|
||||||
|
func (db *DB) DumpSchema() error {
|
||||||
|
drv, err := db.GetDriver()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(sqlDB)
|
|
||||||
|
|
||||||
schema, err := drv.DumpSchema(db.DatabaseURL, sqlDB)
|
return db.dumpSchema(drv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) dumpSchema(drv Driver) error {
|
||||||
|
if db.WaitBefore {
|
||||||
|
err := db.wait(drv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.openDatabaseForMigration(drv)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
|
schema, err := drv.DumpSchema(sqlDB)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -201,6 +225,15 @@ func (db *DB) DumpSchema() error {
|
||||||
return ioutil.WriteFile(db.SchemaFile, schema, 0644)
|
return ioutil.WriteFile(db.SchemaFile, schema, 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ensureDir creates a directory if it does not already exist
|
||||||
|
func ensureDir(dir string) error {
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("unable to create directory `%s`", dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
|
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
|
||||||
|
|
||||||
// NewMigration creates a new migration file
|
// NewMigration creates a new migration file
|
||||||
|
|
@ -231,13 +264,13 @@ func (db *DB) NewMigration(name string) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer mustClose(file)
|
defer dbutil.MustClose(file)
|
||||||
_, err = file.WriteString(migrationTemplate)
|
_, err = file.WriteString(migrationTemplate)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
|
func doTransaction(sqlDB *sql.DB, txFunc func(dbutil.Transaction) error) error {
|
||||||
tx, err := db.Begin()
|
tx, err := sqlDB.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -253,27 +286,31 @@ func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) openDatabaseForMigration() (Driver, *sql.DB, error) {
|
func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) {
|
||||||
drv, err := db.GetDriver()
|
sqlDB, err := drv.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
sqlDB, err := drv.Open(db.DatabaseURL)
|
if err := drv.CreateMigrationsTable(sqlDB); err != nil {
|
||||||
if err != nil {
|
dbutil.MustClose(sqlDB)
|
||||||
return nil, nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := drv.CreateMigrationsTable(db.DatabaseURL, sqlDB); err != nil {
|
return sqlDB, nil
|
||||||
mustClose(sqlDB)
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return drv, sqlDB, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate migrates database to the latest version
|
// Migrate migrates database to the latest version
|
||||||
func (db *DB) Migrate() error {
|
func (db *DB) Migrate() error {
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db.migrate(drv)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) migrate(drv Driver) error {
|
||||||
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
|
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -284,17 +321,17 @@ func (db *DB) Migrate() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if db.WaitBefore {
|
if db.WaitBefore {
|
||||||
err := db.Wait()
|
err := db.wait(drv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, sqlDB, err := db.openDatabaseForMigration()
|
sqlDB, err := db.openDatabaseForMigration(drv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(sqlDB)
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
applied, err := drv.SelectMigrations(sqlDB, -1)
|
applied, err := drv.SelectMigrations(sqlDB, -1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -315,7 +352,7 @@ func (db *DB) Migrate() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
execMigration := func(tx Transaction) error {
|
execMigration := func(tx dbutil.Transaction) error {
|
||||||
// run actual migration
|
// run actual migration
|
||||||
result, err := tx.Exec(up.Contents)
|
result, err := tx.Exec(up.Contents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -343,12 +380,23 @@ func (db *DB) Migrate() error {
|
||||||
|
|
||||||
// automatically update schema file, silence errors
|
// automatically update schema file, silence errors
|
||||||
if db.AutoDumpSchema {
|
if db.AutoDumpSchema {
|
||||||
_ = db.DumpSchema()
|
_ = db.dumpSchema(drv)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func printVerbose(result sql.Result) {
|
||||||
|
lastInsertID, err := result.LastInsertId()
|
||||||
|
if err == nil {
|
||||||
|
fmt.Printf("Last insert ID: %d\n", lastInsertID)
|
||||||
|
}
|
||||||
|
rowsAffected, err := result.RowsAffected()
|
||||||
|
if err == nil {
|
||||||
|
fmt.Printf("Rows affected: %d\n", rowsAffected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
|
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
|
||||||
files, err := ioutil.ReadDir(dir)
|
files, err := ioutil.ReadDir(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -400,18 +448,23 @@ func migrationVersion(filename string) string {
|
||||||
|
|
||||||
// Rollback rolls back the most recent migration
|
// Rollback rolls back the most recent migration
|
||||||
func (db *DB) Rollback() error {
|
func (db *DB) Rollback() error {
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if db.WaitBefore {
|
if db.WaitBefore {
|
||||||
err := db.Wait()
|
err := db.wait(drv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, sqlDB, err := db.openDatabaseForMigration()
|
sqlDB, err := db.openDatabaseForMigration(drv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(sqlDB)
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
applied, err := drv.SelectMigrations(sqlDB, 1)
|
applied, err := drv.SelectMigrations(sqlDB, 1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -439,7 +492,7 @@ func (db *DB) Rollback() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
execMigration := func(tx Transaction) error {
|
execMigration := func(tx dbutil.Transaction) error {
|
||||||
// rollback migration
|
// rollback migration
|
||||||
result, err := tx.Exec(down.Contents)
|
result, err := tx.Exec(down.Contents)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -466,53 +519,20 @@ func (db *DB) Rollback() error {
|
||||||
|
|
||||||
// automatically update schema file, silence errors
|
// automatically update schema file, silence errors
|
||||||
if db.AutoDumpSchema {
|
if db.AutoDumpSchema {
|
||||||
_ = db.DumpSchema()
|
_ = db.dumpSchema(drv)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkMigrationsStatus(db *DB) ([]statusResult, error) {
|
|
||||||
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(files) == 0 {
|
|
||||||
return nil, fmt.Errorf("no migration files found")
|
|
||||||
}
|
|
||||||
|
|
||||||
drv, sqlDB, err := db.openDatabaseForMigration()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer mustClose(sqlDB)
|
|
||||||
|
|
||||||
applied, err := drv.SelectMigrations(sqlDB, -1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var results []statusResult
|
|
||||||
|
|
||||||
for _, filename := range files {
|
|
||||||
ver := migrationVersion(filename)
|
|
||||||
res := statusResult{filename: filename}
|
|
||||||
if ok := applied[ver]; ok {
|
|
||||||
res.applied = true
|
|
||||||
} else {
|
|
||||||
res.applied = false
|
|
||||||
}
|
|
||||||
|
|
||||||
results = append(results, res)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Status shows the status of all migrations
|
// Status shows the status of all migrations
|
||||||
func (db *DB) Status(quiet bool) (int, error) {
|
func (db *DB) Status(quiet bool) (int, error) {
|
||||||
results, err := checkMigrationsStatus(db)
|
drv, err := db.GetDriver()
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
results, err := db.CheckMigrationsStatus(drv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return -1, err
|
return -1, err
|
||||||
}
|
}
|
||||||
|
|
@ -521,11 +541,11 @@ func (db *DB) Status(quiet bool) (int, error) {
|
||||||
var line string
|
var line string
|
||||||
|
|
||||||
for _, res := range results {
|
for _, res := range results {
|
||||||
if res.applied {
|
if res.Applied {
|
||||||
line = fmt.Sprintf("[X] %s", res.filename)
|
line = fmt.Sprintf("[X] %s", res.Filename)
|
||||||
totalApplied++
|
totalApplied++
|
||||||
} else {
|
} else {
|
||||||
line = fmt.Sprintf("[ ] %s", res.filename)
|
line = fmt.Sprintf("[ ] %s", res.Filename)
|
||||||
}
|
}
|
||||||
if !quiet {
|
if !quiet {
|
||||||
fmt.Println(line)
|
fmt.Println(line)
|
||||||
|
|
@ -541,3 +561,42 @@ func (db *DB) Status(quiet bool) (int, error) {
|
||||||
|
|
||||||
return totalPending, nil
|
return totalPending, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckMigrationsStatus returns the status of all available mgirations
|
||||||
|
func (db *DB) CheckMigrationsStatus(drv Driver) ([]StatusResult, error) {
|
||||||
|
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(files) == 0 {
|
||||||
|
return nil, fmt.Errorf("no migration files found")
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlDB, err := db.openDatabaseForMigration(drv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
|
applied, err := drv.SelectMigrations(sqlDB, -1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []StatusResult
|
||||||
|
|
||||||
|
for _, filename := range files {
|
||||||
|
ver := migrationVersion(filename)
|
||||||
|
res := StatusResult{Filename: filename}
|
||||||
|
if ok := applied[ver]; ok {
|
||||||
|
res.Applied = true
|
||||||
|
} else {
|
||||||
|
res.Applied = false
|
||||||
|
}
|
||||||
|
|
||||||
|
results = append(results, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package dbmate
|
package dbmate_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
|
@ -8,13 +8,19 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/mysql"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/postgres"
|
||||||
|
_ "github.com/amacneil/dbmate/pkg/driver/sqlite"
|
||||||
|
|
||||||
"github.com/kami-zh/go-capturer"
|
"github.com/kami-zh/go-capturer"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testdataDir string
|
var testdataDir string
|
||||||
|
|
||||||
func newTestDB(t *testing.T, u *url.URL) *DB {
|
func newTestDB(t *testing.T, u *url.URL) *dbmate.DB {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// only chdir once, because testdata is relative to current directory
|
// only chdir once, because testdata is relative to current directory
|
||||||
|
|
@ -26,17 +32,16 @@ func newTestDB(t *testing.T, u *url.URL) *DB {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
db := New(u)
|
db := dbmate.New(u)
|
||||||
db.AutoDumpSchema = false
|
db.AutoDumpSchema = false
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
db := dbmate.New(dbutil.MustParseURL("foo:test"))
|
||||||
db := New(u)
|
|
||||||
require.True(t, db.AutoDumpSchema)
|
require.True(t, db.AutoDumpSchema)
|
||||||
require.Equal(t, u.String(), db.DatabaseURL.String())
|
require.Equal(t, "foo:test", db.DatabaseURL.String())
|
||||||
require.Equal(t, "./db/migrations", db.MigrationsDir)
|
require.Equal(t, "./db/migrations", db.MigrationsDir)
|
||||||
require.Equal(t, "schema_migrations", db.MigrationsTableName)
|
require.Equal(t, "schema_migrations", db.MigrationsTableName)
|
||||||
require.Equal(t, "./db/schema.sql", db.SchemaFile)
|
require.Equal(t, "./db/schema.sql", db.SchemaFile)
|
||||||
|
|
@ -46,20 +51,30 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetDriver(t *testing.T) {
|
func TestGetDriver(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
t.Run("missing URL", func(t *testing.T) {
|
||||||
db := New(u)
|
db := dbmate.New(nil)
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
require.Nil(t, drv)
|
||||||
|
require.EqualError(t, err, "invalid url")
|
||||||
|
})
|
||||||
|
|
||||||
drv, err := db.GetDriver()
|
t.Run("missing schema", func(t *testing.T) {
|
||||||
require.NoError(t, err)
|
db := dbmate.New(dbutil.MustParseURL("//hi"))
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
require.Nil(t, drv)
|
||||||
|
require.EqualError(t, err, "invalid url")
|
||||||
|
})
|
||||||
|
|
||||||
// driver should have default migrations table set
|
t.Run("invalid driver", func(t *testing.T) {
|
||||||
pgDrv, ok := drv.(*PostgresDriver)
|
db := dbmate.New(dbutil.MustParseURL("foo://bar"))
|
||||||
require.True(t, ok)
|
drv, err := db.GetDriver()
|
||||||
require.Equal(t, "schema_migrations", pgDrv.migrationsTableName)
|
require.EqualError(t, err, "unsupported driver: foo")
|
||||||
|
require.Nil(t, drv)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWait(t *testing.T) {
|
func TestWait(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
|
||||||
db := newTestDB(t, u)
|
db := newTestDB(t, u)
|
||||||
|
|
||||||
// speed up our retry loop for testing
|
// speed up our retry loop for testing
|
||||||
|
|
@ -83,7 +98,7 @@ func TestWait(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDumpSchema(t *testing.T) {
|
func TestDumpSchema(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
|
||||||
db := newTestDB(t, u)
|
db := newTestDB(t, u)
|
||||||
|
|
||||||
// create custom schema file directory
|
// create custom schema file directory
|
||||||
|
|
@ -120,7 +135,7 @@ func TestDumpSchema(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoDumpSchema(t *testing.T) {
|
func TestAutoDumpSchema(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
|
||||||
db := newTestDB(t, u)
|
db := newTestDB(t, u)
|
||||||
db.AutoDumpSchema = true
|
db.AutoDumpSchema = true
|
||||||
|
|
||||||
|
|
@ -177,7 +192,7 @@ func checkWaitCalled(t *testing.T, u *url.URL, command func() error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testWaitBefore(t *testing.T, verbose bool) {
|
func testWaitBefore(t *testing.T, verbose bool) {
|
||||||
u := postgresTestURL(t)
|
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
|
||||||
db := newTestDB(t, u)
|
db := newTestDB(t, u)
|
||||||
db.Verbose = verbose
|
db.Verbose = verbose
|
||||||
db.WaitBefore = true
|
db.WaitBefore = true
|
||||||
|
|
@ -234,173 +249,173 @@ Rows affected: 0`)
|
||||||
Rows affected: 0`)
|
Rows affected: 0`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testURLs(t *testing.T) []*url.URL {
|
func testURLs() []*url.URL {
|
||||||
return []*url.URL{
|
return []*url.URL{
|
||||||
postgresTestURL(t),
|
dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL")),
|
||||||
mySQLTestURL(t),
|
dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL")),
|
||||||
sqliteTestURL(t),
|
dbutil.MustParseURL(os.Getenv("SQLITE_TEST_URL")),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testMigrateURL(t *testing.T, u *url.URL) {
|
|
||||||
db := newTestDB(t, u)
|
|
||||||
|
|
||||||
// drop and recreate database
|
|
||||||
err := db.Drop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.Create()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// migrate
|
|
||||||
err = db.Migrate()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// verify results
|
|
||||||
sqlDB, err := getDriverOpen(u)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer mustClose(sqlDB)
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
|
||||||
where version = '20151129054053'`).Scan(&count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
|
|
||||||
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMigrate(t *testing.T) {
|
func TestMigrate(t *testing.T) {
|
||||||
for _, u := range testURLs(t) {
|
for _, u := range testURLs() {
|
||||||
testMigrateURL(t, u)
|
t.Run(u.Scheme, func(t *testing.T) {
|
||||||
|
db := newTestDB(t, u)
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// drop and recreate database
|
||||||
|
err = db.Drop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = db.Create()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// migrate
|
||||||
|
err = db.Migrate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify results
|
||||||
|
sqlDB, err := drv.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
||||||
|
where version = '20151129054053'`).Scan(&count)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testUpURL(t *testing.T, u *url.URL) {
|
|
||||||
db := newTestDB(t, u)
|
|
||||||
|
|
||||||
// drop database
|
|
||||||
err := db.Drop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// create and migrate
|
|
||||||
err = db.CreateAndMigrate()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// verify results
|
|
||||||
sqlDB, err := getDriverOpen(u)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer mustClose(sqlDB)
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
|
||||||
where version = '20151129054053'`).Scan(&count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
|
|
||||||
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUp(t *testing.T) {
|
func TestUp(t *testing.T) {
|
||||||
for _, u := range testURLs(t) {
|
for _, u := range testURLs() {
|
||||||
testUpURL(t, u)
|
t.Run(u.Scheme, func(t *testing.T) {
|
||||||
|
db := newTestDB(t, u)
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// drop database
|
||||||
|
err = db.Drop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// create and migrate
|
||||||
|
err = db.CreateAndMigrate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify results
|
||||||
|
sqlDB, err := drv.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
||||||
|
where version = '20151129054053'`).Scan(&count)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testRollbackURL(t *testing.T, u *url.URL) {
|
|
||||||
db := newTestDB(t, u)
|
|
||||||
|
|
||||||
// drop, recreate, and migrate database
|
|
||||||
err := db.Drop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.Create()
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.Migrate()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// verify migration
|
|
||||||
sqlDB, err := getDriverOpen(u)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer mustClose(sqlDB)
|
|
||||||
|
|
||||||
count := 0
|
|
||||||
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
|
||||||
where version = '20151129054053'`).Scan(&count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
|
|
||||||
err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// rollback
|
|
||||||
err = db.Rollback()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// verify rollback
|
|
||||||
err = sqlDB.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 1, count)
|
|
||||||
|
|
||||||
err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
|
|
||||||
require.NotNil(t, err)
|
|
||||||
require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRollback(t *testing.T) {
|
func TestRollback(t *testing.T) {
|
||||||
for _, u := range testURLs(t) {
|
for _, u := range testURLs() {
|
||||||
testRollbackURL(t, u)
|
t.Run(u.Scheme, func(t *testing.T) {
|
||||||
|
db := newTestDB(t, u)
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// drop, recreate, and migrate database
|
||||||
|
err = db.Drop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = db.Create()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = db.Migrate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify migration
|
||||||
|
sqlDB, err := drv.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
||||||
|
where version = '20151129054053'`).Scan(&count)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// rollback
|
||||||
|
err = db.Rollback()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify rollback
|
||||||
|
err = sqlDB.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 1, count)
|
||||||
|
|
||||||
|
err = sqlDB.QueryRow("select count(*) from posts").Scan(&count)
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func testStatusURL(t *testing.T, u *url.URL) {
|
|
||||||
db := newTestDB(t, u)
|
|
||||||
|
|
||||||
// drop, recreate, and migrate database
|
|
||||||
err := db.Drop()
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = db.Create()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// verify migration
|
|
||||||
sqlDB, err := getDriverOpen(u)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer mustClose(sqlDB)
|
|
||||||
|
|
||||||
// two pending
|
|
||||||
results, err := checkMigrationsStatus(db)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, results, 2)
|
|
||||||
require.False(t, results[0].applied)
|
|
||||||
require.False(t, results[1].applied)
|
|
||||||
|
|
||||||
// run migrations
|
|
||||||
err = db.Migrate()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// two applied
|
|
||||||
results, err = checkMigrationsStatus(db)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, results, 2)
|
|
||||||
require.True(t, results[0].applied)
|
|
||||||
require.True(t, results[1].applied)
|
|
||||||
|
|
||||||
// rollback last migration
|
|
||||||
err = db.Rollback()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// one applied, one pending
|
|
||||||
results, err = checkMigrationsStatus(db)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, results, 2)
|
|
||||||
require.True(t, results[0].applied)
|
|
||||||
require.False(t, results[1].applied)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatus(t *testing.T) {
|
func TestStatus(t *testing.T) {
|
||||||
for _, u := range testURLs(t) {
|
for _, u := range testURLs() {
|
||||||
testStatusURL(t, u)
|
t.Run(u.Scheme, func(t *testing.T) {
|
||||||
|
db := newTestDB(t, u)
|
||||||
|
drv, err := db.GetDriver()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// drop, recreate, and migrate database
|
||||||
|
err = db.Drop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = db.Create()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// verify migration
|
||||||
|
sqlDB, err := drv.Open()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer dbutil.MustClose(sqlDB)
|
||||||
|
|
||||||
|
// two pending
|
||||||
|
results, err := db.CheckMigrationsStatus(drv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
require.False(t, results[0].Applied)
|
||||||
|
require.False(t, results[1].Applied)
|
||||||
|
|
||||||
|
// run migrations
|
||||||
|
err = db.Migrate()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// two applied
|
||||||
|
results, err = db.CheckMigrationsStatus(drv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
require.True(t, results[0].Applied)
|
||||||
|
require.True(t, results[1].Applied)
|
||||||
|
|
||||||
|
// rollback last migration
|
||||||
|
err = db.Rollback()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// one applied, one pending
|
||||||
|
results, err = db.CheckMigrationsStatus(drv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, results, 2)
|
||||||
|
require.True(t, results[0].Applied)
|
||||||
|
require.False(t, results[1].Applied)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,56 +2,37 @@ package dbmate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Driver provides top level database functions
|
// Driver provides top level database functions
|
||||||
type Driver interface {
|
type Driver interface {
|
||||||
Open(*url.URL) (*sql.DB, error)
|
Open() (*sql.DB, error)
|
||||||
DatabaseExists(*url.URL) (bool, error)
|
DatabaseExists() (bool, error)
|
||||||
CreateDatabase(*url.URL) error
|
CreateDatabase() error
|
||||||
DropDatabase(*url.URL) error
|
DropDatabase() error
|
||||||
DumpSchema(*url.URL, *sql.DB) ([]byte, error)
|
DumpSchema(*sql.DB) ([]byte, error)
|
||||||
SetMigrationsTableName(string)
|
CreateMigrationsTable(*sql.DB) error
|
||||||
CreateMigrationsTable(*url.URL, *sql.DB) error
|
|
||||||
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
||||||
InsertMigration(Transaction, string) error
|
InsertMigration(dbutil.Transaction, string) error
|
||||||
DeleteMigration(Transaction, string) error
|
DeleteMigration(dbutil.Transaction, string) error
|
||||||
Ping(*url.URL) error
|
Ping() error
|
||||||
}
|
}
|
||||||
|
|
||||||
var drivers = map[string]Driver{}
|
// DriverConfig holds configuration passed to driver constructors
|
||||||
|
type DriverConfig struct {
|
||||||
// RegisterDriver registers a driver for a URL scheme
|
DatabaseURL *url.URL
|
||||||
func RegisterDriver(drv Driver, scheme string) {
|
MigrationsTableName string
|
||||||
drivers[scheme] = drv
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transaction can represent a database or open transaction
|
// DriverFunc represents a driver constructor
|
||||||
type Transaction interface {
|
type DriverFunc func(DriverConfig) Driver
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
var drivers = map[string]DriverFunc{}
|
||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
|
||||||
}
|
// RegisterDriver registers a driver constructor for a given URL scheme
|
||||||
|
func RegisterDriver(f DriverFunc, scheme string) {
|
||||||
// getDriver loads a database driver by name
|
drivers[scheme] = f
|
||||||
func getDriver(name string) (Driver, error) {
|
|
||||||
if drv, ok := drivers[name]; ok {
|
|
||||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
|
||||||
|
|
||||||
return drv, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported driver: %s", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
|
|
||||||
func getDriverOpen(u *url.URL) (*sql.DB, error) {
|
|
||||||
drv, err := getDriver(u.Scheme)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return drv.Open(u)
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
package dbmate
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetDriver_Postgres(t *testing.T) {
|
|
||||||
drv, err := getDriver("postgres")
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, ok := drv.(*PostgresDriver)
|
|
||||||
require.Equal(t, true, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetDriver_MySQL(t *testing.T) {
|
|
||||||
drv, err := getDriver("mysql")
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, ok := drv.(*MySQLDriver)
|
|
||||||
require.Equal(t, true, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetDriver_Error(t *testing.T) {
|
|
||||||
drv, err := getDriver("foo")
|
|
||||||
require.EqualError(t, err, "unsupported driver: foo")
|
|
||||||
require.Nil(t, drv)
|
|
||||||
}
|
|
||||||
|
|
@ -1,58 +0,0 @@
|
||||||
package dbmate
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDatabaseName(t *testing.T) {
|
|
||||||
u, err := url.Parse("ignore://localhost/foo?query")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
name := databaseName(u)
|
|
||||||
require.Equal(t, "foo", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDatabaseName_Empty(t *testing.T) {
|
|
||||||
u, err := url.Parse("ignore://localhost")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
name := databaseName(u)
|
|
||||||
require.Equal(t, "", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTrimLeadingSQLComments(t *testing.T) {
|
|
||||||
in := "--\n" +
|
|
||||||
"-- foo\n\n" +
|
|
||||||
"-- bar\n\n" +
|
|
||||||
"real stuff\n" +
|
|
||||||
"-- end\n"
|
|
||||||
out, err := trimLeadingSQLComments([]byte(in))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "real stuff\n-- end\n", string(out))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryColumn(t *testing.T) {
|
|
||||||
u := postgresTestURL(t)
|
|
||||||
db, err := sql.Open("postgres", u.String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
val, err := queryColumn(db, "select concat('foo_', unnest($1::text[]))",
|
|
||||||
pq.Array([]string{"hi", "there"}))
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryValue(t *testing.T) {
|
|
||||||
u := postgresTestURL(t)
|
|
||||||
db, err := sql.Open("postgres", u.String())
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
val, err := queryValue(db, "select $1::int + $2::int", "5", 2)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "7", val)
|
|
||||||
}
|
|
||||||
|
|
@ -1,21 +1,26 @@
|
||||||
package dbmate
|
package dbutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
)
|
)
|
||||||
|
|
||||||
// databaseName returns the database name from a URL
|
// Transaction can represent a database or open transaction
|
||||||
func databaseName(u *url.URL) string {
|
type Transaction interface {
|
||||||
|
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||||
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
|
QueryRow(query string, args ...interface{}) *sql.Row
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseName returns the database name from a URL
|
||||||
|
func DatabaseName(u *url.URL) string {
|
||||||
name := u.Path
|
name := u.Path
|
||||||
if len(name) > 0 && name[:1] == "/" {
|
if len(name) > 0 && name[:1] == "/" {
|
||||||
name = name[1:]
|
name = name[1:]
|
||||||
|
|
@ -24,24 +29,15 @@ func databaseName(u *url.URL) string {
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
|
|
||||||
// mustClose ensures a stream is closed
|
// MustClose ensures a stream is closed
|
||||||
func mustClose(c io.Closer) {
|
func MustClose(c io.Closer) {
|
||||||
if err := c.Close(); err != nil {
|
if err := c.Close(); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensureDir creates a directory if it does not already exist
|
// RunCommand runs a command and returns the stdout if successful
|
||||||
func ensureDir(dir string) error {
|
func RunCommand(name string, args ...string) ([]byte, error) {
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
|
||||||
return fmt.Errorf("unable to create directory `%s`", dir)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// runCommand runs a command and returns the stdout if successful
|
|
||||||
func runCommand(name string, args ...string) ([]byte, error) {
|
|
||||||
var stdout, stderr bytes.Buffer
|
var stdout, stderr bytes.Buffer
|
||||||
cmd := exec.Command(name, args...)
|
cmd := exec.Command(name, args...)
|
||||||
cmd.Stdout = &stdout
|
cmd.Stdout = &stdout
|
||||||
|
|
@ -61,10 +57,10 @@ func runCommand(name string, args ...string) ([]byte, error) {
|
||||||
return stdout.Bytes(), nil
|
return stdout.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// trimLeadingSQLComments removes sql comments and blank lines from the beginning of text
|
// TrimLeadingSQLComments removes sql comments and blank lines from the beginning of text
|
||||||
// generally when performing sql dumps these contain host-specific information such as
|
// generally when performing sql dumps these contain host-specific information such as
|
||||||
// client/server version numbers
|
// client/server version numbers
|
||||||
func trimLeadingSQLComments(data []byte) ([]byte, error) {
|
func TrimLeadingSQLComments(data []byte) ([]byte, error) {
|
||||||
// create decent size buffer
|
// create decent size buffer
|
||||||
out := bytes.NewBuffer(make([]byte, 0, len(data)))
|
out := bytes.NewBuffer(make([]byte, 0, len(data)))
|
||||||
|
|
||||||
|
|
@ -101,15 +97,15 @@ func trimLeadingSQLComments(data []byte) ([]byte, error) {
|
||||||
return out.Bytes(), nil
|
return out.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryColumn runs a SQL statement and returns a slice of strings
|
// QueryColumn runs a SQL statement and returns a slice of strings
|
||||||
// it is assumed that the statement returns only one column
|
// it is assumed that the statement returns only one column
|
||||||
// e.g. schema_migrations table
|
// e.g. schema_migrations table
|
||||||
func queryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
|
func QueryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
|
||||||
rows, err := db.Query(query, args...)
|
rows, err := db.Query(query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer mustClose(rows)
|
defer MustClose(rows)
|
||||||
|
|
||||||
// read into slice
|
// read into slice
|
||||||
var result []string
|
var result []string
|
||||||
|
|
@ -128,10 +124,10 @@ func queryColumn(db Transaction, query string, args ...interface{}) ([]string, e
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryValue runs a SQL statement and returns a single string
|
// QueryValue runs a SQL statement and returns a single string
|
||||||
// it is assumed that the statement returns only one row and one column
|
// it is assumed that the statement returns only one row and one column
|
||||||
// sql NULL is returned as empty string
|
// sql NULL is returned as empty string
|
||||||
func queryValue(db Transaction, query string, args ...interface{}) (string, error) {
|
func QueryValue(db Transaction, query string, args ...interface{}) (string, error) {
|
||||||
var result sql.NullString
|
var result sql.NullString
|
||||||
err := db.QueryRow(query, args...).Scan(&result)
|
err := db.QueryRow(query, args...).Scan(&result)
|
||||||
if err != nil || !result.Valid {
|
if err != nil || !result.Valid {
|
||||||
|
|
@ -141,13 +137,17 @@ func queryValue(db Transaction, query string, args ...interface{}) (string, erro
|
||||||
return result.String, nil
|
return result.String, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func printVerbose(result sql.Result) {
|
// MustParseURL parses a URL from string, and panics if it fails.
|
||||||
lastInsertID, err := result.LastInsertId()
|
// It is used during testing and in cases where we are parsing a generated URL.
|
||||||
if err == nil {
|
func MustParseURL(s string) *url.URL {
|
||||||
fmt.Printf("Last insert ID: %d\n", lastInsertID)
|
if s == "" {
|
||||||
|
panic("missing url")
|
||||||
}
|
}
|
||||||
rowsAffected, err := result.RowsAffected()
|
|
||||||
if err == nil {
|
u, err := url.Parse(s)
|
||||||
fmt.Printf("Rows affected: %d\n", rowsAffected)
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return u
|
||||||
}
|
}
|
||||||
58
pkg/dbutil/dbutil_test.go
Normal file
58
pkg/dbutil/dbutil_test.go
Normal file
|
|
@ -0,0 +1,58 @@
|
||||||
|
package dbutil_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3" // database/sql driver
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDatabaseName(t *testing.T) {
|
||||||
|
t.Run("valid", func(t *testing.T) {
|
||||||
|
u := dbutil.MustParseURL("foo://host/dbname?query")
|
||||||
|
name := dbutil.DatabaseName(u)
|
||||||
|
require.Equal(t, "dbname", name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty", func(t *testing.T) {
|
||||||
|
u := dbutil.MustParseURL("foo://host")
|
||||||
|
name := dbutil.DatabaseName(u)
|
||||||
|
require.Equal(t, "", name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrimLeadingSQLComments(t *testing.T) {
|
||||||
|
in := "--\n" +
|
||||||
|
"-- foo\n\n" +
|
||||||
|
"-- bar\n\n" +
|
||||||
|
"real stuff\n" +
|
||||||
|
"-- end\n"
|
||||||
|
out, err := dbutil.TrimLeadingSQLComments([]byte(in))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "real stuff\n-- end\n", string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect to in-memory sqlite database for testing
|
||||||
|
const sqliteMemoryDB = "file:dbutil.sqlite3?mode=memory&cache=shared"
|
||||||
|
|
||||||
|
func TestQueryColumn(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", sqliteMemoryDB)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
val, err := dbutil.QueryColumn(db, "select 'foo_' || val from (select ? as val union select ?)",
|
||||||
|
"hi", "there")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryValue(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", sqliteMemoryDB)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
val, err := dbutil.QueryValue(db, "select $1 + $2", "5", 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "7", val)
|
||||||
|
}
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package dbmate
|
package clickhouse
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
@ -9,19 +9,31 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/ClickHouse/clickhouse-go"
|
"github.com/ClickHouse/clickhouse-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RegisterDriver(&ClickHouseDriver{}, "clickhouse")
|
dbmate.RegisterDriver(NewDriver, "clickhouse")
|
||||||
}
|
}
|
||||||
|
|
||||||
// ClickHouseDriver provides top level database functions
|
// Driver provides top level database functions
|
||||||
type ClickHouseDriver struct {
|
type Driver struct {
|
||||||
migrationsTableName string
|
migrationsTableName string
|
||||||
|
databaseURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
|
// NewDriver initializes the driver
|
||||||
|
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
|
||||||
|
return &Driver{
|
||||||
|
migrationsTableName: config.MigrationsTableName,
|
||||||
|
databaseURL: config.DatabaseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionString(initialURL *url.URL) string {
|
||||||
u := *initialURL
|
u := *initialURL
|
||||||
|
|
||||||
u.Scheme = "tcp"
|
u.Scheme = "tcp"
|
||||||
|
|
@ -50,31 +62,31 @@ func normalizeClickHouseURL(initialURL *url.URL) *url.URL {
|
||||||
}
|
}
|
||||||
u.RawQuery = query.Encode()
|
u.RawQuery = query.Encode()
|
||||||
|
|
||||||
return &u
|
return u.String()
|
||||||
}
|
|
||||||
|
|
||||||
// SetMigrationsTableName sets the schema migrations table name
|
|
||||||
func (drv *ClickHouseDriver) SetMigrationsTableName(name string) {
|
|
||||||
drv.migrationsTableName = name
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv *ClickHouseDriver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) Open() (*sql.DB, error) {
|
||||||
return sql.Open("clickhouse", normalizeClickHouseURL(u).String())
|
return sql.Open("clickhouse", connectionString(drv.databaseURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *ClickHouseDriver) openClickHouseDB(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) openClickHouseDB() (*sql.DB, error) {
|
||||||
|
// clone databaseURL
|
||||||
|
clickhouseURL, err := url.Parse(connectionString(drv.databaseURL))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// connect to clickhouse database
|
// connect to clickhouse database
|
||||||
clickhouseURL := normalizeClickHouseURL(u)
|
|
||||||
values := clickhouseURL.Query()
|
values := clickhouseURL.Query()
|
||||||
values.Set("database", "default")
|
values.Set("database", "default")
|
||||||
clickhouseURL.RawQuery = values.Encode()
|
clickhouseURL.RawQuery = values.Encode()
|
||||||
|
|
||||||
return drv.Open(clickhouseURL)
|
return sql.Open("clickhouse", clickhouseURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
|
func (drv *Driver) databaseName() string {
|
||||||
name := normalizeClickHouseURL(u).Query().Get("database")
|
name := dbutil.MustParseURL(connectionString(drv.databaseURL)).Query().Get("database")
|
||||||
if name == "" {
|
if name == "" {
|
||||||
name = "default"
|
name = "default"
|
||||||
}
|
}
|
||||||
|
|
@ -83,7 +95,7 @@ func (drv *ClickHouseDriver) databaseName(u *url.URL) string {
|
||||||
|
|
||||||
var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
|
var clickhouseValidIdentifier = regexp.MustCompile(`^[a-zA-Z_][0-9a-zA-Z_]*$`)
|
||||||
|
|
||||||
func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
|
func (drv *Driver) quoteIdentifier(str string) string {
|
||||||
if clickhouseValidIdentifier.MatchString(str) {
|
if clickhouseValidIdentifier.MatchString(str) {
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
@ -94,15 +106,15 @@ func (drv *ClickHouseDriver) quoteIdentifier(str string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase creates the specified database
|
// CreateDatabase creates the specified database
|
||||||
func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
|
func (drv *Driver) CreateDatabase() error {
|
||||||
name := drv.databaseName(u)
|
name := drv.databaseName()
|
||||||
fmt.Printf("Creating: %s\n", name)
|
fmt.Printf("Creating: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openClickHouseDB(u)
|
db, err := drv.openClickHouseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec("create database " + drv.quoteIdentifier(name))
|
_, err = db.Exec("create database " + drv.quoteIdentifier(name))
|
||||||
|
|
||||||
|
|
@ -110,27 +122,27 @@ func (drv *ClickHouseDriver) CreateDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops the specified database (if it exists)
|
// DropDatabase drops the specified database (if it exists)
|
||||||
func (drv *ClickHouseDriver) DropDatabase(u *url.URL) error {
|
func (drv *Driver) DropDatabase() error {
|
||||||
name := drv.databaseName(u)
|
name := drv.databaseName()
|
||||||
fmt.Printf("Dropping: %s\n", name)
|
fmt.Printf("Dropping: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openClickHouseDB(u)
|
db, err := drv.openClickHouseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
|
_, err = db.Exec("drop database if exists " + drv.quoteIdentifier(name))
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
|
func (drv *Driver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseName string) error {
|
||||||
buf.WriteString("\n--\n-- Database schema\n--\n\n")
|
buf.WriteString("\n--\n-- Database schema\n--\n\n")
|
||||||
|
|
||||||
buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
|
buf.WriteString("CREATE DATABASE " + drv.quoteIdentifier(databaseName) + " IF NOT EXISTS;\n\n")
|
||||||
|
|
||||||
tables, err := queryColumn(db, "show tables")
|
tables, err := dbutil.QueryColumn(db, "show tables")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -147,11 +159,11 @@ func (drv *ClickHouseDriver) schemaDump(db *sql.DB, buf *bytes.Buffer, databaseN
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
|
func (drv *Driver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer) error {
|
||||||
migrationsTable := drv.quotedMigrationsTableName()
|
migrationsTable := drv.quotedMigrationsTableName()
|
||||||
|
|
||||||
// load applied migrations
|
// load applied migrations
|
||||||
migrations, err := queryColumn(db,
|
migrations, err := dbutil.QueryColumn(db,
|
||||||
fmt.Sprintf("select version from %s final ", migrationsTable)+
|
fmt.Sprintf("select version from %s final ", migrationsTable)+
|
||||||
"where applied order by version asc",
|
"where applied order by version asc",
|
||||||
)
|
)
|
||||||
|
|
@ -178,11 +190,11 @@ func (drv *ClickHouseDriver) schemaMigrationsDump(db *sql.DB, buf *bytes.Buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DumpSchema returns the current database schema
|
// DumpSchema returns the current database schema
|
||||||
func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
err = drv.schemaDump(db, &buf, drv.databaseName(u))
|
err = drv.schemaDump(db, &buf, drv.databaseName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -196,14 +208,14 @@ func (drv *ClickHouseDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseExists determines whether the database exists
|
// DatabaseExists determines whether the database exists
|
||||||
func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
|
func (drv *Driver) DatabaseExists() (bool, error) {
|
||||||
name := drv.databaseName(u)
|
name := drv.databaseName()
|
||||||
|
|
||||||
db, err := drv.openClickHouseDB(u)
|
db, err := drv.openClickHouseDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
exists := false
|
exists := false
|
||||||
err = db.QueryRow("SELECT 1 FROM system.databases where name = ?", name).
|
err = db.QueryRow("SELECT 1 FROM system.databases where name = ?", name).
|
||||||
|
|
@ -216,7 +228,7 @@ func (drv *ClickHouseDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMigrationsTable creates the schema migrations table
|
// CreateMigrationsTable creates the schema migrations table
|
||||||
func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(fmt.Sprintf(`
|
_, err := db.Exec(fmt.Sprintf(`
|
||||||
create table if not exists %s (
|
create table if not exists %s (
|
||||||
version String,
|
version String,
|
||||||
|
|
@ -232,7 +244,7 @@ func (drv *ClickHouseDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
// with an optional limit (in descending order)
|
// with an optional limit (in descending order)
|
||||||
func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||||
query := fmt.Sprintf("select version from %s final where applied order by version desc",
|
query := fmt.Sprintf("select version from %s final where applied order by version desc",
|
||||||
drv.quotedMigrationsTableName())
|
drv.quotedMigrationsTableName())
|
||||||
|
|
||||||
|
|
@ -244,7 +256,7 @@ func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer mustClose(rows)
|
defer dbutil.MustClose(rows)
|
||||||
|
|
||||||
migrations := map[string]bool{}
|
migrations := map[string]bool{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
@ -264,7 +276,7 @@ func (drv *ClickHouseDriver) SelectMigrations(db *sql.DB, limit int) (map[string
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertMigration adds a new migration record
|
// InsertMigration adds a new migration record
|
||||||
func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) error {
|
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
||||||
version)
|
version)
|
||||||
|
|
@ -273,7 +285,7 @@ func (drv *ClickHouseDriver) InsertMigration(db Transaction, version string) err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMigration removes a migration record
|
// DeleteMigration removes a migration record
|
||||||
func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) error {
|
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
|
fmt.Sprintf("insert into %s (version, applied) values (?, ?)",
|
||||||
drv.quotedMigrationsTableName()),
|
drv.quotedMigrationsTableName()),
|
||||||
|
|
@ -285,15 +297,15 @@ func (drv *ClickHouseDriver) DeleteMigration(db Transaction, version string) err
|
||||||
|
|
||||||
// Ping verifies a connection to the database server. It does not verify whether the
|
// Ping verifies a connection to the database server. It does not verify whether the
|
||||||
// specified database exists.
|
// specified database exists.
|
||||||
func (drv *ClickHouseDriver) Ping(u *url.URL) error {
|
func (drv *Driver) Ping() error {
|
||||||
// attempt connection to primary database, not "clickhouse" database
|
// attempt connection to primary database, not "clickhouse" database
|
||||||
// to support servers with no "clickhouse" database
|
// to support servers with no "clickhouse" database
|
||||||
// (see https://github.com/amacneil/dbmate/issues/78)
|
// (see https://github.com/amacneil/dbmate/issues/78)
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
@ -309,6 +321,6 @@ func (drv *ClickHouseDriver) Ping(u *url.URL) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *ClickHouseDriver) quotedMigrationsTableName() string {
|
func (drv *Driver) quotedMigrationsTableName() string {
|
||||||
return drv.quoteIdentifier(drv.migrationsTableName)
|
return drv.quoteIdentifier(drv.migrationsTableName)
|
||||||
}
|
}
|
||||||
|
|
@ -1,108 +1,117 @@
|
||||||
package dbmate
|
package clickhouse
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func clickhouseTestURL(t *testing.T) *url.URL {
|
func testClickHouseDriver(t *testing.T) *Driver {
|
||||||
u, err := url.Parse("clickhouse://clickhouse:9000?database=dbmate")
|
u := dbutil.MustParseURL(os.Getenv("CLICKHOUSE_TEST_URL"))
|
||||||
|
drv, err := dbmate.New(u).GetDriver()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return u
|
return drv.(*Driver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testClickHouseDriver() *ClickHouseDriver {
|
func prepTestClickHouseDB(t *testing.T) *sql.DB {
|
||||||
drv := &ClickHouseDriver{}
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
|
||||||
|
|
||||||
return drv
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepTestClickHouseDB(t *testing.T, u *url.URL) *sql.DB {
|
|
||||||
drv := testClickHouseDriver()
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// connect database
|
// connect database
|
||||||
db, err := sql.Open("clickhouse", u.String())
|
db, err := sql.Open("clickhouse", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeClickHouseURLSimplified(t *testing.T) {
|
func TestGetDriver(t *testing.T) {
|
||||||
u, err := url.Parse("clickhouse://user:pass@host/db")
|
db := dbmate.New(dbutil.MustParseURL("clickhouse://"))
|
||||||
|
drvInterface, err := db.GetDriver()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := normalizeClickHouseURL(u).String()
|
// driver should have URL and default migrations table set
|
||||||
require.Equal(t, "tcp://host:9000?database=db&password=pass&username=user", s)
|
drv, ok := drvInterface.(*Driver)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
|
||||||
|
require.Equal(t, "schema_migrations", drv.migrationsTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeClickHouseURLCanonical(t *testing.T) {
|
func TestConnectionString(t *testing.T) {
|
||||||
u, err := url.Parse("clickhouse://host:9000?database=db&password=pass&username=user")
|
t.Run("simple", func(t *testing.T) {
|
||||||
require.NoError(t, err)
|
u, err := url.Parse("clickhouse://user:pass@host/db")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := normalizeClickHouseURL(u).String()
|
s := connectionString(u)
|
||||||
require.Equal(t, "tcp://host:9000?database=db&password=pass&username=user", s)
|
require.Equal(t, "tcp://host:9000?database=db&password=pass&username=user", s)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("canonical", func(t *testing.T) {
|
||||||
|
u, err := url.Parse("clickhouse://host:9000?database=db&password=pass&username=user")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s := connectionString(u)
|
||||||
|
require.Equal(t, "tcp://host:9000?database=db&password=pass&username=user", s)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseCreateDropDatabase(t *testing.T) {
|
func TestClickHouseCreateDropDatabase(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
u := clickhouseTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database exists and we can connect to it
|
// check that database exists and we can connect to it
|
||||||
func() {
|
func() {
|
||||||
db, err := sql.Open("clickhouse", u.String())
|
db, err := sql.Open("clickhouse", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// drop the database
|
// drop the database
|
||||||
err = drv.DropDatabase(u)
|
err = drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database no longer exists
|
// check that database no longer exists
|
||||||
func() {
|
func() {
|
||||||
db, err := sql.Open("clickhouse", u.String())
|
db, err := sql.Open("clickhouse", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
require.EqualError(t, err, "code: 81, message: Database dbmate doesn't exist")
|
require.EqualError(t, err, "code: 81, message: Database dbmate_test doesn't exist")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseDumpSchema(t *testing.T) {
|
func TestClickHouseDumpSchema(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := clickhouseTestURL(t)
|
|
||||||
|
|
||||||
// prepare database
|
// prepare database
|
||||||
db := prepTestClickHouseDB(t, u)
|
db := prepTestClickHouseDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// insert migration
|
// insert migration
|
||||||
|
|
@ -120,9 +129,9 @@ func TestClickHouseDumpSchema(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DumpSchema should return schema
|
// DumpSchema should return schema
|
||||||
schema, err := drv.DumpSchema(u, db)
|
schema, err := drv.DumpSchema(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName(u)+".test_migrations")
|
require.Contains(t, string(schema), "CREATE TABLE "+drv.databaseName()+".test_migrations")
|
||||||
require.Contains(t, string(schema), "--\n"+
|
require.Contains(t, string(schema), "--\n"+
|
||||||
"-- Dbmate schema migrations\n"+
|
"-- Dbmate schema migrations\n"+
|
||||||
"--\n\n"+
|
"--\n\n"+
|
||||||
|
|
@ -131,66 +140,63 @@ func TestClickHouseDumpSchema(t *testing.T) {
|
||||||
" ('abc2');\n")
|
" ('abc2');\n")
|
||||||
|
|
||||||
// DumpSchema should return error if command fails
|
// DumpSchema should return error if command fails
|
||||||
values := u.Query()
|
values := drv.databaseURL.Query()
|
||||||
values.Set("database", "fakedb")
|
values.Set("database", "fakedb")
|
||||||
u.RawQuery = values.Encode()
|
drv.databaseURL.RawQuery = values.Encode()
|
||||||
db, err = sql.Open("clickhouse", u.String())
|
db, err = sql.Open("clickhouse", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
schema, err = drv.DumpSchema(u, db)
|
schema, err = drv.DumpSchema(db)
|
||||||
require.Nil(t, schema)
|
require.Nil(t, schema)
|
||||||
require.EqualError(t, err, "code: 81, message: Database fakedb doesn't exist")
|
require.EqualError(t, err, "code: 81, message: Database fakedb doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseDatabaseExists(t *testing.T) {
|
func TestClickHouseDatabaseExists(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
u := clickhouseTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return false
|
// DatabaseExists should return false
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return true
|
// DatabaseExists should return true
|
||||||
exists, err = drv.DatabaseExists(u)
|
exists, err = drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, true, exists)
|
require.Equal(t, true, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseDatabaseExists_Error(t *testing.T) {
|
func TestClickHouseDatabaseExists_Error(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
u := clickhouseTestURL(t)
|
values := drv.databaseURL.Query()
|
||||||
values := u.Query()
|
|
||||||
values.Set("username", "invalid")
|
values.Set("username", "invalid")
|
||||||
u.RawQuery = values.Encode()
|
drv.databaseURL.RawQuery = values.Encode()
|
||||||
|
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
require.EqualError(t, err, "code: 192, message: Unknown user invalid")
|
require.EqualError(t, err, "code: 192, message: Unknown user invalid")
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseCreateMigrationsTable(t *testing.T) {
|
func TestClickHouseCreateMigrationsTable(t *testing.T) {
|
||||||
t.Run("default table", func(t *testing.T) {
|
t.Run("default table", func(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
u := clickhouseTestURL(t)
|
db := prepTestClickHouseDB(t)
|
||||||
db := prepTestClickHouseDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
count := 0
|
count := 0
|
||||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||||
require.EqualError(t, err, "code: 60, message: Table dbmate.schema_migrations doesn't exist.")
|
require.EqualError(t, err, "code: 60, message: Table dbmate_test.schema_migrations doesn't exist.")
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// migrations table should exist
|
// migrations table should exist
|
||||||
|
|
@ -198,25 +204,24 @@ func TestClickHouseCreateMigrationsTable(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom table", func(t *testing.T) {
|
t.Run("custom table", func(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("testMigrations")
|
drv.migrationsTableName = "testMigrations"
|
||||||
|
|
||||||
u := clickhouseTestURL(t)
|
db := prepTestClickHouseDB(t)
|
||||||
db := prepTestClickHouseDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
count := 0
|
count := 0
|
||||||
err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
|
err := db.QueryRow("select count(*) from \"testMigrations\"").Scan(&count)
|
||||||
require.EqualError(t, err, "code: 60, message: Table dbmate.testMigrations doesn't exist.")
|
require.EqualError(t, err, "code: 60, message: Table dbmate_test.testMigrations doesn't exist.")
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// migrations table should exist
|
// migrations table should exist
|
||||||
|
|
@ -224,20 +229,19 @@ func TestClickHouseCreateMigrationsTable(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseSelectMigrations(t *testing.T) {
|
func TestClickHouseSelectMigrations(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := clickhouseTestURL(t)
|
db := prepTestClickHouseDB(t)
|
||||||
db := prepTestClickHouseDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
|
|
@ -268,14 +272,13 @@ func TestClickHouseSelectMigrations(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseInsertMigration(t *testing.T) {
|
func TestClickHouseInsertMigration(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := clickhouseTestURL(t)
|
db := prepTestClickHouseDB(t)
|
||||||
db := prepTestClickHouseDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
|
|
@ -297,14 +300,13 @@ func TestClickHouseInsertMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseDeleteMigration(t *testing.T) {
|
func TestClickHouseDeleteMigration(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := clickhouseTestURL(t)
|
db := prepTestClickHouseDB(t)
|
||||||
db := prepTestClickHouseDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tx, err := db.Begin()
|
tx, err := db.Begin()
|
||||||
|
|
@ -332,42 +334,41 @@ func TestClickHouseDeleteMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHousePing(t *testing.T) {
|
func TestClickHousePing(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
u := clickhouseTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping database
|
// ping database
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping invalid host should return error
|
// ping invalid host should return error
|
||||||
u.Host = "clickhouse:404"
|
drv.databaseURL.Host = "clickhouse:404"
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "connect: connection refused")
|
require.Contains(t, err.Error(), "connect: connection refused")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestClickHouseQuotedMigrationsTableName(t *testing.T) {
|
func TestClickHouseQuotedMigrationsTableName(t *testing.T) {
|
||||||
t.Run("default name", func(t *testing.T) {
|
t.Run("default name", func(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, "schema_migrations", name)
|
require.Equal(t, "schema_migrations", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom name", func(t *testing.T) {
|
t.Run("custom name", func(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("fooMigrations")
|
drv.migrationsTableName = "fooMigrations"
|
||||||
|
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, "fooMigrations", name)
|
require.Equal(t, "fooMigrations", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("quoted name", func(t *testing.T) {
|
t.Run("quoted name", func(t *testing.T) {
|
||||||
drv := testClickHouseDriver()
|
drv := testClickHouseDriver(t)
|
||||||
drv.SetMigrationsTableName("bizarre\"$name")
|
drv.migrationsTableName = "bizarre\"$name"
|
||||||
|
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, `"bizarre""$name"`, name)
|
require.Equal(t, `"bizarre""$name"`, name)
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package dbmate
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
@ -7,19 +7,31 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql" // mysql driver for database/sql
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
|
_ "github.com/go-sql-driver/mysql" // database/sql driver
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
RegisterDriver(&MySQLDriver{}, "mysql")
|
dbmate.RegisterDriver(NewDriver, "mysql")
|
||||||
}
|
}
|
||||||
|
|
||||||
// MySQLDriver provides top level database functions
|
// Driver provides top level database functions
|
||||||
type MySQLDriver struct {
|
type Driver struct {
|
||||||
migrationsTableName string
|
migrationsTableName string
|
||||||
|
databaseURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeMySQLURL(u *url.URL) string {
|
// NewDriver initializes the driver
|
||||||
|
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
|
||||||
|
return &Driver{
|
||||||
|
migrationsTableName: config.MigrationsTableName,
|
||||||
|
databaseURL: config.DatabaseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionString(u *url.URL) string {
|
||||||
query := u.Query()
|
query := u.Query()
|
||||||
query.Set("multiStatements", "true")
|
query.Set("multiStatements", "true")
|
||||||
|
|
||||||
|
|
@ -53,40 +65,40 @@ func normalizeMySQLURL(u *url.URL) string {
|
||||||
return normalizedString
|
return normalizedString
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMigrationsTableName sets the schema migrations table name
|
|
||||||
func (drv *MySQLDriver) SetMigrationsTableName(name string) {
|
|
||||||
drv.migrationsTableName = name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv *MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) Open() (*sql.DB, error) {
|
||||||
return sql.Open("mysql", normalizeMySQLURL(u))
|
return sql.Open("mysql", connectionString(drv.databaseURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) openRootDB() (*sql.DB, error) {
|
||||||
|
// clone databaseURL
|
||||||
|
rootURL, err := url.Parse(drv.databaseURL.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// connect to no particular database
|
// connect to no particular database
|
||||||
rootURL := *u
|
|
||||||
rootURL.Path = "/"
|
rootURL.Path = "/"
|
||||||
|
|
||||||
return drv.Open(&rootURL)
|
return sql.Open("mysql", connectionString(rootURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *MySQLDriver) quoteIdentifier(str string) string {
|
func (drv *Driver) quoteIdentifier(str string) string {
|
||||||
str = strings.Replace(str, "`", "\\`", -1)
|
str = strings.Replace(str, "`", "\\`", -1)
|
||||||
|
|
||||||
return fmt.Sprintf("`%s`", str)
|
return fmt.Sprintf("`%s`", str)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase creates the specified database
|
// CreateDatabase creates the specified database
|
||||||
func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
|
func (drv *Driver) CreateDatabase() error {
|
||||||
name := databaseName(u)
|
name := dbutil.DatabaseName(drv.databaseURL)
|
||||||
fmt.Printf("Creating: %s\n", name)
|
fmt.Printf("Creating: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec(fmt.Sprintf("create database %s",
|
_, err = db.Exec(fmt.Sprintf("create database %s",
|
||||||
drv.quoteIdentifier(name)))
|
drv.quoteIdentifier(name)))
|
||||||
|
|
@ -95,15 +107,15 @@ func (drv *MySQLDriver) CreateDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops the specified database (if it exists)
|
// DropDatabase drops the specified database (if it exists)
|
||||||
func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
|
func (drv *Driver) DropDatabase() error {
|
||||||
name := databaseName(u)
|
name := dbutil.DatabaseName(drv.databaseURL)
|
||||||
fmt.Printf("Dropping: %s\n", name)
|
fmt.Printf("Dropping: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
||||||
drv.quoteIdentifier(name)))
|
drv.quoteIdentifier(name)))
|
||||||
|
|
@ -111,37 +123,37 @@ func (drv *MySQLDriver) DropDatabase(u *url.URL) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *MySQLDriver) mysqldumpArgs(u *url.URL) []string {
|
func (drv *Driver) mysqldumpArgs() []string {
|
||||||
// generate CLI arguments
|
// generate CLI arguments
|
||||||
args := []string{"--opt", "--routines", "--no-data",
|
args := []string{"--opt", "--routines", "--no-data",
|
||||||
"--skip-dump-date", "--skip-add-drop-table"}
|
"--skip-dump-date", "--skip-add-drop-table"}
|
||||||
|
|
||||||
if hostname := u.Hostname(); hostname != "" {
|
if hostname := drv.databaseURL.Hostname(); hostname != "" {
|
||||||
args = append(args, "--host="+hostname)
|
args = append(args, "--host="+hostname)
|
||||||
}
|
}
|
||||||
if port := u.Port(); port != "" {
|
if port := drv.databaseURL.Port(); port != "" {
|
||||||
args = append(args, "--port="+port)
|
args = append(args, "--port="+port)
|
||||||
}
|
}
|
||||||
if username := u.User.Username(); username != "" {
|
if username := drv.databaseURL.User.Username(); username != "" {
|
||||||
args = append(args, "--user="+username)
|
args = append(args, "--user="+username)
|
||||||
}
|
}
|
||||||
// mysql recommends against using environment variables to supply password
|
// mysql recommends against using environment variables to supply password
|
||||||
// https://dev.mysql.com/doc/refman/5.7/en/password-security-user.html
|
// https://dev.mysql.com/doc/refman/5.7/en/password-security-user.html
|
||||||
if password, set := u.User.Password(); set {
|
if password, set := drv.databaseURL.User.Password(); set {
|
||||||
args = append(args, "--password="+password)
|
args = append(args, "--password="+password)
|
||||||
}
|
}
|
||||||
|
|
||||||
// add database name
|
// add database name
|
||||||
args = append(args, strings.TrimLeft(u.Path, "/"))
|
args = append(args, dbutil.DatabaseName(drv.databaseURL))
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
migrationsTable := drv.quotedMigrationsTableName()
|
migrationsTable := drv.quotedMigrationsTableName()
|
||||||
|
|
||||||
// load applied migrations
|
// load applied migrations
|
||||||
migrations, err := queryColumn(db,
|
migrations, err := dbutil.QueryColumn(db,
|
||||||
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
|
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -165,8 +177,8 @@ func (drv *MySQLDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DumpSchema returns the current database schema
|
// DumpSchema returns the current database schema
|
||||||
func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
|
||||||
schema, err := runCommand("mysqldump", drv.mysqldumpArgs(u)...)
|
schema, err := dbutil.RunCommand("mysqldump", drv.mysqldumpArgs()...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -177,18 +189,18 @@ func (drv *MySQLDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
schema = append(schema, migrations...)
|
schema = append(schema, migrations...)
|
||||||
return trimLeadingSQLComments(schema)
|
return dbutil.TrimLeadingSQLComments(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseExists determines whether the database exists
|
// DatabaseExists determines whether the database exists
|
||||||
func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
func (drv *Driver) DatabaseExists() (bool, error) {
|
||||||
name := databaseName(u)
|
name := dbutil.DatabaseName(drv.databaseURL)
|
||||||
|
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
exists := false
|
exists := false
|
||||||
err = db.QueryRow("select true from information_schema.schemata "+
|
err = db.QueryRow("select true from information_schema.schemata "+
|
||||||
|
|
@ -201,7 +213,7 @@ func (drv *MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMigrationsTable creates the schema_migrations table
|
// CreateMigrationsTable creates the schema_migrations table
|
||||||
func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(fmt.Sprintf("create table if not exists %s "+
|
_, err := db.Exec(fmt.Sprintf("create table if not exists %s "+
|
||||||
"(version varchar(255) primary key) character set latin1 collate latin1_bin",
|
"(version varchar(255) primary key) character set latin1 collate latin1_bin",
|
||||||
drv.quotedMigrationsTableName()))
|
drv.quotedMigrationsTableName()))
|
||||||
|
|
@ -211,7 +223,7 @@ func (drv *MySQLDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
// with an optional limit (in descending order)
|
// with an optional limit (in descending order)
|
||||||
func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||||
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
|
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
|
||||||
if limit >= 0 {
|
if limit >= 0 {
|
||||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||||
|
|
@ -221,7 +233,7 @@ func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer mustClose(rows)
|
defer dbutil.MustClose(rows)
|
||||||
|
|
||||||
migrations := map[string]bool{}
|
migrations := map[string]bool{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
@ -241,7 +253,7 @@ func (drv *MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertMigration adds a new migration record
|
// InsertMigration adds a new migration record
|
||||||
func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
|
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
||||||
version)
|
version)
|
||||||
|
|
@ -250,7 +262,7 @@ func (drv *MySQLDriver) InsertMigration(db Transaction, version string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMigration removes a migration record
|
// DeleteMigration removes a migration record
|
||||||
func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
|
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
|
||||||
version)
|
version)
|
||||||
|
|
@ -260,16 +272,16 @@ func (drv *MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
||||||
|
|
||||||
// Ping verifies a connection to the database server. It does not verify whether the
|
// Ping verifies a connection to the database server. It does not verify whether the
|
||||||
// specified database exists.
|
// specified database exists.
|
||||||
func (drv *MySQLDriver) Ping(u *url.URL) error {
|
func (drv *Driver) Ping() error {
|
||||||
db, err := drv.openRootDB(u)
|
db, err := drv.openRootDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
return db.Ping()
|
return db.Ping()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *MySQLDriver) quotedMigrationsTableName() string {
|
func (drv *Driver) quotedMigrationsTableName() string {
|
||||||
return drv.quoteIdentifier(drv.migrationsTableName)
|
return drv.quoteIdentifier(drv.migrationsTableName)
|
||||||
}
|
}
|
||||||
|
|
@ -1,137 +1,146 @@
|
||||||
package dbmate
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func mySQLTestURL(t *testing.T) *url.URL {
|
func testMySQLDriver(t *testing.T) *Driver {
|
||||||
u, err := url.Parse("mysql://root:root@mysql/dbmate")
|
u := dbutil.MustParseURL(os.Getenv("MYSQL_TEST_URL"))
|
||||||
|
drv, err := dbmate.New(u).GetDriver()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return u
|
return drv.(*Driver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testMySQLDriver() *MySQLDriver {
|
func prepTestMySQLDB(t *testing.T) *sql.DB {
|
||||||
drv := &MySQLDriver{}
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
|
||||||
|
|
||||||
return drv
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepTestMySQLDB(t *testing.T, u *url.URL) *sql.DB {
|
|
||||||
drv := testMySQLDriver()
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// connect database
|
// connect database
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeMySQLURLDefaults(t *testing.T) {
|
func TestGetDriver(t *testing.T) {
|
||||||
u, err := url.Parse("mysql://host/foo")
|
db := dbmate.New(dbutil.MustParseURL("mysql://"))
|
||||||
|
drvInterface, err := db.GetDriver()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "", u.Port())
|
|
||||||
|
|
||||||
s := normalizeMySQLURL(u)
|
// driver should have URL and default migrations table set
|
||||||
require.Equal(t, "tcp(host:3306)/foo?multiStatements=true", s)
|
drv, ok := drvInterface.(*Driver)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
|
||||||
|
require.Equal(t, "schema_migrations", drv.migrationsTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeMySQLURLCustom(t *testing.T) {
|
func TestConnectionString(t *testing.T) {
|
||||||
u, err := url.Parse("mysql://bob:secret@host:123/foo?flag=on")
|
t.Run("defaults", func(t *testing.T) {
|
||||||
require.NoError(t, err)
|
u, err := url.Parse("mysql://host/foo")
|
||||||
require.Equal(t, "123", u.Port())
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "", u.Port())
|
||||||
|
|
||||||
s := normalizeMySQLURL(u)
|
s := connectionString(u)
|
||||||
require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
require.Equal(t, "tcp(host:3306)/foo?multiStatements=true", s)
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestNormalizeMySQLURLCustomSpecialChars(t *testing.T) {
|
t.Run("custom", func(t *testing.T) {
|
||||||
u, err := url.Parse("mysql://duhfsd7s:123!@123!@@host:123/foo?flag=on")
|
u, err := url.Parse("mysql://bob:secret@host:123/foo?flag=on")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "123", u.Port())
|
require.Equal(t, "123", u.Port())
|
||||||
|
|
||||||
s := normalizeMySQLURL(u)
|
s := connectionString(u)
|
||||||
require.Equal(t, "duhfsd7s:123!@123!@@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
||||||
}
|
})
|
||||||
|
|
||||||
func TestNormalizeMySQLURLSocket(t *testing.T) {
|
t.Run("special chars", func(t *testing.T) {
|
||||||
// test with no user/pass
|
u, err := url.Parse("mysql://duhfsd7s:123!@123!@@host:123/foo?flag=on")
|
||||||
u, err := url.Parse("mysql:///foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
require.Equal(t, "123", u.Port())
|
||||||
require.Equal(t, "", u.Host)
|
|
||||||
|
|
||||||
s := normalizeMySQLURL(u)
|
s := connectionString(u)
|
||||||
require.Equal(t, "unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
|
require.Equal(t, "duhfsd7s:123!@123!@@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
||||||
|
})
|
||||||
|
|
||||||
// test with user/pass
|
t.Run("socket", func(t *testing.T) {
|
||||||
u, err = url.Parse("mysql://bob:secret@fakehost/foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
|
// test with no user/pass
|
||||||
require.NoError(t, err)
|
u, err := url.Parse("mysql:///foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "", u.Host)
|
||||||
|
|
||||||
s = normalizeMySQLURL(u)
|
s := connectionString(u)
|
||||||
require.Equal(t, "bob:secret@unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
|
require.Equal(t, "unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
|
||||||
|
|
||||||
|
// test with user/pass
|
||||||
|
u, err = url.Parse("mysql://bob:secret@fakehost/foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s = connectionString(u)
|
||||||
|
require.Equal(t, "bob:secret@unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLCreateDropDatabase(t *testing.T) {
|
func TestMySQLCreateDropDatabase(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
u := mySQLTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database exists and we can connect to it
|
// check that database exists and we can connect to it
|
||||||
func() {
|
func() {
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// drop the database
|
// drop the database
|
||||||
err = drv.DropDatabase(u)
|
err = drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database no longer exists
|
// check that database no longer exists
|
||||||
func() {
|
func() {
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
require.NotNil(t, err)
|
require.Error(t, err)
|
||||||
require.Regexp(t, "Unknown database 'dbmate'", err.Error())
|
require.Regexp(t, "Unknown database 'dbmate_test'", err.Error())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLDumpSchema(t *testing.T) {
|
func TestMySQLDumpSchema(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := mySQLTestURL(t)
|
|
||||||
|
|
||||||
// prepare database
|
// prepare database
|
||||||
db := prepTestMySQLDB(t, u)
|
db := prepTestMySQLDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// insert migration
|
// insert migration
|
||||||
|
|
@ -141,7 +150,7 @@ func TestMySQLDumpSchema(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DumpSchema should return schema
|
// DumpSchema should return schema
|
||||||
schema, err := drv.DumpSchema(u, db)
|
schema, err := drv.DumpSchema(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
|
require.Contains(t, string(schema), "CREATE TABLE `test_migrations`")
|
||||||
require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
|
require.Contains(t, string(schema), "\n-- Dump completed\n\n"+
|
||||||
|
|
@ -155,8 +164,8 @@ func TestMySQLDumpSchema(t *testing.T) {
|
||||||
"UNLOCK TABLES;\n")
|
"UNLOCK TABLES;\n")
|
||||||
|
|
||||||
// DumpSchema should return error if command fails
|
// DumpSchema should return error if command fails
|
||||||
u.Path = "/fakedb"
|
drv.databaseURL.Path = "/fakedb"
|
||||||
schema, err = drv.DumpSchema(u, db)
|
schema, err = drv.DumpSchema(db)
|
||||||
require.Nil(t, schema)
|
require.Nil(t, schema)
|
||||||
require.EqualError(t, err, "mysqldump: [Warning] Using a password "+
|
require.EqualError(t, err, "mysqldump: [Warning] Using a password "+
|
||||||
"on the command line interface can be insecure.\n"+
|
"on the command line interface can be insecure.\n"+
|
||||||
|
|
@ -165,54 +174,52 @@ func TestMySQLDumpSchema(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLDatabaseExists(t *testing.T) {
|
func TestMySQLDatabaseExists(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
u := mySQLTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return false
|
// DatabaseExists should return false
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return true
|
// DatabaseExists should return true
|
||||||
exists, err = drv.DatabaseExists(u)
|
exists, err = drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, true, exists)
|
require.Equal(t, true, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLDatabaseExists_Error(t *testing.T) {
|
func TestMySQLDatabaseExists_Error(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
u := mySQLTestURL(t)
|
drv.databaseURL.User = url.User("invalid")
|
||||||
u.User = url.User("invalid")
|
|
||||||
|
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
|
require.Error(t, err)
|
||||||
require.Regexp(t, "Access denied for user 'invalid'@", err.Error())
|
require.Regexp(t, "Access denied for user 'invalid'@", err.Error())
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLCreateMigrationsTable(t *testing.T) {
|
func TestMySQLCreateMigrationsTable(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := mySQLTestURL(t)
|
db := prepTestMySQLDB(t)
|
||||||
db := prepTestMySQLDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
count := 0
|
count := 0
|
||||||
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Regexp(t, "Table 'dbmate.test_migrations' doesn't exist", err.Error())
|
require.Regexp(t, "Table 'dbmate_test.test_migrations' doesn't exist", err.Error())
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// migrations table should exist
|
// migrations table should exist
|
||||||
|
|
@ -220,19 +227,18 @@ func TestMySQLCreateMigrationsTable(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLSelectMigrations(t *testing.T) {
|
func TestMySQLSelectMigrations(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := mySQLTestURL(t)
|
db := prepTestMySQLDB(t)
|
||||||
db := prepTestMySQLDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`insert into test_migrations (version)
|
_, err = db.Exec(`insert into test_migrations (version)
|
||||||
|
|
@ -254,14 +260,13 @@ func TestMySQLSelectMigrations(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLInsertMigration(t *testing.T) {
|
func TestMySQLInsertMigration(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := mySQLTestURL(t)
|
db := prepTestMySQLDB(t)
|
||||||
db := prepTestMySQLDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
|
|
@ -280,14 +285,13 @@ func TestMySQLInsertMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLDeleteMigration(t *testing.T) {
|
func TestMySQLDeleteMigration(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := mySQLTestURL(t)
|
db := prepTestMySQLDB(t)
|
||||||
db := prepTestMySQLDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`insert into test_migrations (version)
|
_, err = db.Exec(`insert into test_migrations (version)
|
||||||
|
|
@ -304,34 +308,33 @@ func TestMySQLDeleteMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLPing(t *testing.T) {
|
func TestMySQLPing(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
u := mySQLTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping database
|
// ping database
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping invalid host should return error
|
// ping invalid host should return error
|
||||||
u.Host = "mysql:404"
|
drv.databaseURL.Host = "mysql:404"
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "connect: connection refused")
|
require.Contains(t, err.Error(), "connect: connection refused")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMySQLQuotedMigrationsTableName(t *testing.T) {
|
func TestMySQLQuotedMigrationsTableName(t *testing.T) {
|
||||||
t.Run("default name", func(t *testing.T) {
|
t.Run("default name", func(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, "`schema_migrations`", name)
|
require.Equal(t, "`schema_migrations`", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom name", func(t *testing.T) {
|
t.Run("custom name", func(t *testing.T) {
|
||||||
drv := testMySQLDriver()
|
drv := testMySQLDriver(t)
|
||||||
drv.SetMigrationsTableName("fooMigrations")
|
drv.migrationsTableName = "fooMigrations"
|
||||||
|
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, "`fooMigrations`", name)
|
require.Equal(t, "`fooMigrations`", name)
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
package dbmate
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
@ -7,21 +7,32 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
drv := &PostgresDriver{}
|
dbmate.RegisterDriver(NewDriver, "postgres")
|
||||||
RegisterDriver(drv, "postgres")
|
dbmate.RegisterDriver(NewDriver, "postgresql")
|
||||||
RegisterDriver(drv, "postgresql")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PostgresDriver provides top level database functions
|
// Driver provides top level database functions
|
||||||
type PostgresDriver struct {
|
type Driver struct {
|
||||||
migrationsTableName string
|
migrationsTableName string
|
||||||
|
databaseURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizePostgresURL(u *url.URL) *url.URL {
|
// NewDriver initializes the driver
|
||||||
|
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
|
||||||
|
return &Driver{
|
||||||
|
migrationsTableName: config.MigrationsTableName,
|
||||||
|
databaseURL: config.DatabaseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectionString(u *url.URL) string {
|
||||||
hostname := u.Hostname()
|
hostname := u.Hostname()
|
||||||
port := u.Port()
|
port := u.Port()
|
||||||
query := u.Query()
|
query := u.Query()
|
||||||
|
|
@ -56,11 +67,11 @@ func normalizePostgresURL(u *url.URL) *url.URL {
|
||||||
out.Host = fmt.Sprintf("%s:%s", hostname, port)
|
out.Host = fmt.Sprintf("%s:%s", hostname, port)
|
||||||
out.RawQuery = query.Encode()
|
out.RawQuery = query.Encode()
|
||||||
|
|
||||||
return out
|
return out.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizePostgresURLForDump(u *url.URL) []string {
|
func connectionArgsForDump(u *url.URL) []string {
|
||||||
u = normalizePostgresURL(u)
|
u = dbutil.MustParseURL(connectionString(u))
|
||||||
|
|
||||||
// find schemas from search_path
|
// find schemas from search_path
|
||||||
query := u.Query()
|
query := u.Query()
|
||||||
|
|
@ -80,34 +91,34 @@ func normalizePostgresURLForDump(u *url.URL) []string {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMigrationsTableName sets the schema migrations table name
|
|
||||||
func (drv *PostgresDriver) SetMigrationsTableName(name string) {
|
|
||||||
drv.migrationsTableName = name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv *PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) Open() (*sql.DB, error) {
|
||||||
return sql.Open("postgres", normalizePostgresURL(u).String())
|
return sql.Open("postgres", connectionString(drv.databaseURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) openPostgresDB() (*sql.DB, error) {
|
||||||
|
// clone databaseURL
|
||||||
|
postgresURL, err := url.Parse(connectionString(drv.databaseURL))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// connect to postgres database
|
// connect to postgres database
|
||||||
postgresURL := *u
|
|
||||||
postgresURL.Path = "postgres"
|
postgresURL.Path = "postgres"
|
||||||
|
|
||||||
return drv.Open(&postgresURL)
|
return sql.Open("postgres", postgresURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase creates the specified database
|
// CreateDatabase creates the specified database
|
||||||
func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
|
func (drv *Driver) CreateDatabase() error {
|
||||||
name := databaseName(u)
|
name := dbutil.DatabaseName(drv.databaseURL)
|
||||||
fmt.Printf("Creating: %s\n", name)
|
fmt.Printf("Creating: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openPostgresDB(u)
|
db, err := drv.openPostgresDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec(fmt.Sprintf("create database %s",
|
_, err = db.Exec(fmt.Sprintf("create database %s",
|
||||||
pq.QuoteIdentifier(name)))
|
pq.QuoteIdentifier(name)))
|
||||||
|
|
@ -116,15 +127,15 @@ func (drv *PostgresDriver) CreateDatabase(u *url.URL) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops the specified database (if it exists)
|
// DropDatabase drops the specified database (if it exists)
|
||||||
func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
|
func (drv *Driver) DropDatabase() error {
|
||||||
name := databaseName(u)
|
name := dbutil.DatabaseName(drv.databaseURL)
|
||||||
fmt.Printf("Dropping: %s\n", name)
|
fmt.Printf("Dropping: %s\n", name)
|
||||||
|
|
||||||
db, err := drv.openPostgresDB(u)
|
db, err := drv.openPostgresDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
||||||
pq.QuoteIdentifier(name)))
|
pq.QuoteIdentifier(name)))
|
||||||
|
|
@ -132,14 +143,14 @@ func (drv *PostgresDriver) DropDatabase(u *url.URL) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// load applied migrations
|
// load applied migrations
|
||||||
migrations, err := queryColumn(db,
|
migrations, err := dbutil.QueryColumn(db,
|
||||||
"select quote_literal(version) from "+migrationsTable+" order by version asc")
|
"select quote_literal(version) from "+migrationsTable+" order by version asc")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -159,11 +170,11 @@ func (drv *PostgresDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DumpSchema returns the current database schema
|
// DumpSchema returns the current database schema
|
||||||
func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
|
||||||
// load schema
|
// load schema
|
||||||
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
|
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
|
||||||
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
|
"--no-privileges", "--no-owner"}, connectionArgsForDump(drv.databaseURL)...)
|
||||||
schema, err := runCommand("pg_dump", args...)
|
schema, err := dbutil.RunCommand("pg_dump", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -174,18 +185,18 @@ func (drv *PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
schema = append(schema, migrations...)
|
schema = append(schema, migrations...)
|
||||||
return trimLeadingSQLComments(schema)
|
return dbutil.TrimLeadingSQLComments(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseExists determines whether the database exists
|
// DatabaseExists determines whether the database exists
|
||||||
func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
func (drv *Driver) DatabaseExists() (bool, error) {
|
||||||
name := databaseName(u)
|
name := dbutil.DatabaseName(drv.databaseURL)
|
||||||
|
|
||||||
db, err := drv.openPostgresDB(u)
|
db, err := drv.openPostgresDB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
exists := false
|
exists := false
|
||||||
err = db.QueryRow("select true from pg_database where datname = $1", name).
|
err = db.QueryRow("select true from pg_database where datname = $1", name).
|
||||||
|
|
@ -198,8 +209,8 @@ func (drv *PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMigrationsTable creates the schema_migrations table
|
// CreateMigrationsTable creates the schema_migrations table
|
||||||
func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db, u)
|
schema, migrationsTable, err := drv.quotedMigrationsTableNameParts(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -235,7 +246,7 @@ func (drv *PostgresDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
// with an optional limit (in descending order)
|
// with an optional limit (in descending order)
|
||||||
func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -250,7 +261,7 @@ func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]b
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer mustClose(rows)
|
defer dbutil.MustClose(rows)
|
||||||
|
|
||||||
migrations := map[string]bool{}
|
migrations := map[string]bool{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
@ -270,7 +281,7 @@ func (drv *PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]b
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertMigration adds a new migration record
|
// InsertMigration adds a new migration record
|
||||||
func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error {
|
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
|
||||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -282,7 +293,7 @@ func (drv *PostgresDriver) InsertMigration(db Transaction, version string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMigration removes a migration record
|
// DeleteMigration removes a migration record
|
||||||
func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
|
||||||
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
migrationsTable, err := drv.quotedMigrationsTableName(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -295,15 +306,15 @@ func (drv *PostgresDriver) DeleteMigration(db Transaction, version string) error
|
||||||
|
|
||||||
// Ping verifies a connection to the database server. It does not verify whether the
|
// Ping verifies a connection to the database server. It does not verify whether the
|
||||||
// specified database exists.
|
// specified database exists.
|
||||||
func (drv *PostgresDriver) Ping(u *url.URL) error {
|
func (drv *Driver) Ping() error {
|
||||||
// attempt connection to primary database, not "postgres" database
|
// attempt connection to primary database, not "postgres" database
|
||||||
// to support servers with no "postgres" database
|
// to support servers with no "postgres" database
|
||||||
// (see https://github.com/amacneil/dbmate/issues/78)
|
// (see https://github.com/amacneil/dbmate/issues/78)
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|
@ -319,8 +330,8 @@ func (drv *PostgresDriver) Ping(u *url.URL) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, error) {
|
func (drv *Driver) quotedMigrationsTableName(db dbutil.Transaction) (string, error) {
|
||||||
schema, name, err := drv.quotedMigrationsTableNameParts(db, nil)
|
schema, name, err := drv.quotedMigrationsTableNameParts(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
@ -328,7 +339,7 @@ func (drv *PostgresDriver) quotedMigrationsTableName(db Transaction) (string, er
|
||||||
return schema + "." + name, nil
|
return schema + "." + name, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url.URL) (string, string, error) {
|
func (drv *Driver) quotedMigrationsTableNameParts(db dbutil.Transaction) (string, string, error) {
|
||||||
schema := ""
|
schema := ""
|
||||||
tableNameParts := strings.Split(drv.migrationsTableName, ".")
|
tableNameParts := strings.Split(drv.migrationsTableName, ".")
|
||||||
if len(tableNameParts) > 1 {
|
if len(tableNameParts) > 1 {
|
||||||
|
|
@ -336,9 +347,9 @@ func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url
|
||||||
schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
|
schema, tableNameParts = tableNameParts[0], tableNameParts[1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
if schema == "" && u != nil {
|
if schema == "" {
|
||||||
// no schema specified with table name, try URL search path if available
|
// no schema specified with table name, try URL search path if available
|
||||||
searchPath := strings.Split(u.Query().Get("search_path"), ",")
|
searchPath := strings.Split(drv.databaseURL.Query().Get("search_path"), ",")
|
||||||
schema = strings.TrimSpace(searchPath[0])
|
schema = strings.TrimSpace(searchPath[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -346,7 +357,7 @@ func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url
|
||||||
if schema == "" {
|
if schema == "" {
|
||||||
// if no URL available, use current schema
|
// if no URL available, use current schema
|
||||||
// this is a hack because we don't always have the URL context available
|
// this is a hack because we don't always have the URL context available
|
||||||
schema, err = queryValue(db, "select current_schema()")
|
schema, err = dbutil.QueryValue(db, "select current_schema()")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
@ -361,7 +372,7 @@ func (drv *PostgresDriver) quotedMigrationsTableNameParts(db Transaction, u *url
|
||||||
// use server rather than client to do this to avoid unnecessary quotes
|
// use server rather than client to do this to avoid unnecessary quotes
|
||||||
// (which would change schema.sql diff)
|
// (which would change schema.sql diff)
|
||||||
tableNameParts = append([]string{schema}, tableNameParts...)
|
tableNameParts = append([]string{schema}, tableNameParts...)
|
||||||
quotedNameParts, err := queryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
|
quotedNameParts, err := dbutil.QueryColumn(db, "select quote_ident(unnest($1::text[]))", pq.Array(tableNameParts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", err
|
return "", "", err
|
||||||
}
|
}
|
||||||
|
|
@ -1,46 +1,56 @@
|
||||||
package dbmate
|
package postgres
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func postgresTestURL(t *testing.T) *url.URL {
|
func testPostgresDriver(t *testing.T) *Driver {
|
||||||
u, err := url.Parse("postgres://postgres:postgres@postgres/dbmate?sslmode=disable")
|
u := dbutil.MustParseURL(os.Getenv("POSTGRES_TEST_URL"))
|
||||||
|
drv, err := dbmate.New(u).GetDriver()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return u
|
return drv.(*Driver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testPostgresDriver() *PostgresDriver {
|
func prepTestPostgresDB(t *testing.T) *sql.DB {
|
||||||
drv := &PostgresDriver{}
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
|
||||||
|
|
||||||
return drv
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepTestPostgresDB(t *testing.T, u *url.URL) *sql.DB {
|
|
||||||
drv := testPostgresDriver()
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// connect database
|
// connect database
|
||||||
db, err := sql.Open("postgres", u.String())
|
db, err := sql.Open("postgres", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizePostgresURL(t *testing.T) {
|
func TestGetDriver(t *testing.T) {
|
||||||
|
db := dbmate.New(dbutil.MustParseURL("postgres://"))
|
||||||
|
drvInterface, err := db.GetDriver()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// driver should have URL and default migrations table set
|
||||||
|
drv, ok := drvInterface.(*Driver)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
|
||||||
|
require.Equal(t, "schema_migrations", drv.migrationsTableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionString(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
expected string
|
expected string
|
||||||
|
|
@ -63,13 +73,13 @@ func TestNormalizePostgresURL(t *testing.T) {
|
||||||
u, err := url.Parse(c.input)
|
u, err := url.Parse(c.input)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual := normalizePostgresURL(u).String()
|
actual := connectionString(u)
|
||||||
require.Equal(t, c.expected, actual)
|
require.Equal(t, c.expected, actual)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizePostgresURLForDump(t *testing.T) {
|
func TestConnectionArgsForDump(t *testing.T) {
|
||||||
cases := []struct {
|
cases := []struct {
|
||||||
input string
|
input string
|
||||||
expected []string
|
expected []string
|
||||||
|
|
@ -87,59 +97,57 @@ func TestNormalizePostgresURLForDump(t *testing.T) {
|
||||||
u, err := url.Parse(c.input)
|
u, err := url.Parse(c.input)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
actual := normalizePostgresURLForDump(u)
|
actual := connectionArgsForDump(u)
|
||||||
require.Equal(t, c.expected, actual)
|
require.Equal(t, c.expected, actual)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresCreateDropDatabase(t *testing.T) {
|
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
u := postgresTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database exists and we can connect to it
|
// check that database exists and we can connect to it
|
||||||
func() {
|
func() {
|
||||||
db, err := sql.Open("postgres", u.String())
|
db, err := sql.Open("postgres", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// drop the database
|
// drop the database
|
||||||
err = drv.DropDatabase(u)
|
err = drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database no longer exists
|
// check that database no longer exists
|
||||||
func() {
|
func() {
|
||||||
db, err := sql.Open("postgres", u.String())
|
db, err := sql.Open("postgres", drv.databaseURL.String())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
err = db.Ping()
|
err = db.Ping()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, "pq: database \"dbmate\" does not exist", err.Error())
|
require.Equal(t, "pq: database \"dbmate_test\" does not exist", err.Error())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresDumpSchema(t *testing.T) {
|
func TestPostgresDumpSchema(t *testing.T) {
|
||||||
t.Run("default migrations table", func(t *testing.T) {
|
t.Run("default migrations table", func(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
u := postgresTestURL(t)
|
|
||||||
|
|
||||||
// prepare database
|
// prepare database
|
||||||
db := prepTestPostgresDB(t, u)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// insert migration
|
// insert migration
|
||||||
|
|
@ -149,7 +157,7 @@ func TestPostgresDumpSchema(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DumpSchema should return schema
|
// DumpSchema should return schema
|
||||||
schema, err := drv.DumpSchema(u, db)
|
schema, err := drv.DumpSchema(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
|
require.Contains(t, string(schema), "CREATE TABLE public.schema_migrations")
|
||||||
require.Contains(t, string(schema), "\n--\n"+
|
require.Contains(t, string(schema), "\n--\n"+
|
||||||
|
|
@ -163,23 +171,21 @@ func TestPostgresDumpSchema(t *testing.T) {
|
||||||
" ('abc2');\n")
|
" ('abc2');\n")
|
||||||
|
|
||||||
// DumpSchema should return error if command fails
|
// DumpSchema should return error if command fails
|
||||||
u.Path = "/fakedb"
|
drv.databaseURL.Path = "/fakedb"
|
||||||
schema, err = drv.DumpSchema(u, db)
|
schema, err = drv.DumpSchema(db)
|
||||||
require.Nil(t, schema)
|
require.Nil(t, schema)
|
||||||
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
|
require.EqualError(t, err, "pg_dump: [archiver (db)] connection to database "+
|
||||||
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
|
"\"fakedb\" failed: FATAL: database \"fakedb\" does not exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom migrations table with schema", func(t *testing.T) {
|
t.Run("custom migrations table with schema", func(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName("camelSchema.testMigrations")
|
drv.migrationsTableName = "camelSchema.testMigrations"
|
||||||
|
|
||||||
u := postgresTestURL(t)
|
|
||||||
|
|
||||||
// prepare database
|
// prepare database
|
||||||
db := prepTestPostgresDB(t, u)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// insert migration
|
// insert migration
|
||||||
|
|
@ -189,7 +195,7 @@ func TestPostgresDumpSchema(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DumpSchema should return schema
|
// DumpSchema should return schema
|
||||||
schema, err := drv.DumpSchema(u, db)
|
schema, err := drv.DumpSchema(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"")
|
require.Contains(t, string(schema), "CREATE TABLE \"camelSchema\".\"testMigrations\"")
|
||||||
require.Contains(t, string(schema), "\n--\n"+
|
require.Contains(t, string(schema), "\n--\n"+
|
||||||
|
|
@ -205,34 +211,32 @@ func TestPostgresDumpSchema(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresDatabaseExists(t *testing.T) {
|
func TestPostgresDatabaseExists(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
u := postgresTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return false
|
// DatabaseExists should return false
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return true
|
// DatabaseExists should return true
|
||||||
exists, err = drv.DatabaseExists(u)
|
exists, err = drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, true, exists)
|
require.Equal(t, true, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresDatabaseExists_Error(t *testing.T) {
|
func TestPostgresDatabaseExists_Error(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
u := postgresTestURL(t)
|
drv.databaseURL.User = url.User("invalid")
|
||||||
u.User = url.User("invalid")
|
|
||||||
|
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, "pq: password authentication failed for user \"invalid\"", err.Error())
|
require.Equal(t, "pq: password authentication failed for user \"invalid\"", err.Error())
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
|
|
@ -240,10 +244,9 @@ func TestPostgresDatabaseExists_Error(t *testing.T) {
|
||||||
|
|
||||||
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
t.Run("default schema", func(t *testing.T) {
|
t.Run("default schema", func(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
u := postgresTestURL(t)
|
db := prepTestPostgresDB(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
count := 0
|
count := 0
|
||||||
|
|
@ -252,7 +255,7 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
|
require.Equal(t, "pq: relation \"public.schema_migrations\" does not exist", err.Error())
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// migrations table should exist
|
// migrations table should exist
|
||||||
|
|
@ -260,18 +263,20 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom search path", func(t *testing.T) {
|
t.Run("custom search path", func(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName("testMigrations")
|
drv.migrationsTableName = "testMigrations"
|
||||||
|
|
||||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=camelFoo")
|
u, err := url.Parse(drv.databaseURL.String() + "&search_path=camelFoo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
db := prepTestPostgresDB(t, u)
|
drv.databaseURL = u
|
||||||
defer mustClose(db)
|
|
||||||
|
db := prepTestPostgresDB(t)
|
||||||
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
// delete schema
|
// delete schema
|
||||||
_, err = db.Exec("drop schema if exists \"camelFoo\"")
|
_, err = db.Exec("drop schema if exists \"camelFoo\"")
|
||||||
|
|
@ -291,7 +296,7 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
|
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// camelFoo schema should be created, and migrations table should exist only in camelFoo schema
|
// camelFoo schema should be created, and migrations table should exist only in camelFoo schema
|
||||||
|
|
@ -302,18 +307,20 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
|
require.Equal(t, "pq: relation \"public.testMigrations\" does not exist", err.Error())
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom schema", func(t *testing.T) {
|
t.Run("custom schema", func(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName("camelSchema.testMigrations")
|
drv.migrationsTableName = "camelSchema.testMigrations"
|
||||||
|
|
||||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
db := prepTestPostgresDB(t, u)
|
drv.databaseURL = u
|
||||||
defer mustClose(db)
|
|
||||||
|
db := prepTestPostgresDB(t)
|
||||||
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
// delete schemas
|
// delete schemas
|
||||||
_, err = db.Exec("drop schema if exists foo")
|
_, err = db.Exec("drop schema if exists foo")
|
||||||
|
|
@ -328,7 +335,7 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error())
|
require.Equal(t, "pq: relation \"camelSchema.testMigrations\" does not exist", err.Error())
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// camelSchema should be created, and testMigrations table should exist
|
// camelSchema should be created, and testMigrations table should exist
|
||||||
|
|
@ -341,20 +348,19 @@ func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||||
require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error())
|
require.Equal(t, "pq: relation \"foo.testMigrations\" does not exist", err.Error())
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresSelectMigrations(t *testing.T) {
|
func TestPostgresSelectMigrations(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := postgresTestURL(t)
|
db := prepTestPostgresDB(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`insert into public.test_migrations (version)
|
_, err = db.Exec(`insert into public.test_migrations (version)
|
||||||
|
|
@ -376,14 +382,13 @@ func TestPostgresSelectMigrations(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresInsertMigration(t *testing.T) {
|
func TestPostgresInsertMigration(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := postgresTestURL(t)
|
db := prepTestPostgresDB(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
|
|
@ -402,14 +407,13 @@ func TestPostgresInsertMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresDeleteMigration(t *testing.T) {
|
func TestPostgresDeleteMigration(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := postgresTestURL(t)
|
db := prepTestPostgresDB(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`insert into public.test_migrations (version)
|
_, err = db.Exec(`insert into public.test_migrations (version)
|
||||||
|
|
@ -426,31 +430,28 @@ func TestPostgresDeleteMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresPing(t *testing.T) {
|
func TestPostgresPing(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
drv := testPostgresDriver(t)
|
||||||
u := postgresTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping database
|
// ping database
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping invalid host should return error
|
// ping invalid host should return error
|
||||||
u.Host = "postgres:404"
|
drv.databaseURL.Host = "postgres:404"
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Contains(t, err.Error(), "connect: connection refused")
|
require.Contains(t, err.Error(), "connect: connection refused")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresQuotedMigrationsTableName(t *testing.T) {
|
func TestPostgresQuotedMigrationsTableName(t *testing.T) {
|
||||||
drv := testPostgresDriver()
|
|
||||||
|
|
||||||
t.Run("default schema", func(t *testing.T) {
|
t.Run("default schema", func(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
drv := testPostgresDriver(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
name, err := drv.quotedMigrationsTableName(db)
|
name, err := drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
@ -458,32 +459,29 @@ func TestPostgresQuotedMigrationsTableName(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom schema", func(t *testing.T) {
|
t.Run("custom schema", func(t *testing.T) {
|
||||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo,bar,public")
|
drv := testPostgresDriver(t)
|
||||||
|
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo,bar,public")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
db := prepTestPostgresDB(t, u)
|
drv.databaseURL = u
|
||||||
defer mustClose(db)
|
|
||||||
|
db := prepTestPostgresDB(t)
|
||||||
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
// if "foo" schema does not exist, current schema should be "public"
|
|
||||||
_, err = db.Exec("drop schema if exists foo")
|
_, err = db.Exec("drop schema if exists foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_, err = db.Exec("drop schema if exists bar")
|
_, err = db.Exec("drop schema if exists bar")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
name, err := drv.quotedMigrationsTableName(db)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, "public.schema_migrations", name)
|
|
||||||
|
|
||||||
// if "foo" schema exists, it should be used
|
// should use first schema from search path
|
||||||
_, err = db.Exec("create schema foo")
|
name, err := drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
|
||||||
name, err = drv.quotedMigrationsTableName(db)
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "foo.schema_migrations", name)
|
require.Equal(t, "foo.schema_migrations", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("no schema", func(t *testing.T) {
|
t.Run("no schema", func(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
drv := testPostgresDriver(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
// this is an unlikely edge case, but if for some reason there is
|
// this is an unlikely edge case, but if for some reason there is
|
||||||
// no current schema then we should default to "public"
|
// no current schema then we should default to "public"
|
||||||
|
|
@ -496,48 +494,54 @@ func TestPostgresQuotedMigrationsTableName(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom table name", func(t *testing.T) {
|
t.Run("custom table name", func(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
drv := testPostgresDriver(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
drv.SetMigrationsTableName("simple_name")
|
drv.migrationsTableName = "simple_name"
|
||||||
name, err := drv.quotedMigrationsTableName(db)
|
name, err := drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "public.simple_name", name)
|
require.Equal(t, "public.simple_name", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom table name quoted", func(t *testing.T) {
|
t.Run("custom table name quoted", func(t *testing.T) {
|
||||||
u := postgresTestURL(t)
|
drv := testPostgresDriver(t)
|
||||||
db := prepTestPostgresDB(t, u)
|
db := prepTestPostgresDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
// this table name will need quoting
|
// this table name will need quoting
|
||||||
drv.SetMigrationsTableName("camelCase")
|
drv.migrationsTableName = "camelCase"
|
||||||
name, err := drv.quotedMigrationsTableName(db)
|
name, err := drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "public.\"camelCase\"", name)
|
require.Equal(t, "public.\"camelCase\"", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom table name with custom schema", func(t *testing.T) {
|
t.Run("custom table name with custom schema", func(t *testing.T) {
|
||||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
drv := testPostgresDriver(t)
|
||||||
|
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
db := prepTestPostgresDB(t, u)
|
drv.databaseURL = u
|
||||||
defer mustClose(db)
|
|
||||||
|
db := prepTestPostgresDB(t)
|
||||||
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec("create schema if not exists foo")
|
_, err = db.Exec("create schema if not exists foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
drv.SetMigrationsTableName("simple_name")
|
drv.migrationsTableName = "simple_name"
|
||||||
name, err := drv.quotedMigrationsTableName(db)
|
name, err := drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "foo.simple_name", name)
|
require.Equal(t, "foo.simple_name", name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom table name overrides schema", func(t *testing.T) {
|
t.Run("custom table name overrides schema", func(t *testing.T) {
|
||||||
u, err := url.Parse(postgresTestURL(t).String() + "&search_path=foo")
|
drv := testPostgresDriver(t)
|
||||||
|
u, err := url.Parse(drv.databaseURL.String() + "&search_path=foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
db := prepTestPostgresDB(t, u)
|
drv.databaseURL = u
|
||||||
defer mustClose(db)
|
|
||||||
|
db := prepTestPostgresDB(t)
|
||||||
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
_, err = db.Exec("create schema if not exists foo")
|
_, err = db.Exec("create schema if not exists foo")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
@ -545,19 +549,19 @@ func TestPostgresQuotedMigrationsTableName(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// if schema is specified as part of table name, it should override search_path
|
// if schema is specified as part of table name, it should override search_path
|
||||||
drv.SetMigrationsTableName("bar.simple_name")
|
drv.migrationsTableName = "bar.simple_name"
|
||||||
name, err := drv.quotedMigrationsTableName(db)
|
name, err := drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "bar.simple_name", name)
|
require.Equal(t, "bar.simple_name", name)
|
||||||
|
|
||||||
// schema and table name should be quoted if necessary
|
// schema and table name should be quoted if necessary
|
||||||
drv.SetMigrationsTableName("barName.camelTable")
|
drv.migrationsTableName = "barName.camelTable"
|
||||||
name, err = drv.quotedMigrationsTableName(db)
|
name, err = drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "\"barName\".\"camelTable\"", name)
|
require.Equal(t, "\"barName\".\"camelTable\"", name)
|
||||||
|
|
||||||
// more than 2 components is unexpected but we will quote and pass it along anyway
|
// more than 2 components is unexpected but we will quote and pass it along anyway
|
||||||
drv.SetMigrationsTableName("whyWould.i.doThis")
|
drv.migrationsTableName = "whyWould.i.doThis"
|
||||||
name, err = drv.quotedMigrationsTableName(db)
|
name, err = drv.quotedMigrationsTableName(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, "\"whyWould\".i.\"doThis\"", name)
|
require.Equal(t, "\"whyWould\".i.\"doThis\"", name)
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
// +build cgo
|
// +build cgo
|
||||||
|
|
||||||
package dbmate
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
|
@ -11,58 +11,68 @@ import (
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
_ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql
|
_ "github.com/mattn/go-sqlite3" // database/sql driver
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
drv := &SQLiteDriver{}
|
dbmate.RegisterDriver(NewDriver, "sqlite")
|
||||||
RegisterDriver(drv, "sqlite")
|
dbmate.RegisterDriver(NewDriver, "sqlite3")
|
||||||
RegisterDriver(drv, "sqlite3")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SQLiteDriver provides top level database functions
|
// Driver provides top level database functions
|
||||||
type SQLiteDriver struct {
|
type Driver struct {
|
||||||
migrationsTableName string
|
migrationsTableName string
|
||||||
|
databaseURL *url.URL
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqlitePath(u *url.URL) string {
|
// NewDriver initializes the driver
|
||||||
// strip one leading slash
|
func NewDriver(config dbmate.DriverConfig) dbmate.Driver {
|
||||||
// absolute URLs can be specified as sqlite:////tmp/foo.sqlite3
|
return &Driver{
|
||||||
str := regexp.MustCompile("^/").ReplaceAllString(u.Path, "")
|
migrationsTableName: config.MigrationsTableName,
|
||||||
|
databaseURL: config.DatabaseURL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionString converts a URL into a valid connection string
|
||||||
|
func ConnectionString(u *url.URL) string {
|
||||||
|
// duplicate URL and remove scheme
|
||||||
|
newURL := *u
|
||||||
|
newURL.Scheme = ""
|
||||||
|
|
||||||
|
// trim duplicate leading slashes
|
||||||
|
str := regexp.MustCompile("^//+").ReplaceAllString(newURL.String(), "/")
|
||||||
|
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMigrationsTableName sets the schema migrations table name
|
|
||||||
func (drv *SQLiteDriver) SetMigrationsTableName(name string) {
|
|
||||||
drv.migrationsTableName = name
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv *SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv *Driver) Open() (*sql.DB, error) {
|
||||||
return sql.Open("sqlite3", sqlitePath(u))
|
return sql.Open("sqlite3", ConnectionString(drv.databaseURL))
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDatabase creates the specified database
|
// CreateDatabase creates the specified database
|
||||||
func (drv *SQLiteDriver) CreateDatabase(u *url.URL) error {
|
func (drv *Driver) CreateDatabase() error {
|
||||||
fmt.Printf("Creating: %s\n", sqlitePath(u))
|
fmt.Printf("Creating: %s\n", ConnectionString(drv.databaseURL))
|
||||||
|
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
return db.Ping()
|
return db.Ping()
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropDatabase drops the specified database (if it exists)
|
// DropDatabase drops the specified database (if it exists)
|
||||||
func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
|
func (drv *Driver) DropDatabase() error {
|
||||||
path := sqlitePath(u)
|
path := ConnectionString(drv.databaseURL)
|
||||||
fmt.Printf("Dropping: %s\n", path)
|
fmt.Printf("Dropping: %s\n", path)
|
||||||
|
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -73,11 +83,11 @@ func (drv *SQLiteDriver) DropDatabase(u *url.URL) error {
|
||||||
return os.Remove(path)
|
return os.Remove(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
func (drv *Driver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
migrationsTable := drv.quotedMigrationsTableName()
|
migrationsTable := drv.quotedMigrationsTableName()
|
||||||
|
|
||||||
// load applied migrations
|
// load applied migrations
|
||||||
migrations, err := queryColumn(db,
|
migrations, err := dbutil.QueryColumn(db,
|
||||||
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
|
fmt.Sprintf("select quote(version) from %s order by version asc", migrationsTable))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -98,9 +108,9 @@ func (drv *SQLiteDriver) schemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DumpSchema returns the current database schema
|
// DumpSchema returns the current database schema
|
||||||
func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
func (drv *Driver) DumpSchema(db *sql.DB) ([]byte, error) {
|
||||||
path := sqlitePath(u)
|
path := ConnectionString(drv.databaseURL)
|
||||||
schema, err := runCommand("sqlite3", path, ".schema")
|
schema, err := dbutil.RunCommand("sqlite3", path, ".schema")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -111,12 +121,12 @@ func (drv *SQLiteDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
schema = append(schema, migrations...)
|
schema = append(schema, migrations...)
|
||||||
return trimLeadingSQLComments(schema)
|
return dbutil.TrimLeadingSQLComments(schema)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseExists determines whether the database exists
|
// DatabaseExists determines whether the database exists
|
||||||
func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
|
func (drv *Driver) DatabaseExists() (bool, error) {
|
||||||
_, err := os.Stat(sqlitePath(u))
|
_, err := os.Stat(ConnectionString(drv.databaseURL))
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
@ -128,7 +138,7 @@ func (drv *SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateMigrationsTable creates the schema migrations table
|
// CreateMigrationsTable creates the schema migrations table
|
||||||
func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
func (drv *Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) +
|
fmt.Sprintf("create table if not exists %s ", drv.quotedMigrationsTableName()) +
|
||||||
"(version varchar(255) primary key)")
|
"(version varchar(255) primary key)")
|
||||||
|
|
@ -138,7 +148,7 @@ func (drv *SQLiteDriver) CreateMigrationsTable(u *url.URL, db *sql.DB) error {
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
// with an optional limit (in descending order)
|
// with an optional limit (in descending order)
|
||||||
func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
func (drv *Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||||
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
|
query := fmt.Sprintf("select version from %s order by version desc", drv.quotedMigrationsTableName())
|
||||||
if limit >= 0 {
|
if limit >= 0 {
|
||||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||||
|
|
@ -148,7 +158,7 @@ func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]boo
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer mustClose(rows)
|
defer dbutil.MustClose(rows)
|
||||||
|
|
||||||
migrations := map[string]bool{}
|
migrations := map[string]bool{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
@ -168,7 +178,7 @@ func (drv *SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]boo
|
||||||
}
|
}
|
||||||
|
|
||||||
// InsertMigration adds a new migration record
|
// InsertMigration adds a new migration record
|
||||||
func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
|
func (drv *Driver) InsertMigration(db dbutil.Transaction, version string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
fmt.Sprintf("insert into %s (version) values (?)", drv.quotedMigrationsTableName()),
|
||||||
version)
|
version)
|
||||||
|
|
@ -177,7 +187,7 @@ func (drv *SQLiteDriver) InsertMigration(db Transaction, version string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMigration removes a migration record
|
// DeleteMigration removes a migration record
|
||||||
func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
|
func (drv *Driver) DeleteMigration(db dbutil.Transaction, version string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
|
fmt.Sprintf("delete from %s where version = ?", drv.quotedMigrationsTableName()),
|
||||||
version)
|
version)
|
||||||
|
|
@ -188,23 +198,23 @@ func (drv *SQLiteDriver) DeleteMigration(db Transaction, version string) error {
|
||||||
// Ping verifies a connection to the database. Due to the way SQLite works, by
|
// Ping verifies a connection to the database. Due to the way SQLite works, by
|
||||||
// testing whether the database is valid, it will automatically create the database
|
// testing whether the database is valid, it will automatically create the database
|
||||||
// if it does not already exist.
|
// if it does not already exist.
|
||||||
func (drv *SQLiteDriver) Ping(u *url.URL) error {
|
func (drv *Driver) Ping() error {
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
|
|
||||||
return db.Ping()
|
return db.Ping()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv *SQLiteDriver) quotedMigrationsTableName() string {
|
func (drv *Driver) quotedMigrationsTableName() string {
|
||||||
return drv.quoteIdentifier(drv.migrationsTableName)
|
return drv.quoteIdentifier(drv.migrationsTableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// quoteIdentifier quotes a table or column name
|
// quoteIdentifier quotes a table or column name
|
||||||
// we fall back to lib/pq implementation since both use ansi standard (double quotes)
|
// we fall back to lib/pq implementation since both use ansi standard (double quotes)
|
||||||
// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
|
// and mattn/go-sqlite3 doesn't provide a sqlite-specific equivalent
|
||||||
func (drv *SQLiteDriver) quoteIdentifier(s string) string {
|
func (drv *Driver) quoteIdentifier(s string) string {
|
||||||
return pq.QuoteIdentifier(s)
|
return pq.QuoteIdentifier(s)
|
||||||
}
|
}
|
||||||
|
|
@ -1,59 +1,91 @@
|
||||||
// +build cgo
|
// +build cgo
|
||||||
|
|
||||||
package dbmate
|
package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbmate"
|
||||||
|
"github.com/amacneil/dbmate/pkg/dbutil"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func sqliteTestURL(t *testing.T) *url.URL {
|
func testSQLiteDriver(t *testing.T) *Driver {
|
||||||
u, err := url.Parse("sqlite3:////tmp/dbmate.sqlite3")
|
u := dbutil.MustParseURL(os.Getenv("SQLITE_TEST_URL"))
|
||||||
|
drv, err := dbmate.New(u).GetDriver()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return u
|
return drv.(*Driver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func testSQLiteDriver() *SQLiteDriver {
|
func prepTestSQLiteDB(t *testing.T) *sql.DB {
|
||||||
drv := &SQLiteDriver{}
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName(DefaultMigrationsTableName)
|
|
||||||
|
|
||||||
return drv
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepTestSQLiteDB(t *testing.T, u *url.URL) *sql.DB {
|
|
||||||
drv := testSQLiteDriver()
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// connect database
|
// connect database
|
||||||
db, err := drv.Open(u)
|
db, err := drv.Open()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetDriver(t *testing.T) {
|
||||||
|
db := dbmate.New(dbutil.MustParseURL("sqlite://"))
|
||||||
|
drvInterface, err := db.GetDriver()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// driver should have URL and default migrations table set
|
||||||
|
drv, ok := drvInterface.(*Driver)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, db.DatabaseURL.String(), drv.databaseURL.String())
|
||||||
|
require.Equal(t, "schema_migrations", drv.migrationsTableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectionString(t *testing.T) {
|
||||||
|
t.Run("relative", func(t *testing.T) {
|
||||||
|
u := dbutil.MustParseURL("sqlite:foo/bar.sqlite3?mode=ro")
|
||||||
|
require.Equal(t, "foo/bar.sqlite3?mode=ro", ConnectionString(u))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("absolute", func(t *testing.T) {
|
||||||
|
u := dbutil.MustParseURL("sqlite:/tmp/foo.sqlite3?mode=ro")
|
||||||
|
require.Equal(t, "/tmp/foo.sqlite3?mode=ro", ConnectionString(u))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("three slashes", func(t *testing.T) {
|
||||||
|
// interpreted as absolute path
|
||||||
|
u := dbutil.MustParseURL("sqlite:///tmp/foo.sqlite3?mode=ro")
|
||||||
|
require.Equal(t, "/tmp/foo.sqlite3?mode=ro", ConnectionString(u))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("four slashes", func(t *testing.T) {
|
||||||
|
// interpreted as absolute path
|
||||||
|
// supported for backwards compatibility
|
||||||
|
u := dbutil.MustParseURL("sqlite:////tmp/foo.sqlite3?mode=ro")
|
||||||
|
require.Equal(t, "/tmp/foo.sqlite3?mode=ro", ConnectionString(u))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestSQLiteCreateDropDatabase(t *testing.T) {
|
func TestSQLiteCreateDropDatabase(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
u := sqliteTestURL(t)
|
path := ConnectionString(drv.databaseURL)
|
||||||
path := sqlitePath(u)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database exists
|
// check that database exists
|
||||||
|
|
@ -61,7 +93,7 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// drop the database
|
// drop the database
|
||||||
err = drv.DropDatabase(u)
|
err = drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that database no longer exists
|
// check that database no longer exists
|
||||||
|
|
@ -71,15 +103,13 @@ func TestSQLiteCreateDropDatabase(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteDumpSchema(t *testing.T) {
|
func TestSQLiteDumpSchema(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := sqliteTestURL(t)
|
|
||||||
|
|
||||||
// prepare database
|
// prepare database
|
||||||
db := prepTestSQLiteDB(t, u)
|
db := prepTestSQLiteDB(t)
|
||||||
defer mustClose(db)
|
defer dbutil.MustClose(db)
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// insert migration
|
// insert migration
|
||||||
|
|
@ -89,7 +119,7 @@ func TestSQLiteDumpSchema(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DumpSchema should return schema
|
// DumpSchema should return schema
|
||||||
schema, err := drv.DumpSchema(u, db)
|
schema, err := drv.DumpSchema(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"")
|
require.Contains(t, string(schema), "CREATE TABLE IF NOT EXISTS \"test_migrations\"")
|
||||||
require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+
|
require.Contains(t, string(schema), ");\n-- Dbmate schema migrations\n"+
|
||||||
|
|
@ -98,50 +128,50 @@ func TestSQLiteDumpSchema(t *testing.T) {
|
||||||
" ('abc2');\n")
|
" ('abc2');\n")
|
||||||
|
|
||||||
// DumpSchema should return error if command fails
|
// DumpSchema should return error if command fails
|
||||||
u.Path = "/."
|
drv.databaseURL = dbutil.MustParseURL(".")
|
||||||
schema, err = drv.DumpSchema(u, db)
|
schema, err = drv.DumpSchema(db)
|
||||||
require.Nil(t, schema)
|
require.Nil(t, schema)
|
||||||
|
require.Error(t, err)
|
||||||
require.EqualError(t, err, "Error: unable to open database \".\": "+
|
require.EqualError(t, err, "Error: unable to open database \".\": "+
|
||||||
"unable to open database file")
|
"unable to open database file")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteDatabaseExists(t *testing.T) {
|
func TestSQLiteDatabaseExists(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
u := sqliteTestURL(t)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return false
|
// DatabaseExists should return false
|
||||||
exists, err := drv.DatabaseExists(u)
|
exists, err := drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, false, exists)
|
require.Equal(t, false, exists)
|
||||||
|
|
||||||
// create database
|
// create database
|
||||||
err = drv.CreateDatabase(u)
|
err = drv.CreateDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// DatabaseExists should return true
|
// DatabaseExists should return true
|
||||||
exists, err = drv.DatabaseExists(u)
|
exists, err = drv.DatabaseExists()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, true, exists)
|
require.Equal(t, true, exists)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteCreateMigrationsTable(t *testing.T) {
|
func TestSQLiteCreateMigrationsTable(t *testing.T) {
|
||||||
t.Run("default table", func(t *testing.T) {
|
t.Run("default table", func(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
u := sqliteTestURL(t)
|
db := prepTestSQLiteDB(t)
|
||||||
db := prepTestSQLiteDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
count := 0
|
count := 0
|
||||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||||
|
require.Error(t, err)
|
||||||
require.Regexp(t, "no such table: schema_migrations", err.Error())
|
require.Regexp(t, "no such table: schema_migrations", err.Error())
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// migrations table should exist
|
// migrations table should exist
|
||||||
|
|
@ -149,25 +179,25 @@ func TestSQLiteCreateMigrationsTable(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom table", func(t *testing.T) {
|
t.Run("custom table", func(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := sqliteTestURL(t)
|
db := prepTestSQLiteDB(t)
|
||||||
db := prepTestSQLiteDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
// migrations table should not exist
|
// migrations table should not exist
|
||||||
count := 0
|
count := 0
|
||||||
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
err := db.QueryRow("select count(*) from test_migrations").Scan(&count)
|
||||||
|
require.Error(t, err)
|
||||||
require.Regexp(t, "no such table: test_migrations", err.Error())
|
require.Regexp(t, "no such table: test_migrations", err.Error())
|
||||||
|
|
||||||
// create table
|
// create table
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// migrations table should exist
|
// migrations table should exist
|
||||||
|
|
@ -175,20 +205,19 @@ func TestSQLiteCreateMigrationsTable(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create table should be idempotent
|
// create table should be idempotent
|
||||||
err = drv.CreateMigrationsTable(u, db)
|
err = drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteSelectMigrations(t *testing.T) {
|
func TestSQLiteSelectMigrations(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := sqliteTestURL(t)
|
db := prepTestSQLiteDB(t)
|
||||||
db := prepTestSQLiteDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`insert into test_migrations (version)
|
_, err = db.Exec(`insert into test_migrations (version)
|
||||||
|
|
@ -210,14 +239,13 @@ func TestSQLiteSelectMigrations(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteInsertMigration(t *testing.T) {
|
func TestSQLiteInsertMigration(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := sqliteTestURL(t)
|
db := prepTestSQLiteDB(t)
|
||||||
db := prepTestSQLiteDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
count := 0
|
count := 0
|
||||||
|
|
@ -236,14 +264,13 @@ func TestSQLiteInsertMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteDeleteMigration(t *testing.T) {
|
func TestSQLiteDeleteMigration(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName("test_migrations")
|
drv.migrationsTableName = "test_migrations"
|
||||||
|
|
||||||
u := sqliteTestURL(t)
|
db := prepTestSQLiteDB(t)
|
||||||
db := prepTestSQLiteDB(t, u)
|
defer dbutil.MustClose(db)
|
||||||
defer mustClose(db)
|
|
||||||
|
|
||||||
err := drv.CreateMigrationsTable(u, db)
|
err := drv.CreateMigrationsTable(db)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = db.Exec(`insert into test_migrations (version)
|
_, err = db.Exec(`insert into test_migrations (version)
|
||||||
|
|
@ -260,16 +287,15 @@ func TestSQLiteDeleteMigration(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLitePing(t *testing.T) {
|
func TestSQLitePing(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
u := sqliteTestURL(t)
|
path := ConnectionString(drv.databaseURL)
|
||||||
path := sqlitePath(u)
|
|
||||||
|
|
||||||
// drop any existing database
|
// drop any existing database
|
||||||
err := drv.DropDatabase(u)
|
err := drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// ping database
|
// ping database
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// check that the database was created (sqlite-only behavior)
|
// check that the database was created (sqlite-only behavior)
|
||||||
|
|
@ -277,7 +303,7 @@ func TestSQLitePing(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// drop the database
|
// drop the database
|
||||||
err = drv.DropDatabase(u)
|
err = drv.DropDatabase()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// create directory where database file is expected
|
// create directory where database file is expected
|
||||||
|
|
@ -289,20 +315,20 @@ func TestSQLitePing(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// ping database should fail
|
// ping database should fail
|
||||||
err = drv.Ping(u)
|
err = drv.Ping()
|
||||||
require.EqualError(t, err, "unable to open database file: is a directory")
|
require.EqualError(t, err, "unable to open database file: is a directory")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSQLiteQuotedMigrationsTableName(t *testing.T) {
|
func TestSQLiteQuotedMigrationsTableName(t *testing.T) {
|
||||||
t.Run("default name", func(t *testing.T) {
|
t.Run("default name", func(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, `"schema_migrations"`, name)
|
require.Equal(t, `"schema_migrations"`, name)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("custom name", func(t *testing.T) {
|
t.Run("custom name", func(t *testing.T) {
|
||||||
drv := testSQLiteDriver()
|
drv := testSQLiteDriver(t)
|
||||||
drv.SetMigrationsTableName("fooMigrations")
|
drv.migrationsTableName = "fooMigrations"
|
||||||
|
|
||||||
name := drv.quotedMigrationsTableName()
|
name := drv.quotedMigrationsTableName()
|
||||||
require.Equal(t, `"fooMigrations"`, name)
|
require.Equal(t, `"fooMigrations"`, name)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue