Add postgres support for socket parameter (#136)

Simplifies connecting to postgres via sockets (https://github.com/amacneil/dbmate/issues/107), and standardize the `socket` parameter across both mysql and postgresql.

Closes #107
This commit is contained in:
Adrian Macneil 2020-05-24 17:12:41 -07:00 committed by GitHub
parent 45a122eb86
commit ed9e57a4ad
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 3 deletions

View file

@ -62,12 +62,20 @@ func TestNormalizeMySQLURLCustomSpecialChars(t *testing.T) {
}
func TestNormalizeMySQLURLSocket(t *testing.T) {
// test with no user/pass
u, err := url.Parse("mysql:///foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
require.NoError(t, err)
require.Equal(t, "", u.Host)
s := normalizeMySQLURL(u)
require.Equal(t, "unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
// test with user/pass
u, err = url.Parse("mysql://bob:secret@fakehost/foo?socket=/var/run/mysqld/mysqld.sock&flag=on")
require.NoError(t, err)
s = normalizeMySQLURL(u)
require.Equal(t, "bob:secret@unix(/var/run/mysqld/mysqld.sock)/foo?flag=on&multiStatements=true", s)
}
func TestMySQLCreateDropDatabase(t *testing.T) {

View file

@ -19,9 +19,47 @@ func init() {
type PostgresDriver struct {
}
func normalizePostgresURL(u *url.URL) string {
hostname := u.Hostname()
port := u.Port()
query := u.Query()
// support socket parameter for consistency with mysql
if query.Get("socket") != "" {
query.Set("host", query.Get("socket"))
query.Del("socket")
}
// default hostname
if hostname == "" {
hostname = "localhost"
}
// host param overrides url hostname
if query.Get("host") != "" {
hostname = ""
}
// always specify a port
if query.Get("port") != "" {
port = query.Get("port")
query.Del("port")
}
if port == "" {
port = "5432"
}
// generate output URL
out, _ := url.Parse(u.String())
out.Host = fmt.Sprintf("%s:%s", hostname, port)
out.RawQuery = query.Encode()
return out.String()
}
// Open creates a new database connection
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("postgres", u.String())
return sql.Open("postgres", normalizePostgresURL(u))
}
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
@ -91,7 +129,7 @@ func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
// load schema
schema, err := runCommand("pg_dump", "--format=plain", "--encoding=UTF8",
"--schema-only", "--no-privileges", "--no-owner", u.String())
"--schema-only", "--no-privileges", "--no-owner", normalizePostgresURL(u))
if err != nil {
return nil, err
}

View file

@ -34,6 +34,49 @@ func prepTestPostgresDB(t *testing.T) *sql.DB {
return db
}
func TestNormalizePostgresURLDefaults(t *testing.T) {
u, err := url.Parse("postgres:///foo")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://localhost:5432/foo", s)
}
func TestNormalizePostgresURLCustom(t *testing.T) {
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?bar=baz")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@myhost:1234/foo?bar=baz", s)
}
func TestNormalizePostgresURLHostPortParams(t *testing.T) {
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@myhost:9999/foo?bar=baz", s)
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=new&port=9999")
require.NoError(t, err)
s = normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@:9999/foo?host=new", s)
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql")
require.NoError(t, err)
s = normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
}
func TestNormalizePostgresURLSocketParam(t *testing.T) {
u, err := url.Parse("postgres://bob:secret@localhost/foo?socket=/var/run/postgresql")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
u, err = url.Parse("postgres:///foo?socket=/var/run/postgresql")
require.NoError(t, err)
s = normalizePostgresURL(u)
require.Equal(t, "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
}
func TestPostgresCreateDropDatabase(t *testing.T) {
drv := PostgresDriver{}
u := postgresTestURL(t)