Implement rollback command

This commit is contained in:
Adrian Macneil 2015-11-25 12:26:57 -08:00
parent ece5d3cf0e
commit 1c4cf2c122
4 changed files with 138 additions and 25 deletions

View file

@ -107,6 +107,29 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
return tx.Commit() return tx.Commit()
} }
func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) {
u, err := GetDatabaseURL(ctx)
if err != nil {
return nil, nil, err
}
drv, err := driver.Get(u.Scheme)
if err != nil {
return nil, nil, err
}
db, err := drv.Open(u)
if err != nil {
return nil, nil, err
}
if err := drv.CreateMigrationsTable(db); err != nil {
return nil, nil, err
}
return drv, db, nil
}
// MigrateCommand migrates database to the latest version // MigrateCommand migrates database to the latest version
func MigrateCommand(ctx *cli.Context) error { func MigrateCommand(ctx *cli.Context) error {
migrationsDir := ctx.GlobalString("migrations-dir") migrationsDir := ctx.GlobalString("migrations-dir")
@ -119,27 +142,15 @@ func MigrateCommand(ctx *cli.Context) error {
return fmt.Errorf("No migration files found.") return fmt.Errorf("No migration files found.")
} }
u, err := GetDatabaseURL(ctx) drv, db, err := openDatabaseForMigration(ctx)
if err != nil { if db != nil {
return err
}
drv, err := driver.Get(u.Scheme)
if err != nil {
return err
}
db, err := drv.Open(u)
if err != nil {
return err
}
defer db.Close() defer db.Close()
}
if err := drv.CreateMigrationsTable(db); err != nil { if err != nil {
return err return err
} }
applied, err := drv.SelectMigrations(db) applied, err := drv.SelectMigrations(db, -1)
if err != nil { if err != nil {
return err return err
} }
@ -178,31 +189,66 @@ func MigrateCommand(ctx *cli.Context) error {
return nil return nil
} }
func findAvailableMigrations(dir string) (map[string]struct{}, error) { func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
files, err := ioutil.ReadDir(dir) files, err := ioutil.ReadDir(dir)
if err != nil { if err != nil {
return nil, fmt.Errorf("Could not find migrations directory `%s`.", dir) return nil, fmt.Errorf("Could not find migrations directory `%s`.", dir)
} }
nameRegexp := regexp.MustCompile(`^\d.*\.sql$`) matches := []string{}
migrations := map[string]struct{}{}
for _, file := range files { for _, file := range files {
if file.IsDir() { if file.IsDir() {
continue continue
} }
name := file.Name() name := file.Name()
if !nameRegexp.MatchString(name) { if !re.MatchString(name) {
continue continue
} }
matches = append(matches, name)
}
return matches, nil
}
func findAvailableMigrations(dir string) (map[string]struct{}, error) {
re := regexp.MustCompile(`^\d.*\.sql$`)
files, err := findMigrationFiles(dir, re)
if err != nil {
return nil, err
}
// why does go not have Set?
// convert into map for easier lookups
migrations := map[string]struct{}{}
for _, name := range files {
migrations[name] = struct{}{} migrations[name] = struct{}{}
} }
return migrations, nil return migrations, nil
} }
func findMigrationFile(dir string, ver string) (string, error) {
if ver == "" {
panic("migration version is required")
}
ver = regexp.QuoteMeta(ver)
re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))
files, err := findMigrationFiles(dir, re)
if err != nil {
return "", err
}
if len(files) == 0 {
return "", fmt.Errorf("Can't find migration file: %s*.sql", ver)
}
return files[0], nil
}
func migrationVersion(filename string) string { func migrationVersion(filename string) string {
return regexp.MustCompile(`^\d+`).FindString(filename) return regexp.MustCompile(`^\d+`).FindString(filename)
} }
@ -243,3 +289,58 @@ func parseMigration(path string) (map[string]string, error) {
return migrations, nil return migrations, nil
} }
// RollbackCommand rolls back the most recent migration
func RollbackCommand(ctx *cli.Context) error {
drv, db, err := openDatabaseForMigration(ctx)
if db != nil {
defer db.Close()
}
if err != nil {
return err
}
applied, err := drv.SelectMigrations(db, 1)
if err != nil {
return err
}
// grab most recent applied migration (applied has len=1)
latest := ""
for ver := range applied {
latest = ver
}
if latest == "" {
return fmt.Errorf("Can't rollback: no migrations have been applied.")
}
migrationsDir := ctx.GlobalString("migrations-dir")
filename, err := findMigrationFile(migrationsDir, latest)
if err != nil {
return err
}
fmt.Printf("Rolling back: %s\n", filename)
migration, err := parseMigration(filepath.Join(migrationsDir, filename))
if err != nil {
return err
}
// begin transaction
doTransaction(db, func(tx shared.Transaction) error {
// rollback migration
if _, err := tx.Exec(migration["down"]); err != nil {
return err
}
// remove migration record
if err := drv.DeleteMigration(tx, latest); err != nil {
return err
}
return nil
})
return nil
}

View file

@ -14,7 +14,7 @@ type Driver interface {
CreateDatabase(*url.URL) error CreateDatabase(*url.URL) error
DropDatabase(*url.URL) error DropDatabase(*url.URL) error
CreateMigrationsTable(*sql.DB) error CreateMigrationsTable(*sql.DB) error
SelectMigrations(*sql.DB) (map[string]struct{}, error) SelectMigrations(*sql.DB, int) (map[string]struct{}, error)
InsertMigration(shared.Transaction, string) error InsertMigration(shared.Transaction, string) error
DeleteMigration(shared.Transaction, string) error DeleteMigration(shared.Transaction, string) error
} }

View file

@ -67,8 +67,13 @@ func (postgres Driver) CreateMigrationsTable(db *sql.DB) error {
} }
// SelectMigrations returns a list of applied migrations // SelectMigrations returns a list of applied migrations
func (postgres Driver) SelectMigrations(db *sql.DB) (map[string]struct{}, error) { // with an optional limit (in descending order)
rows, err := db.Query("SELECT version FROM schema_migrations") func (postgres Driver) SelectMigrations(db *sql.DB, limit int) (map[string]struct{}, error) {
query := "SELECT version FROM schema_migrations ORDER BY version DESC"
if limit >= 0 {
query = fmt.Sprintf("%s LIMIT %d", query, limit)
}
rows, err := db.Query(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -42,6 +42,13 @@ func NewApp() *cli.App {
runCommand(MigrateCommand, ctx) runCommand(MigrateCommand, ctx)
}, },
}, },
{
Name: "rollback",
Usage: "Rollback the most recent migration",
Action: func(ctx *cli.Context) {
runCommand(RollbackCommand, ctx)
},
},
{ {
Name: "new", Name: "new",
Usage: "Generate a new migration file", Usage: "Generate a new migration file",