mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2025-12-11 23:50:04 +01:00
Implement rollback command
This commit is contained in:
parent
ece5d3cf0e
commit
1c4cf2c122
4 changed files with 138 additions and 25 deletions
145
commands.go
145
commands.go
|
|
@ -107,6 +107,29 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) {
|
||||
u, err := GetDatabaseURL(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
drv, err := driver.Get(u.Scheme)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
db, err := drv.Open(u)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := drv.CreateMigrationsTable(db); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return drv, db, nil
|
||||
}
|
||||
|
||||
// MigrateCommand migrates database to the latest version
|
||||
func MigrateCommand(ctx *cli.Context) error {
|
||||
migrationsDir := ctx.GlobalString("migrations-dir")
|
||||
|
|
@ -119,27 +142,15 @@ func MigrateCommand(ctx *cli.Context) error {
|
|||
return fmt.Errorf("No migration files found.")
|
||||
}
|
||||
|
||||
u, err := GetDatabaseURL(ctx)
|
||||
drv, db, err := openDatabaseForMigration(ctx)
|
||||
if db != nil {
|
||||
defer db.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
drv, err := driver.Get(u.Scheme)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db, err := drv.Open(u)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err := drv.CreateMigrationsTable(db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applied, err := drv.SelectMigrations(db)
|
||||
applied, err := drv.SelectMigrations(db, -1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -178,31 +189,66 @@ func MigrateCommand(ctx *cli.Context) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func findAvailableMigrations(dir string) (map[string]struct{}, error) {
|
||||
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Could not find migrations directory `%s`.", dir)
|
||||
}
|
||||
|
||||
nameRegexp := regexp.MustCompile(`^\d.*\.sql$`)
|
||||
migrations := map[string]struct{}{}
|
||||
|
||||
matches := []string{}
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
name := file.Name()
|
||||
if !nameRegexp.MatchString(name) {
|
||||
if !re.MatchString(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
matches = append(matches, name)
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
}
|
||||
|
||||
func findAvailableMigrations(dir string) (map[string]struct{}, error) {
|
||||
re := regexp.MustCompile(`^\d.*\.sql$`)
|
||||
files, err := findMigrationFiles(dir, re)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// why does go not have Set?
|
||||
// convert into map for easier lookups
|
||||
migrations := map[string]struct{}{}
|
||||
for _, name := range files {
|
||||
migrations[name] = struct{}{}
|
||||
}
|
||||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
func findMigrationFile(dir string, ver string) (string, error) {
|
||||
if ver == "" {
|
||||
panic("migration version is required")
|
||||
}
|
||||
|
||||
ver = regexp.QuoteMeta(ver)
|
||||
re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))
|
||||
|
||||
files, err := findMigrationFiles(dir, re)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(files) == 0 {
|
||||
return "", fmt.Errorf("Can't find migration file: %s*.sql", ver)
|
||||
}
|
||||
|
||||
return files[0], nil
|
||||
}
|
||||
|
||||
func migrationVersion(filename string) string {
|
||||
return regexp.MustCompile(`^\d+`).FindString(filename)
|
||||
}
|
||||
|
|
@ -243,3 +289,58 @@ func parseMigration(path string) (map[string]string, error) {
|
|||
|
||||
return migrations, nil
|
||||
}
|
||||
|
||||
// RollbackCommand rolls back the most recent migration
|
||||
func RollbackCommand(ctx *cli.Context) error {
|
||||
drv, db, err := openDatabaseForMigration(ctx)
|
||||
if db != nil {
|
||||
defer db.Close()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applied, err := drv.SelectMigrations(db, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// grab most recent applied migration (applied has len=1)
|
||||
latest := ""
|
||||
for ver := range applied {
|
||||
latest = ver
|
||||
}
|
||||
if latest == "" {
|
||||
return fmt.Errorf("Can't rollback: no migrations have been applied.")
|
||||
}
|
||||
|
||||
migrationsDir := ctx.GlobalString("migrations-dir")
|
||||
filename, err := findMigrationFile(migrationsDir, latest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Rolling back: %s\n", filename)
|
||||
|
||||
migration, err := parseMigration(filepath.Join(migrationsDir, filename))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// begin transaction
|
||||
doTransaction(db, func(tx shared.Transaction) error {
|
||||
// rollback migration
|
||||
if _, err := tx.Exec(migration["down"]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove migration record
|
||||
if err := drv.DeleteMigration(tx, latest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ type Driver interface {
|
|||
CreateDatabase(*url.URL) error
|
||||
DropDatabase(*url.URL) error
|
||||
CreateMigrationsTable(*sql.DB) error
|
||||
SelectMigrations(*sql.DB) (map[string]struct{}, error)
|
||||
SelectMigrations(*sql.DB, int) (map[string]struct{}, error)
|
||||
InsertMigration(shared.Transaction, string) error
|
||||
DeleteMigration(shared.Transaction, string) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,8 +67,13 @@ func (postgres Driver) CreateMigrationsTable(db *sql.DB) error {
|
|||
}
|
||||
|
||||
// SelectMigrations returns a list of applied migrations
|
||||
func (postgres Driver) SelectMigrations(db *sql.DB) (map[string]struct{}, error) {
|
||||
rows, err := db.Query("SELECT version FROM schema_migrations")
|
||||
// with an optional limit (in descending order)
|
||||
func (postgres Driver) SelectMigrations(db *sql.DB, limit int) (map[string]struct{}, error) {
|
||||
query := "SELECT version FROM schema_migrations ORDER BY version DESC"
|
||||
if limit >= 0 {
|
||||
query = fmt.Sprintf("%s LIMIT %d", query, limit)
|
||||
}
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
7
main.go
7
main.go
|
|
@ -42,6 +42,13 @@ func NewApp() *cli.App {
|
|||
runCommand(MigrateCommand, ctx)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "rollback",
|
||||
Usage: "Rollback the most recent migration",
|
||||
Action: func(ctx *cli.Context) {
|
||||
runCommand(RollbackCommand, ctx)
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "new",
|
||||
Usage: "Generate a new migration file",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue