diff --git a/README.md b/README.md index 46d9647..f6c4c07 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,12 @@ protocol://username:password@host:port/database_name?options DATABASE_URL="mysql://username:password@127.0.0.1:3306/database_name" ``` +A socket parameter can be specified to connect through a unix socket file: + +```sh +DATABASE_URL="mysql://username:password@/database_name?socket=/var/run/mysqld/mysqld.sock" +``` + **PostgreSQL** When connecting to Postgres, you may need to add the `sslmode=disable` option to your connection string, as dbmate by default requires a TLS connection (some other frameworks/languages allow unencrypted connections by default). diff --git a/pkg/dbmate/mysql.go b/pkg/dbmate/mysql.go index 55c405a..636b63a 100644 --- a/pkg/dbmate/mysql.go +++ b/pkg/dbmate/mysql.go @@ -19,20 +19,20 @@ type MySQLDriver struct { } func normalizeMySQLURL(u *url.URL) string { - // set default port - host := u.Host - - if u.Port() == "" { - host = fmt.Sprintf("%s:3306", host) - } - - // host format required by go-sql-driver/mysql - host = fmt.Sprintf("tcp(%s)", host) - query := u.Query() query.Set("multiStatements", "true") - queryString := query.Encode() + host := u.Host + protocol := "tcp" + + if query.Get("socket") != "" { + protocol = "unix" + host = query.Get("socket") + query.Del("socket") + } else if u.Port() == "" { + // set default port + host = fmt.Sprintf("%s:3306", host) + } // Get decoded user:pass userPassEncoded := u.User.String() @@ -45,8 +45,9 @@ func normalizeMySQLURL(u *url.URL) string { normalizedString = userPass + "@" } - normalizedString = fmt.Sprintf("%s%s%s?%s", normalizedString, - host, u.Path, queryString) + // connection string format required by go-sql-driver/mysql + normalizedString = fmt.Sprintf("%s%s(%s)%s?%s", normalizedString, + protocol, host, u.Path, query.Encode()) return normalizedString } diff --git a/pkg/dbmate/mysql_test.go b/pkg/dbmate/mysql_test.go index 4aaf3e6..b1d6ebb 100644 --- a/pkg/dbmate/mysql_test.go +++ b/pkg/dbmate/mysql_test.go @@ -61,6 +61,15 @@ func TestNormalizeMySQLURLCustomSpecialChars(t *testing.T) { require.Equal(t, "duhfsd7s:123!@123!@@tcp(host:123)/foo?flag=on&multiStatements=true", s) } +func TestNormalizeMySQLURLSocket(t *testing.T) { + u, err := url.Parse("mysql:///foo?socket=/var/run/mysqld/mysqld.sock&flag=on") + require.NoError(t, err) + require.Equal(t, "", u.Host) + + s := normalizeMySQLURL(u) + require.Equal(t, "unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s) +} + func TestMySQLCreateDropDatabase(t *testing.T) { drv := MySQLDriver{} u := mySQLTestURL(t)