Skip to content

Commit

Permalink
multi: remove context.TODOs
Browse files Browse the repository at this point in the history
  • Loading branch information
ellemouton committed Jan 21, 2025
1 parent 85f4bcd commit fa642a5
Show file tree
Hide file tree
Showing 17 changed files with 213 additions and 171 deletions.
4 changes: 3 additions & 1 deletion lnrpc/invoicesrpc/addinvoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ type AddInvoiceConfig struct {

// QueryBlindedRoutes can be used to generate a few routes to this node
// that can then be used in the construction of a blinded payment path.
QueryBlindedRoutes func(lnwire.MilliSatoshi) ([]*route.Route, error)
QueryBlindedRoutes func(context.Context, lnwire.MilliSatoshi) (
[]*route.Route, error)
}

// AddInvoiceData contains the required data to create a new invoice.
Expand Down Expand Up @@ -521,6 +522,7 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig,

//nolint:ll
paths, err := blindedpath.BuildBlindedPaymentPaths(
ctx,
&blindedpath.BuildBlindedPathCfg{
FindRoutes: cfg.QueryBlindedRoutes,
FetchChannelEdgesByID: cfg.Graph.FetchChannelEdgesByID,
Expand Down
5 changes: 3 additions & 2 deletions lnrpc/routerrpc/router_backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ type RouterBackend struct {

// FindRoute is a closure that abstracts away how we locate/query for
// routes.
FindRoute func(*routing.RouteRequest) (*route.Route, float64, error)
FindRoute func(context.Context, *routing.RouteRequest) (*route.Route,
float64, error)

MissionControl MissionControl

Expand Down Expand Up @@ -169,7 +170,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
// Query the channel router for a possible path to the destination that
// can carry `in.Amt` satoshis _including_ the total fee required on
// the route
route, successProb, err := r.FindRoute(routeReq)
route, successProb, err := r.FindRoute(ctx, routeReq)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions lnrpc/routerrpc/router_backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool,
}
}

findRoute := func(req *routing.RouteRequest) (*route.Route, float64,
error) {
findRoute := func(_ context.Context, req *routing.RouteRequest) (
*route.Route, float64, error) {

if int64(req.Amount) != amtSat*1000 {
t.Fatal("unexpected amount")
Expand Down
12 changes: 6 additions & 6 deletions lnrpc/routerrpc/router_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ func (s *Server) EstimateRouteFee(ctx context.Context,
return nil, errors.New("amount must be greater than 0")

default:
return s.probeDestination(req.Dest, req.AmtSat)
return s.probeDestination(ctx, req.Dest, req.AmtSat)
}

case isProbeInvoice:
Expand All @@ -449,8 +449,8 @@ func (s *Server) EstimateRouteFee(ctx context.Context,

// probeDestination estimates fees along a route to a destination based on the
// contents of the local graph.
func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse,
error) {
func (s *Server) probeDestination(ctx context.Context, dest []byte,
amtSat int64) (*RouteFeeResponse, error) {

destNode, err := route.NewVertexFromBytes(dest)
if err != nil {
Expand Down Expand Up @@ -478,7 +478,7 @@ func (s *Server) probeDestination(dest []byte, amtSat int64) (*RouteFeeResponse,
return nil, err
}

route, _, err := s.cfg.Router.FindRoute(routeReq)
route, _, err := s.cfg.Router.FindRoute(ctx, routeReq)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1438,7 +1438,7 @@ func (s *Server) trackPaymentStream(context context.Context,
}

// BuildRoute builds a route from a list of hop addresses.
func (s *Server) BuildRoute(_ context.Context,
func (s *Server) BuildRoute(ctx context.Context,
req *BuildRouteRequest) (*BuildRouteResponse, error) {

if len(req.HopPubkeys) == 0 {
Expand Down Expand Up @@ -1499,7 +1499,7 @@ func (s *Server) BuildRoute(_ context.Context,

// Build the route and return it to the caller.
route, err := s.cfg.Router.BuildRoute(
amt, hops, outgoingChan, req.FinalCltvDelta, payAddr,
ctx, amt, hops, outgoingChan, req.FinalCltvDelta, payAddr,
firstHopBlob,
)
if err != nil {
Expand Down
8 changes: 5 additions & 3 deletions routing/blindedpath/blinded_path.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package blindedpath

import (
"bytes"
"context"
"errors"
"fmt"
"math"
Expand Down Expand Up @@ -38,7 +39,8 @@ type BuildBlindedPathCfg struct {
// various lengths and may even contain only a single hop. Any route
// shorter than MinNumHops will be padded with dummy hops during route
// construction.
FindRoutes func(value lnwire.MilliSatoshi) ([]*route.Route, error)
FindRoutes func(ctx context.Context, value lnwire.MilliSatoshi) (
[]*route.Route, error)

// FetchChannelEdgesByID attempts to look up the two directed edges for
// the channel identified by the channel ID.
Expand Down Expand Up @@ -111,12 +113,12 @@ type BuildBlindedPathCfg struct {

// BuildBlindedPaymentPaths uses the passed config to construct a set of blinded
// payment paths that can be added to the invoice.
func BuildBlindedPaymentPaths(cfg *BuildBlindedPathCfg) (
func BuildBlindedPaymentPaths(ctx context.Context, cfg *BuildBlindedPathCfg) (
[]*zpay32.BlindedPaymentPath, error) {

// Find some appropriate routes for the value to be routed. This will
// return a set of routes made up of real nodes.
routes, err := cfg.FindRoutes(cfg.ValueMsat)
routes, err := cfg.FindRoutes(ctx, cfg.ValueMsat)
if err != nil {
return nil, err
}
Expand Down
34 changes: 22 additions & 12 deletions routing/blindedpath/blinded_path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package blindedpath

import (
"bytes"
"context"
"encoding/hex"
"fmt"
"math/rand"
Expand Down Expand Up @@ -548,6 +549,9 @@ func genBlindedRouteData(rand *rand.Rand) *record.BlindedRouteData {
// https://github.com/lightning/bolts/blob/master/proposals/route-blinding.md
// This example does not use any dummy hops.
func TestBuildBlindedPath(t *testing.T) {
t.Parallel()
ctx := context.Background()

// Alice chooses the following path to herself for blinded path
// construction:
// Carol -> Bob -> Alice.
Expand Down Expand Up @@ -591,9 +595,9 @@ func TestBuildBlindedPath(t *testing.T) {
},
}

paths, err := BuildBlindedPaymentPaths(&BuildBlindedPathCfg{
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
error) {
paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
[]*route.Route, error) {

return []*route.Route{realRoute}, nil
},
Expand Down Expand Up @@ -716,6 +720,9 @@ func TestBuildBlindedPath(t *testing.T) {
// TestBuildBlindedPathWithDummyHops tests the construction of a blinded path
// which includes dummy hops.
func TestBuildBlindedPathWithDummyHops(t *testing.T) {
t.Parallel()
ctx := context.Background()

// Alice chooses the following path to herself for blinded path
// construction:
// Carol -> Bob -> Alice.
Expand Down Expand Up @@ -759,9 +766,9 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
},
}

paths, err := BuildBlindedPaymentPaths(&BuildBlindedPathCfg{
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
error) {
paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
[]*route.Route, error) {

return []*route.Route{realRoute}, nil
},
Expand Down Expand Up @@ -929,9 +936,9 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
// the first 2 calls. FindRoutes returns 3 routes and so by the end, we
// still get 1 valid path.
var errCount int
paths, err = BuildBlindedPaymentPaths(&BuildBlindedPathCfg{
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
error) {
paths, err = BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
[]*route.Route, error) {

return []*route.Route{realRoute, realRoute, realRoute},
nil
Expand Down Expand Up @@ -997,7 +1004,10 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) {
// correctly for the case where the destination node is also the introduction
// node.
func TestSingleHopBlindedPath(t *testing.T) {
t.Parallel()

var (
ctx = context.Background()
_, pkC = btcec.PrivKeyFromBytes([]byte{1})
carol = route.NewVertex(pkC)
)
Expand All @@ -1009,9 +1019,9 @@ func TestSingleHopBlindedPath(t *testing.T) {
Hops: []*route.Hop{},
}

paths, err := BuildBlindedPaymentPaths(&BuildBlindedPathCfg{
FindRoutes: func(_ lnwire.MilliSatoshi) ([]*route.Route,
error) {
paths, err := BuildBlindedPaymentPaths(ctx, &BuildBlindedPathCfg{
FindRoutes: func(_ context.Context, _ lnwire.MilliSatoshi) (
[]*route.Route, error) {

return []*route.Route{realRoute}, nil
},
Expand Down
8 changes: 5 additions & 3 deletions routing/integrated_routing_context_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"fmt"
"math"
"os"
Expand Down Expand Up @@ -173,8 +174,8 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,
)
require.NoError(c.t, err)

getBandwidthHints := func(_ graphdb.RoutingGraph) (bandwidthHints,
error) {
getBandwidthHints := func(_ context.Context, _ graphdb.RoutingGraph) (
bandwidthHints, error) {

// Create bandwidth hints based on local channel balances.
bandwidthHints := map[uint64]lnwire.MilliSatoshi{}
Expand Down Expand Up @@ -237,7 +238,8 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32,

// Find a route.
route, err := session.RequestRoute(
amtRemaining, lnwire.MaxMilliSatoshi, inFlightHtlcs, 0,
context.Background(), amtRemaining,
lnwire.MaxMilliSatoshi, inFlightHtlcs, 0,
lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
},
Expand Down
14 changes: 8 additions & 6 deletions routing/mock_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package routing

import (
"context"
"fmt"
"sync"

Expand Down Expand Up @@ -168,9 +169,9 @@ type mockPaymentSessionOld struct {

var _ PaymentSession = (*mockPaymentSessionOld)(nil)

func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
_, height uint32, _ lnwire.CustomRecords) (*route.Route,
error) {
func (m *mockPaymentSessionOld) RequestRoute(_ context.Context, _,
_ lnwire.MilliSatoshi, _, height uint32, _ lnwire.CustomRecords) (
*route.Route, error) {

if m.release != nil {
m.release <- struct{}{}
Expand Down Expand Up @@ -695,12 +696,13 @@ type mockPaymentSession struct {

var _ PaymentSession = (*mockPaymentSession)(nil)

func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
activeShards, height uint32,
func (m *mockPaymentSession) RequestRoute(ctx context.Context, maxAmt,
feeLimit lnwire.MilliSatoshi, activeShards, height uint32,
firstHopCustomRecords lnwire.CustomRecords) (*route.Route, error) {

args := m.Called(
maxAmt, feeLimit, activeShards, height, firstHopCustomRecords,
ctx, maxAmt, feeLimit, activeShards, height,
firstHopCustomRecords,
)

// Type assertion on nil will fail, so we check and return here.
Expand Down
12 changes: 5 additions & 7 deletions routing/pathfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ const (
)

// pathFinder defines the interface of a path finding algorithm.
type pathFinder = func(g *graphParams, r *RestrictParams,
type pathFinder = func(ctx context.Context, g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64, finalHtlcExpiry int32) (
[]*unifiedEdge, float64, error)
Expand Down Expand Up @@ -577,12 +577,10 @@ func getOutgoingBalance(ctx context.Context, node route.Vertex,
// source. This is to properly accumulate fees that need to be paid along the
// path and accurately check the amount to forward at every node against the
// available bandwidth.
func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
self, source, target route.Vertex, amt lnwire.MilliSatoshi,
timePref float64, finalHtlcExpiry int32) ([]*unifiedEdge, float64,
error) {

ctx := context.TODO()
func findPath(ctx context.Context, g *graphParams, r *RestrictParams,
cfg *PathFindingConfig, self, source, target route.Vertex,
amt lnwire.MilliSatoshi, timePref float64,
finalHtlcExpiry int32) ([]*unifiedEdge, float64, error) {

// Pathfinding can be a significant portion of the total payment
// latency, especially on low-powered devices. Log several metrics to
Expand Down
17 changes: 9 additions & 8 deletions routing/pathfind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2305,28 +2305,29 @@ func TestPathInsufficientCapacityWithFee(t *testing.T) {

func TestPathFindSpecExample(t *testing.T) {
t.Parallel()
ctx := context.Background()

// All our path finding tests will assume a starting height of 100, so
// we'll pass that in to ensure that the router uses 100 as the current
// height.
const startingHeight = 100
ctx := createTestCtxFromFile(t, startingHeight, specExampleFilePath)
tCtx := createTestCtxFromFile(t, startingHeight, specExampleFilePath)

// We'll first exercise the scenario of a direct payment from Bob to
// Carol, so we set "B" as the source node so path finding starts from
// Bob.
bob := ctx.aliases["B"]
bob := tCtx.aliases["B"]

// Query for a route of 4,999,999 mSAT to carol.
carol := ctx.aliases["C"]
carol := tCtx.aliases["C"]
const amt lnwire.MilliSatoshi = 4999999
req, err := NewRouteRequest(
bob, &carol, amt, 0, noRestrictions, nil, nil,
nil, MinCLTVDelta,
)
require.NoError(t, err, "invalid route request")

route, _, err := ctx.router.FindRoute(req)
route, _, err := tCtx.router.FindRoute(ctx, req)
require.NoError(t, err, "unable to find route")

// Now we'll examine the route returned for correctness.
Expand All @@ -2343,8 +2344,8 @@ func TestPathFindSpecExample(t *testing.T) {

// Next, we'll set A as the source node so we can assert that we create
// the proper route for any queries starting with Alice.
alice := ctx.aliases["A"]
ctx.router.cfg.SelfNode = alice
alice := tCtx.aliases["A"]
tCtx.router.cfg.SelfNode = alice

// We'll now request a route from A -> B -> C.
req, err = NewRouteRequest(
Expand All @@ -2353,7 +2354,7 @@ func TestPathFindSpecExample(t *testing.T) {
)
require.NoError(t, err, "invalid route request")

route, _, err = ctx.router.FindRoute(req)
route, _, err = tCtx.router.FindRoute(ctx, req)
require.NoError(t, err, "unable to find routes")

// The route should be two hops.
Expand Down Expand Up @@ -3234,7 +3235,7 @@ func dbFindPath(graph *graphdb.ChannelGraph,
}()

route, _, err := findPath(
&graphParams{
context.Background(), &graphParams{
additionalEdges: additionalEdges,
bandwidthHints: bandwidthHints,
graph: graphSess,
Expand Down
Loading

0 comments on commit fa642a5

Please sign in to comment.