Refactor drivers into separate packages (#179)

`dbmate` package was starting to get a bit polluted. This PR migrates each driver into a separate package, with clean separation between each.

In addition:

* Drivers are now initialized with a URL, avoiding the need to pass `*url.URL` to every method
* Sqlite supports a cleaner syntax for relative paths
* Driver tests now load their test URL from environment variables

Public API of `dbmate` package has not changed (no changes to `main` package).
This commit is contained in:
Adrian Macneil 2020-11-19 15:04:42 +13:00 committed by GitHub
parent c907c3f5c6
commit 61771e386d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 1195 additions and 1078 deletions

153
pkg/dbutil/dbutil.go Normal file
View file

@ -0,0 +1,153 @@
package dbutil
import (
"bufio"
"bytes"
"database/sql"
"errors"
"io"
"net/url"
"os/exec"
"strings"
"unicode"
)
// Transaction can represent a database or open transaction
type Transaction interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// DatabaseName returns the database name from a URL
func DatabaseName(u *url.URL) string {
name := u.Path
if len(name) > 0 && name[:1] == "/" {
name = name[1:]
}
return name
}
// MustClose ensures a stream is closed
func MustClose(c io.Closer) {
if err := c.Close(); err != nil {
panic(err)
}
}
// RunCommand runs a command and returns the stdout if successful
func RunCommand(name string, args ...string) ([]byte, error) {
var stdout, stderr bytes.Buffer
cmd := exec.Command(name, args...)
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
// return stderr if available
if s := strings.TrimSpace(stderr.String()); s != "" {
return nil, errors.New(s)
}
// otherwise return error
return nil, err
}
// return stdout
return stdout.Bytes(), nil
}
// TrimLeadingSQLComments removes sql comments and blank lines from the beginning of text
// generally when performing sql dumps these contain host-specific information such as
// client/server version numbers
func TrimLeadingSQLComments(data []byte) ([]byte, error) {
// create decent size buffer
out := bytes.NewBuffer(make([]byte, 0, len(data)))
// iterate over sql lines
preamble := true
scanner := bufio.NewScanner(bytes.NewReader(data))
for scanner.Scan() {
// we read bytes directly for premature performance optimization
line := scanner.Bytes()
if preamble && (len(line) == 0 || bytes.Equal(line[0:2], []byte("--"))) {
// header section, skip this line in output buffer
continue
}
// header section is over
preamble = false
// trim trailing whitespace
line = bytes.TrimRightFunc(line, unicode.IsSpace)
// copy bytes to output buffer
if _, err := out.Write(line); err != nil {
return nil, err
}
if _, err := out.WriteString("\n"); err != nil {
return nil, err
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
return out.Bytes(), nil
}
// QueryColumn runs a SQL statement and returns a slice of strings
// it is assumed that the statement returns only one column
// e.g. schema_migrations table
func QueryColumn(db Transaction, query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer MustClose(rows)
// read into slice
var result []string
for rows.Next() {
var v string
if err := rows.Scan(&v); err != nil {
return nil, err
}
result = append(result, v)
}
if err = rows.Err(); err != nil {
return nil, err
}
return result, nil
}
// QueryValue runs a SQL statement and returns a single string
// it is assumed that the statement returns only one row and one column
// sql NULL is returned as empty string
func QueryValue(db Transaction, query string, args ...interface{}) (string, error) {
var result sql.NullString
err := db.QueryRow(query, args...).Scan(&result)
if err != nil || !result.Valid {
return "", err
}
return result.String, nil
}
// MustParseURL parses a URL from string, and panics if it fails.
// It is used during testing and in cases where we are parsing a generated URL.
func MustParseURL(s string) *url.URL {
if s == "" {
panic("missing url")
}
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}

58
pkg/dbutil/dbutil_test.go Normal file
View file

@ -0,0 +1,58 @@
package dbutil_test
import (
"database/sql"
"testing"
"github.com/amacneil/dbmate/pkg/dbutil"
_ "github.com/mattn/go-sqlite3" // database/sql driver
"github.com/stretchr/testify/require"
)
func TestDatabaseName(t *testing.T) {
t.Run("valid", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host/dbname?query")
name := dbutil.DatabaseName(u)
require.Equal(t, "dbname", name)
})
t.Run("empty", func(t *testing.T) {
u := dbutil.MustParseURL("foo://host")
name := dbutil.DatabaseName(u)
require.Equal(t, "", name)
})
}
func TestTrimLeadingSQLComments(t *testing.T) {
in := "--\n" +
"-- foo\n\n" +
"-- bar\n\n" +
"real stuff\n" +
"-- end\n"
out, err := dbutil.TrimLeadingSQLComments([]byte(in))
require.NoError(t, err)
require.Equal(t, "real stuff\n-- end\n", string(out))
}
// connect to in-memory sqlite database for testing
const sqliteMemoryDB = "file:dbutil.sqlite3?mode=memory&cache=shared"
func TestQueryColumn(t *testing.T) {
db, err := sql.Open("sqlite3", sqliteMemoryDB)
require.NoError(t, err)
val, err := dbutil.QueryColumn(db, "select 'foo_' || val from (select ? as val union select ?)",
"hi", "there")
require.NoError(t, err)
require.Equal(t, []string{"foo_hi", "foo_there"}, val)
}
func TestQueryValue(t *testing.T) {
db, err := sql.Open("sqlite3", sqliteMemoryDB)
require.NoError(t, err)
val, err := dbutil.QueryValue(db, "select $1 + $2", "5", 2)
require.NoError(t, err)
require.Equal(t, "7", val)
}