diff --git a/Makefile b/Makefile index bd076f1..63b3cbb 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ lint: $(DOCKER) errcheck ./... test: - $(DOCKER) go test -p=1 -v ./... + $(DOCKER) go test -p 1 -v ./... diff --git a/README.md b/README.md index e257d7d..ff8be3a 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ Dbmate is a database migration tool, to keep your database schema in sync across ## Features -* Currently supports PostgreSQL only. +* Supports PostgreSQL and MySQL. * Powerful, [purpose-built DSL](https://en.wikipedia.org/wiki/SQL#Data_definition) for writing schema migrations. * Migrations are timestamp-versioned, to avoid version number conflicts with multiple developers. * Supports creating and dropping databases (handy in development/test). diff --git a/commands_test.go b/commands_test.go index 4d90ef3..9043ac4 100644 --- a/commands_test.go +++ b/commands_test.go @@ -52,11 +52,25 @@ func postgresTestURL(t *testing.T) *url.URL { return u } +func mysqlTestURL(t *testing.T) *url.URL { + str := os.Getenv("MYSQL_PORT") + require.NotEmpty(t, str, "missing MYSQL_PORT environment variable") + + u, err := url.Parse(str) + require.Nil(t, err) + + u.Scheme = "mysql" + u.User = url.UserPassword("root", "root") + u.Path = "/dbmate" + + return u +} + func testURLs(t *testing.T) []*url.URL { return []*url.URL{ postgresTestURL(t), + mysqlTestURL(t), } - } func mustClose(c io.Closer) { @@ -178,7 +192,8 @@ func testRollbackCommandURL(t *testing.T, u *url.URL) { require.Equal(t, 0, count) err = db.QueryRow("select count(*) from users").Scan(&count) - require.Equal(t, "pq: relation \"users\" does not exist", err.Error()) + require.NotNil(t, err) + require.Regexp(t, "(does not exist|doesn't exist)", err.Error()) } func TestRollbackCommand(t *testing.T) { diff --git a/docker-compose.yml b/docker-compose.yml index c5e065b..2a578f6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,8 +1,13 @@ -postgres: - image: postgres:9.4 dbmate: build: . volumes: - .:/go/src/github.com/adrianmacneil/dbmate links: + - mysql - postgres +mysql: + image: mysql:5.7 + environment: + MYSQL_ROOT_PASSWORD: root +postgres: + image: postgres:9.4 diff --git a/driver/driver.go b/driver/driver.go index c783f88..ef69446 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -3,6 +3,7 @@ package driver import ( "database/sql" "fmt" + "github.com/adrianmacneil/dbmate/driver/mysql" "github.com/adrianmacneil/dbmate/driver/postgres" "github.com/adrianmacneil/dbmate/driver/shared" "net/url" @@ -23,6 +24,8 @@ type Driver interface { // Get loads a database driver by name func Get(name string) (Driver, error) { switch name { + case "mysql": + return mysql.Driver{}, nil case "postgres": return postgres.Driver{}, nil default: diff --git a/driver/mysql/mysql.go b/driver/mysql/mysql.go new file mode 100644 index 0000000..adfce26 --- /dev/null +++ b/driver/mysql/mysql.go @@ -0,0 +1,156 @@ +package mysql + +import ( + "database/sql" + "fmt" + "github.com/adrianmacneil/dbmate/driver/shared" + _ "github.com/adrianmacneil/go-mysql" // mysql driver + "io" + "net/url" + "strings" +) + +// Driver provides top level database functions +type Driver struct { +} + +func normalizeURL(u *url.URL) string { + normalizedURL := *u + normalizedURL.Scheme = "" + normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host) + + query := normalizedURL.Query() + query.Set("multiStatements", "true") + normalizedURL.RawQuery = query.Encode() + + str := normalizedURL.String() + return strings.TrimLeft(str, "/") +} + +// Open creates a new database connection +func (drv Driver) Open(u *url.URL) (*sql.DB, error) { + return sql.Open("mysql", normalizeURL(u)) +} + +func (drv Driver) openRootDB(u *url.URL) (*sql.DB, error) { + // connect to no particular database + rootURL := *u + rootURL.Path = "/" + + return drv.Open(&rootURL) +} + +func mustClose(c io.Closer) { + if err := c.Close(); err != nil { + panic(err) + } +} + +func quoteIdentifier(str string) string { + str = strings.Replace(str, "`", "\\`", -1) + + return fmt.Sprintf("`%s`", str) +} + +// CreateDatabase creates the specified database +func (drv Driver) CreateDatabase(u *url.URL) error { + name := shared.DatabaseName(u) + fmt.Printf("Creating: %s\n", name) + + db, err := drv.openRootDB(u) + if err != nil { + return err + } + defer mustClose(db) + + _, err = db.Exec(fmt.Sprintf("create database %s", + quoteIdentifier(name))) + + return err +} + +// DropDatabase drops the specified database (if it exists) +func (drv Driver) DropDatabase(u *url.URL) error { + name := shared.DatabaseName(u) + fmt.Printf("Dropping: %s\n", name) + + db, err := drv.openRootDB(u) + if err != nil { + return err + } + defer mustClose(db) + + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", + quoteIdentifier(name))) + + return err +} + +// DatabaseExists determines whether the database exists +func (drv Driver) DatabaseExists(u *url.URL) (bool, error) { + name := shared.DatabaseName(u) + + db, err := drv.openRootDB(u) + if err != nil { + return false, err + } + defer mustClose(db) + + exists := false + err = db.QueryRow(`select true from information_schema.schemata + where schema_name = ?`, name).Scan(&exists) + if err == sql.ErrNoRows { + return false, nil + } + + return exists, err +} + +// CreateMigrationsTable creates the schema_migrations table +func (drv Driver) CreateMigrationsTable(db *sql.DB) error { + _, err := db.Exec(`create table if not exists schema_migrations ( + version varchar(255) primary key)`) + + return err +} + +// SelectMigrations returns a list of applied migrations +// with an optional limit (in descending order) +func (drv Driver) SelectMigrations(db *sql.DB, limit int) (map[string]bool, error) { + query := "select version from schema_migrations order by version desc" + if limit >= 0 { + query = fmt.Sprintf("%s limit %d", query, limit) + } + rows, err := db.Query(query) + if err != nil { + return nil, err + } + + defer mustClose(rows) + + migrations := map[string]bool{} + for rows.Next() { + var version string + if err := rows.Scan(&version); err != nil { + return nil, err + } + + migrations[version] = true + } + + return migrations, nil +} + +// InsertMigration adds a new migration record +func (drv Driver) InsertMigration(db shared.Transaction, version string) error { + _, err := db.Exec("insert into schema_migrations (version) values (?)", version) + + return err +} + +// DeleteMigration removes a migration record +func (drv Driver) DeleteMigration(db shared.Transaction, version string) error { + _, err := db.Exec("delete from schema_migrations where version = ?", version) + + return err +} diff --git a/driver/mysql/mysql_test.go b/driver/mysql/mysql_test.go new file mode 100644 index 0000000..6f741ac --- /dev/null +++ b/driver/mysql/mysql_test.go @@ -0,0 +1,214 @@ +package mysql_test + +import ( + "database/sql" + "github.com/adrianmacneil/dbmate/driver/mysql" + "github.com/stretchr/testify/require" + "io" + "net/url" + "os" + "testing" +) + +func testURL(t *testing.T) *url.URL { + str := os.Getenv("MYSQL_PORT") + require.NotEmpty(t, str, "missing MYSQL_PORT environment variable") + + u, err := url.Parse(str) + require.Nil(t, err) + + u.Scheme = "mysql" + u.User = url.UserPassword("root", "root") + u.Path = "/dbmate" + + return u +} + +func mustClose(c io.Closer) { + if err := c.Close(); err != nil { + panic(err) + } +} + +func prepTestDB(t *testing.T) *sql.DB { + drv := mysql.Driver{} + u := testURL(t) + + // drop any existing database + err := drv.DropDatabase(u) + require.Nil(t, err) + + // create database + err = drv.CreateDatabase(u) + require.Nil(t, err) + + // connect database + db, err := drv.Open(u) + require.Nil(t, err) + + return db +} + +func TestCreateDropDatabase(t *testing.T) { + drv := mysql.Driver{} + u := testURL(t) + + // drop any existing database + err := drv.DropDatabase(u) + require.Nil(t, err) + + // create database + err = drv.CreateDatabase(u) + require.Nil(t, err) + + // check that database exists and we can connect to it + func() { + db, err := drv.Open(u) + require.Nil(t, err) + defer mustClose(db) + + err = db.Ping() + require.Nil(t, err) + }() + + // drop the database + err = drv.DropDatabase(u) + require.Nil(t, err) + + // check that database no longer exists + func() { + db, err := drv.Open(u) + require.Nil(t, err) + defer mustClose(db) + + err = db.Ping() + require.NotNil(t, err) + require.Regexp(t, "Unknown database 'dbmate'", err.Error()) + }() +} + +func TestDatabaseExists(t *testing.T) { + drv := mysql.Driver{} + u := testURL(t) + + // drop any existing database + err := drv.DropDatabase(u) + require.Nil(t, err) + + // DatabaseExists should return false + exists, err := drv.DatabaseExists(u) + require.Nil(t, err) + require.Equal(t, false, exists) + + // create database + err = drv.CreateDatabase(u) + require.Nil(t, err) + + // DatabaseExists should return true + exists, err = drv.DatabaseExists(u) + require.Nil(t, err) + require.Equal(t, true, exists) +} + +func TestDatabaseExists_Error(t *testing.T) { + drv := mysql.Driver{} + u := testURL(t) + u.User = url.User("invalid") + + exists, err := drv.DatabaseExists(u) + require.Regexp(t, "Access denied for user 'invalid'@", err.Error()) + require.Equal(t, false, exists) +} + +func TestCreateMigrationsTable(t *testing.T) { + drv := mysql.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + // migrations table should not exist + count := 0 + err := db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Regexp(t, "Table 'dbmate.schema_migrations' doesn't exist", err.Error()) + + // create table + err = drv.CreateMigrationsTable(db) + require.Nil(t, err) + + // migrations table should exist + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Nil(t, err) + + // create table should be idempotent + err = drv.CreateMigrationsTable(db) + require.Nil(t, err) +} + +func TestSelectMigrations(t *testing.T) { + drv := mysql.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + err := drv.CreateMigrationsTable(db) + require.Nil(t, err) + + _, err = db.Exec(`insert into schema_migrations (version) + values ('abc2'), ('abc1'), ('abc3')`) + require.Nil(t, err) + + migrations, err := drv.SelectMigrations(db, -1) + require.Nil(t, err) + require.Equal(t, true, migrations["abc1"]) + require.Equal(t, true, migrations["abc2"]) + require.Equal(t, true, migrations["abc2"]) + + // test limit param + migrations, err = drv.SelectMigrations(db, 1) + require.Nil(t, err) + require.Equal(t, true, migrations["abc3"]) + require.Equal(t, false, migrations["abc1"]) + require.Equal(t, false, migrations["abc2"]) +} + +func TestInsertMigration(t *testing.T) { + drv := mysql.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + err := drv.CreateMigrationsTable(db) + require.Nil(t, err) + + count := 0 + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Nil(t, err) + require.Equal(t, 0, count) + + // insert migration + err = drv.InsertMigration(db, "abc1") + require.Nil(t, err) + + err = db.QueryRow("select count(*) from schema_migrations where version = 'abc1'"). + Scan(&count) + require.Nil(t, err) + require.Equal(t, 1, count) +} + +func TestDeleteMigration(t *testing.T) { + drv := mysql.Driver{} + db := prepTestDB(t) + defer mustClose(db) + + err := drv.CreateMigrationsTable(db) + require.Nil(t, err) + + _, err = db.Exec(`insert into schema_migrations (version) + values ('abc1'), ('abc2')`) + require.Nil(t, err) + + err = drv.DeleteMigration(db, "abc2") + require.Nil(t, err) + + count := 0 + err = db.QueryRow("select count(*) from schema_migrations").Scan(&count) + require.Nil(t, err) + require.Equal(t, 1, count) +}