dbmate/pkg/dbmate/db.go

643 lines
13 KiB
Go
Raw Normal View History

package dbmate
2015-11-25 10:57:58 -08:00
import (
"database/sql"
"errors"
2015-11-25 10:57:58 -08:00
"fmt"
"io"
2015-11-25 10:57:58 -08:00
"net/url"
"os"
"path/filepath"
"regexp"
"sort"
2015-11-25 10:57:58 -08:00
"time"
"github.com/amacneil/dbmate/pkg/dbutil"
2015-11-25 10:57:58 -08:00
)
const (
	// DefaultMigrationsDir is the directory searched for migration files
	// unless another directory is configured.
	DefaultMigrationsDir = "./db/migrations"

	// DefaultMigrationsTableName is the database table in which applied
	// migrations are recorded by default.
	DefaultMigrationsTableName = "schema_migrations"

	// DefaultSchemaFile is the default location of schema.sql.
	DefaultSchemaFile = "./db/schema.sql"

	// DefaultWaitInterval is the default pause between connection attempts.
	DefaultWaitInterval = time.Second

	// DefaultWaitTimeout is the default maximum time spent on connection
	// attempts.
	DefaultWaitTimeout = 60 * time.Second
)
// DB allows dbmate actions to be performed on a specified database.
// New returns a DB populated with the package defaults.
type DB struct {
	// AutoDumpSchema makes migrations and rollbacks rewrite SchemaFile when
	// they finish; errors from that dump are ignored.
	AutoDumpSchema bool
	// DatabaseURL identifies the database; its scheme selects the driver
	// (see GetDriver).
	DatabaseURL *url.URL
	// MigrationsDir is the directory scanned for migration files.
	MigrationsDir string
	// MigrationsTableName is passed to the driver as the table that records
	// applied migrations.
	MigrationsTableName string
	// SchemaFile is the path the schema dump is written to.
	SchemaFile string
	// Verbose prints the last insert ID and rows affected of each executed
	// migration, when the driver's result reports them.
	Verbose bool
	// WaitBefore makes most commands (create, drop, migrate, rollback,
	// schema dump) wait for the database server before doing any work.
	WaitBefore bool
	// WaitInterval is the pause between connection attempts while waiting.
	WaitInterval time.Duration
	// WaitTimeout bounds how long connection attempts are retried, measured
	// as the sum of WaitInterval sleeps.
	WaitTimeout time.Duration
	// Limit bounds how many migrations an operation selects. New sets it to
	// -1 (presumably "no limit" — the interpretation of non-positive values
	// is up to the driver's SelectMigrations; verify). Rollback treats a
	// value <= 0 as "one migration" when TargetVersion is empty.
	Limit int
	// TargetVersion, when non-empty, makes migrations stop once that version
	// is reached and makes Rollback undo only migrations newer than it.
	TargetVersion string
	// Log receives progress and status output; New sets it to os.Stdout.
	Log io.Writer
}
var (
	// migrationFileRegexp matches the names of valid migration files: names
	// that begin with a digit and end in ".sql".
	migrationFileRegexp = regexp.MustCompile(`^\d.*\.sql$`)
)
// StatusResult represents an available migration status: one migration file
// and whether it has been applied.
type StatusResult struct {
	// Filename is the migration's file name (no directory component) as
	// found in the migrations directory.
	Filename string
	// Applied reports whether the migration's version is recorded as applied
	// in the database.
	Applied bool
}
// New returns a DB for databaseURL, configured with the package defaults:
// schema auto-dumping on, the default migrations directory, table and schema
// file, no pre-wait, default wait interval/timeout, Limit -1, no target
// version, and output written to os.Stdout.
func New(databaseURL *url.URL) *DB {
	db := new(DB)
	db.AutoDumpSchema = true
	db.DatabaseURL = databaseURL
	db.MigrationsDir = DefaultMigrationsDir
	db.MigrationsTableName = DefaultMigrationsTableName
	db.SchemaFile = DefaultSchemaFile
	db.WaitBefore = false
	db.WaitInterval = DefaultWaitInterval
	db.WaitTimeout = DefaultWaitTimeout
	db.Limit = -1
	db.TargetVersion = ""
	db.Log = os.Stdout
	return db
}
// GetDriver builds the driver registered for the scheme of db.DatabaseURL.
// It fails when the URL is missing or has no scheme, or when no driver is
// registered for that scheme.
func (db *DB) GetDriver() (Driver, error) {
	u := db.DatabaseURL
	if u == nil || u.Scheme == "" {
		return nil, errors.New("invalid url, have you set your --url flag or DATABASE_URL environment variable?")
	}

	newDriver := drivers[u.Scheme]
	if newDriver == nil {
		return nil, fmt.Errorf("unsupported driver: %s", u.Scheme)
	}

	return newDriver(DriverConfig{
		DatabaseURL:         u,
		MigrationsTableName: db.MigrationsTableName,
		Log:                 db.Log,
	}), nil
}
2018-04-15 18:37:57 -07:00
// Wait blocks until the database server is available. It does not verify that
// the specified database exists, only that the host is ready to accept connections.
func (db *DB) Wait() error {
drv, err := db.GetDriver()
if err != nil {
return err
}
return db.wait(drv)
}
func (db *DB) wait(drv Driver) error {
2018-04-15 18:37:57 -07:00
// attempt connection to database server
err := drv.Ping()
2018-04-15 18:37:57 -07:00
if err == nil {
// connection successful
return nil
}
fmt.Fprint(db.Log, "Waiting for database")
2018-04-15 18:37:57 -07:00
for i := 0 * time.Second; i < db.WaitTimeout; i += db.WaitInterval {
fmt.Fprint(db.Log, ".")
2018-04-15 18:37:57 -07:00
time.Sleep(db.WaitInterval)
// attempt connection to database server
err = drv.Ping()
2018-04-15 18:37:57 -07:00
if err == nil {
// connection successful
fmt.Fprint(db.Log, "\n")
2018-04-15 18:37:57 -07:00
return nil
}
}
// if we find outselves here, we could not connect within the timeout
fmt.Fprint(db.Log, "\n")
2018-04-15 18:37:57 -07:00
return fmt.Errorf("unable to connect to database: %s", err)
}
// CreateAndMigrate creates the database when it does not exist yet and then
// runs migrations.
//
// If existence cannot be determined (e.g. the user lacks permission to list
// databases), creation is skipped and migration proceeds anyway.
func (db *DB) CreateAndMigrate() error {
	drv, err := db.GetDriver()
	if err != nil {
		return err
	}

	if db.WaitBefore {
		if err := db.wait(drv); err != nil {
			return err
		}
	}

	exists, existsErr := drv.DatabaseExists()
	needsCreate := existsErr == nil && !exists
	if needsCreate {
		if err := drv.CreateDatabase(); err != nil {
			return err
		}
	}

	return db.migrate(drv)
}
// Create asks the driver selected by db.DatabaseURL to create the current
// database, waiting for the server first when db.WaitBefore is set.
func (db *DB) Create() error {
	drv, err := db.GetDriver()
	if err != nil {
		return err
	}

	if !db.WaitBefore {
		return drv.CreateDatabase()
	}
	if err := db.wait(drv); err != nil {
		return err
	}
	return drv.CreateDatabase()
}
// Drop drops the current database (if it exists)
func (db *DB) Drop() error {
drv, err := db.GetDriver()
2015-11-25 10:57:58 -08:00
if err != nil {
return err
}
if db.WaitBefore {
err := db.wait(drv)
if err != nil {
return err
}
}
return drv.DropDatabase()
}
// DumpSchema writes the current database schema to a file
func (db *DB) DumpSchema() error {
drv, err := db.GetDriver()
2015-11-25 10:57:58 -08:00
if err != nil {
return err
}
return db.dumpSchema(drv)
2015-11-25 10:57:58 -08:00
}
// dumpSchema asks drv for the current schema and writes it to db.SchemaFile
// with mode 0644, creating the file's parent directory if needed. It waits
// for the server first when db.WaitBefore is set.
func (db *DB) dumpSchema(drv Driver) error {
	if db.WaitBefore {
		if err := db.wait(drv); err != nil {
			return err
		}
	}

	sqlDB, err := db.openDatabaseForMigration(drv)
	if err != nil {
		return err
	}
	defer dbutil.MustClose(sqlDB)

	schema, err := drv.DumpSchema(sqlDB)
	if err != nil {
		return err
	}

	target := db.SchemaFile
	fmt.Fprintf(db.Log, "Writing: %s\n", target)

	// the schema file may live in a directory that does not exist yet
	if err = ensureDir(filepath.Dir(target)); err != nil {
		return err
	}
	return os.WriteFile(target, schema, 0644)
}
// ensureDir creates dir, including any missing parents, if it does not
// already exist; it succeeds without changes when dir is already a directory.
//
// The os.MkdirAll error is wrapped (%w) so callers can see, and match with
// errors.Is/As, the real cause such as a permission problem or a path
// component that is a regular file.
func ensureDir(dir string) error {
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("unable to create directory `%s`: %w", dir, err)
	}
	return nil
}
2015-11-25 10:57:58 -08:00
const migrationTemplate = "-- migrate:up\n\n\n-- migrate:down\n\n"
// NewMigration creates a new migration file
func (db *DB) NewMigration(name string) error {
2015-11-25 10:57:58 -08:00
// new migration name
timestamp := time.Now().UTC().Format("20060102150405")
if name == "" {
return fmt.Errorf("please specify a name for the new migration")
2015-11-25 10:57:58 -08:00
}
name = fmt.Sprintf("%s_%s.sql", timestamp, name)
// create migrations dir if missing
if err := ensureDir(db.MigrationsDir); err != nil {
return err
2015-11-25 10:57:58 -08:00
}
// check file does not already exist
path := filepath.Join(db.MigrationsDir, name)
fmt.Fprintf(db.Log, "Creating migration: %s\n", path)
2015-11-25 10:57:58 -08:00
if _, err := os.Stat(path); !os.IsNotExist(err) {
return fmt.Errorf("file already exists")
2015-11-25 10:57:58 -08:00
}
// write new migration
file, err := os.Create(path)
if err != nil {
return err
}
defer dbutil.MustClose(file)
2015-11-25 10:57:58 -08:00
_, err = file.WriteString(migrationTemplate)
2018-04-15 19:59:56 -07:00
return err
2015-11-25 10:57:58 -08:00
}
func doTransaction(sqlDB *sql.DB, txFunc func(dbutil.Transaction) error) error {
tx, err := sqlDB.Begin()
2015-11-25 10:57:58 -08:00
if err != nil {
return err
}
if err := txFunc(tx); err != nil {
2015-11-27 14:27:44 -08:00
if err1 := tx.Rollback(); err1 != nil {
return err1
}
2015-11-25 10:57:58 -08:00
return err
}
return tx.Commit()
}
func (db *DB) openDatabaseForMigration(drv Driver) (*sql.DB, error) {
sqlDB, err := drv.Open()
2015-11-25 10:57:58 -08:00
if err != nil {
return nil, err
2015-11-25 10:57:58 -08:00
}
if err := drv.CreateMigrationsTable(sqlDB); err != nil {
dbutil.MustClose(sqlDB)
return nil, err
2015-11-25 10:57:58 -08:00
}
return sqlDB, nil
2015-11-25 12:26:57 -08:00
}
// Migrate migrates the database to the latest version (or up to
// db.TargetVersion when one is set).
func (db *DB) Migrate() error {
	drv, err := db.GetDriver()
	if err == nil {
		err = db.migrate(drv)
	}
	return err
}
func (db *DB) migrate(drv Driver) error {
files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
2015-11-25 10:57:58 -08:00
if err != nil {
return err
}
if len(files) == 0 {
return fmt.Errorf("no migration files found")
2015-11-25 12:26:57 -08:00
}
if db.WaitBefore {
err := db.wait(drv)
if err != nil {
return err
}
}
sqlDB, err := db.openDatabaseForMigration(drv)
2015-11-25 12:26:57 -08:00
if err != nil {
2015-11-25 10:57:58 -08:00
return err
}
defer dbutil.MustClose(sqlDB)
2015-11-25 10:57:58 -08:00
applied, err := drv.SelectMigrations(sqlDB, db.Limit)
2015-11-25 10:57:58 -08:00
if err != nil {
return err
}
for _, filename := range files {
2015-11-25 10:57:58 -08:00
ver := migrationVersion(filename)
if ok := applied[ver]; ok && ver != db.TargetVersion {
2015-11-25 10:57:58 -08:00
// migration already applied
continue
}
fmt.Fprintf(db.Log, "Applying: %s\n", filename)
2015-11-25 10:57:58 -08:00
up, _, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
2015-11-25 10:57:58 -08:00
if err != nil {
return err
}
execMigration := func(tx dbutil.Transaction) error {
2015-11-25 10:57:58 -08:00
// run actual migration
result, err := tx.Exec(up.Contents)
if err != nil {
2015-11-25 10:57:58 -08:00
return err
} else if db.Verbose {
db.printVerbose(result)
2015-11-25 10:57:58 -08:00
}
// record migration
2018-01-07 12:52:10 -08:00
return drv.InsertMigration(tx, ver)
}
if up.Options.Transaction() {
// begin transaction
err = doTransaction(sqlDB, execMigration)
} else {
// run outside of transaction
err = execMigration(sqlDB)
}
if err != nil {
return err
}
if ver == db.TargetVersion {
fmt.Fprintf(db.Log, "Reached target version %s\n", ver)
break
}
2015-11-25 10:57:58 -08:00
}
// automatically update schema file, silence errors
if db.AutoDumpSchema {
_ = db.dumpSchema(drv)
}
2015-11-25 10:57:58 -08:00
return nil
}
// printVerbose writes the last insert ID and the number of affected rows of
// result to db.Log, skipping whichever value the driver cannot report.
func (db *DB) printVerbose(result sql.Result) {
	if id, idErr := result.LastInsertId(); idErr == nil {
		fmt.Fprintf(db.Log, "Last insert ID: %d\n", id)
	}
	if n, nErr := result.RowsAffected(); nErr == nil {
		fmt.Fprintf(db.Log, "Rows affected: %d\n", n)
	}
}
// findMigrationFiles returns the names of the non-directory entries in dir
// whose names match re, sorted lexically. Subdirectories are skipped even
// when their names match. An empty, non-nil slice is returned when nothing
// matches.
//
// If dir cannot be read, the returned error wraps the os.ReadDir error so the
// real cause (missing directory, permissions, ...) is not lost.
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, fmt.Errorf("could not find migrations directory `%s`: %w", dir, err)
	}

	matches := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || !re.MatchString(entry.Name()) {
			continue
		}
		matches = append(matches, entry.Name())
	}

	// os.ReadDir already sorts by name; sorting here keeps the order part of
	// this function's own contract.
	sort.Strings(matches)
	return matches, nil
}
2015-11-25 12:26:57 -08:00
func findMigrationFile(dir string, ver string) (string, error) {
if ver == "" {
panic("migration version is required")
}
ver = regexp.QuoteMeta(ver)
re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))
files, err := findMigrationFiles(dir, re)
if err != nil {
return "", err
}
if len(files) == 0 {
return "", fmt.Errorf("can't find migration file: %s*.sql", ver)
2015-11-25 12:26:57 -08:00
}
return files[0], nil
}
// migrationVersionRegexp matches the run of leading digits that forms a
// migration file's version. It is compiled once instead of on every
// migrationVersion call (which happens for every file in migrate/status).
var migrationVersionRegexp = regexp.MustCompile(`^\d+`)

// migrationVersion returns the leading digits of filename (its version), or
// "" when filename does not start with a digit.
func migrationVersion(filename string) string {
	return migrationVersionRegexp.FindString(filename)
}
// Rollback rolls back the most recent migration
func (db *DB) Rollback() error {
drv, err := db.GetDriver()
if err != nil {
return err
}
if db.WaitBefore {
err := db.wait(drv)
if err != nil {
return err
}
}
sqlDB, err := db.openDatabaseForMigration(drv)
2015-11-25 12:26:57 -08:00
if err != nil {
return err
}
defer dbutil.MustClose(sqlDB)
2015-11-25 12:26:57 -08:00
limit := db.Limit
// default limit is -1, if we don't specify a version it should only rollback one version, not all
if limit <= 0 && db.TargetVersion == "" {
limit = 1
}
applied, err := drv.SelectMigrations(sqlDB, limit)
2015-11-25 12:26:57 -08:00
if err != nil {
return err
}
if len(applied) == 0 {
return fmt.Errorf("can't rollback, no migrations found")
2015-11-25 12:26:57 -08:00
}
var versions []string
for v := range applied {
versions = append(versions, v)
2015-11-25 12:26:57 -08:00
}
// new → old
sort.Sort(sort.Reverse(sort.StringSlice(versions)))
2015-11-25 12:26:57 -08:00
if db.TargetVersion != "" {
cache := map[string]bool{}
found := false
// latest version comes first, so take every version until the version matches
for _, ver := range versions {
if ver == db.TargetVersion {
found = true
break
}
cache[ver] = true
}
if !found {
return fmt.Errorf("target version not found")
}
applied = cache
2015-11-25 12:26:57 -08:00
}
for version := range applied {
filename, err := findMigrationFile(db.MigrationsDir, version)
if err != nil {
2015-11-25 12:26:57 -08:00
return err
}
fmt.Fprintf(db.Log, "Rolling back: %s\n", filename)
_, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
if err != nil {
return err
}
execMigration := func(tx dbutil.Transaction) error {
// rollback migration
result, err := tx.Exec(down.Contents)
if err != nil {
return err
} else if db.Verbose {
db.printVerbose(result)
}
// remove migration record
return drv.DeleteMigration(tx, version)
}
if down.Options.Transaction() {
// begin transaction
err = doTransaction(sqlDB, execMigration)
} else {
// run outside of transaction
err = execMigration(sqlDB)
}
if err != nil {
return err
}
}
2015-11-25 12:26:57 -08:00
// automatically update schema file, silence errors
if db.AutoDumpSchema {
_ = db.dumpSchema(drv)
}
2015-11-25 12:26:57 -08:00
return nil
}
// Status reports every migration to db.Log as "[X] file" (applied) or
// "[ ] file" (pending), followed by a blank line and the applied/pending
// totals; nothing is printed when quiet is true. It returns the number of
// pending migrations, or -1 together with an error.
func (db *DB) Status(quiet bool) (int, error) {
	drv, err := db.GetDriver()
	if err != nil {
		return -1, err
	}

	results, err := db.CheckMigrationsStatus(drv)
	if err != nil {
		return -1, err
	}

	appliedCount := 0
	for _, r := range results {
		format := "[ ] %s"
		if r.Applied {
			format = "[X] %s"
			appliedCount++
		}
		if !quiet {
			fmt.Fprintln(db.Log, fmt.Sprintf(format, r.Filename))
		}
	}

	pendingCount := len(results) - appliedCount
	if !quiet {
		fmt.Fprintln(db.Log)
		fmt.Fprintf(db.Log, "Applied: %d\n", appliedCount)
		fmt.Fprintf(db.Log, "Pending: %d\n", pendingCount)
	}

	return pendingCount, nil
}
// CheckMigrationsStatus returns the status of all available migrations: one
// StatusResult per migration file in db.MigrationsDir, in filename order,
// with Applied set when the file's version is recorded in the database
// behind drv.
func (db *DB) CheckMigrationsStatus(drv Driver) ([]StatusResult, error) {
	files, err := findMigrationFiles(db.MigrationsDir, migrationFileRegexp)
	if err != nil {
		return nil, err
	}
	if len(files) == 0 {
		return nil, fmt.Errorf("no migration files found")
	}

	sqlDB, err := db.openDatabaseForMigration(drv)
	if err != nil {
		return nil, err
	}
	defer dbutil.MustClose(sqlDB)

	// Load every applied version: limiting this query (e.g. with db.Limit)
	// would report older applied migrations as pending.
	applied, err := drv.SelectMigrations(sqlDB, -1)
	if err != nil {
		return nil, err
	}

	results := make([]StatusResult, 0, len(files))
	for _, filename := range files {
		results = append(results, StatusResult{
			Filename: filename,
			Applied:  applied[migrationVersion(filename)],
		})
	}
	return results, nil
}