diff --git a/pkg/dbmate/postgres.go b/pkg/dbmate/postgres.go index 084cff5..5d751cd 100644 --- a/pkg/dbmate/postgres.go +++ b/pkg/dbmate/postgres.go @@ -19,7 +19,7 @@ func init() { type PostgresDriver struct { } -func normalizePostgresURL(u *url.URL) string { +func normalizePostgresURL(u *url.URL) *url.URL { hostname := u.Hostname() port := u.Port() query := u.Query() @@ -54,12 +54,33 @@ func normalizePostgresURL(u *url.URL) string { out.Host = fmt.Sprintf("%s:%s", hostname, port) out.RawQuery = query.Encode() - return out.String() + return out +} + +func normalizePostgresURLForDump(u *url.URL) []string { + u = normalizePostgresURL(u) + + // find schemas from search_path + query := u.Query() + schemas := strings.Split(query.Get("search_path"), ",") + query.Del("search_path") + u.RawQuery = query.Encode() + + out := []string{} + for _, schema := range schemas { + schema = strings.TrimSpace(schema) + if schema != "" { + out = append(out, "--schema", schema) + } + } + out = append(out, u.String()) + + return out } // Open creates a new database connection func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) { - return sql.Open("postgres", normalizePostgresURL(u)) + return sql.Open("postgres", normalizePostgresURL(u).String()) } func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) { @@ -128,8 +149,9 @@ func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) { // DumpSchema returns the current database schema func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) { // load schema - schema, err := runCommand("pg_dump", "--format=plain", "--encoding=UTF8", - "--schema-only", "--no-privileges", "--no-owner", normalizePostgresURL(u)) + args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only", + "--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...) + schema, err := runCommand("pg_dump", args...) if err != nil { return nil, err } diff --git a/pkg/dbmate/postgres_test.go b/pkg/dbmate/postgres_test.go index 1294a3e..4e5739a 100644 --- a/pkg/dbmate/postgres_test.go +++ b/pkg/dbmate/postgres_test.go @@ -34,47 +34,57 @@ func prepTestPostgresDB(t *testing.T) *sql.DB { return db } -func TestNormalizePostgresURLDefaults(t *testing.T) { - u, err := url.Parse("postgres:///foo") - require.NoError(t, err) - s := normalizePostgresURL(u) - require.Equal(t, "postgres://localhost:5432/foo", s) +func TestNormalizePostgresURL(t *testing.T) { + cases := []struct { + input string + expected string + }{ + // defaults + {"postgres:///foo", "postgres://localhost:5432/foo"}, + // support custom url params + {"postgres://bob:secret@myhost:1234/foo?bar=baz", "postgres://bob:secret@myhost:1234/foo?bar=baz"}, + // support `host` and `port` via url params + {"postgres://bob:secret@myhost:1234/foo?host=new&port=9999", "postgres://bob:secret@:9999/foo?host=new"}, + {"postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz", "postgres://bob:secret@myhost:9999/foo?bar=baz"}, + // support unix sockets via `host` or `socket` param + {"postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql", "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql"}, + {"postgres://bob:secret@localhost/foo?socket=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"}, + {"postgres:///foo?socket=/var/run/postgresql", "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"}, + } + + for _, c := range cases { + t.Run(c.input, func(t *testing.T) { + u, err := url.Parse(c.input) + require.NoError(t, err) + + actual := normalizePostgresURL(u).String() + require.Equal(t, c.expected, actual) + }) + } } -func TestNormalizePostgresURLCustom(t *testing.T) { - u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?bar=baz") - require.NoError(t, err) - s := normalizePostgresURL(u) - require.Equal(t, "postgres://bob:secret@myhost:1234/foo?bar=baz", s) -} +func TestNormalizePostgresURLForDump(t *testing.T) { + cases := []struct { + input string + expected []string + }{ + // defaults + {"postgres:///foo", []string{"postgres://localhost:5432/foo"}}, + // support single schema + {"postgres:///foo?search_path=foo", []string{"--schema", "foo", "postgres://localhost:5432/foo"}}, + // support multiple schemas + {"postgres:///foo?search_path=foo,public", []string{"--schema", "foo", "--schema", "public", "postgres://localhost:5432/foo"}}, + } -func TestNormalizePostgresURLHostPortParams(t *testing.T) { - u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz") - require.NoError(t, err) - s := normalizePostgresURL(u) - require.Equal(t, "postgres://bob:secret@myhost:9999/foo?bar=baz", s) + for _, c := range cases { + t.Run(c.input, func(t *testing.T) { + u, err := url.Parse(c.input) + require.NoError(t, err) - u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=new&port=9999") - require.NoError(t, err) - s = normalizePostgresURL(u) - require.Equal(t, "postgres://bob:secret@:9999/foo?host=new", s) - - u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql") - require.NoError(t, err) - s = normalizePostgresURL(u) - require.Equal(t, "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql", s) -} - -func TestNormalizePostgresURLSocketParam(t *testing.T) { - u, err := url.Parse("postgres://bob:secret@localhost/foo?socket=/var/run/postgresql") - require.NoError(t, err) - s := normalizePostgresURL(u) - require.Equal(t, "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s) - - u, err = url.Parse("postgres:///foo?socket=/var/run/postgresql") - require.NoError(t, err) - s = normalizePostgresURL(u) - require.Equal(t, "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s) + actual := normalizePostgresURLForDump(u) + require.Equal(t, c.expected, actual) + }) + } } func TestPostgresCreateDropDatabase(t *testing.T) {