Skip to content

Commit

Permalink
Refactored Store to use BaseRunners instead of proxies
Browse files Browse the repository at this point in the history
Query logger is now inherited for transactions (fixes #254)
Paves the way for #256
  • Loading branch information
nadiamoe committed Jan 29, 2018
1 parent 3162cdd commit 1dec030
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 58 deletions.
4 changes: 2 additions & 2 deletions batcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type batchQueryRunner struct {
q Query
oneToOneRels []Relationship
oneToManyRels []Relationship
db squirrel.DBProxy
db squirrel.BaseRunner
builder squirrel.SelectBuilder
total int
eof bool
Expand All @@ -24,7 +24,7 @@ type batchQueryRunner struct {

var errNoMoreRows = errors.New("kallax: there are no more rows in the result set")

func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQueryRunner {
func newBatchQueryRunner(schema Schema, db squirrel.BaseRunner, q Query) *batchQueryRunner {
cols, builder := q.compile()
var (
oneToOneRels []Relationship
Expand Down
6 changes: 3 additions & 3 deletions batcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestBatcherLimit(t *testing.T) {
q.BatchSize(2)
q.Limit(5)
r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1")))
runner := newBatchQueryRunner(ModelSchema, store.proxy, q)
runner := newBatchQueryRunner(ModelSchema, store.runner, q)
rs := NewBatchingResultSet(runner)

var count int
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestBatcherNoExtraQueryIfLessThanLimit(t *testing.T) {
var queries int
proxy := store.DebugWith(func(_ string, _ ...interface{}) {
queries++
}).proxy
}).runner
runner := newBatchQueryRunner(ModelSchema, proxy, q)
rs := NewBatchingResultSet(runner)

Expand Down Expand Up @@ -130,7 +130,7 @@ func TestBatcherNoExtraQueryIfLessThanBatchSize(t *testing.T) {
var queries int
proxy := store.DebugWith(func(_ string, _ ...interface{}) {
queries++
}).proxy
}).runner
runner := newBatchQueryRunner(ModelSchema, proxy, q)
rs := NewBatchingResultSet(runner)

Expand Down
151 changes: 98 additions & 53 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,62 +60,87 @@ func StoreFrom(to, from GenericStorer) {
// logs it.
type LoggerFunc func(string, ...interface{})

// debugProxy is a database proxy that logs all SQL statements executed.
type debugProxy struct {
func defaultLogger(message string, args ...interface{}) {
log.Printf("%s, args: %v", message, args)
}

// basicLogger is a database runner that logs all SQL statements executed.
type basicLogger struct {
logger LoggerFunc
proxy squirrel.DBProxy
runner squirrel.BaseRunner
}

func defaultLogger(message string, args ...interface{}) {
log.Printf("%s, args: %v", message, args)
// basicLogger is a database runner that logs all SQL statements executed.
type proxyLogger struct {
basicLogger
}

func (p *debugProxy) Exec(query string, args ...interface{}) (sql.Result, error) {
func (p *basicLogger) Exec(query string, args ...interface{}) (sql.Result, error) {
p.logger(fmt.Sprintf("kallax: Exec: %s", query), args...)
return p.proxy.Exec(query, args...)
return p.runner.Exec(query, args...)
}

func (p *debugProxy) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (p *basicLogger) Query(query string, args ...interface{}) (*sql.Rows, error) {
p.logger(fmt.Sprintf("kallax: Query: %s", query), args...)
return p.proxy.Query(query, args...)
return p.runner.Query(query, args...)
}

func (p *debugProxy) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
p.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
return p.proxy.QueryRow(query, args...)
func (p *proxyLogger) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
p.basicLogger.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
if queryRower, ok := p.basicLogger.runner.(squirrel.QueryRower); ok {
return queryRower.QueryRow(query, args...)
} else {
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
}
}

func (p *debugProxy) Prepare(query string) (*sql.Stmt, error) {
p.logger(fmt.Sprintf("kallax: Prepare: %s", query))
return p.proxy.Prepare(query)
func (p *proxyLogger) Prepare(query string) (*sql.Stmt, error) {
// If chained runner is a proxy, run Prepare(). Otherwise, noop.
if preparer, ok := p.basicLogger.runner.(squirrel.Preparer); ok {
p.basicLogger.logger(fmt.Sprintf("kallax: Prepare: %s", query))
return preparer.Prepare(query)
} else {
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
}
}

// Store is a structure capable of retrieving records from a concrete table in
// the database.
type Store struct {
builder squirrel.StatementBuilderType
db *sql.DB
proxy squirrel.DBProxy
db interface {
squirrel.BaseRunner
squirrel.PreparerContext
}
runner squirrel.BaseRunner
useCacher bool
logger LoggerFunc
}

// NewStore returns a new Store instance.
func NewStore(db *sql.DB) *Store {
proxy := squirrel.NewStmtCacher(db)
builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).RunWith(proxy)
return &Store{
db: db,
proxy: proxy,
builder: builder,
}
return (&Store{
db: db,
useCacher: true,
}).init()
}

func newStoreWithTransaction(tx *sql.Tx) *Store {
proxy := squirrel.NewStmtCacher(tx)
builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).RunWith(proxy)
return &Store{
proxy: proxy,
builder: builder,
// init initializes the store runner with debugging or caching, and returns itself for chainability
func (s *Store) init() *Store {
s.runner = s.db

if s.useCacher {
s.runner = squirrel.NewStmtCacher(s.db)
}

if s.logger != nil && !s.useCacher {
// Use BasicLogger as wrapper
s.runner = &basicLogger{s.logger, s.runner}
} else if s.logger != nil && s.useCacher {
// We're using a proxy (cacher), so use proxyLogger instead
s.runner = &proxyLogger{basicLogger{s.logger, s.runner}}
}

return s
}

// Debug returns a new store that will print all SQL statements to stdout using
Expand All @@ -127,12 +152,11 @@ func (s *Store) Debug() *Store {
// DebugWith returns a new store that will print all SQL statements using the
// given logger function.
func (s *Store) DebugWith(logger LoggerFunc) *Store {
proxy := &debugProxy{logger, s.proxy}
return &Store{
builder: s.builder.RunWith(proxy),
db: s.db,
proxy: proxy,
}
return (&Store{
db: s.db,
useCacher: s.useCacher,
logger: logger,
}).init()
}

// Insert insert the given record in the table, returns error if no-new
Expand Down Expand Up @@ -192,9 +216,20 @@ func (s *Store) Insert(schema Schema, record Record) error {
}

query.WriteString(fmt.Sprintf(" RETURNING %s", schema.ID().String()))
err = s.proxy.QueryRow(query.String(), values...).Scan(pk)
//err = s.runner.QueryRow(query.String(), values...).Scan(pk)
rows, err := s.runner.Query(query.String(), values...)
if err != nil {
return err
}
if rows.Next() {
err = rows.Scan(pk)
rows.Close()
if err != nil {
return err
}
}
} else {
_, err = s.proxy.Exec(query.String(), values...)
_, err = s.runner.Exec(query.String(), values...)
}

if err != nil {
Expand Down Expand Up @@ -255,7 +290,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
query.WriteRune('=')
query.WriteString(fmt.Sprintf("$%d", len(columnNames)+1))

result, err := s.proxy.Exec(query.String(), append(values, record.GetID())...)
result, err := s.runner.Exec(query.String(), append(values, record.GetID())...)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -300,7 +335,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
query.WriteString(schema.ID().String())
query.WriteString("=$1")

_, err := s.proxy.Exec(query.String(), record.GetID())
_, err := s.runner.Exec(query.String(), record.GetID())
return err
}

Expand All @@ -309,7 +344,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
// WARNING: A result set created from a raw query can only be scanned using the
// RawScan method of ResultSet, instead of Scan.
func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
rows, err := s.proxy.Query(sql, params...)
rows, err := s.runner.Query(sql, params...)
if err != nil {
return nil, err
}
Expand All @@ -320,7 +355,7 @@ func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
// RawExec executes a raw SQL query with the given parameters and returns
// the number of affected rows.
func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
result, err := s.proxy.Exec(sql, params...)
result, err := s.runner.Exec(sql, params...)
if err != nil {
return 0, err
}
Expand All @@ -332,7 +367,7 @@ func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
func (s *Store) Find(q Query) (ResultSet, error) {
rels := q.getRelationships()
if containsRelationshipOfType(rels, OneToMany) {
return NewBatchingResultSet(newBatchQueryRunner(q.Schema(), s.proxy, q)), nil
return NewBatchingResultSet(newBatchQueryRunner(q.Schema(), s.runner, q)), nil
}

columns, builder := q.compile()
Expand All @@ -344,7 +379,7 @@ func (s *Store) Find(q Query) (ResultSet, error) {
builder = builder.Limit(limit)
}

rows, err := builder.RunWith(s.proxy).Query()
rows, err := builder.RunWith(s.runner).Query()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -379,7 +414,7 @@ func (s *Store) Reload(schema Schema, record Record) error {
q.Limit(1)
columns, builder := q.compile()

rows, err := builder.RunWith(s.proxy).Query()
rows, err := builder.RunWith(s.runner).Query()
if err != nil {
return err
}
Expand All @@ -399,7 +434,7 @@ func (s *Store) Count(q Query) (count int64, err error) {
_, queryBuilder := q.compile()
builder := builder.Set(queryBuilder, "Columns", nil).(squirrel.SelectBuilder)
err = builder.Column(fmt.Sprintf("COUNT(%s)", all.QualifiedName(q.Schema()))).
RunWith(s.proxy).
RunWith(s.runner).
QueryRow().
Scan(&count)
return
Expand All @@ -423,16 +458,26 @@ func (s *Store) MustCount(q Query) int64 {
// If a transaction is already opened in this store, instead of opening a new
// one, the other will be reused.
func (s *Store) Transaction(callback func(*Store) error) error {
if s.db == nil {
var tx *sql.Tx
var err error
if db, ok := s.db.(*sql.DB); ok {
// db is *sql.DB, not *sql.Tx
tx, err = db.Begin()
if err != nil {
return fmt.Errorf("kallax: can't open transaction: %s", err)
}
} else {
// store is already holding a transaction
return callback(s)
}

tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("kallax: can't open transaction: %s", err)
}
txStore := (&Store{
db: tx,
logger: s.logger,
useCacher: true,
}).init()

if err := callback(newStoreWithTransaction(tx)); err != nil {
if err := callback(txStore); err != nil {
if err := tx.Rollback(); err != nil {
return fmt.Errorf("kallax: unable to rollback transaction: %s", err)
}
Expand Down

0 comments on commit 1dec030

Please sign in to comment.