From a2700af1cfd092b0b2c07584b1f0664b7a579cc3 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 8 Oct 2024 13:26:09 +0200 Subject: [PATCH] update migration --- channeldb/migration/lnwire21/msat.go | 38 ++ channeldb/migration/lnwire21/true_boolean.go | 37 ++ channeldb/migration32/migration_test.go | 329 +++++----- .../migration32/mission_control_store.go | 566 ++++++++++++++---- channeldb/migration32/route.go | 22 + routing/result_interpretation.go | 2 - 6 files changed, 738 insertions(+), 256 deletions(-) create mode 100644 channeldb/migration/lnwire21/true_boolean.go diff --git a/channeldb/migration/lnwire21/msat.go b/channeldb/migration/lnwire21/msat.go index 7473d72c826..47c6762850e 100644 --- a/channeldb/migration/lnwire21/msat.go +++ b/channeldb/migration/lnwire21/msat.go @@ -2,8 +2,10 @@ package lnwire import ( "fmt" + "io" "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -49,3 +51,39 @@ func (m MilliSatoshi) String() string { } // TODO(roasbeef): extend with arithmetic operations? + +// Record returns a TLV record that can be used to encode/decode a MilliSatoshi +// to/from a TLV stream. +func (m *MilliSatoshi) Record() tlv.Record { + return tlv.MakeDynamicRecord( + 0, m, tlv.SizeBigSize(m), encodeMilliSatoshis, + decodeMilliSatoshis, + ) +} +func encodeMilliSatoshis(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*MilliSatoshi); ok { + bigSize := uint64(*v) + + return tlv.EBigSize(w, &bigSize, buf) + } + + return tlv.NewTypeForEncodingErr(val, "lnwire.MilliSatoshi") +} + +func decodeMilliSatoshis(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + if v, ok := val.(*MilliSatoshi); ok { + var bigSize uint64 + err := tlv.DBigSize(r, &bigSize, buf, l) + if err != nil { + return err + } + + *v = MilliSatoshi(bigSize) + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "lnwire.MilliSatoshi", l, l) +} diff --git a/channeldb/migration/lnwire21/true_boolean.go b/channeldb/migration/lnwire21/true_boolean.go new file mode 100644 index 00000000000..ae282039c3c --- /dev/null +++ b/channeldb/migration/lnwire21/true_boolean.go @@ -0,0 +1,37 @@ +package lnwire + +import ( + "io" + + "github.com/lightningnetwork/lnd/tlv" +) + +// TrueBoolean is a record that indicates true or false using the presence of +// the record. If the record is absent, it indicates false. If it is presence, +// it indicates true. +type TrueBoolean struct{} + +// Record returns the tlv record for the boolean entry. +func (b *TrueBoolean) Record() tlv.Record { + return tlv.MakeStaticRecord( + 0, b, 0, booleanEncoder, booleanDecoder, + ) +} + +func booleanEncoder(_ io.Writer, val interface{}, _ *[8]byte) error { + if _, ok := val.(*TrueBoolean); ok { + return nil + } + + return tlv.NewTypeForEncodingErr(val, "TrueBoolean") +} + +func booleanDecoder(_ io.Reader, val interface{}, _ *[8]byte, + l uint64) error { + + if _, ok := val.(*TrueBoolean); ok && (l == 0 || l == 1) { + return nil + } + + return tlv.NewTypeForEncodingErr(val, "TrueBoolean") +} diff --git a/channeldb/migration32/migration_test.go b/channeldb/migration32/migration_test.go index 1ce6016ed43..74709bd3d04 100644 --- a/channeldb/migration32/migration_test.go +++ b/channeldb/migration32/migration_test.go @@ -9,6 +9,7 @@ import ( lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" "github.com/lightningnetwork/lnd/channeldb/migtest" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" ) var ( @@ -24,174 +25,206 @@ var ( _ = pubKeyY.SetByteSlice(pubkeyBytes) pubkey = btcec.NewPublicKey(new(btcec.FieldVal).SetInt(4), pubKeyY) - paymentResultCommon1 = paymentResultCommon{ + customRecord = map[uint64][]byte{ + 65536: {4, 2, 2}, + } + + resultOld1 = paymentResultOld{ id: 0, timeFwd: time.Unix(0, 1), timeReply: time.Unix(0, 2), success: false, failureSourceIdx: &failureIndex, failure: &lnwire.FailFeeInsufficient{}, + route: &Route{ + TotalTimeLock: 100, + TotalAmount: 400, + SourcePubKey: testPub, + Hops: []*Hop{ + // A hop with MPP, AMP and custom + // records. + { + PubKeyBytes: testPub, + ChannelID: 100, + OutgoingTimeLock: 300, + AmtToForward: 500, + MPP: &MPP{ + paymentAddr: [32]byte{4, 5}, + totalMsat: 900, + }, + AMP: &{ + rootShare: [32]byte{0, 0}, + setID: [32]byte{5, 5, 5}, + childIndex: 90, + }, + CustomRecords: customRecord, + Metadata: []byte{6, 7, 7}, + }, + // A legacy hop. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + LegacyPayload: true, + }, + // A hop with a blinding key. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + BlindingPoint: pubkey, + EncryptedData: []byte{1, 2, 3}, + TotalAmtMsat: 600, + }, + // A hop with a blinding key and custom + // records. + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + CustomRecords: customRecord, + BlindingPoint: pubkey, + EncryptedData: []byte{1, 2, 3}, + TotalAmtMsat: 600, + }, + }, + }, } - paymentResultCommon2 = paymentResultCommon{ + resultOld2 = paymentResultOld{ id: 2, timeFwd: time.Unix(0, 4), timeReply: time.Unix(0, 7), success: true, - } -) - -// TestMigrateMCRouteSerialisation tests that the MigrateMCRouteSerialisation -// migration function correctly migrates the MC store from using the old route -// encoding to using the newer, more minimal route encoding. -func TestMigrateMCRouteSerialisation(t *testing.T) { - customRecord := map[uint64][]byte{ - 65536: {4, 2, 2}, - } - - resultsOld := []*paymentResultOld{ - { - paymentResultCommon: paymentResultCommon1, - route: &Route{ - TotalTimeLock: 100, - TotalAmount: 400, - SourcePubKey: testPub, - Hops: []*Hop{ - // A hop with MPP, AMP and custom - // records. - { - PubKeyBytes: testPub, - ChannelID: 100, - OutgoingTimeLock: 300, - AmtToForward: 500, - MPP: &MPP{ - paymentAddr: [32]byte{ - 4, 5, - }, - totalMsat: 900, - }, - AMP: &{ - rootShare: [32]byte{ - 0, 0, - }, - setID: [32]byte{ - 5, 5, 5, - }, - childIndex: 90, - }, - CustomRecords: customRecord, - Metadata: []byte{6, 7, 7}, - }, - // A legacy hop. - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - LegacyPayload: true, - }, - // A hop with a blinding key. - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - BlindingPoint: pubkey, - EncryptedData: []byte{ - 1, 2, 3, - }, - TotalAmtMsat: 600, - }, - // A hop with a blinding key and custom - // records. - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - CustomRecords: customRecord, - BlindingPoint: pubkey, - EncryptedData: []byte{ - 1, 2, 3, - }, - TotalAmtMsat: 600, - }, - }, - }, - }, - { - paymentResultCommon: paymentResultCommon2, - route: &Route{ - TotalTimeLock: 101, - TotalAmount: 401, - SourcePubKey: testPub2, - Hops: []*Hop{ - { - PubKeyBytes: testPub, - ChannelID: 800, - OutgoingTimeLock: 4, - AmtToForward: 4, - BlindingPoint: pubkey, - EncryptedData: []byte{ - 1, 2, 3, - }, - TotalAmtMsat: 600, - }, + route: &Route{ + TotalTimeLock: 101, + TotalAmount: 401, + SourcePubKey: testPub2, + Hops: []*Hop{ + { + PubKeyBytes: testPub, + ChannelID: 800, + OutgoingTimeLock: 4, + AmtToForward: 4, + BlindingPoint: pubkey, + EncryptedData: []byte{1, 2, 3}, + CustomRecords: customRecord, + TotalAmtMsat: 600, }, }, }, } - expectedResultsNew := []*paymentResultNew{ - { - paymentResultCommon: paymentResultCommon1, - route: &mcRoute{ - sourcePubKey: testPub, - totalAmount: 400, - hops: []*mcHop{ - { - channelID: 100, - pubKeyBytes: testPub, - amtToFwd: 500, - hasCustomRecords: true, - }, - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - }, - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - hasBlindingPoint: true, - }, - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - hasBlindingPoint: true, - hasCustomRecords: true, - }, + resultNew1 = paymentResultNew{ + id: 0, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType1, uint64]( + uint64(time.Unix(0, 1).UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType2, uint64]( + uint64(time.Unix(0, 2).UnixNano()), + ), + failureSourceIdx: tlv.SomeRecordT[tlv.TlvType5, uint8]( + tlv.NewPrimitiveRecord[tlv.TlvType5, uint8]( + uint8(failureIndex), + ), + ), + failure: tlv.SomeRecordT[tlv.TlvType6, failureMessage]( + tlv.NewRecordT[tlv.TlvType6, failureMessage]( + failureMessage{&lnwire.FailFeeInsufficient{}}, + ), + ), + route: tlv.NewRecordT[tlv.TlvType3, mcRoute](mcRoute{ + sourcePubKey: tlv.NewRecordT[tlv.TlvType0, Vertex]( + testPub, + ), + totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](400), + hops: tlv.NewRecordT[tlv.TlvType2, mcHops](mcHops{ + { + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](100), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](500), + hasCustomRecords: tlv.SomeRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](), + ), }, - }, - }, - { - paymentResultCommon: paymentResultCommon2, - route: &mcRoute{ - sourcePubKey: testPub2, - totalAmount: 401, - hops: []*mcHop{ - { - channelID: 800, - pubKeyBytes: testPub, - amtToFwd: 4, - hasBlindingPoint: true, - }, + { + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), }, - }, - }, + { + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), + hasBlindingPoint: tlv.SomeRecordT[tlv.TlvType3, lnwire.TrueBoolean]( + tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](), + ), + }, + { + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), + hasCustomRecords: tlv.SomeRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](), + ), + hasBlindingPoint: tlv.SomeRecordT[tlv.TlvType3, lnwire.TrueBoolean]( + tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](), + ), + }, + }), + }), + } + + resultNew2 = paymentResultNew{ + id: 2, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType1, uint64]( + uint64(time.Unix(0, 4).UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType2, uint64]( + uint64(time.Unix(0, 7).UnixNano()), + ), + success: tlv.SomeRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + tlv.NewRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + lnwire.TrueBoolean{}, + ), + ), + route: tlv.NewRecordT[tlv.TlvType3, mcRoute](mcRoute{ + sourcePubKey: tlv.NewRecordT[tlv.TlvType0, Vertex]( + testPub2, + ), + totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](401), + hops: tlv.NewRecordT[tlv.TlvType2, mcHops](mcHops{ + { + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](800), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex](testPub), + amtToFwd: tlv.NewPrimitiveRecord[tlv.TlvType2, lnwire.MilliSatoshi](4), + hasCustomRecords: tlv.SomeRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean](), + ), + hasBlindingPoint: tlv.SomeRecordT[tlv.TlvType3, lnwire.TrueBoolean]( + tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean](), + ), + }, + }), + }), } +) + +// TestMigrateMCRouteSerialisation tests that the MigrateMCRouteSerialisation +// migration function correctly migrates the MC store from using the old route +// encoding to using the newer, more minimal route encoding. +func TestMigrateMCRouteSerialisation(t *testing.T) { + var ( + resultsOld = []*paymentResultOld{ + &resultOld1, &resultOld2, + } + expectedResultsNew = []*paymentResultNew{ + &resultNew1, &resultNew2, + } + ) // Prime the database with some mission control data that uses the // old route encoding. diff --git a/channeldb/migration32/mission_control_store.go b/channeldb/migration32/mission_control_store.go index 3953cd1f25b..90de172430b 100644 --- a/channeldb/migration32/mission_control_store.go +++ b/channeldb/migration32/mission_control_store.go @@ -8,6 +8,7 @@ import ( "github.com/btcsuite/btcd/wire" lnwire "github.com/lightningnetwork/lnd/channeldb/migration/lnwire21" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -22,30 +23,22 @@ var ( resultsKey = []byte("missioncontrol-results") ) -// paymentResultCommon holds the fields that are shared by the old and new -// payment result encoding. -type paymentResultCommon struct { +// paymentResultOld is the information that becomes available when a payment +// attempt completes. +type paymentResultOld struct { id uint64 timeFwd, timeReply time.Time + route *Route success bool failureSourceIdx *int failure lnwire.FailureMessage } -// paymentResultOld is the information that becomes available when a payment -// attempt completes. -type paymentResultOld struct { - paymentResultCommon - route *Route -} - // deserializeOldResult deserializes a payment result using the old encoding. func deserializeOldResult(k, v []byte) (*paymentResultOld, error) { // Parse payment id. result := paymentResultOld{ - paymentResultCommon: paymentResultCommon{ - id: byteOrder.Uint64(k[8:]), - }, + id: byteOrder.Uint64(k[8:]), } r := bytes.NewReader(v) @@ -99,67 +92,425 @@ func deserializeOldResult(k, v []byte) (*paymentResultOld, error) { // convertPaymentResult converts a paymentResultOld to a paymentResultNew. func convertPaymentResult(old *paymentResultOld) *paymentResultNew { - return &paymentResultNew{ - paymentResultCommon: old.paymentResultCommon, - route: extractMCRoute(old.route), + return newPaymentResult( + old.id, extractMCRoute(old.route), old.timeFwd, old.timeReply, + old.success, old.failureSourceIdx, old.failure, + ) +} + +func newPaymentResult(id uint64, rt *mcRoute, timeFwd, timeReply time.Time, + success bool, failureSourceIdx *int, + failure lnwire.FailureMessage) *paymentResultNew { + + result := &paymentResultNew{ + id: id, + timeFwd: tlv.NewPrimitiveRecord[tlv.TlvType1, uint64]( + uint64(timeFwd.UnixNano()), + ), + timeReply: tlv.NewPrimitiveRecord[tlv.TlvType2, uint64]( + uint64(timeReply.UnixNano()), + ), + route: tlv.NewRecordT[tlv.TlvType3, mcRoute](*rt), + } + + if success { + result.success = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + lnwire.TrueBoolean{}, + ), + ) + } + + if failureSourceIdx != nil { + result.failureSourceIdx = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType5, uint8]( + uint8(*failureSourceIdx), + ), + ) + } + + if failure != nil { + result.failure = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType6, failureMessage]( + failureMessage{failure}, + ), + ) } + + return result } // paymentResultNew is the information that becomes available when a payment // attempt completes. type paymentResultNew struct { - paymentResultCommon - route *mcRoute + id uint64 + timeFwd tlv.RecordT[tlv.TlvType1, uint64] + timeReply tlv.RecordT[tlv.TlvType2, uint64] + route tlv.RecordT[tlv.TlvType3, mcRoute] + success tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean] + failureSourceIdx tlv.OptionalRecordT[tlv.TlvType5, uint8] + failure tlv.OptionalRecordT[tlv.TlvType6, failureMessage] +} + +type failureMessage struct { + lnwire.FailureMessage +} + +// Record returns a TLV record that can be used to encode/decode a list of +// mcRoute to/from a TLV stream. +func (r *failureMessage) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeFailureMessage(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodeFailureMessage, decodeFailureMessage, + ) +} + +func encodeFailureMessage(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*failureMessage); ok { + var b bytes.Buffer + err := lnwire.EncodeFailureMessage(&b, v.FailureMessage, 0) + if err != nil { + return err + } + + _, err = w.Write(b.Bytes()) + + return err + } + + return tlv.NewTypeForEncodingErr(val, "routing.failureMessage") +} + +func decodeFailureMessage(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*failureMessage); ok { + msg, err := lnwire.DecodeFailureMessage(r, 0) + if err != nil { + return err + } + + *v = failureMessage{ + FailureMessage: msg, + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.failureMessage", l, l) } // extractMCRoute extracts the fields required by MC from the Route struct to // create the more minimal mcRoute struct. -func extractMCRoute(route *Route) *mcRoute { +// extractMCRoute extracts the fields required by MC from the Route struct to +// create the more minimal mcRoute struct. +func extractMCRoute(r *Route) *mcRoute { + //nolint:lll return &mcRoute{ - sourcePubKey: route.SourcePubKey, - totalAmount: route.TotalAmount, - hops: extractMCHops(route.Hops), + sourcePubKey: tlv.NewRecordT[tlv.TlvType0, Vertex](r.SourcePubKey), + totalAmount: tlv.NewRecordT[tlv.TlvType1, lnwire.MilliSatoshi](r.TotalAmount), + hops: tlv.NewRecordT[tlv.TlvType2, mcHops](extractMCHops(r.Hops)), } } // extractMCHops extracts the Hop fields that MC actually uses from a slice of // Hops. -func extractMCHops(hops []*Hop) []*mcHop { - mcHops := make([]*mcHop, len(hops)) +func extractMCHops(hops []*Hop) mcHops { + hopList := make(mcHops, len(hops)) for i, hop := range hops { - mcHops[i] = extractMCHop(hop) + hopList[i] = extractMCHop(hop) } - return mcHops + return hopList } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. +// +//nolint:lll func extractMCHop(hop *Hop) *mcHop { - return &mcHop{ - channelID: hop.ChannelID, - pubKeyBytes: hop.PubKeyBytes, - amtToFwd: hop.AmtToForward, - hasBlindingPoint: hop.BlindingPoint != nil, - hasCustomRecords: len(hop.CustomRecords) > 0, + h := mcHop{ + channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64]( + hop.ChannelID, + ), + pubKeyBytes: tlv.NewRecordT[tlv.TlvType1, Vertex]( + hop.PubKeyBytes, + ), + amtToFwd: tlv.NewRecordT[tlv.TlvType2, lnwire.MilliSatoshi]( + hop.AmtToForward, + ), } + + if hop.BlindingPoint != nil { + h.hasBlindingPoint = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType3, lnwire.TrueBoolean]( + lnwire.TrueBoolean{}, + ), + ) + } + + if len(hop.CustomRecords) != 0 { + h.hasCustomRecords = tlv.SomeRecordT( + tlv.NewRecordT[tlv.TlvType4, lnwire.TrueBoolean]( + lnwire.TrueBoolean{}, + ), + ) + } + + return &h } // mcRoute holds the bare minimum info about a payment attempt route that MC // requires. type mcRoute struct { - sourcePubKey Vertex - totalAmount lnwire.MilliSatoshi - hops []*mcHop + sourcePubKey tlv.RecordT[tlv.TlvType0, Vertex] + totalAmount tlv.RecordT[tlv.TlvType1, lnwire.MilliSatoshi] + hops tlv.RecordT[tlv.TlvType2, mcHops] +} + +// Record returns a TLV record that can be used to encode/decode a list of +// mcRoute to/from a TLV stream. +func (r *mcRoute) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeMCRoute(&b, r, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, r, recordSize, encodeMCRoute, decodeMCRoute, + ) +} + +func encodeMCRoute(w io.Writer, val interface{}, _ *[8]byte) error { + if v, ok := val.(*mcRoute); ok { + return serializeRoute(w, v) + } + + return tlv.NewTypeForEncodingErr(val, "routing.mcRoute") +} + +func decodeMCRoute(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if v, ok := val.(*mcRoute); ok { + route, err := deserializeRoute(io.LimitReader(r, int64(l))) + if err != nil { + return err + } + + *v = *route + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.mcRoute", l, l) +} + +// mcHops is a list of mcHop records. +type mcHops []*mcHop + +// Record returns a TLV record that can be used to encode/decode a list of +// mcHop to/from a TLV stream. +func (h *mcHops) Record() tlv.Record { + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := encodeMCHops(&b, h, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, h, recordSize, encodeMCHops, decodeMCHops, + ) +} + +func encodeMCHops(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*mcHops); ok { + // Encode the number of hops as a var int. + if err := tlv.WriteVarInt(w, uint64(len(*v)), buf); err != nil { + return err + } + + // With that written out, we'll now encode the entries + // themselves as a sub-TLV record, which includes its _own_ + // inner length prefix. + for _, hop := range *v { + var hopBytes bytes.Buffer + if err := serializeNewHop(&hopBytes, hop); err != nil { + return err + } + + // We encode the record with a varint length followed by + // the _raw_ TLV bytes. + tlvLen := uint64(len(hopBytes.Bytes())) + if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil { + return err + } + + if _, err := w.Write(hopBytes.Bytes()); err != nil { + return err + } + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "routing.mcHops") +} + +func decodeMCHops(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*mcHops); ok { + // First, we'll decode the varint that encodes how many hops + // are encoded in the stream. + numHops, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Now that we know how many records we'll need to read, we can + // iterate and read them all out in series. + for i := uint64(0); i < numHops; i++ { + // Read out the varint that encodes the size of this + // inner TLV record. + hopSize, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Using this information, we'll create a new limited + // reader that'll return an EOF once the end has been + // reached so the stream stops consuming bytes. + innerTlvReader := &io.LimitedReader{ + R: r, + N: int64(hopSize), + } + + hop, err := deserializeNewHop(innerTlvReader) + if err != nil { + return err + } + + *v = append(*v, hop) + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "routing.mcHops", l, l) +} + +// serializeRoute serializes a mcRoute and writes the resulting bytes to the +// given io.Writer. +func serializeRoute(w io.Writer, r *mcRoute) error { + records := lnwire.ProduceRecordsSorted( + &r.sourcePubKey, + &r.totalAmount, + &r.hops, + ) + + return lnwire.EncodeRecordsTo(w, records) +} + +// deserializeRoute deserializes the mcRoute from the given io.Reader. +func deserializeRoute(r io.Reader) (*mcRoute, error) { + var rt mcRoute + records := lnwire.ProduceRecordsSorted( + &rt.sourcePubKey, + &rt.totalAmount, + &rt.hops, + ) + + _, err := lnwire.DecodeRecords(r, records...) + if err != nil { + return nil, err + } + + return &rt, nil +} + +// deserializeNewHop deserializes the mcHop from the given io.Reader. +func deserializeNewHop(r io.Reader) (*mcHop, error) { + var ( + h mcHop + blinding = tlv.ZeroRecordT[tlv.TlvType3, lnwire.TrueBoolean]() + custom = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]() + ) + records := lnwire.ProduceRecordsSorted( + &h.channelID, + &h.pubKeyBytes, + &h.amtToFwd, + &blinding, + &custom, + ) + + typeMap, err := lnwire.DecodeRecords(r, records...) + if err != nil { + return nil, err + } + + if _, ok := typeMap[h.hasBlindingPoint.TlvType()]; ok { + h.hasBlindingPoint = tlv.SomeRecordT(blinding) + } + + if _, ok := typeMap[h.hasCustomRecords.TlvType()]; ok { + h.hasCustomRecords = tlv.SomeRecordT(custom) + } + + return &h, nil +} + +// serializeNewHop serializes a mcHop and writes the resulting bytes to the +// given io.Writer. +func serializeNewHop(w io.Writer, h *mcHop) error { + recordProducers := []tlv.RecordProducer{ + &h.channelID, + &h.pubKeyBytes, + &h.amtToFwd, + } + + h.hasBlindingPoint.WhenSome(func( + hasBlinding tlv.RecordT[tlv.TlvType3, lnwire.TrueBoolean]) { + + recordProducers = append(recordProducers, &hasBlinding) + }) + + h.hasCustomRecords.WhenSome(func( + hasCustom tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) { + + recordProducers = append(recordProducers, &hasCustom) + }) + + return lnwire.EncodeRecordsTo( + w, lnwire.ProduceRecordsSorted(recordProducers...), + ) } // mcHop holds the bare minimum info about a payment attempt route hop that MC // requires. type mcHop struct { - channelID uint64 - pubKeyBytes Vertex - amtToFwd lnwire.MilliSatoshi - hasBlindingPoint bool - hasCustomRecords bool + channelID tlv.RecordT[tlv.TlvType0, uint64] + pubKeyBytes tlv.RecordT[tlv.TlvType1, Vertex] + amtToFwd tlv.RecordT[tlv.TlvType2, lnwire.MilliSatoshi] + hasBlindingPoint tlv.OptionalRecordT[tlv.TlvType3, lnwire.TrueBoolean] + hasCustomRecords tlv.OptionalRecordT[tlv.TlvType4, lnwire.TrueBoolean] } // serializeOldResult serializes a payment result and returns a key and value @@ -222,51 +573,88 @@ func getResultKeyOld(rp *paymentResultOld) []byte { return keyBytes[:] } -// serializeNewResult serializes a payment result and returns a key and value -// byte slice to insert into the bucket. -func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) { - // Write timestamps, success status, failure source index and route. - var b bytes.Buffer +func deserializeNewResult(k, v []byte) (*paymentResultNew, error) { + // Parse payment id. + result := paymentResultNew{ + id: byteOrder.Uint64(k[8:]), + } - var dbFailureSourceIdx int32 - if rp.failureSourceIdx == nil { - dbFailureSourceIdx = unknownFailureSourceIdx - } else { - dbFailureSourceIdx = int32(*rp.failureSourceIdx) + var ( + success = tlv.ZeroRecordT[tlv.TlvType4, lnwire.TrueBoolean]() + failIndex = tlv.ZeroRecordT[tlv.TlvType5, uint8]() + failMsg = tlv.ZeroRecordT[tlv.TlvType6, failureMessage]() + ) + recordProducers := []tlv.RecordProducer{ + &result.timeFwd, + &result.timeReply, + &result.route, + &success, + &failIndex, + &failMsg, } - err := WriteElements( - &b, - uint64(rp.timeFwd.UnixNano()), - uint64(rp.timeReply.UnixNano()), - rp.success, dbFailureSourceIdx, + r := bytes.NewReader(v) + typeMap, err := lnwire.DecodeRecords( + r, lnwire.ProduceRecordsSorted(recordProducers...)..., ) if err != nil { - return nil, nil, err + return nil, err } - if err := serializeMCRoute(&b, rp.route); err != nil { - return nil, nil, err + if _, ok := typeMap[result.success.TlvType()]; ok { + result.success = tlv.SomeRecordT(success) } - // Write failure. If there is no failure message, write an empty - // byte slice. - var failureBytes bytes.Buffer - if rp.failure != nil { - err := lnwire.EncodeFailureMessage(&failureBytes, rp.failure, 0) - if err != nil { - return nil, nil, err - } + if _, ok := typeMap[result.failureSourceIdx.TlvType()]; ok { + result.failureSourceIdx = tlv.SomeRecordT(failIndex) } - err = wire.WriteVarBytes(&b, 0, failureBytes.Bytes()) - if err != nil { - return nil, nil, err + + if _, ok := typeMap[result.failure.TlvType()]; ok { + result.failure = tlv.SomeRecordT(failMsg) + } + + return &result, nil +} + +// serializeNewResult serializes a payment result and returns a key and value +// byte slice to insert into the bucket. +func serializeNewResult(rp *paymentResultNew) ([]byte, []byte, error) { + recordProducers := []tlv.RecordProducer{ + &rp.timeFwd, + &rp.timeReply, + &rp.route, } + rp.success.WhenSome( + func(success tlv.RecordT[tlv.TlvType4, lnwire.TrueBoolean]) { + recordProducers = append(recordProducers, &success) + }, + ) + + rp.failureSourceIdx.WhenSome( + func(idx tlv.RecordT[tlv.TlvType5, uint8]) { + recordProducers = append(recordProducers, &idx) + }, + ) + + rp.failure.WhenSome( + func(failMsg tlv.RecordT[tlv.TlvType6, failureMessage]) { + recordProducers = append(recordProducers, &failMsg) + }, + ) + // Compose key that identifies this result. key := getResultKeyNew(rp) - return key, b.Bytes(), nil + var buff bytes.Buffer + err := lnwire.EncodeRecordsTo( + &buff, lnwire.ProduceRecordsSorted(recordProducers...), + ) + if err != nil { + return nil, nil, err + } + + return key, buff.Bytes(), err } // getResultKeyNew returns a byte slice representing a unique key for this @@ -278,43 +666,9 @@ func getResultKeyNew(rp *paymentResultNew) []byte { // key. This allows importing mission control data from an external // source without key collisions and keeps the records sorted // chronologically. - byteOrder.PutUint64(keyBytes[:], uint64(rp.timeReply.UnixNano())) + byteOrder.PutUint64(keyBytes[:], rp.timeReply.Val) byteOrder.PutUint64(keyBytes[8:], rp.id) - copy(keyBytes[16:], rp.route.sourcePubKey[:]) + copy(keyBytes[16:], rp.route.Val.sourcePubKey.Val[:]) return keyBytes[:] } - -// serializeMCRoute serializes an mcRoute and writes the bytes to the given -// io.Writer. -func serializeMCRoute(w io.Writer, r *mcRoute) error { - if err := WriteElements( - w, r.totalAmount, r.sourcePubKey[:], - ); err != nil { - return err - } - - if err := WriteElements(w, uint32(len(r.hops))); err != nil { - return err - } - - for _, h := range r.hops { - if err := serializeNewHop(w, h); err != nil { - return err - } - } - - return nil -} - -// serializeMCRoute serializes an mcHop and writes the bytes to the given -// io.Writer. -func serializeNewHop(w io.Writer, h *mcHop) error { - return WriteElements(w, - h.pubKeyBytes[:], - h.channelID, - h.amtToFwd, - h.hasBlindingPoint, - h.hasCustomRecords, - ) -} diff --git a/channeldb/migration32/route.go b/channeldb/migration32/route.go index a4d40a45cb3..1a36513217f 100644 --- a/channeldb/migration32/route.go +++ b/channeldb/migration32/route.go @@ -29,6 +29,28 @@ const VertexSize = 33 // public key. type Vertex [VertexSize]byte +// Record returns a TLV record that can be used to encode/decode a Vertex +// to/from a TLV stream. +func (v *Vertex) Record() tlv.Record { + return tlv.MakeStaticRecord(0, v, 33, encodeVertex, decodeVertex) +} + +func encodeVertex(w io.Writer, val interface{}, _ *[8]byte) error { + if b, ok := val.(*Vertex); ok { + _, err := w.Write(b[:]) + return err + } + return tlv.NewTypeForEncodingErr(val, "Vertex") +} + +func decodeVertex(r io.Reader, val interface{}, _ *[8]byte, l uint64) error { + if b, ok := val.(*Vertex); ok { + _, err := io.ReadFull(r, b[:]) + return err + } + return tlv.NewTypeForDecodingErr(val, "Vertex", l, 33) +} + // Route represents a path through the channel graph which runs over one or // more channels in succession. This struct carries all the information // required to craft the Sphinx onion packet, and send the payment along the diff --git a/routing/result_interpretation.go b/routing/result_interpretation.go index 35a62571766..fb2f8084db6 100644 --- a/routing/result_interpretation.go +++ b/routing/result_interpretation.go @@ -583,8 +583,6 @@ func extractMCHops(hops []*route.Hop) mcHops { } // extractMCHop extracts the Hop fields that MC actually uses from a Hop. -// -//nolint:lll func extractMCHop(hop *route.Hop) *mcHop { h := mcHop{ channelID: tlv.NewPrimitiveRecord[tlv.TlvType0, uint64](