mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2026-02-02 17:35:08 +01:00
Fix special chars passwords in MySQL Driver (#76)
When MySQL password included special chars such as exclamation point or @ (at), MySQL backend errored out (invalid password). url.userinfo.String() (which gets called inside url.String()) returns %-encoded strings and MySQL interprets it as an actual password. Now the function percent-decodes it first before returning. Closes #57
This commit is contained in:
parent
c2c05ffb91
commit
3250277c26
2 changed files with 32 additions and 10 deletions
|
|
@ -19,23 +19,36 @@ type MySQLDriver struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeMySQLURL(u *url.URL) string {
|
func normalizeMySQLURL(u *url.URL) string {
|
||||||
normalizedURL := *u
|
|
||||||
normalizedURL.Scheme = ""
|
|
||||||
|
|
||||||
// set default port
|
// set default port
|
||||||
if normalizedURL.Port() == "" {
|
host := u.Host
|
||||||
normalizedURL.Host = fmt.Sprintf("%s:3306", normalizedURL.Host)
|
|
||||||
|
if u.Port() == "" {
|
||||||
|
host = fmt.Sprintf("%s:3306", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// host format required by go-sql-driver/mysql
|
// host format required by go-sql-driver/mysql
|
||||||
normalizedURL.Host = fmt.Sprintf("tcp(%s)", normalizedURL.Host)
|
host = fmt.Sprintf("tcp(%s)", host)
|
||||||
|
|
||||||
query := normalizedURL.Query()
|
query := u.Query()
|
||||||
query.Set("multiStatements", "true")
|
query.Set("multiStatements", "true")
|
||||||
normalizedURL.RawQuery = query.Encode()
|
|
||||||
|
|
||||||
str := normalizedURL.String()
|
queryString := query.Encode()
|
||||||
return strings.TrimLeft(str, "/")
|
|
||||||
|
// Get decoded user:pass
|
||||||
|
userPassEncoded := u.User.String()
|
||||||
|
userPass, _ := url.QueryUnescape(userPassEncoded)
|
||||||
|
|
||||||
|
// Build DSN w/ user:pass percent-decoded
|
||||||
|
normalizedString := ""
|
||||||
|
|
||||||
|
if userPass != "" { // user:pass can be empty
|
||||||
|
normalizedString = userPass + "@"
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedString = fmt.Sprintf("%s%s%s?%s", normalizedString,
|
||||||
|
host, u.Path, queryString)
|
||||||
|
|
||||||
|
return normalizedString
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,15 @@ func TestNormalizeMySQLURLCustom(t *testing.T) {
|
||||||
require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
require.Equal(t, "bob:secret@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeMySQLURLCustomSpecialChars(t *testing.T) {
|
||||||
|
u, err := url.Parse("mysql://duhfsd7s:123!@123!@@host:123/foo?flag=on")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "123", u.Port())
|
||||||
|
|
||||||
|
s := normalizeMySQLURL(u)
|
||||||
|
require.Equal(t, "duhfsd7s:123!@123!@@tcp(host:123)/foo?flag=on&multiStatements=true", s)
|
||||||
|
}
|
||||||
|
|
||||||
func TestMySQLCreateDropDatabase(t *testing.T) {
|
func TestMySQLCreateDropDatabase(t *testing.T) {
|
||||||
drv := MySQLDriver{}
|
drv := MySQLDriver{}
|
||||||
u := mySQLTestURL(t)
|
u := mySQLTestURL(t)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue