add ability to specify limits and the target version to migrate to

This commit is contained in:
Technofab 2022-01-07 21:12:30 +01:00
parent 5b60f68107
commit a55233c50b
2 changed files with 83 additions and 36 deletions

12
main.go
View file

@ -109,6 +109,7 @@ func NewApp() *cli.App {
}, },
}, },
Action: action(func(db *dbmate.DB, c *cli.Context) error { Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.TargetVersion = c.Args().First()
db.Verbose = c.Bool("verbose") db.Verbose = c.Bool("verbose")
return db.CreateAndMigrate() return db.CreateAndMigrate()
}), }),
@ -129,7 +130,7 @@ func NewApp() *cli.App {
}, },
{ {
Name: "migrate", Name: "migrate",
Usage: "Migrate to the latest version", Usage: "Migrate to the specified or latest version",
Flags: []cli.Flag{ Flags: []cli.Flag{
&cli.BoolFlag{ &cli.BoolFlag{
Name: "verbose", Name: "verbose",
@ -139,6 +140,7 @@ func NewApp() *cli.App {
}, },
}, },
Action: action(func(db *dbmate.DB, c *cli.Context) error { Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.TargetVersion = c.Args().First()
db.Verbose = c.Bool("verbose") db.Verbose = c.Bool("verbose")
return db.Migrate() return db.Migrate()
}), }),
@ -154,8 +156,16 @@ func NewApp() *cli.App {
EnvVars: []string{"DBMATE_VERBOSE"}, EnvVars: []string{"DBMATE_VERBOSE"},
Usage: "print the result of each statement execution", Usage: "print the result of each statement execution",
}, },
&cli.IntFlag{
Name: "limit",
Aliases: []string{"l"},
Usage: "Limits the amount of rollbacks (defaults to 1 if no target version is specified)",
Value: -1,
},
}, },
Action: action(func(db *dbmate.DB, c *cli.Context) error { Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.TargetVersion = c.Args().First()
db.Limit = c.Int("limit")
db.Verbose = c.Bool("verbose") db.Verbose = c.Bool("verbose")
return db.Rollback() return db.Rollback()
}), }),

View file

@ -41,6 +41,8 @@ type DB struct {
WaitBefore bool WaitBefore bool
WaitInterval time.Duration WaitInterval time.Duration
WaitTimeout time.Duration WaitTimeout time.Duration
Limit int
TargetVersion string
Log io.Writer Log io.Writer
} }
@ -64,6 +66,8 @@ func New(databaseURL *url.URL) *DB {
WaitBefore: false, WaitBefore: false,
WaitInterval: DefaultWaitInterval, WaitInterval: DefaultWaitInterval,
WaitTimeout: DefaultWaitTimeout, WaitTimeout: DefaultWaitTimeout,
Limit: -1,
TargetVersion: "",
Log: os.Stdout, Log: os.Stdout,
} }
} }
@ -336,14 +340,14 @@ func (db *DB) migrate(drv Driver) error {
} }
defer dbutil.MustClose(sqlDB) defer dbutil.MustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, -1) applied, err := drv.SelectMigrations(sqlDB, db.Limit)
if err != nil { if err != nil {
return err return err
} }
for _, filename := range files { for _, filename := range files {
ver := migrationVersion(filename) ver := migrationVersion(filename)
if ok := applied[ver]; ok { if ok := applied[ver]; ok && ver != db.TargetVersion {
// migration already applied // migration already applied
continue continue
} }
@ -379,6 +383,11 @@ func (db *DB) migrate(drv Driver) error {
if err != nil { if err != nil {
return err return err
} }
if ver == db.TargetVersion {
fmt.Fprintf(db.Log, "Reached target version %s\n", ver)
break
}
} }
// automatically update schema file, silence errors // automatically update schema file, silence errors
@ -469,27 +478,54 @@ func (db *DB) Rollback() error {
} }
defer dbutil.MustClose(sqlDB) defer dbutil.MustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, 1) limit := db.Limit
// default limit is -1, if we don't specify a version it should only rollback one version, not all
if limit <= 0 && db.TargetVersion == "" {
limit = 1
}
applied, err := drv.SelectMigrations(sqlDB, limit)
if err != nil { if err != nil {
return err return err
} }
// grab most recent applied migration (applied has len=1) if len(applied) == 0 {
latest := "" return fmt.Errorf("can't rollback, no migrations found")
for ver := range applied {
latest = ver
}
if latest == "" {
return fmt.Errorf("can't rollback: no migrations have been applied")
} }
filename, err := findMigrationFile(db.MigrationsDir, latest) var versions []string
for v := range applied {
versions = append(versions, v)
}
// new → old
sort.Sort(sort.Reverse(sort.StringSlice(versions)))
if db.TargetVersion != "" {
cache := map[string]bool{}
found := false
// latest version comes first, so take every version until the version matches
for _, ver := range versions {
if ver == db.TargetVersion {
found = true
break
}
cache[ver] = true
}
if !found {
return fmt.Errorf("target version not found")
}
applied = cache
}
for version := range applied {
filename, err := findMigrationFile(db.MigrationsDir, version)
if err != nil { if err != nil {
return err return err
} }
fmt.Fprintf(db.Log, "Rolling back: %s\n", filename) fmt.Fprintf(db.Log, "Rolling back: %s\n", filename)
_, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename)) _, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
if err != nil { if err != nil {
return err return err
@ -505,7 +541,7 @@ func (db *DB) Rollback() error {
} }
// remove migration record // remove migration record
return drv.DeleteMigration(tx, latest) return drv.DeleteMigration(tx, version)
} }
if down.Options.Transaction() { if down.Options.Transaction() {
@ -519,6 +555,7 @@ func (db *DB) Rollback() error {
if err != nil { if err != nil {
return err return err
} }
}
// automatically update schema file, silence errors // automatically update schema file, silence errors
if db.AutoDumpSchema { if db.AutoDumpSchema {
@ -582,7 +619,7 @@ func (db *DB) CheckMigrationsStatus(drv Driver) ([]StatusResult, error) {
} }
defer dbutil.MustClose(sqlDB) defer dbutil.MustClose(sqlDB)
applied, err := drv.SelectMigrations(sqlDB, -1) applied, err := drv.SelectMigrations(sqlDB, db.Limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }