Skip to content

Commit

Permalink
sqlite: replace query placeholders with prepared statements
Browse files Browse the repository at this point in the history
  • Loading branch information
n8maninger committed Jul 16, 2024
1 parent 4dc4705 commit 3bb27e9
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 49 deletions.
27 changes: 18 additions & 9 deletions persist/sqlite/peers.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,25 @@ func (s *Store) Banned(peer string) (banned bool, _ error) {
}

err = s.transaction(func(tx *txn) error {
query := `SELECT net_cidr, expiration FROM syncer_bans WHERE net_cidr IN (` + queryPlaceHolders(len(checkSubnets)) + `) ORDER BY expiration DESC LIMIT 1`

var subnet string
var expiration time.Time
err := tx.QueryRow(query, queryArgs(checkSubnets)...).Scan(&subnet, decode(&expiration))
banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows
if err == nil && banned {
s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration))
checkSubnetStmt, err := tx.Prepare(`SELECT expiration FROM syncer_bans WHERE net_cidr = $1 ORDER BY expiration DESC LIMIT 1`)
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
return err
defer checkSubnetStmt.Close()

for _, subnet := range checkSubnets {
var expiration time.Time

err := checkSubnetStmt.QueryRow(subnet).Scan(decode(&expiration))
banned = time.Now().Before(expiration) // will return false for any sql errors, including ErrNoRows
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("failed to check ban status: %w", err)
} else if banned {
s.log.Debug("found ban", zap.String("subnet", subnet), zap.Time("expiration", expiration))
return nil
}
}
return nil
})
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return false, fmt.Errorf("failed to check ban status: %w", err)
Expand Down
27 changes: 0 additions & 27 deletions persist/sqlite/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"math/rand"
"strings"
"time"

_ "github.com/mattn/go-sqlite3" // import sqlite3 driver
Expand Down Expand Up @@ -171,32 +170,6 @@ func (tx *txn) QueryRow(query string, args ...any) *row {
return &row{r, tx.log.Named("row")}
}

func queryPlaceHolders(n int) string {
if n == 0 {
return ""
} else if n == 1 {
return "?"
}
var b strings.Builder
b.Grow(((n - 1) * 2) + 1) // ?,?
for i := 0; i < n-1; i++ {
b.WriteString("?,")
}
b.WriteString("?")
return b.String()
}

func queryArgs[T any](args []T) []any {
if len(args) == 0 {
return nil
}
out := make([]any, len(args))
for i, arg := range args {
out[i] = arg
}
return out
}

// getDBVersion returns the current version of the database.
func getDBVersion(db *sql.DB) (version int64) {
// error is ignored -- the database may not have been initialized yet.
Expand Down
41 changes: 28 additions & 13 deletions persist/sqlite/wallet.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,42 @@ import (
)

func (s *Store) getWalletEventRelevantAddresses(tx *txn, id wallet.ID, eventIDs []int64) (map[int64][]types.Address, error) {
query := `SELECT ea.event_id, sa.sia_address
stmt, err := tx.Prepare(`SELECT sa.sia_address
FROM event_addresses ea
INNER JOIN sia_addresses sa ON (ea.address_id = sa.id)
WHERE event_id IN (` + queryPlaceHolders(len(eventIDs)) + `) AND address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=?)`

rows, err := tx.Query(query, append(queryArgs(eventIDs), id)...)
INNER JOIN wallet_addresses wa ON (ea.address_id = wa.address_id)
WHERE wa.wallet_id=? AND ea.event_id=?`)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to prepare statement: %w", err)
}
defer stmt.Close()

relevant := func(walletID wallet.ID, eventID int64) (addresses []types.Address, err error) {
rows, err := stmt.Query(walletID, eventID)
if err != nil {
return nil, fmt.Errorf("failed to query relevant addresses: %w", err)
}
defer rows.Close()

for rows.Next() {
var address types.Address
if err := rows.Scan(decode(&address)); err != nil {
return nil, fmt.Errorf("failed to scan relevant address: %w", err)
}
addresses = append(addresses, address)
}
return addresses, rows.Err()
}
defer rows.Close()

relevantAddresses := make(map[int64][]types.Address)
for rows.Next() {
var eventID int64
var address types.Address
if err := rows.Scan(&eventID, decode(&address)); err != nil {
return nil, fmt.Errorf("failed to scan relevant address: %w", err)
for _, eventID := range eventIDs {
addresses, err := relevant(id, eventID)
if err != nil {
return nil, err
}
relevantAddresses[eventID] = append(relevantAddresses[eventID], address)
relevantAddresses[eventID] = addresses
}
return relevantAddresses, rows.Err()
return relevantAddresses, nil
}

// WalletEvents returns the events relevant to a wallet, sorted by height descending.
Expand Down

0 comments on commit 3bb27e9

Please sign in to comment.