mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2026-02-02 17:35:08 +01:00
Move code into pkg directory (#22)
This commit is contained in:
parent
4e01c75eca
commit
54a9fbc859
15 changed files with 3 additions and 3 deletions
344
pkg/dbmate/db.go
Normal file
344
pkg/dbmate/db.go
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultMigrationsDir specifies default directory to find migration files
|
||||
var DefaultMigrationsDir = "./db/migrations"
|
||||
|
||||
// DB allows dbmate actions to be performed on a specified database
|
||||
type DB struct {
|
||||
DatabaseURL *url.URL
|
||||
MigrationsDir string
|
||||
}
|
||||
|
||||
// NewDB initializes a new dbmate database
|
||||
func NewDB(databaseURL *url.URL) *DB {
|
||||
return &DB{
|
||||
DatabaseURL: databaseURL,
|
||||
MigrationsDir: DefaultMigrationsDir,
|
||||
}
|
||||
}
|
||||
|
||||
// GetDriver loads the required database driver
|
||||
func (db *DB) GetDriver() (Driver, error) {
|
||||
return GetDriver(db.DatabaseURL.Scheme)
|
||||
}
|
||||
|
||||
// Up creates the database (if necessary) and runs migrations
|
||||
func (db *DB) Up() error {
|
||||
drv, err := db.GetDriver()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create database if it does not already exist
|
||||
// skip this step if we cannot determine status
|
||||
// (e.g. user does not have list database permission)
|
||||
exists, err := drv.DatabaseExists(db.DatabaseURL)
|
||||
if err == nil && !exists {
|
||||
if err := drv.CreateDatabase(db.DatabaseURL); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// migrate
|
||||
return db.Migrate()
|
||||
}
|
||||
|
||||
// Create creates the current database
|
||||
func (db *DB) Create() error {
|
||||
drv, err := db.GetDriver()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return drv.CreateDatabase(db.DatabaseURL)
|
||||
}
|
||||
|
||||
// Drop drops the current database (if it exists)
|
||||
func (db *DB) Drop() error {
|
||||
drv, err := db.GetDriver()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return drv.DropDatabase(db.DatabaseURL)
|
||||
}
|
||||
|
||||
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
|
||||
|
||||
// New creates a new migration file
|
||||
func (db *DB) New(name string) error {
|
||||
// new migration name
|
||||
timestamp := time.Now().UTC().Format("20060102150405")
|
||||
if name == "" {
|
||||
return fmt.Errorf("please specify a name for the new migration")
|
||||
}
|
||||
name = fmt.Sprintf("%s_%s.sql", timestamp, name)
|
||||
|
||||
// create migrations dir if missing
|
||||
if err := os.MkdirAll(db.MigrationsDir, 0755); err != nil {
|
||||
return fmt.Errorf("unable to create directory `%s`", db.MigrationsDir)
|
||||
}
|
||||
|
||||
// check file does not already exist
|
||||
path := filepath.Join(db.MigrationsDir, name)
|
||||
fmt.Printf("Creating migration: %s\n", path)
|
||||
|
||||
if _, err := os.Stat(path); !os.IsNotExist(err) {
|
||||
return fmt.Errorf("file already exists")
|
||||
}
|
||||
|
||||
// write new migration
|
||||
file, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer mustClose(file)
|
||||
_, err = file.WriteString(migrationTemplate)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func doTransaction(db *sql.DB, txFunc func(Transaction) error) error {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := txFunc(tx); err != nil {
|
||||
if err1 := tx.Rollback(); err1 != nil {
|
||||
return err1
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (db *DB) openDatabaseForMigration() (Driver, *sql.DB, error) {
|
||||
drv, err := db.GetDriver()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sqlDB, err := drv.Open(db.DatabaseURL)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := drv.CreateMigrationsTable(sqlDB); err != nil {
|
||||
mustClose(sqlDB)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return drv, sqlDB, nil
|
||||
}
|
||||
|
||||
// Migrate migrates database to the latest version
|
||||
func (db *DB) Migrate() error {
|
||||
re := regexp.MustCompile(`^\d.*\.sql$`)
|
||||
files, err := findMigrationFiles(db.MigrationsDir, re)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return fmt.Errorf("no migration files found")
|
||||
}
|
||||
|
||||
drv, sqlDB, err := db.openDatabaseForMigration()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
applied, err := drv.SelectMigrations(sqlDB, -1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, filename := range files {
|
||||
ver := migrationVersion(filename)
|
||||
if ok := applied[ver]; ok {
|
||||
// migration already applied
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("Applying: %s\n", filename)
|
||||
|
||||
migration, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
err = doTransaction(sqlDB, func(tx Transaction) error {
|
||||
// run actual migration
|
||||
if _, err := tx.Exec(migration["up"]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// record migration
|
||||
return drv.InsertMigration(tx, ver)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not find migrations directory `%s`", dir)
|
||||
}
|
||||
|
||||
matches := []string{}
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := file.Name()
|
||||
if !re.MatchString(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
matches = append(matches, name)
|
||||
}
|
||||
|
||||
sort.Strings(matches)
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func findMigrationFile(dir string, ver string) (string, error) {
|
||||
if ver == "" {
|
||||
panic("migration version is required")
|
||||
}
|
||||
|
||||
ver = regexp.QuoteMeta(ver)
|
||||
re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))
|
||||
|
||||
files, err := findMigrationFiles(dir, re)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return "", fmt.Errorf("can't find migration file: %s*.sql", ver)
|
||||
}
|
||||
|
||||
return files[0], nil
|
||||
}
|
||||
|
||||
func migrationVersion(filename string) string {
|
||||
return regexp.MustCompile(`^\d+`).FindString(filename)
|
||||
}
|
||||
|
||||
// parseMigration reads a migration file into a map with up/down keys
|
||||
// implementation is similar to regexp.Split()
|
||||
func parseMigration(path string) (map[string]string, error) {
|
||||
// read migration file into string
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contents := string(data)
|
||||
|
||||
// split string on our trigger comment
|
||||
separatorRegexp := regexp.MustCompile(`(?m)^-- migrate:(.*)$`)
|
||||
matches := separatorRegexp.FindAllStringSubmatchIndex(contents, -1)
|
||||
|
||||
migrations := map[string]string{}
|
||||
direction := ""
|
||||
beg := 0
|
||||
end := 0
|
||||
|
||||
for _, match := range matches {
|
||||
end = match[0]
|
||||
if direction != "" {
|
||||
// write previous direction to output map
|
||||
migrations[direction] = contents[beg:end]
|
||||
}
|
||||
|
||||
// each match records the start of a new direction
|
||||
direction = contents[match[2]:match[3]]
|
||||
beg = match[1]
|
||||
}
|
||||
|
||||
// write final direction to output map
|
||||
migrations[direction] = contents[beg:]
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// Rollback rolls back the most recent migration
|
||||
func (db *DB) Rollback() error {
|
||||
drv, sqlDB, err := db.openDatabaseForMigration()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
applied, err := drv.SelectMigrations(sqlDB, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// grab most recent applied migration (applied has len=1)
|
||||
latest := ""
|
||||
for ver := range applied {
|
||||
latest = ver
|
||||
}
|
||||
if latest == "" {
|
||||
return fmt.Errorf("can't rollback: no migrations have been applied")
|
||||
}
|
||||
|
||||
filename, err := findMigrationFile(db.MigrationsDir, latest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Rolling back: %s\n", filename)
|
||||
|
||||
migration, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
err = doTransaction(sqlDB, func(tx Transaction) error {
|
||||
// rollback migration
|
||||
if _, err := tx.Exec(migration["down"]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove migration record
|
||||
return drv.DeleteMigration(tx, latest)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
145
pkg/dbmate/db_test.go
Normal file
145
pkg/dbmate/db_test.go
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testdataDir string
|
||||
|
||||
func newTestDB(t *testing.T, u *url.URL) *DB {
|
||||
var err error
|
||||
|
||||
// only chdir once, because testdata is relative to current directory
|
||||
if testdataDir == "" {
|
||||
testdataDir, err = filepath.Abs("../../testdata")
|
||||
require.Nil(t, err)
|
||||
|
||||
err = os.Chdir(testdataDir)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
return NewDB(u)
|
||||
}
|
||||
|
||||
func testURLs(t *testing.T) []*url.URL {
|
||||
return []*url.URL{
|
||||
postgresTestURL(t),
|
||||
mySQLTestURL(t),
|
||||
sqliteTestURL(t),
|
||||
}
|
||||
}
|
||||
|
||||
func testMigrateURL(t *testing.T, u *url.URL) {
|
||||
db := newTestDB(t, u)
|
||||
|
||||
// drop and recreate database
|
||||
err := db.Drop()
|
||||
require.Nil(t, err)
|
||||
err = db.Create()
|
||||
require.Nil(t, err)
|
||||
|
||||
// migrate
|
||||
err = db.Migrate()
|
||||
require.Nil(t, err)
|
||||
|
||||
// verify results
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
count := 0
|
||||
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
||||
where version = '20151129054053'`).Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
for _, u := range testURLs(t) {
|
||||
testMigrateURL(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func testUpURL(t *testing.T, u *url.URL) {
|
||||
db := newTestDB(t, u)
|
||||
|
||||
// drop database
|
||||
err := db.Drop()
|
||||
require.Nil(t, err)
|
||||
|
||||
// create and migrate
|
||||
err = db.Up()
|
||||
require.Nil(t, err)
|
||||
|
||||
// verify results
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
count := 0
|
||||
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
||||
where version = '20151129054053'`).Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestUp(t *testing.T) {
|
||||
for _, u := range testURLs(t) {
|
||||
testUpURL(t, u)
|
||||
}
|
||||
}
|
||||
|
||||
func testRollbackURL(t *testing.T, u *url.URL) {
|
||||
db := newTestDB(t, u)
|
||||
|
||||
// drop, recreate, and migrate database
|
||||
err := db.Drop()
|
||||
require.Nil(t, err)
|
||||
err = db.Create()
|
||||
require.Nil(t, err)
|
||||
err = db.Migrate()
|
||||
require.Nil(t, err)
|
||||
|
||||
// verify migration
|
||||
sqlDB, err := GetDriverOpen(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(sqlDB)
|
||||
|
||||
count := 0
|
||||
err = sqlDB.QueryRow(`select count(*) from schema_migrations
|
||||
where version = '20151129054053'`).Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
// rollback
|
||||
err = db.Rollback()
|
||||
require.Nil(t, err)
|
||||
|
||||
// verify rollback
|
||||
err = sqlDB.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
err = sqlDB.QueryRow("select count(*) from users").Scan(&count)
|
||||
require.NotNil(t, err)
|
||||
require.Regexp(t, "(does not exist|doesn't exist|no such table)", err.Error())
|
||||
}
|
||||
|
||||
func TestRollback(t *testing.T) {
|
||||
for _, u := range testURLs(t) {
|
||||
testRollbackURL(t, u)
|
||||
}
|
||||
}
|
||||
48
pkg/dbmate/driver.go
Normal file
48
pkg/dbmate/driver.go
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Driver provides top level database functions
|
||||
type Driver interface {
|
||||
Open(*url.URL) (*sql.DB, error)
|
||||
DatabaseExists(*url.URL) (bool, error)
|
||||
CreateDatabase(*url.URL) error
|
||||
DropDatabase(*url.URL) error
|
||||
CreateMigrationsTable(*sql.DB) error
|
||||
SelectMigrations(*sql.DB, int) (map[string]bool, error)
|
||||
InsertMigration(Transaction, string) error
|
||||
DeleteMigration(Transaction, string) error
|
||||
}
|
||||
|
||||
// Transaction can represent a database or open transaction
|
||||
type Transaction interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
// GetDriver loads a database driver by name
|
||||
func GetDriver(name string) (Driver, error) {
|
||||
switch name {
|
||||
case "mysql":
|
||||
return MySQLDriver{}, nil
|
||||
case "postgres", "postgresql":
|
||||
return PostgresDriver{}, nil
|
||||
case "sqlite", "sqlite3":
|
||||
return SQLiteDriver{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown driver: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// GetDriverOpen is a shortcut for GetDriver(u.Scheme).Open(u)
|
||||
func GetDriverOpen(u *url.URL) (*sql.DB, error) {
|
||||
drv, err := GetDriver(u.Scheme)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return drv.Open(u)
|
||||
}
|
||||
27
pkg/dbmate/driver_test.go
Normal file
27
pkg/dbmate/driver_test.go
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetDriver_Postgres(t *testing.T) {
|
||||
drv, err := GetDriver("postgres")
|
||||
require.Nil(t, err)
|
||||
_, ok := drv.(PostgresDriver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGetDriver_MySQL(t *testing.T) {
|
||||
drv, err := GetDriver("mysql")
|
||||
require.Nil(t, err)
|
||||
_, ok := drv.(MySQLDriver)
|
||||
require.Equal(t, true, ok)
|
||||
}
|
||||
|
||||
func TestGetDriver_Error(t *testing.T) {
|
||||
drv, err := GetDriver("foo")
|
||||
require.Equal(t, "unknown driver: foo", err.Error())
|
||||
require.Nil(t, drv)
|
||||
}
|
||||
156
pkg/dbmate/mysql.go
Normal file
156
pkg/dbmate/mysql.go
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql" // mysql driver for database/sql
|
||||
)
|
||||
|
||||
// MySQLDriver provides top level database functions
|
||||
type MySQLDriver struct {
|
||||
}
|
||||
|
||||
func normalizeMySQLURL(u *url.URL) string {
|
||||
normalizedURL := *u
|
||||
normalizedURL.Scheme = ""
|
||||
|
||||
// set default port
|
||||
if normalizedURL.Port() == "" {
|
||||
normalizedURL.Host = fmt.Sprintf("%s:3306", normalizedURL.Host)
|
||||
}
|
||||
|
||||
// host format required by go-sql-driver/mysql
|
||||
normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host)
|
||||
|
||||
query := normalizedURL.Query()
|
||||
query.Set("multiStatements", "true")
|
||||
normalizedURL.RawQuery = query.Encode()
|
||||
|
||||
str := normalizedURL.String()
|
||||
return strings.TrimLeft(str, "/")
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv MySQLDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("mysql", normalizeMySQLURL(u))
|
||||
}
|
||||
|
||||
func (drv MySQLDriver) openRootDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to no particular database
|
||||
rootURL := *u
|
||||
rootURL.Path = "/"
|
||||
|
||||
return drv.Open(&rootURL)
|
||||
}
|
||||
|
||||
func quoteIdentifier(str string) string {
|
||||
str = strings.Replace(str, "`", "\\`", -1)
|
||||
|
||||
return fmt.Sprintf("`%s`", str)
|
||||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv MySQLDriver) CreateDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec(fmt.Sprintf("create database %s",
|
||||
quoteIdentifier(name)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv MySQLDriver) DropDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
||||
quoteIdentifier(name)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv MySQLDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := databaseName(u)
|
||||
|
||||
db, err := drv.openRootDB(u)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
exists := false
|
||||
err = db.QueryRow(`select true from information_schema.schemata
|
||||
where schema_name = ?`, name).Scan(&exists)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv MySQLDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||
version varchar(255) primary key)`)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv MySQLDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer mustClose(rows)
|
||||
|
||||
migrations := map[string]bool{}
|
||||
for rows.Next() {
|
||||
var version string
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations[version] = true
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv MySQLDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv MySQLDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
||||
|
||||
return err
|
||||
}
|
||||
217
pkg/dbmate/mysql_test.go
Normal file
217
pkg/dbmate/mysql_test.go
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func mySQLTestURL(t *testing.T) *url.URL {
|
||||
u, err := url.Parse("mysql://root:root@mysql/dbmate")
|
||||
require.Nil(t, err)
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func prepTestMySQLDB(t *testing.T) *sql.DB {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// connect database
|
||||
db, err := drv.Open(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestNormalizeMySQLURLDefaults(t *testing.T) {
|
||||
u, err := url.Parse("mysql://host/foo")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "", u.Port())
|
||||
|
||||
s := normalizeMySQLURL(u)
|
||||
require.Equal(t, "tcp(host:3306)/foo?multiStatements=true", s)
|
||||
}
|
||||
|
||||
func TestNormalizeMySQLURLCustom(t *testing.T) {
|
||||
u, err := url.Parse("mysql://bob:secret@host:123/foo?flag=on")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "123", u.Port())
|
||||
|
||||
s := normalizeMySQLURL(u)
|
||||
require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
||||
}
|
||||
|
||||
func TestMySQLCreateDropDatabase(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check that database exists and we can connect to it
|
||||
func() {
|
||||
db, err := drv.Open(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
err = db.Ping()
|
||||
require.Nil(t, err)
|
||||
}()
|
||||
|
||||
// drop the database
|
||||
err = drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check that database no longer exists
|
||||
func() {
|
||||
db, err := drv.Open(u)
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
err = db.Ping()
|
||||
require.NotNil(t, err)
|
||||
require.Regexp(t, "Unknown database 'dbmate'", err.Error())
|
||||
}()
|
||||
}
|
||||
|
||||
func TestMySQLDatabaseExists(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// DatabaseExists should return false
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, false, exists)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// DatabaseExists should return true
|
||||
exists, err = drv.DatabaseExists(u)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestMySQLDatabaseExists_Error(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
u := mySQLTestURL(t)
|
||||
u.User = url.User("invalid")
|
||||
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
require.Regexp(t, "Access denied for user 'invalid'@", err.Error())
|
||||
require.Equal(t, false, exists)
|
||||
}
|
||||
|
||||
func TestMySQLCreateMigrationsTable(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Regexp(t, "Table 'dbmate.schema_migrations' doesn't exist", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestMySQLSelectMigrations(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
values ('abc2'), ('abc1'), ('abc3')`)
|
||||
require.Nil(t, err)
|
||||
|
||||
migrations, err := drv.SelectMigrations(db, -1)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, migrations["abc1"])
|
||||
require.Equal(t, true, migrations["abc2"])
|
||||
require.Equal(t, true, migrations["abc2"])
|
||||
|
||||
// test limit param
|
||||
migrations, err = drv.SelectMigrations(db, 1)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, migrations["abc3"])
|
||||
require.Equal(t, false, migrations["abc1"])
|
||||
require.Equal(t, false, migrations["abc2"])
|
||||
}
|
||||
|
||||
func TestMySQLInsertMigration(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
// insert migration
|
||||
err = drv.InsertMigration(db, "abc1")
|
||||
require.Nil(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
|
||||
Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestMySQLDeleteMigration(t *testing.T) {
|
||||
drv := MySQLDriver{}
|
||||
db := prepTestMySQLDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
values ('abc1'), ('abc2')`)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = drv.DeleteMigration(db, "abc2")
|
||||
require.Nil(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
129
pkg/dbmate/postgres.go
Normal file
129
pkg/dbmate/postgres.go
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// PostgresDriver provides top level database functions
|
||||
type PostgresDriver struct {
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("postgres", u.String())
|
||||
}
|
||||
|
||||
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||
// connect to postgres database
|
||||
postgresURL := *u
|
||||
postgresURL.Path = "postgres"
|
||||
|
||||
return drv.Open(&postgresURL)
|
||||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv PostgresDriver) CreateDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Creating: %s\n", name)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec(fmt.Sprintf("create database %s",
|
||||
pq.QuoteIdentifier(name)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv PostgresDriver) DropDatabase(u *url.URL) error {
|
||||
name := databaseName(u)
|
||||
fmt.Printf("Dropping: %s\n", name)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
_, err = db.Exec(fmt.Sprintf("drop database if exists %s",
|
||||
pq.QuoteIdentifier(name)))
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv PostgresDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
name := databaseName(u)
|
||||
|
||||
db, err := drv.openPostgresDB(u)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
exists := false
|
||||
err = db.QueryRow("select true from pg_database where datname = $1", name).
|
||||
Scan(&exists)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv PostgresDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||
version varchar(255) primary key)`)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv PostgresDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer mustClose(rows)
|
||||
|
||||
migrations := map[string]bool{}
|
||||
for rows.Next() {
|
||||
var version string
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations[version] = true
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv PostgresDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values ($1)", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv PostgresDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = $1", version)
|
||||
|
||||
return err
|
||||
}
|
||||
199
pkg/dbmate/postgres_test.go
Normal file
199
pkg/dbmate/postgres_test.go
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func postgresTestURL(t *testing.T) *url.URL {
|
||||
u, err := url.Parse("postgres://postgres:postgres@postgres/dbmate?sslmode=disable")
|
||||
require.Nil(t, err)
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func prepTestPostgresDB(t *testing.T) *sql.DB {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// connect database
|
||||
db, err := sql.Open("postgres", u.String())
|
||||
require.Nil(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check that database exists and we can connect to it
|
||||
func() {
|
||||
db, err := sql.Open("postgres", u.String())
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
err = db.Ping()
|
||||
require.Nil(t, err)
|
||||
}()
|
||||
|
||||
// drop the database
|
||||
err = drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check that database no longer exists
|
||||
func() {
|
||||
db, err := sql.Open("postgres", u.String())
|
||||
require.Nil(t, err)
|
||||
defer mustClose(db)
|
||||
|
||||
err = db.Ping()
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, "pq: database \"dbmate\" does not exist", err.Error())
|
||||
}()
|
||||
}
|
||||
|
||||
func TestPostgresDatabaseExists(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// DatabaseExists should return false
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, false, exists)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// DatabaseExists should return true
|
||||
exists, err = drv.DatabaseExists(u)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestPostgresDatabaseExists_Error(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
u := postgresTestURL(t)
|
||||
u.User = url.User("invalid")
|
||||
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
require.Equal(t, "pq: role \"invalid\" does not exist", err.Error())
|
||||
require.Equal(t, false, exists)
|
||||
}
|
||||
|
||||
func TestPostgresCreateMigrationsTable(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Equal(t, "pq: relation \"schema_migrations\" does not exist", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestPostgresSelectMigrations(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
values ('abc2'), ('abc1'), ('abc3')`)
|
||||
require.Nil(t, err)
|
||||
|
||||
migrations, err := drv.SelectMigrations(db, -1)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, migrations["abc1"])
|
||||
require.Equal(t, true, migrations["abc2"])
|
||||
require.Equal(t, true, migrations["abc2"])
|
||||
|
||||
// test limit param
|
||||
migrations, err = drv.SelectMigrations(db, 1)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, migrations["abc3"])
|
||||
require.Equal(t, false, migrations["abc1"])
|
||||
require.Equal(t, false, migrations["abc2"])
|
||||
}
|
||||
|
||||
func TestPostgresInsertMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
// insert migration
|
||||
err = drv.InsertMigration(db, "abc1")
|
||||
require.Nil(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
|
||||
Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestPostgresDeleteMigration(t *testing.T) {
|
||||
drv := PostgresDriver{}
|
||||
db := prepTestPostgresDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
values ('abc1'), ('abc2')`)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = drv.DeleteMigration(db, "abc2")
|
||||
require.Nil(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
119
pkg/dbmate/sqlite.go
Normal file
119
pkg/dbmate/sqlite.go
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
"regexp"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // sqlite driver for database/sql
|
||||
)
|
||||
|
||||
// SQLiteDriver provides top level database functions
|
||||
type SQLiteDriver struct {
|
||||
}
|
||||
|
||||
func sqlitePath(u *url.URL) string {
|
||||
// strip one leading slash
|
||||
// absolute URLs can be specified as sqlite:////tmp/foo.sqlite3
|
||||
str := regexp.MustCompile("^/").ReplaceAllString(u.Path, "")
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv SQLiteDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("sqlite3", sqlitePath(u))
|
||||
}
|
||||
|
||||
// CreateDatabase creates the specified database
|
||||
func (drv SQLiteDriver) CreateDatabase(u *url.URL) error {
|
||||
fmt.Printf("Creating: %s\n", sqlitePath(u))
|
||||
|
||||
db, err := drv.Open(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer mustClose(db)
|
||||
|
||||
return db.Ping()
|
||||
}
|
||||
|
||||
// DropDatabase drops the specified database (if it exists)
|
||||
func (drv SQLiteDriver) DropDatabase(u *url.URL) error {
|
||||
path := sqlitePath(u)
|
||||
fmt.Printf("Dropping: %s\n", path)
|
||||
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
return os.Remove(path)
|
||||
}
|
||||
|
||||
// DatabaseExists determines whether the database exists
|
||||
func (drv SQLiteDriver) DatabaseExists(u *url.URL) (bool, error) {
|
||||
_, err := os.Stat(sqlitePath(u))
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CreateMigrationsTable creates the schema_migrations table
|
||||
func (drv SQLiteDriver) CreateMigrationsTable(db *sql.DB) error {
|
||||
_, err := db.Exec(`create table if not exists schema_migrations (
|
||||
version varchar(255) primary key)`)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
// with an optional limit (in descending order)
|
||||
func (drv SQLiteDriver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) {
|
||||
query := "select version from schema_migrations order by version desc"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s limit %d", query, limit)
|
||||
}
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer mustClose(rows)
|
||||
|
||||
migrations := map[string]bool{}
|
||||
for rows.Next() {
|
||||
var version string
|
||||
if err := rows.Scan(&version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrations[version] = true
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// InsertMigration adds a new migration record
|
||||
func (drv SQLiteDriver) InsertMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("insert into schema_migrations (version) values (?)", version)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMigration removes a migration record
|
||||
func (drv SQLiteDriver) DeleteMigration(db Transaction, version string) error {
|
||||
_, err := db.Exec("delete from schema_migrations where version = ?", version)
|
||||
|
||||
return err
|
||||
}
|
||||
178
pkg/dbmate/sqlite_test.go
Normal file
178
pkg/dbmate/sqlite_test.go
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func sqliteTestURL(t *testing.T) *url.URL {
|
||||
u, err := url.Parse("sqlite3:////tmp/dbmate.sqlite3")
|
||||
require.Nil(t, err)
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func prepTestSQLiteDB(t *testing.T) *sql.DB {
|
||||
drv := SQLiteDriver{}
|
||||
u := sqliteTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// connect database
|
||||
db, err := drv.Open(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestSQLiteCreateDropDatabase(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
u := sqliteTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check that database exists
|
||||
_, err = os.Stat(sqlitePath(u))
|
||||
require.Nil(t, err)
|
||||
|
||||
// drop the database
|
||||
err = drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// check that database no longer exists
|
||||
_, err = os.Stat(sqlitePath(u))
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, true, os.IsNotExist(err))
|
||||
}
|
||||
|
||||
func TestSQLiteDatabaseExists(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
u := sqliteTestURL(t)
|
||||
|
||||
// drop any existing database
|
||||
err := drv.DropDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// DatabaseExists should return false
|
||||
exists, err := drv.DatabaseExists(u)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, false, exists)
|
||||
|
||||
// create database
|
||||
err = drv.CreateDatabase(u)
|
||||
require.Nil(t, err)
|
||||
|
||||
// DatabaseExists should return true
|
||||
exists, err = drv.DatabaseExists(u)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, exists)
|
||||
}
|
||||
|
||||
func TestSQLiteCreateMigrationsTable(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
db := prepTestSQLiteDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
// migrations table should not exist
|
||||
count := 0
|
||||
err := db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Regexp(t, "no such table: schema_migrations", err.Error())
|
||||
|
||||
// create table
|
||||
err = drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
// migrations table should exist
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
|
||||
// create table should be idempotent
|
||||
err = drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestSQLiteSelectMigrations(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
db := prepTestSQLiteDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
values ('abc2'), ('abc1'), ('abc3')`)
|
||||
require.Nil(t, err)
|
||||
|
||||
migrations, err := drv.SelectMigrations(db, -1)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, migrations["abc1"])
|
||||
require.Equal(t, true, migrations["abc2"])
|
||||
require.Equal(t, true, migrations["abc2"])
|
||||
|
||||
// test limit param
|
||||
migrations, err = drv.SelectMigrations(db, 1)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, true, migrations["abc3"])
|
||||
require.Equal(t, false, migrations["abc1"])
|
||||
require.Equal(t, false, migrations["abc2"])
|
||||
}
|
||||
|
||||
func TestSQLiteInsertMigration(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
db := prepTestSQLiteDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
|
||||
// insert migration
|
||||
err = drv.InsertMigration(db, "abc1")
|
||||
require.Nil(t, err)
|
||||
|
||||
err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'").
|
||||
Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestSQLiteDeleteMigration(t *testing.T) {
|
||||
drv := SQLiteDriver{}
|
||||
db := prepTestSQLiteDB(t)
|
||||
defer mustClose(db)
|
||||
|
||||
err := drv.CreateMigrationsTable(db)
|
||||
require.Nil(t, err)
|
||||
|
||||
_, err = db.Exec(`insert into schema_migrations (version)
|
||||
values ('abc1'), ('abc2')`)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = drv.DeleteMigration(db, "abc2")
|
||||
require.Nil(t, err)
|
||||
|
||||
count := 0
|
||||
err = db.QueryRow("select count(*) from schema_migrations").Scan(&count)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
22
pkg/dbmate/utils.go
Normal file
22
pkg/dbmate/utils.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// databaseName returns the database name from a URL
|
||||
func databaseName(u *url.URL) string {
|
||||
name := u.Path
|
||||
if len(name) > 0 && name[:1] == "/" {
|
||||
name = name[1:]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
func mustClose(c io.Closer) {
|
||||
if err := c.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
24
pkg/dbmate/utils_test.go
Normal file
24
pkg/dbmate/utils_test.go
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
package dbmate
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDatabaseName(t *testing.T) {
|
||||
u, err := url.Parse("ignore://localhost/foo?query")
|
||||
require.Nil(t, err)
|
||||
|
||||
name := databaseName(u)
|
||||
require.Equal(t, "foo", name)
|
||||
}
|
||||
|
||||
func TestDatabaseName_Empty(t *testing.T) {
|
||||
u, err := url.Parse("ignore://localhost")
|
||||
require.Nil(t, err)
|
||||
|
||||
name := databaseName(u)
|
||||
require.Equal(t, "", name)
|
||||
}
|
||||
4
pkg/dbmate/version.go
Normal file
4
pkg/dbmate/version.go
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
package dbmate
|
||||
|
||||
// Version of dbmate
|
||||
const Version = "1.2.1"
|
||||
Loading…
Add table
Add a link
Reference in a new issue