spv: Request cfilters only for best chain
This changes the initial cfilter fetch to request cfilters only for the
best chain candidate.  This avoids having to perform work for sidechains
that are not (and will never become) the main chain.

Additionally, this refactors the initial cfilter fetching logic into a
separate function.  In the future, this will make it easier to refactor
cfilter fetching so that it can be performed in a batched fashion.
matheusd committed Nov 20, 2023
1 parent 284762f commit 38621ff
Showing 2 changed files with 86 additions and 36 deletions.
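
The heart of the change is a helper (added to spv/backend.go below) that fetches the cfilter for every node of the best chain candidate concurrently, one request per block, and fails the whole batch as soon as any fetch or commitment validation fails.  The following is a minimal, self-contained sketch of that errgroup fan-out pattern; the blockNode, filter, and peer types and the fetchCFilter method are placeholders for this illustration, not dcrwallet or dcrd APIs.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// blockNode and filter stand in for dcrwallet's wallet.BlockNode and dcrd's
// gcs.FilterV2; peer stands in for a p2p.RemotePeer.  All three are
// placeholders for this sketch only.
type blockNode struct {
	hash   string
	height int32
}

type filter struct{ forBlock string }

type peer struct{}

// fetchCFilter is a stand-in for a CFilterV2-style request; a real
// implementation would perform a network round trip and validate the
// returned filter against the block's header commitment before accepting it.
func (p *peer) fetchCFilter(ctx context.Context, hash string) (*filter, error) {
	return &filter{forBlock: hash}, nil
}

// fetchCFilters requests one cfilter per node concurrently.  Results are
// written into a slice indexed by node position, so output order matches
// input order regardless of which request finishes first.  If any request
// fails, the errgroup cancels the shared context and Wait returns the first
// error, so the whole batch fails fast.
func fetchCFilters(ctx context.Context, p *peer, nodes []*blockNode) ([]*filter, error) {
	if len(nodes) == 0 {
		return nil, nil
	}
	g, ctx := errgroup.WithContext(ctx)
	res := make([]*filter, len(nodes))
	for i := range nodes {
		i := i // capture the loop variable for the goroutine below
		g.Go(func() error {
			f, err := p.fetchCFilter(ctx, nodes[i].hash)
			if err != nil {
				return err
			}
			res[i] = f
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, err
	}
	return res, nil
}

func main() {
	nodes := []*blockNode{{hash: "000...a", height: 1}, {hash: "000...b", height: 2}}
	filters, err := fetchCFilters(context.Background(), &peer{}, nodes)
	if err != nil {
		fmt.Println("cfilter fetch failed:", err)
		return
	}
	fmt.Printf("fetched %d cfilters\n", len(filters))
}
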
48 changes: 48 additions & 0 deletions spv/backend.go
@@ -6,6 +6,7 @@ package spv

import (
"context"
"fmt"
"runtime"
"sync"

@@ -18,6 +19,7 @@ import (
"github.com/decred/dcrd/gcs/v4"
"github.com/decred/dcrd/txscript/v4/stdaddr"
"github.com/decred/dcrd/wire"
"golang.org/x/sync/errgroup"
)

var _ wallet.NetworkBackend = (*Syncer)(nil)
@@ -89,6 +91,52 @@ func (s *Syncer) CFiltersV2(ctx context.Context, blockHashes []*chainhash.Hash)
}
}

// cfiltersV2FromNodes fetches cfilters for all the specified nodes from a
// remote peer.
func (s *Syncer) cfiltersV2FromNodes(ctx context.Context, cnet wire.CurrencyNet, rp *p2p.RemotePeer, nodes []*wallet.BlockNode) ([]*gcs.FilterV2, error) {
if len(nodes) == 0 {
return nil, nil
}

g, ctx := errgroup.WithContext(ctx)
res := make([]*gcs.FilterV2, len(nodes))
for i := range nodes {
i := i
g.Go(func() error {
node := nodes[i]
filter, proofIndex, proof, err := rp.CFilterV2(ctx, node.Hash)
if err != nil {
log.Tracef("Unable to fetch cfilter for "+
"block %v (height %d) from %v: %v",
node.Hash, node.Header.Height,
rp, err)
return err
}

err = validate.CFilterV2HeaderCommitment(cnet, node.Header,
filter, proofIndex, proof)
if err != nil {
errMsg := fmt.Sprintf("CFilter for block %v (height %d) "+
"received from %v failed validation: %v",
node.Hash, node.Header.Height,
rp, err)
log.Warnf(errMsg)
err := errors.E(errors.Protocol, errMsg)
rp.Disconnect(err)
return err
}

res[i] = filter
return nil
})
}
err := g.Wait()
if err != nil {
return nil, err
}
return res, nil
}

func (s *Syncer) String() string {
// This method is part of the wallet.Peer interface and will typically
// specify the remote address of the peer. Since the syncer can encompass
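
The sync.go hunk below calls this helper with only the tail of the best chain candidate that still lacks filters, rather than with every header in the batch.  A small sketch of that selection, using a placeholder node type (not the wallet.BlockNode API):

package main

import "fmt"

// node is a placeholder for a block node that may or may not already have its
// cfilter attached (FilterV2 != nil in the real code).
type node struct {
	height int32
	filter []byte
}

// missingFilterSuffix returns the suffix of the best chain candidate starting
// at the first node that has no cfilter yet.  Because filters are attached in
// order, everything before that point is already populated and can be skipped.
func missingFilterSuffix(bestChain []*node) []*node {
	for i := range bestChain {
		if bestChain[i].filter == nil {
			return bestChain[i:]
		}
	}
	return nil
}

func main() {
	chain := []*node{
		{height: 100, filter: []byte{0x01}},
		{height: 101, filter: []byte{0x02}},
		{height: 102}, // first node missing its cfilter
		{height: 103},
	}
	missing := missingFilterSuffix(chain)
	fmt.Printf("need cfilters for %d of %d nodes\n", len(missing), len(chain))
}
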
74 changes: 38 additions & 36 deletions spv/sync.go
@@ -1342,6 +1342,7 @@ nextbatch:
tipHash, tipHeight, time.Since(startTime).Round(time.Second))
return nil
}
log.Tracef("Attempting next batch of headers from %v", rp)

// Request headers from the selected peer.
locators, err := s.wallet.BlockLocators(ctx, nil)
@@ -1350,6 +1351,7 @@
}
headers, err := rp.Headers(ctx, locators, &hashStop)
if err != nil {
log.Debugf("Unable to fetch headers from %v: %v", rp, err)
continue nextbatch
}

@@ -1372,6 +1374,8 @@
// Try to pick a different peer with a higher advertised
// height or check there are no such peers (thus we're
// done with fetching headers for initial sync).
log.Tracef("Skipping to next batch due to "+
"len(headers) == 0 from %v", rp)
continue nextbatch
}

@@ -1412,36 +1416,9 @@
}
continue nextbatch
}
s.sidechainMu.Unlock()

g, ctx := errgroup.WithContext(ctx)
for i := range headers {
i := i
g.Go(func() error {
node := nodes[i]
filter, proofIndex, proof, err := rp.CFilterV2(ctx, node.Hash)
if err != nil {
return err
}

err = validate.CFilterV2HeaderCommitment(cnet, node.Header,
filter, proofIndex, proof)
if err != nil {
return err
}

node.FilterV2 = filter
return nil
})
}
err = g.Wait()
if err != nil {
rp.Disconnect(err)
continue nextbatch
}

// Add new headers to the sidechain forest.
var added int
s.sidechainMu.Lock()
for _, n := range nodes {
haveBlock, _, _ := s.wallet.BlockInMainChain(ctx, n.Hash)
if haveBlock {
@@ -1451,15 +1428,8 @@
added++
}
}
if added == 0 {
s.sidechainMu.Unlock()

continue nextbatch
}
s.fetchHeadersProgress(headers[len(headers)-1])
log.Debugf("Fetched %d new header(s) ending at height %d from %v",
added, nodes[len(nodes)-1].Header.Height, rp)

// Determine if this extends the best known chain.
bestChain, err := s.wallet.EvaluateBestChain(ctx, &s.sidechains)
if err != nil {
s.sidechainMu.Unlock()
@@ -1471,6 +1441,38 @@
continue nextbatch
}

s.fetchHeadersProgress(headers[len(headers)-1])
log.Debugf("Fetched %d new header(s) ending at height %d from %v",
added, headers[len(headers)-1].Height, rp)

// Fetch cfilters for nodes which don't yet have them.
var missingCFNodes []*wallet.BlockNode
for i := range bestChain {
if bestChain[i].FilterV2 == nil {
missingCFNodes = bestChain[i:]
break
}
}
s.sidechainMu.Unlock()
filters, err := s.cfiltersV2FromNodes(ctx, cnet, rp, missingCFNodes)
if err != nil {
log.Debugf("Unable to fetch missing cfilters from %v: %v",
rp, err)
continue nextbatch
}
if len(missingCFNodes) > 0 {
log.Debugf("Fetched %d new cfilters(s) ending at height %d from %v",
len(missingCFNodes),
missingCFNodes[len(missingCFNodes)-1].Header.Height,
rp)
}

// Switch the best chain, now that all cfilters have been
// fetched for it.
s.sidechainMu.Lock()
for i := range missingCFNodes {
missingCFNodes[i].FilterV2 = filters[i]
}
prevChain, err := s.wallet.ChainSwitch(ctx, &s.sidechains, bestChain, nil)
if err != nil {
s.sidechainMu.Unlock()
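
One detail worth noting in the hunk above: s.sidechainMu is released before the network round trip and reacquired only to attach the fetched filters and perform the chain switch, so other goroutines are not blocked on the sidechain forest while cfilters are in flight.  A generic sketch of that snapshot / fetch-outside-the-lock / publish pattern (the forest type and fetchOverNetwork function here are illustrative, not dcrwallet code):

package main

import (
	"context"
	"fmt"
	"sync"
)

// forest stands in for the mutex-guarded sidechain forest in the syncer; its
// fields and helper below are illustrative only.
type forest struct {
	mu      sync.Mutex
	pending []string // block hashes still missing cfilters
	filters map[string][]byte
}

// fetchOverNetwork is a stand-in for the slow network fetch performed while
// the lock is not held.
func fetchOverNetwork(ctx context.Context, hashes []string) (map[string][]byte, error) {
	out := make(map[string][]byte, len(hashes))
	for _, h := range hashes {
		out[h] = []byte("filter-for-" + h)
	}
	return out, nil
}

func (f *forest) fillMissingFilters(ctx context.Context) error {
	// Snapshot the outstanding work under the lock...
	f.mu.Lock()
	missing := append([]string(nil), f.pending...)
	f.mu.Unlock()

	// ...do the slow network I/O without holding it...
	fetched, err := fetchOverNetwork(ctx, missing)
	if err != nil {
		return err
	}

	// ...and reacquire the lock only to publish the results.
	f.mu.Lock()
	defer f.mu.Unlock()
	for h, flt := range fetched {
		f.filters[h] = flt
	}
	f.pending = nil
	return nil
}

func main() {
	f := &forest{pending: []string{"aaa", "bbb"}, filters: make(map[string][]byte)}
	if err := f.fillMissingFilters(context.Background()); err != nil {
		fmt.Println("fetch failed:", err)
		return
	}
	fmt.Println(len(f.filters), "filters attached")
}
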
