mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2026-02-02 17:35:08 +01:00
Implement rollback command
This commit is contained in:
parent
ece5d3cf0e
commit
1c4cf2c122
4 changed files with 138 additions and 25 deletions
145
commands.go
145
commands.go
|
|
@ -107,6 +107,29 @@ func doTransaction(db *sql.DB, txFunc func(shared.Transaction) error) error {
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openDatabaseForMigration(ctx *cli.Context) (driver.Driver, *sql.DB, error) {
|
||||||
|
u, err := GetDatabaseURL(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
drv, err := driver.Get(u.Scheme)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := drv.Open(u)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := drv.CreateMigrationsTable(db); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return drv, db, nil
|
||||||
|
}
|
||||||
|
|
||||||
// MigrateCommand migrates database to the latest version
|
// MigrateCommand migrates database to the latest version
|
||||||
func MigrateCommand(ctx *cli.Context) error {
|
func MigrateCommand(ctx *cli.Context) error {
|
||||||
migrationsDir := ctx.GlobalString("migrations-dir")
|
migrationsDir := ctx.GlobalString("migrations-dir")
|
||||||
|
|
@ -119,27 +142,15 @@ func MigrateCommand(ctx *cli.Context) error {
|
||||||
return fmt.Errorf("No migration files found.")
|
return fmt.Errorf("No migration files found.")
|
||||||
}
|
}
|
||||||
|
|
||||||
u, err := GetDatabaseURL(ctx)
|
drv, db, err := openDatabaseForMigration(ctx)
|
||||||
|
if db != nil {
|
||||||
|
defer db.Close()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
drv, err := driver.Get(u.Scheme)
|
applied, err := drv.SelectMigrations(db, -1)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := drv.Open(u)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
if err := drv.CreateMigrationsTable(db); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
applied, err := drv.SelectMigrations(db)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -178,31 +189,66 @@ func MigrateCommand(ctx *cli.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func findAvailableMigrations(dir string) (map[string]struct{}, error) {
|
func findMigrationFiles(dir string, re *regexp.Regexp) ([]string, error) {
|
||||||
files, err := ioutil.ReadDir(dir)
|
files, err := ioutil.ReadDir(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Could not find migrations directory `%s`.", dir)
|
return nil, fmt.Errorf("Could not find migrations directory `%s`.", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
nameRegexp := regexp.MustCompile(`^\d.*\.sql$`)
|
matches := []string{}
|
||||||
migrations := map[string]struct{}{}
|
|
||||||
|
|
||||||
for _, file := range files {
|
for _, file := range files {
|
||||||
if file.IsDir() {
|
if file.IsDir() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
name := file.Name()
|
name := file.Name()
|
||||||
if !nameRegexp.MatchString(name) {
|
if !re.MatchString(name) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
matches = append(matches, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func findAvailableMigrations(dir string) (map[string]struct{}, error) {
|
||||||
|
re := regexp.MustCompile(`^\d.*\.sql$`)
|
||||||
|
files, err := findMigrationFiles(dir, re)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// why does go not have Set?
|
||||||
|
// convert into map for easier lookups
|
||||||
|
migrations := map[string]struct{}{}
|
||||||
|
for _, name := range files {
|
||||||
migrations[name] = struct{}{}
|
migrations[name] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return migrations, nil
|
return migrations, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func findMigrationFile(dir string, ver string) (string, error) {
|
||||||
|
if ver == "" {
|
||||||
|
panic("migration version is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
ver = regexp.QuoteMeta(ver)
|
||||||
|
re := regexp.MustCompile(fmt.Sprintf(`^%s.*\.sql$`, ver))
|
||||||
|
|
||||||
|
files, err := findMigrationFiles(dir, re)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(files) == 0 {
|
||||||
|
return "", fmt.Errorf("Can't find migration file: %s*.sql", ver)
|
||||||
|
}
|
||||||
|
|
||||||
|
return files[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
func migrationVersion(filename string) string {
|
func migrationVersion(filename string) string {
|
||||||
return regexp.MustCompile(`^\d+`).FindString(filename)
|
return regexp.MustCompile(`^\d+`).FindString(filename)
|
||||||
}
|
}
|
||||||
|
|
@ -243,3 +289,58 @@ func parseMigration(path string) (map[string]string, error) {
|
||||||
|
|
||||||
return migrations, nil
|
return migrations, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RollbackCommand rolls back the most recent migration
|
||||||
|
func RollbackCommand(ctx *cli.Context) error {
|
||||||
|
drv, db, err := openDatabaseForMigration(ctx)
|
||||||
|
if db != nil {
|
||||||
|
defer db.Close()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
applied, err := drv.SelectMigrations(db, 1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// grab most recent applied migration (applied has len=1)
|
||||||
|
latest := ""
|
||||||
|
for ver := range applied {
|
||||||
|
latest = ver
|
||||||
|
}
|
||||||
|
if latest == "" {
|
||||||
|
return fmt.Errorf("Can't rollback: no migrations have been applied.")
|
||||||
|
}
|
||||||
|
|
||||||
|
migrationsDir := ctx.GlobalString("migrations-dir")
|
||||||
|
filename, err := findMigrationFile(migrationsDir, latest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Rolling back: %s\n", filename)
|
||||||
|
|
||||||
|
migration, err := parseMigration(filepath.Join(migrationsDir, filename))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// begin transaction
|
||||||
|
doTransaction(db, func(tx shared.Transaction) error {
|
||||||
|
// rollback migration
|
||||||
|
if _, err := tx.Exec(migration["down"]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove migration record
|
||||||
|
if err := drv.DeleteMigration(tx, latest); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ type Driver interface {
|
||||||
CreateDatabase(*url.URL) error
|
CreateDatabase(*url.URL) error
|
||||||
DropDatabase(*url.URL) error
|
DropDatabase(*url.URL) error
|
||||||
CreateMigrationsTable(*sql.DB) error
|
CreateMigrationsTable(*sql.DB) error
|
||||||
SelectMigrations(*sql.DB) (map[string]struct{}, error)
|
SelectMigrations(*sql.DB, int) (map[string]struct{}, error)
|
||||||
InsertMigration(shared.Transaction, string) error
|
InsertMigration(shared.Transaction, string) error
|
||||||
DeleteMigration(shared.Transaction, string) error
|
DeleteMigration(shared.Transaction, string) error
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -67,8 +67,13 @@ func (postgres Driver) CreateMigrationsTable(db *sql.DB) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectMigrations returns a list of applied migrations
|
// SelectMigrations returns a list of applied migrations
|
||||||
func (postgres Driver) SelectMigrations(db *sql.DB) (map[string]struct{}, error) {
|
// with an optional limit (in descending order)
|
||||||
rows, err := db.Query("SELECT version FROM schema_migrations")
|
func (postgres Driver) SelectMigrations(db *sql.DB, limit int) (map[string]struct{}, error) {
|
||||||
|
query := "SELECT version FROM schema_migrations ORDER BY version DESC"
|
||||||
|
if limit >= 0 {
|
||||||
|
query = fmt.Sprintf("%s LIMIT %d", query, limit)
|
||||||
|
}
|
||||||
|
rows, err := db.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
7
main.go
7
main.go
|
|
@ -42,6 +42,13 @@ func NewApp() *cli.App {
|
||||||
runCommand(MigrateCommand, ctx)
|
runCommand(MigrateCommand, ctx)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Name: "rollback",
|
||||||
|
Usage: "Rollback the most recent migration",
|
||||||
|
Action: func(ctx *cli.Context) {
|
||||||
|
runCommand(RollbackCommand, ctx)
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Name: "new",
|
Name: "new",
|
||||||
Usage: "Generate a new migration file",
|
Usage: "Generate a new migration file",
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue