Skip to content

Commit

Permalink
Basic query testing framework
Browse files Browse the repository at this point in the history
  • Loading branch information
nassibnassar committed Sep 17, 2019
1 parent 4a2a3a5 commit 79537be
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 4 deletions.
5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module github.com/folio-org/ldp-analytics

go 1.13

require (
github.com/lib/pq v1.2.0
github.com/nassibnassar/goconfig v0.1.0
)
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/nassibnassar/goconfig v0.1.0 h1:T+TRnqhi2jKyM6EClc/fcxF0SZ88yhN9cQVuf/SEaLs=
github.com/nassibnassar/goconfig v0.1.0/go.mod h1:opcdYVSH3KaZ7VOAzLo2VD9dDsa4i5YCDRmoOuVRQew=
27 changes: 27 additions & 0 deletions gotest/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package gotest

import (
"fmt"
"os"
"path/filepath"

"github.com/nassibnassar/goconfig/ini"
)

// readConfig reads and returns the configuration file ".ldptestsql", which it
// expects to find in the user's home directory.
func readConfig() (*ini.Config, error) {
homedir, err := os.UserHomeDir()
if err != nil {
return nil,
fmt.Errorf("Error reading configuration file: " +
"Unable to determine home directory")
}
filename := filepath.Join(homedir, ".ldptestsql")
config, err := ini.NewConfigFile(filename)
if err != nil {
return nil,
fmt.Errorf("Error reading configuration file: %v", err)
}
return config, nil
}
30 changes: 30 additions & 0 deletions gotest/database.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package gotest

import (
"database/sql"
"fmt"

_ "github.com/lib/pq"
)

// openDatabase creates and returns a client connection to a specified
// database.
func openDatabase(host, port, user, password, dbname, sslmode string) (*sql.DB,
error) {

connStr := fmt.Sprintf(
"host=%s port=%s user=%s password=%s dbname=%s sslmode=%s",
host, port, user, password, dbname, sslmode)
db, err := sql.Open("postgres", connStr)
if err != nil {
return nil, err
}

// Ping the database to test for connection errors.
err = db.Ping()
if err != nil {
return nil, err
}

return db, nil
}
160 changes: 157 additions & 3 deletions gotest/gotest.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,165 @@
package gotest

import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)

func RunTest(t *testing.T, queryFile string, resultFile string) {
// dir, _ := os.Getwd()
// query := filepath.Join(dir, queryFile)
// result := filepath.Join(dir, resultFile)
dir, _ := os.Getwd()
queryPath := filepath.Join(dir, queryFile)
resultPath := filepath.Join(dir, resultFile)
err := run(queryPath, resultPath)
if err != nil {
t.Errorf("%v", err)
}
}

func run(queryPath string, resultPath string) error {
config, err := readConfig()
if err != nil {
return err
}
databases := strings.Split(config.Get("", "databases"), ",")
for _, d := range databases {
// TODO Move inside loop to function to allow proper db.Close()
section := strings.TrimSpace(d)
db, err := openDatabase(
config.Get(section, "host"),
config.Get(section, "port"),
config.Get(section, "user"),
config.Get(section, "password"),
config.Get(section, "dbname"),
"require")
if err != nil {
return err
}
defer db.Close()
err = runDB(queryPath, resultPath, section, db)
if err != nil {
return err
}
}
return nil
}

func runDB(queryPath string, resultPath string, section string,
db *sql.DB) error {

buf, err := ioutil.ReadFile(queryPath)
if err != nil {
return fmt.Errorf("Unable to read query file: %v", err)
}
query := string(buf)

buf, err = ioutil.ReadFile(resultPath)
if err != nil {
return fmt.Errorf("Unable to read result file: %v", err)
}
expectedResult := string(buf)

rows, err := db.Query(query)
if err != nil {
return fmt.Errorf("Error running query: %v", err)
}
defer rows.Close()

var b strings.Builder

columns, err := rows.Columns()
if err != nil {
return fmt.Errorf("Error running query: %v", err)
}
for x, c := range columns {
if x > 0 {
fmt.Fprintf(&b, ",")
}
fmt.Fprintf(&b, "%s", c)
}
fmt.Fprintf(&b, "\n")

data := make([]interface{}, len(columns))
rdata := make([][]byte, len(columns))
for x := range rdata {
data[x] = &rdata[x]
}
for rows.Next() {
err = rows.Scan(data...)
if err != nil {
return fmt.Errorf("Error running query (scan): %v",
err)
}
for x, r := range rdata {
if x > 0 {
fmt.Fprintf(&b, ",")
}
if r != nil {
fmt.Fprintf(&b, "%s", string(r))
}
}
fmt.Fprintf(&b, "\n")
}

err = rows.Err()
if err != nil {
return fmt.Errorf("Error running query: %v", err)
}

result := b.String()

match := strings.TrimSpace(result) == strings.TrimSpace(expectedResult)

if !match {

fn, err := writeResult(filepath.Dir(queryPath), result)
if err != nil {
fn = "(Unable to write file)"
}

return fmt.Errorf(
"\n\nQuery:\n"+
"%s\n\n"+
"Expected result:\n"+
"%s\n\n"+
"Testing in database: %s\n\n"+
"Got:\n"+
"%s\n\n"+
"Want:\n"+
"%s\n\n"+
"Unexpected result written to:\n"+
"%s\n\n",
queryPath,
resultPath,
section,
strings.TrimSpace(result),
strings.TrimSpace(expectedResult),
fn)

}

return nil
}

// writeResult writes the unexpected result to a file and returns the file
// name.
func writeResult(dir string, result string) (string, error) {
f, err := ioutil.TempFile(dir, "test-result-*.csv")
if err != nil {
return "", err
}
_, err = f.Write([]byte(result))
if err != nil {
f.Close()
return "", err
}
err = f.Close()
if err != nil {
return "", err
}
return f.Name(), nil
}
4 changes: 4 additions & 0 deletions sql/circ_detail/circ_detail_result.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
service_point_name,group_name,ct
,faculty,2
,graduate,2
,staff,1
2 changes: 1 addition & 1 deletion sql/circ_detail/circ_detail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ import (

func TestQuery(t *testing.T) {

gotest.RunTest(t, "circ-detail.sql", "circ-detail.out")
gotest.RunTest(t, "circ-detail.sql", "circ_detail_result.csv")

}

0 comments on commit 79537be

Please sign in to comment.