mirror of
https://github.com/TECHNOFAB11/dbmate.git
synced 2026-02-02 09:25:07 +01:00
postgres: Limit pg_dump to schemas in search_path (#166)
This commit is contained in:
parent
af41fbfb4e
commit
d4ecd0b259
2 changed files with 74 additions and 42 deletions
|
|
@ -19,7 +19,7 @@ func init() {
|
||||||
type PostgresDriver struct {
|
type PostgresDriver struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizePostgresURL(u *url.URL) string {
|
func normalizePostgresURL(u *url.URL) *url.URL {
|
||||||
hostname := u.Hostname()
|
hostname := u.Hostname()
|
||||||
port := u.Port()
|
port := u.Port()
|
||||||
query := u.Query()
|
query := u.Query()
|
||||||
|
|
@ -54,12 +54,33 @@ func normalizePostgresURL(u *url.URL) string {
|
||||||
out.Host = fmt.Sprintf("%s:%s", hostname, port)
|
out.Host = fmt.Sprintf("%s:%s", hostname, port)
|
||||||
out.RawQuery = query.Encode()
|
out.RawQuery = query.Encode()
|
||||||
|
|
||||||
return out.String()
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePostgresURLForDump(u *url.URL) []string {
|
||||||
|
u = normalizePostgresURL(u)
|
||||||
|
|
||||||
|
// find schemas from search_path
|
||||||
|
query := u.Query()
|
||||||
|
schemas := strings.Split(query.Get("search_path"), ",")
|
||||||
|
query.Del("search_path")
|
||||||
|
u.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
out := []string{}
|
||||||
|
for _, schema := range schemas {
|
||||||
|
schema = strings.TrimSpace(schema)
|
||||||
|
if schema != "" {
|
||||||
|
out = append(out, "--schema", schema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out = append(out, u.String())
|
||||||
|
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// Open creates a new database connection
|
// Open creates a new database connection
|
||||||
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
func (drv PostgresDriver) Open(u *url.URL) (*sql.DB, error) {
|
||||||
return sql.Open("postgres", normalizePostgresURL(u))
|
return sql.Open("postgres", normalizePostgresURL(u).String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
func (drv PostgresDriver) openPostgresDB(u *url.URL) (*sql.DB, error) {
|
||||||
|
|
@ -128,8 +149,9 @@ func postgresSchemaMigrationsDump(db *sql.DB) ([]byte, error) {
|
||||||
// DumpSchema returns the current database schema
|
// DumpSchema returns the current database schema
|
||||||
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
func (drv PostgresDriver) DumpSchema(u *url.URL, db *sql.DB) ([]byte, error) {
|
||||||
// load schema
|
// load schema
|
||||||
schema, err := runCommand("pg_dump", "--format=plain", "--encoding=UTF8",
|
args := append([]string{"--format=plain", "--encoding=UTF8", "--schema-only",
|
||||||
"--schema-only", "--no-privileges", "--no-owner", normalizePostgresURL(u))
|
"--no-privileges", "--no-owner"}, normalizePostgresURLForDump(u)...)
|
||||||
|
schema, err := runCommand("pg_dump", args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -34,47 +34,57 @@ func prepTestPostgresDB(t *testing.T) *sql.DB {
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizePostgresURLDefaults(t *testing.T) {
|
func TestNormalizePostgresURL(t *testing.T) {
|
||||||
u, err := url.Parse("postgres:///foo")
|
cases := []struct {
|
||||||
require.NoError(t, err)
|
input string
|
||||||
s := normalizePostgresURL(u)
|
expected string
|
||||||
require.Equal(t, "postgres://localhost:5432/foo", s)
|
}{
|
||||||
|
// defaults
|
||||||
|
{"postgres:///foo", "postgres://localhost:5432/foo"},
|
||||||
|
// support custom url params
|
||||||
|
{"postgres://bob:secret@myhost:1234/foo?bar=baz", "postgres://bob:secret@myhost:1234/foo?bar=baz"},
|
||||||
|
// support `host` and `port` via url params
|
||||||
|
{"postgres://bob:secret@myhost:1234/foo?host=new&port=9999", "postgres://bob:secret@:9999/foo?host=new"},
|
||||||
|
{"postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz", "postgres://bob:secret@myhost:9999/foo?bar=baz"},
|
||||||
|
// support unix sockets via `host` or `socket` param
|
||||||
|
{"postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql", "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql"},
|
||||||
|
{"postgres://bob:secret@localhost/foo?socket=/var/run/postgresql", "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
|
||||||
|
{"postgres:///foo?socket=/var/run/postgresql", "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql"},
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizePostgresURLCustom(t *testing.T) {
|
for _, c := range cases {
|
||||||
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?bar=baz")
|
t.Run(c.input, func(t *testing.T) {
|
||||||
|
u, err := url.Parse(c.input)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
s := normalizePostgresURL(u)
|
|
||||||
require.Equal(t, "postgres://bob:secret@myhost:1234/foo?bar=baz", s)
|
actual := normalizePostgresURL(u).String()
|
||||||
|
require.Equal(t, c.expected, actual)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizePostgresURLHostPortParams(t *testing.T) {
|
func TestNormalizePostgresURLForDump(t *testing.T) {
|
||||||
u, err := url.Parse("postgres://bob:secret@myhost:1234/foo?port=9999&bar=baz")
|
cases := []struct {
|
||||||
require.NoError(t, err)
|
input string
|
||||||
s := normalizePostgresURL(u)
|
expected []string
|
||||||
require.Equal(t, "postgres://bob:secret@myhost:9999/foo?bar=baz", s)
|
}{
|
||||||
|
// defaults
|
||||||
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=new&port=9999")
|
{"postgres:///foo", []string{"postgres://localhost:5432/foo"}},
|
||||||
require.NoError(t, err)
|
// support single schema
|
||||||
s = normalizePostgresURL(u)
|
{"postgres:///foo?search_path=foo", []string{"--schema", "foo", "postgres://localhost:5432/foo"}},
|
||||||
require.Equal(t, "postgres://bob:secret@:9999/foo?host=new", s)
|
// support multiple schemas
|
||||||
|
{"postgres:///foo?search_path=foo,public", []string{"--schema", "foo", "--schema", "public", "postgres://localhost:5432/foo"}},
|
||||||
u, err = url.Parse("postgres://bob:secret@myhost:1234/foo?host=/var/run/postgresql")
|
|
||||||
require.NoError(t, err)
|
|
||||||
s = normalizePostgresURL(u)
|
|
||||||
require.Equal(t, "postgres://bob:secret@:1234/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizePostgresURLSocketParam(t *testing.T) {
|
for _, c := range cases {
|
||||||
u, err := url.Parse("postgres://bob:secret@localhost/foo?socket=/var/run/postgresql")
|
t.Run(c.input, func(t *testing.T) {
|
||||||
|
u, err := url.Parse(c.input)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
s := normalizePostgresURL(u)
|
|
||||||
require.Equal(t, "postgres://bob:secret@:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
|
|
||||||
|
|
||||||
u, err = url.Parse("postgres:///foo?socket=/var/run/postgresql")
|
actual := normalizePostgresURLForDump(u)
|
||||||
require.NoError(t, err)
|
require.Equal(t, c.expected, actual)
|
||||||
s = normalizePostgresURL(u)
|
})
|
||||||
require.Equal(t, "postgres://:5432/foo?host=%2Fvar%2Frun%2Fpostgresql", s)
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPostgresCreateDropDatabase(t *testing.T) {
|
func TestPostgresCreateDropDatabase(t *testing.T) {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue