postgres: Limit pg_dump to schemas in search_path (#166)

This commit is contained in:
Adrian Macneil 2020-10-31 17:13:49 +13:00 committed by GitHub
parent af41fbfb4e
commit d4ecd0b259
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 42 deletions

View file

@ -19,7 +19,7 @@ func init() {
type PostgresDriver struct {
}
func normalizePostgresURL(u *url.URL) string {
func normalizePostgresURL(u *url.URL) *url.URL {
hostname := u.Hostname()
port := u.Port()
query := u.Query()
@ -54,12 +54,33 @@ func normalizePostgresURL(u *url.URL) string {
out.Host = fmt.Sprintf("%s:%s", hostname, port)
out.RawQuery = query.Encode()
return out.String()
return out
}
func normalizePostgresURLForDump(u *url.URL) []string {
u = normalizePostgresURL(u)
// find schemas from search_path
query := u.Query()
schemas := strings.Split(query.Get("search_path"), ",")
query.Del("search_path")
u.RawQuery = query.Encode()
out := []string{}
for _, schema := range schemas {
schema = strings.TrimSpace(schema)
if schema != "" {
out = append(out, "--schema", schema)
}
}
out = append(out, u.String())
return out
}
// Open creates a new database connection
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
return sql.Open("postgres", normalizePostgresURL(u))
return sql.Open("postgres", normalizePostgresURL(u).String())
}
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
@ -128,8 +149,9 @@ func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
// DumpSchema returns the current database schema
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
// load schema
schema, err := runCommand("pg_dump", "--format=plain", "--encoding=UTF8",
"--schema-only", "--no-privileges", "--no-owner", normalizePostgresURL(u))
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
schema, err := runCommand("pg_dump", args...)
if err != nil {
return nil, err
}

View file

@ -34,47 +34,57 @@ func prepTestPostgresDB(t *testing.T) *sql.DB {
return db
}
func TestNormalizePostgresURLDefaults(t *testing.T) {
u, err := url.Parse("postgres:///foo")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://localhost:5432/foo", s)
func TestNormalizePostgresURL(t *testing.T) {
cases := []struct {
input string
expected string
}{
// defaults
{"postgres:///foo", "postgres://localhost:5432/foo"},
// support custom url params
{"postgres://bob:secret@myhost:1234/foo?bar=baz", "postgres://bob:secret@myhost:1234/foo?bar=baz"},
// support `host` and `port` via url params
{"postgres://bob:secret@myhost:1234/foo?host=new&port=9999", "postgres://bob:secret@:9999/foo?host=new"},
{"postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz", "postgres://bob:secret@myhost:9999/foo?bar=baz"},
// support unix sockets via `host` or `socket` param
{"postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql", "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql"},
{"postgres://bob:secret@localhost/foo?socket=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
{"postgres:///foo?socket=/var/run/postgresql", "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
}
for _, c := range cases {
t.Run(c.input, func(t *testing.T) {
u, err := url.Parse(c.input)
require.NoError(t, err)
actual := normalizePostgresURL(u).String()
require.Equal(t, c.expected, actual)
})
}
}
func TestNormalizePostgresURLCustom(t *testing.T) {
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?bar=baz")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@myhost:1234/foo?bar=baz", s)
}
func TestNormalizePostgresURLForDump(t *testing.T) {
cases := []struct {
input string
expected []string
}{
// defaults
{"postgres:///foo", []string{"postgres://localhost:5432/foo"}},
// support single schema
{"postgres:///foo?search_path=foo", []string{"--schema", "foo", "postgres://localhost:5432/foo"}},
// support multiple schemas
{"postgres:///foo?search_path=foo,public", []string{"--schema", "foo", "--schema", "public", "postgres://localhost:5432/foo"}},
}
func TestNormalizePostgresURLHostPortParams(t *testing.T) {
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@myhost:9999/foo?bar=baz", s)
for _, c := range cases {
t.Run(c.input, func(t *testing.T) {
u, err := url.Parse(c.input)
require.NoError(t, err)
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=new&port=9999")
require.NoError(t, err)
s = normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@:9999/foo?host=new", s)
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql")
require.NoError(t, err)
s = normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
}
func TestNormalizePostgresURLSocketParam(t *testing.T) {
u, err := url.Parse("postgres://bob:secret@localhost/foo?socket=/var/run/postgresql")
require.NoError(t, err)
s := normalizePostgresURL(u)
require.Equal(t, "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
u, err = url.Parse("postgres:///foo?socket=/var/run/postgresql")
require.NoError(t, err)
s = normalizePostgresURL(u)
require.Equal(t, "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
actual := normalizePostgresURLForDump(u)
require.Equal(t, c.expected, actual)
})
}
}
func TestPostgresCreateDropDatabase(t *testing.T) {