mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2026-02-02 01:15:09 +01:00
postgres: Limit pg_dump to schemas in search_path (#166)
This commit is contained in:
parent
af41fbfb4e
commit
d4ecd0b259
2 changed files with 74 additions and 42 deletions
|
|
@ -19,7 +19,7 @@ func init() {
|
|||
type PostgresDriver struct {
|
||||
}
|
||||
|
||||
func normalizePostgresURL(u *url.URL) string {
|
||||
func normalizePostgresURL(u *url.URL) *url.URL {
|
||||
hostname := u.Hostname()
|
||||
port := u.Port()
|
||||
query := u.Query()
|
||||
|
|
@ -54,12 +54,33 @@ func normalizePostgresURL(u *url.URL) string {
|
|||
out.Host = fmt.Sprintf("%s:%s", hostname, port)
|
||||
out.RawQuery = query.Encode()
|
||||
|
||||
return out.String()
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizePostgresURLForDump(u *url.URL) []string {
|
||||
u = normalizePostgresURL(u)
|
||||
|
||||
// find schemas from search_path
|
||||
query := u.Query()
|
||||
schemas := strings.Split(query.Get("search_path"), ",")
|
||||
query.Del("search_path")
|
||||
u.RawQuery = query.Encode()
|
||||
|
||||
out := []string{}
|
||||
for _, schema := range schemas {
|
||||
schema = strings.TrimSpace(schema)
|
||||
if schema != "" {
|
||||
out = append(out, "--schema", schema)
|
||||
}
|
||||
}
|
||||
out = append(out, u.String())
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Open creates a new database connection
|
||||
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||
return sql.Open("postgres", normalizePostgresURL(u))
|
||||
return sql.Open("postgres", normalizePostgresURL(u).String())
|
||||
}
|
||||
|
||||
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||
|
|
@ -128,8 +149,9 @@ func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
|||
// DumpSchema returns the current database schema
|
||||
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||
// load schema
|
||||
schema, err := runCommand("pg_dump", "--format=plain", "--encoding=UTF8",
|
||||
"--schema-only", "--no-privileges", "--no-owner", normalizePostgresURL(u))
|
||||
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
|
||||
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
|
||||
schema, err := runCommand("pg_dump", args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -34,47 +34,57 @@ func prepTestPostgresDB(t *testing.T) *sql.DB {
|
|||
return db
|
||||
}
|
||||
|
||||
func TestNormalizePostgresURLDefaults(t *testing.T) {
|
||||
u, err := url.Parse("postgres:///foo")
|
||||
require.NoError(t, err)
|
||||
s := normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://localhost:5432/foo", s)
|
||||
func TestNormalizePostgresURL(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// defaults
|
||||
{"postgres:///foo", "postgres://localhost:5432/foo"},
|
||||
// support custom url params
|
||||
{"postgres://bob:secret@myhost:1234/foo?bar=baz", "postgres://bob:secret@myhost:1234/foo?bar=baz"},
|
||||
// support `host` and `port` via url params
|
||||
{"postgres://bob:secret@myhost:1234/foo?host=new&port=9999", "postgres://bob:secret@:9999/foo?host=new"},
|
||||
{"postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz", "postgres://bob:secret@myhost:9999/foo?bar=baz"},
|
||||
// support unix sockets via `host` or `socket` param
|
||||
{"postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql", "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql"},
|
||||
{"postgres://bob:secret@localhost/foo?socket=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
|
||||
{"postgres:///foo?socket=/var/run/postgresql", "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
|
||||
}
|
||||
|
||||
func TestNormalizePostgresURLCustom(t *testing.T) {
|
||||
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?bar=baz")
|
||||
for _, c := range cases {
|
||||
t.Run(c.input, func(t *testing.T) {
|
||||
u, err := url.Parse(c.input)
|
||||
require.NoError(t, err)
|
||||
s := normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://bob:secret@myhost:1234/foo?bar=baz", s)
|
||||
|
||||
actual := normalizePostgresURL(u).String()
|
||||
require.Equal(t, c.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizePostgresURLHostPortParams(t *testing.T) {
|
||||
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz")
|
||||
require.NoError(t, err)
|
||||
s := normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://bob:secret@myhost:9999/foo?bar=baz", s)
|
||||
|
||||
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=new&port=9999")
|
||||
require.NoError(t, err)
|
||||
s = normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://bob:secret@:9999/foo?host=new", s)
|
||||
|
||||
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql")
|
||||
require.NoError(t, err)
|
||||
s = normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
|
||||
func TestNormalizePostgresURLForDump(t *testing.T) {
|
||||
cases := []struct {
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
// defaults
|
||||
{"postgres:///foo", []string{"postgres://localhost:5432/foo"}},
|
||||
// support single schema
|
||||
{"postgres:///foo?search_path=foo", []string{"--schema", "foo", "postgres://localhost:5432/foo"}},
|
||||
// support multiple schemas
|
||||
{"postgres:///foo?search_path=foo,public", []string{"--schema", "foo", "--schema", "public", "postgres://localhost:5432/foo"}},
|
||||
}
|
||||
|
||||
func TestNormalizePostgresURLSocketParam(t *testing.T) {
|
||||
u, err := url.Parse("postgres://bob:secret@localhost/foo?socket=/var/run/postgresql")
|
||||
for _, c := range cases {
|
||||
t.Run(c.input, func(t *testing.T) {
|
||||
u, err := url.Parse(c.input)
|
||||
require.NoError(t, err)
|
||||
s := normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
|
||||
|
||||
u, err = url.Parse("postgres:///foo?socket=/var/run/postgresql")
|
||||
require.NoError(t, err)
|
||||
s = normalizePostgresURL(u)
|
||||
require.Equal(t, "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
|
||||
actual := normalizePostgresURLForDump(u)
|
||||
require.Equal(t, c.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue