diff --git a/main.go b/main.go index 2ed5875..3899cbe 100644 --- a/main.go +++ b/main.go @@ -109,6 +109,7 @@ func NewApp() *cli.App { }, }, Action: action(func(db *dbmate.DB, c *cli.Context) error { + db.TargetVersion = c.Args().First() db.Verbose = c.Bool("verbose") return db.CreateAndMigrate() }), @@ -129,7 +130,7 @@ func NewApp() *cli.App { }, { Name: "migrate", - Usage: "Migrate to the latest version", + Usage: "Migrate to the specified or latest version", Flags: []cli.Flag{ &cli.BoolFlag{ Name: "verbose", @@ -139,6 +140,7 @@ func NewApp() *cli.App { }, }, Action: action(func(db *dbmate.DB, c *cli.Context) error { + db.TargetVersion = c.Args().First() db.Verbose = c.Bool("verbose") return db.Migrate() }), @@ -154,8 +156,16 @@ func NewApp() *cli.App { EnvVars: []string{"DBMATE_VERBOSE"}, Usage: "print the result of each statement execution", }, + &cli.IntFlag{ + Name: "limit", + Aliases: []string{"l"}, + Usage: "Limits the amount of rollbacks (defaults to 1 if no target version is specified)", + Value: -1, + }, }, Action: action(func(db *dbmate.DB, c *cli.Context) error { + db.TargetVersion = c.Args().First() + db.Limit = c.Int("limit") db.Verbose = c.Bool("verbose") return db.Rollback() }), diff --git a/pkg/dbmate/db.go b/pkg/dbmate/db.go index f64dd88..8240207 100644 --- a/pkg/dbmate/db.go +++ b/pkg/dbmate/db.go @@ -41,6 +41,8 @@ type DB struct { WaitBefore bool WaitInterval time.Duration WaitTimeout time.Duration + Limit int + TargetVersion string Log io.Writer } @@ -64,6 +66,8 @@ func New(databaseURL *url.URL) *DB { WaitBefore: false, WaitInterval: DefaultWaitInterval, WaitTimeout: DefaultWaitTimeout, + Limit: -1, + TargetVersion: "", Log: os.Stdout, } } @@ -336,14 +340,14 @@ func (db *DB) migrate(drv Driver) error { } defer dbutil.MustClose(sqlDB) - applied, err := drv.SelectMigrations(sqlDB, -1) + applied, err := drv.SelectMigrations(sqlDB, db.Limit) if err != nil { return err } for _, filename := range files { ver := migrationVersion(filename) - if ok := applied[ver]; ok { + if ok := applied[ver]; ok && ver != db.TargetVersion { // migration already applied continue } @@ -379,6 +383,11 @@ func (db *DB) migrate(drv Driver) error { if err != nil { return err } + + if ver == db.TargetVersion { + fmt.Fprintf(db.Log, "Reached target version %s\n", ver) + break + } } // automatically update schema file, silence errors @@ -469,55 +478,83 @@ func (db *DB) Rollback() error { } defer dbutil.MustClose(sqlDB) - applied, err := drv.SelectMigrations(sqlDB, 1) + limit := db.Limit + // default limit is -1, if we don't specify a version it should only rollback one version, not all + if limit <= 0 && db.TargetVersion == "" { + limit = 1 + } + + applied, err := drv.SelectMigrations(sqlDB, limit) if err != nil { return err } - // grab most recent applied migration (applied has len=1) - latest := "" - for ver := range applied { - latest = ver - } - if latest == "" { - return fmt.Errorf("can't rollback: no migrations have been applied") + if len(applied) == 0 { + return fmt.Errorf("can't rollback, no migrations found") } - filename, err := findMigrationFile(db.MigrationsDir, latest) - if err != nil { - return err + var versions []string + for v := range applied { + versions = append(versions, v) } - fmt.Fprintf(db.Log, "Rolling back: %s\n", filename) + // new → old + sort.Sort(sort.Reverse(sort.StringSlice(versions))) - _, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename)) - if err != nil { - return err + if db.TargetVersion != "" { + cache := map[string]bool{} + found := false + + // latest version comes first, so take every version until the version matches + for _, ver := range versions { + if ver == db.TargetVersion { + found = true + break + } + cache[ver] = true + } + if !found { + return fmt.Errorf("target version not found") + } + applied = cache } - execMigration := func(tx dbutil.Transaction) error { - // rollback migration - result, err := tx.Exec(down.Contents) + for version := range applied { + filename, err := findMigrationFile(db.MigrationsDir, version) if err != nil { return err - } else if db.Verbose { - db.printVerbose(result) } - // remove migration record - return drv.DeleteMigration(tx, latest) - } + fmt.Fprintf(db.Log, "Rolling back: %s\n", filename) + _, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename)) + if err != nil { + return err + } - if down.Options.Transaction() { - // begin transaction - err = doTransaction(sqlDB, execMigration) - } else { - // run outside of transaction - err = execMigration(sqlDB) - } + execMigration := func(tx dbutil.Transaction) error { + // rollback migration + result, err := tx.Exec(down.Contents) + if err != nil { + return err + } else if db.Verbose { + db.printVerbose(result) + } - if err != nil { - return err + // remove migration record + return drv.DeleteMigration(tx, version) + } + + if down.Options.Transaction() { + // begin transaction + err = doTransaction(sqlDB, execMigration) + } else { + // run outside of transaction + err = execMigration(sqlDB) + } + + if err != nil { + return err + } } // automatically update schema file, silence errors @@ -582,7 +619,7 @@ func (db *DB) CheckMigrationsStatus(drv Driver) ([]StatusResult, error) { } defer dbutil.MustClose(sqlDB) - applied, err := drv.SelectMigrations(sqlDB, -1) + applied, err := drv.SelectMigrations(sqlDB, db.Limit) if err != nil { return nil, err }