forked from fergusstrange/embedded-postgres
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_database.go
129 lines (102 loc) · 3.29 KB
/
prepare_database.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package embeddedpostgres
import (
"context"
"database/sql"
"errors"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"github.com/lib/pq"
)
// initDatabase is the function signature used to initialise a database
// cluster from the extracted binaries located at binaryExtractLocation, with
// the given superuser credentials and optional locale. defaultInitDatabase in
// this file is an implementation with this signature.
type initDatabase func(binaryExtractLocation, username, password, locale string) error

// createDatabase is the function signature used to create the named database
// on a server reachable on the given port with the given credentials.
// defaultCreateDatabase in this file is an implementation with this signature.
type createDatabase func(port uint32, username, password, database string) error
// defaultInitDatabase initialises a new data directory at
// <binaryExtractLocation>/data by running <binaryExtractLocation>/bin/initdb.
//
// initdb is invoked with password authentication ("-A password"), username as
// the "-U" user, and the password passed through a file written by
// createPasswordFile ("--pwfile"). When locale is non-empty it is forwarded as
// "--locale". The child's stdout and stderr are connected to this process's
// stdout and stderr.
//
// It returns the error from createPasswordFile unchanged, or an error that
// names the executed command and wraps the underlying process error when
// initdb fails.
//
// NOTE(review): the password file is not removed by this function, so the
// plaintext password stays on disk after initdb finishes — confirm whether a
// caller cleans it up.
func defaultInitDatabase(binaryExtractLocation, username, password, locale string) error {
	passwordFile, err := createPasswordFile(binaryExtractLocation, password)
	if err != nil {
		return err
	}

	args := []string{
		"-A", "password",
		"-U", username,
		"-D", filepath.Join(binaryExtractLocation, "data"),
		fmt.Sprintf("--pwfile=%s", passwordFile),
	}

	if locale != "" {
		args = append(args, fmt.Sprintf("--locale=%s", locale))
	}

	postgresInitDbBinary := filepath.Join(binaryExtractLocation, "bin/initdb")
	postgresInitDbProcess := exec.Command(postgresInitDbBinary, args...)
	postgresInitDbProcess.Stderr = os.Stderr
	postgresInitDbProcess.Stdout = os.Stdout

	if err := postgresInitDbProcess.Run(); err != nil {
		// Keep the cause (missing binary, non-zero exit status, ...) in the
		// chain instead of reporting only the command line.
		return fmt.Errorf("unable to init database using: %s: %w", postgresInitDbProcess.String(), err)
	}

	return nil
}
// createPasswordFile writes password to the file "pwfile" inside
// binaryExtractLocation with mode 0600 and returns the path of that file.
//
// On failure it returns an empty path and an error that names the target
// location and wraps the underlying write error, so callers can inspect it
// with errors.Is / errors.As.
func createPasswordFile(binaryExtractLocation, password string) (string, error) {
	passwordFileLocation := filepath.Join(binaryExtractLocation, "pwfile")
	if err := ioutil.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil {
		return "", fmt.Errorf("unable to write password file to %s: %w", passwordFileLocation, err)
	}

	return passwordFileLocation, nil
}
// defaultCreateDatabase creates the database named database by connecting to
// the "postgres" database on localhost:port with the given credentials and
// issuing CREATE DATABASE.
//
// When database is "postgres" nothing is done and nil is returned
// (presumably because that database already exists after initialisation —
// confirm against the caller). Any failure is reported through
// errorCustomDatabase.
func defaultCreateDatabase(port uint32, username, password, database string) error {
	if database == "postgres" {
		return nil
	}

	conn, err := openDatabaseConnection(port, username, password, "postgres")
	if err != nil {
		return errorCustomDatabase(database, err)
	}

	db := sql.OpenDB(conn)
	// Release the pool once the statement has run; a failure to close is not
	// actionable here and must not mask the result of CREATE DATABASE.
	defer func() { _ = db.Close() }()

	// Quote the identifier so names containing upper-case letters, hyphens or
	// other special characters are created exactly as given (and cannot break
	// out of the statement). Unquoted names would be case-folded by the server
	// and then no longer match the dbname used for later connections.
	statement := fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(database))
	if _, err := db.Exec(statement); err != nil {
		return errorCustomDatabase(database, err)
	}

	return nil
}
// healthCheckDatabaseOrTimeout repeatedly probes the database described by
// config (port, database, username, password) with healthCheckDatabase until
// a probe succeeds or config.startTimeout elapses.
//
// It returns nil as soon as a probe succeeds, or an error when the timeout
// expires first.
func healthCheckDatabaseOrTimeout(config Config) error {
	// Buffered with capacity one and never closed: if the timeout wins while a
	// probe is still in flight, the goroutine can still deliver its result and
	// exit. With an unbuffered channel closed on return, that late send would
	// panic with "send on closed channel".
	healthCheckSignal := make(chan struct{}, 1)

	timeout, cancelFunc := context.WithTimeout(context.Background(), config.startTimeout)
	defer cancelFunc()

	go func() {
		// The loop ends on the first successful probe or once the context has
		// expired or been cancelled by the deferred cancelFunc.
		for timeout.Err() == nil {
			if err := healthCheckDatabase(config.port, config.database, config.username, config.password); err != nil {
				continue
			}

			healthCheckSignal <- struct{}{}

			return
		}
	}()

	select {
	case <-healthCheckSignal:
		return nil
	case <-timeout.Done():
		return errors.New("timed out waiting for database to become available")
	}
}
// healthCheckDatabase performs a single probe of the given database on
// localhost:port by running "SELECT 1" with the supplied credentials.
//
// It returns nil when the query succeeds, otherwise the error from building
// the connector, running the query, or closing the result set.
func healthCheckDatabase(port uint32, database, username, password string) error {
	conn, err := openDatabaseConnection(port, username, password, database)
	if err != nil {
		return err
	}

	db := sql.OpenDB(conn)
	// This function is called in a retry loop, so the pool must be released on
	// every attempt; the close error carries no information about readiness.
	defer func() { _ = db.Close() }()

	rows, err := db.Query("SELECT 1")
	if err != nil {
		return err
	}

	// An unclosed result set keeps its connection checked out.
	return rows.Close()
}
// openDatabaseConnection builds a lib/pq connector for the given database on
// localhost:port using the supplied credentials, with sslmode disabled.
//
// It returns the connector, or the error reported by pq.NewConnector when the
// connection string cannot be processed.
func openDatabaseConnection(port uint32, username string, password string, database string) (*pq.Connector, error) {
	// Caller-supplied values are quoted so that empty strings, spaces, quotes
	// and backslashes are passed through verbatim instead of corrupting the
	// keyword/value string (e.g. an empty password would otherwise swallow the
	// following "dbname=..." token as its value).
	return pq.NewConnector(fmt.Sprintf("host=localhost port=%d user=%s password=%s dbname=%s sslmode=disable",
		port,
		quoteConnectionValue(username),
		quoteConnectionValue(password),
		quoteConnectionValue(database)))
}

// quoteConnectionValue returns value wrapped in single quotes with every
// single quote and backslash escaped by a backslash, as required for values in
// a keyword/value connection string.
func quoteConnectionValue(value string) string {
	quoted := make([]byte, 0, len(value)+2)
	quoted = append(quoted, '\'')

	// Byte-wise iteration is safe: both escaped characters are ASCII and never
	// occur inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(value); i++ {
		if c := value[i]; c == '\'' || c == '\\' {
			quoted = append(quoted, '\\')
		}

		quoted = append(quoted, value[i])
	}

	return string(append(quoted, '\''))
}
// errorCustomDatabase builds the error returned when the database with the
// custom name database could not be created. The message contains the name
// and the text of err, and err is wrapped so callers can inspect the cause
// with errors.Is / errors.As.
func errorCustomDatabase(database string, err error) error {
	return fmt.Errorf("unable to connect to create database with custom name %s with the following error: %w", database, err)
}