From c62c1fee93527502f0f3cd5a4995a44bcf409886 Mon Sep 17 00:00:00 2001 From: Omer <100387053+omerlavanet@users.noreply.github.com> Date: Sun, 5 Feb 2023 22:02:27 +0200 Subject: [PATCH] PRT-300 consumer state tracker (#285) * finished consumer state tracker * lint * fix some bugs * rpcconsumer fixes * fix missing VrfSk * fixed review changes --------- Co-authored-by: Aleksao998 --- cmd/lavad/main.go | 8 +- protocol/chainlib/chain_fetcher.go | 31 +++++ protocol/chaintracker/chain_tracker.go | 14 +- protocol/chaintracker/chain_tracker_test.go | 16 +-- protocol/rpcconsumer/rpcconsumer.go | 14 +- protocol/rpcconsumer/rpcconsumer_server.go | 17 ++- protocol/rpcprovider/rpcprovider.go | 8 +- .../statetracker/consumer_state_tracker.go | 95 ++++---------- .../finalization_consensus_updater.go | 16 ++- protocol/statetracker/pairing_updater.go | 121 +++++++++++++++--- protocol/statetracker/state_query.go | 80 ++++++++++-- protocol/statetracker/state_tracker.go | 69 ++++++++++ protocol/statetracker/statetracker.go | 5 - protocol/statetracker/tx_sender.go | 95 +++++++++++++- relayer/server.go | 2 +- 15 files changed, 459 insertions(+), 132 deletions(-) create mode 100644 protocol/statetracker/state_tracker.go delete mode 100644 protocol/statetracker/statetracker.go diff --git a/cmd/lavad/main.go b/cmd/lavad/main.go index d1f88ed9dc..a93aef2826 100644 --- a/cmd/lavad/main.go +++ b/cmd/lavad/main.go @@ -306,8 +306,8 @@ func main() { utils.LavaFormatInfo("cache service connected", &map[string]string{"address": cacheAddr}) } } - rpcConsumer.Start(ctx, txFactory, clientCtx, rpcEndpoints, requiredResponses, vrf_sk, cache) - return nil + err = rpcConsumer.Start(ctx, txFactory, clientCtx, rpcEndpoints, requiredResponses, vrf_sk, cache) + return err }, } @@ -425,8 +425,8 @@ func main() { if err != nil { utils.LavaFormatFatal("error fetching chainproxy.ParallelConnectionsFlag", err, nil) } - rpcProvider.Start(ctx, txFactory, clientCtx, rpcProviderEndpoints, cache, 
numberOfNodeParallelConnections) - return nil + err = rpcProvider.Start(ctx, txFactory, clientCtx, rpcProviderEndpoints, cache, numberOfNodeParallelConnections) + return err }, } diff --git a/protocol/chainlib/chain_fetcher.go b/protocol/chainlib/chain_fetcher.go index 9c884c207c..84a4207bdb 100644 --- a/protocol/chainlib/chain_fetcher.go +++ b/protocol/chainlib/chain_fetcher.go @@ -3,6 +3,12 @@ package chainlib import ( "context" "fmt" + + "github.com/cosmos/cosmos-sdk/client" +) + +const ( + TendermintStatusQuery = "status" ) type ChainFetcher struct { @@ -22,3 +28,28 @@ func NewChainFetcher(ctx context.Context, chainProxy ChainProxy /*here needs som cf := &ChainFetcher{chainProxy: chainProxy} return cf } + +type LavaChainFetcher struct { + clientCtx client.Context +} + +func (lcf *LavaChainFetcher) FetchLatestBlockNum(ctx context.Context) (int64, error) { + resultStatus, err := lcf.clientCtx.Client.Status(ctx) + if err != nil { + return 0, err + } + return resultStatus.SyncInfo.LatestBlockHeight, nil +} + +func (lcf *LavaChainFetcher) FetchBlockHashByNum(ctx context.Context, blockNum int64) (string, error) { + resultStatus, err := lcf.clientCtx.Client.Status(ctx) + if err != nil { + return "", err + } + return resultStatus.SyncInfo.LatestBlockHash.String(), nil +} + +func NewLavaChainFetcher(ctx context.Context, clientCtx client.Context) *LavaChainFetcher { + lcf := &LavaChainFetcher{clientCtx: clientCtx} + return lcf +} diff --git a/protocol/chaintracker/chain_tracker.go b/protocol/chaintracker/chain_tracker.go index 7fe9dd735b..3d51169e10 100644 --- a/protocol/chaintracker/chain_tracker.go +++ b/protocol/chaintracker/chain_tracker.go @@ -335,10 +335,16 @@ func (ct *ChainTracker) serve(ctx context.Context, listenAddr string) error { return nil } -func New(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker) { - config.validate() +func New(ctx context.Context, chainFetcher ChainFetcher, config 
ChainTrackerConfig) (chainTracker *ChainTracker, err error) { + err = config.validate() + if err != nil { + return nil, err + } chainTracker = &ChainTracker{forkCallback: config.ForkCallback, newLatestCallback: config.NewLatestCallback, blocksToSave: config.BlocksToSave, chainFetcher: chainFetcher, latestBlockNum: 0, serverBlockMemory: config.ServerBlockMemory} - chainTracker.start(ctx, config.AverageBlockTime) - chainTracker.serve(ctx, config.ServerAddress) + err = chainTracker.start(ctx, config.AverageBlockTime) + if err != nil { + return nil, err + } + err = chainTracker.serve(ctx, config.ServerAddress) return } diff --git a/protocol/chaintracker/chain_tracker_test.go b/protocol/chaintracker/chain_tracker_test.go index 26afde322a..dd22cf7e0d 100644 --- a/protocol/chaintracker/chain_tracker_test.go +++ b/protocol/chaintracker/chain_tracker_test.go @@ -119,8 +119,8 @@ func TestChainTracker(t *testing.T) { currentLatestBlockInMock := mockChainFetcher.AdvanceBlock() chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)} - chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) - + chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) + require.NoError(t, err) for _, advancement := range tt.advancements { for i := 0; i < int(advancement); i++ { currentLatestBlockInMock = mockChainFetcher.AdvanceBlock() @@ -176,8 +176,8 @@ func TestChainTrackerRangeOnly(t *testing.T) { currentLatestBlockInMock := mockChainFetcher.AdvanceBlock() chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)} - chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) - + chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, 
chainTrackerConfig) + require.NoError(t, err) for _, advancement := range tt.advancements { for i := 0; i < int(advancement); i++ { currentLatestBlockInMock = mockChainFetcher.AdvanceBlock() @@ -256,8 +256,8 @@ func TestChainTrackerCallbacks(t *testing.T) { callbackCalledNewLatest = true } chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback, NewLatestCallback: newBlockCallback} - chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) - + chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) + require.NoError(t, err) t.Run("one long test", func(t *testing.T) { for _, tt := range tests { utils.LavaFormatInfo("started test "+tt.name, nil) @@ -338,8 +338,8 @@ func TestChainTrackerMaintainMemory(t *testing.T) { callbackCalledFork = true } chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback} - chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) - + chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig) + require.NoError(t, err) t.Run("one long test", func(t *testing.T) { for _, tt := range tests { utils.LavaFormatInfo("started test "+tt.name, nil) diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go index 06db4951c9..90ad72450e 100644 --- a/protocol/rpcconsumer/rpcconsumer.go +++ b/protocol/rpcconsumer/rpcconsumer.go @@ -33,9 +33,9 @@ var ( type ConsumerStateTrackerInf interface { RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) - RegisterChainParserForSpecUpdates(ctx context.Context, chainParser 
chainlib.ChainParser) + RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser, chainID string) error RegisterFinalizationConsensusForUpdates(context.Context, *lavaprotocol.FinalizationConsensus) - TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) + TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error } type RPCConsumer struct { @@ -46,11 +46,12 @@ type RPCConsumer struct { // spawns a new RPCConsumer server with all it's processes and internals ready for communications func (rpcc *RPCConsumer) Start(ctx context.Context, txFactory tx.Factory, clientCtx client.Context, rpcEndpoints []*lavasession.RPCEndpoint, requiredResponses int, vrf_sk vrf.PrivateKey, cache *performance.Cache) (err error) { // spawn up ConsumerStateTracker - consumerStateTracker := statetracker.ConsumerStateTracker{} - rpcc.consumerStateTracker, err = consumerStateTracker.New(ctx, txFactory, clientCtx) + lavaChainFetcher := chainlib.NewLavaChainFetcher(ctx, clientCtx) + consumerStateTracker, err := statetracker.NewConsumerStateTracker(ctx, txFactory, clientCtx, lavaChainFetcher) if err != nil { return err } + rpcc.consumerStateTracker = consumerStateTracker rpcc.rpcConsumerServers = make(map[string]*RPCConsumerServer, len(rpcEndpoints)) keyName, err := sigs.GetKeyName(clientCtx) @@ -78,7 +79,10 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, txFactory tx.Factory, client if err != nil { return err } - consumerStateTracker.RegisterChainParserForSpecUpdates(ctx, chainParser) + err = consumerStateTracker.RegisterChainParserForSpecUpdates(ctx, chainParser, rpcEndpoint.ChainID) + if err != nil { + return err + } finalizationConsensus := 
&lavaprotocol.FinalizationConsensus{} consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus) rpcc.rpcConsumerServers[key] = &RPCConsumerServer{} diff --git a/protocol/rpcconsumer/rpcconsumer_server.go b/protocol/rpcconsumer/rpcconsumer_server.go index dcc37c958c..fb8fa4f369 100644 --- a/protocol/rpcconsumer/rpcconsumer_server.go +++ b/protocol/rpcconsumer/rpcconsumer_server.go @@ -40,7 +40,7 @@ type RPCConsumerServer struct { } type ConsumerTxSender interface { - TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) + TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error } func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndpoint *lavasession.RPCEndpoint, @@ -58,6 +58,7 @@ func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndp rpccs.cache = cache rpccs.consumerTxSender = consumerStateTracker rpccs.requiredResponses = requiredResponses + rpccs.VrfSk = vrfSk pLogs, err := common.NewRPCConsumerLogs() if err != nil { utils.LavaFormatFatal("failed creating RPCConsumer logs", err, nil) @@ -217,11 +218,16 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider( } // get here only if performed a regular relay successfully expectedBH, numOfProviders := rpccs.finalizationConsensus.ExpectedBlockHeight(rpccs.chainParser) - err = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, epoch, reply.LatestBlock, chainMessage.GetServiceApi().ComputeUnits, relayLatency, expectedBH, numOfProviders, rpccs.consumerSessionManager.GetAtomicPairingAddressesLength()) // session done successfully + pairingAddressesLen := rpccs.consumerSessionManager.GetAtomicPairingAddressesLength() + 
latestBlock := relayResult.Reply.LatestBlock + err = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, epoch, latestBlock, chainMessage.GetServiceApi().ComputeUnits, relayLatency, expectedBH, numOfProviders, pairingAddressesLen) // session done successfully // set cache in a non blocking call go func() { - err2 := rpccs.cache.SetEntry(ctx, relayRequest, chainMessage.GetInterface().Interface, nil, chainID, dappID, reply, relayResult.Finalized) // caching in the portal doesn't care about hashes + new_ctx := context.Background() + new_ctx, cancel := context.WithTimeout(new_ctx, lavaprotocol.DataReliabilityTimeoutIncrease) + defer cancel() + err2 := rpccs.cache.SetEntry(new_ctx, relayRequest, chainMessage.GetInterface().Interface, nil, chainID, dappID, relayResult.Reply, relayResult.Finalized) // caching in the portal doesn't care about hashes if err2 != nil && !performance.NotInitialisedError.Is(err2) { utils.LavaFormatWarning("error updating cache with new entry", err2, nil) } @@ -399,7 +405,10 @@ func (rpccs *RPCConsumerServer) sendDataReliabilityRelayIfApplicable(ctx context report, conflicts := lavaprotocol.VerifyReliabilityResults(relayResult, dataReliabilityVerifications, numberOfReliabilitySessions) if report { for _, conflict := range conflicts { - rpccs.consumerTxSender.TxConflictDetection(ctx, nil, conflict, nil) + err := rpccs.consumerTxSender.TxConflictDetection(ctx, nil, conflict, nil) + if err != nil { + utils.LavaFormatError("could not send detection Transaction", err, &map[string]string{"conflict": fmt.Sprintf("%+v", conflict)}) + } } } // detectionMessage = conflicttypes.NewMsgDetection(consumerAddress, nil, &responseConflict, nil) diff --git a/protocol/rpcprovider/rpcprovider.go b/protocol/rpcprovider/rpcprovider.go index 12a65c8a4f..706a3a005b 100644 --- a/protocol/rpcprovider/rpcprovider.go +++ b/protocol/rpcprovider/rpcprovider.go @@ -92,13 +92,15 @@ func (rpcp *RPCProvider) Start(ctx context.Context, txFactory tx.Factory, client 
_, avergaeBlockTime, blocksToFinalization, blocksInFinalizationData := chainParser.ChainBlockStats() blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData) chainTrackerConfig := chaintracker.ChainTrackerConfig{ - ServerAddress: rpcProviderEndpoint.NodeUrl, BlocksToSave: blocksToSaveChainTracker, - AverageBlockTime: avergaeBlockTime, // divide here to make the querying more often so we don't miss block changes by that much + AverageBlockTime: avergaeBlockTime, ServerBlockMemory: ChainTrackerDefaultMemory + blocksToSaveChainTracker, } chainFetcher := chainlib.NewChainFetcher(ctx, chainProxy) - chainTracker := chaintracker.New(ctx, chainFetcher, chainTrackerConfig) + chainTracker, err := chaintracker.New(ctx, chainFetcher, chainTrackerConfig) + if err != nil { + utils.LavaFormatFatal("failed creating chain tracker", err, &map[string]string{"chainTrackerConfig": fmt.Sprintf("%+v", chainTrackerConfig)}) + } reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker) providerStateTracker.RegisterReliabilityManagerForVoteUpdates(ctx, reliabilityManager) diff --git a/protocol/statetracker/consumer_state_tracker.go b/protocol/statetracker/consumer_state_tracker.go index 936bcb8a31..ee304b68a0 100644 --- a/protocol/statetracker/consumer_state_tracker.go +++ b/protocol/statetracker/consumer_state_tracker.go @@ -3,7 +3,6 @@ package statetracker import ( "context" "fmt" - "sync" "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/tx" @@ -14,100 +13,62 @@ import ( "github.com/lavanet/lava/protocol/lavasession" "github.com/lavanet/lava/utils" conflicttypes "github.com/lavanet/lava/x/conflict/types" - spectypes "github.com/lavanet/lava/x/spec/types" ) // ConsumerStateTracker CSTis a class for tracking consumer data from the lava blockchain, such as epoch changes. 
// it allows also to query specific data form the blockchain and acts as a single place to send transactions type ConsumerStateTracker struct { - consumerAddress sdk.AccAddress - chainTracker *chaintracker.ChainTracker - stateQuery *StateQuery - txSender *TxSender - registrationLock sync.RWMutex - newLavaBlockUpdaters map[string]Updater + consumerAddress sdk.AccAddress + stateQuery *ConsumerStateQuery + txSender *ConsumerTxSender + *StateTracker } -type Updater interface { - Update(int64) - UpdaterKey() string -} - -func (cst *ConsumerStateTracker) New(ctx context.Context, txFactory tx.Factory, clientCtx client.Context) (ret *ConsumerStateTracker, err error) { - // set up StateQuery - // Spin up chain tracker on the lava node, its address is in the --node flag (or its default), on new block call to newLavaBlock - // use StateQuery to get the lava spec and spin up the chain tracker with the right params - // set up txSender the same way - - stateQuery := StateQuery{} - cst.stateQuery, err = stateQuery.New(ctx, clientCtx) +func NewConsumerStateTracker(ctx context.Context, txFactory tx.Factory, clientCtx client.Context, chainFetcher chaintracker.ChainFetcher) (ret *ConsumerStateTracker, err error) { + stateTrackerBase, err := NewStateTracker(ctx, txFactory, clientCtx, chainFetcher) if err != nil { return nil, err } - - txSender := TxSender{} - cst.txSender, err = txSender.New(ctx, txFactory, clientCtx) + txSender, err := NewConsumerTxSender(ctx, clientCtx, txFactory) if err != nil { return nil, err } - cst.consumerAddress = clientCtx.FromAddress + cst := &ConsumerStateTracker{StateTracker: stateTrackerBase, stateQuery: NewConsumerStateQuery(ctx, clientCtx), txSender: txSender} return cst, nil } -func (cst *ConsumerStateTracker) newLavaBlock(latestBlock int64) { - // go over the registered updaters and trigger update - cst.registrationLock.RLock() - defer cst.registrationLock.RUnlock() - for _, updater := range cst.newLavaBlockUpdaters { - updater.Update(latestBlock) - 
} -} - func (cst *ConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) { // register this CSM to get the updated pairing list when a new epoch starts - cst.registrationLock.Lock() - defer cst.registrationLock.Unlock() - // make sure new lava block exists as a callback in stateTracker - // add updatePairingForRegistered as a callback on a new block - - var pairingUpdater *PairingUpdater = nil // UpdaterKey is nil safe - pairingUpdater_raw, ok := cst.newLavaBlockUpdaters[pairingUpdater.UpdaterKey()] + pairingUpdater := NewPairingUpdater(cst.consumerAddress, cst.stateQuery) + pairingUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, pairingUpdater) + pairingUpdater, ok := pairingUpdaterRaw.(*PairingUpdater) if !ok { - pairingUpdater = NewPairingUpdater(cst.consumerAddress, cst.stateQuery) - cst.newLavaBlockUpdaters[pairingUpdater.UpdaterKey()] = pairingUpdater + utils.LavaFormatFatal("invalid updater type returned from RegisterForUpdates", nil, &map[string]string{"updater": fmt.Sprintf("%+v", pairingUpdaterRaw)}) } - pairingUpdater, ok = pairingUpdater_raw.(*PairingUpdater) - if !ok { - utils.LavaFormatFatal("invalid_updater_key in RegisterConsumerSessionManagerForPairingUpdates", nil, &map[string]string{"updaters_map": fmt.Sprintf("%+v", cst.newLavaBlockUpdaters)}) - } - pairingUpdater.RegisterPairing(consumerSessionManager) + pairingUpdater.RegisterPairing(ctx, consumerSessionManager) } func (cst *ConsumerStateTracker) RegisterFinalizationConsensusForUpdates(ctx context.Context, finalizationConsensus *lavaprotocol.FinalizationConsensus) { - cst.registrationLock.Lock() - defer cst.registrationLock.Unlock() - - var finalizationConsensusUpdater *FinalizationConsensusUpdater = nil // UpdaterKey is nil safe - finalizationConsensusUpdater_raw, ok := cst.newLavaBlockUpdaters[finalizationConsensusUpdater.UpdaterKey()] + finalizationConsensusUpdater := 
NewFinalizationConsensusUpdater(cst.consumerAddress, cst.stateQuery) + finalizationConsensusUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, finalizationConsensusUpdater) + finalizationConsensusUpdater, ok := finalizationConsensusUpdaterRaw.(*FinalizationConsensusUpdater) if !ok { - finalizationConsensusUpdater = NewFinalizationConsensusUpdater(cst.consumerAddress, cst.stateQuery) - cst.newLavaBlockUpdaters[finalizationConsensusUpdater.UpdaterKey()] = finalizationConsensusUpdater - } - finalizationConsensusUpdater, ok = finalizationConsensusUpdater_raw.(*FinalizationConsensusUpdater) - if !ok { - utils.LavaFormatFatal("invalid_updater_key in RegisterFinalizationConsensusForUpdates", nil, &map[string]string{"updaters_map": fmt.Sprintf("%+v", cst.newLavaBlockUpdaters)}) + utils.LavaFormatFatal("invalid updater type returned from RegisterForUpdates", nil, &map[string]string{"updater": fmt.Sprintf("%+v", finalizationConsensusUpdaterRaw)}) } finalizationConsensusUpdater.RegisterFinalizationConsensus(finalizationConsensus) } -func (cst *ConsumerStateTracker) RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser) { - // register this chainParser for spec updates - // currently just set the first one, and have a TODO to handle spec changes - // get the spec and set it into the chainParser - spec := spectypes.Spec{} - chainParser.SetSpec(spec) +func (cst *ConsumerStateTracker) RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser, chainID string) error { + // TODO: handle spec changes + spec, err := cst.stateQuery.GetSpec(ctx, chainID) + if err != nil { + return err + } + chainParser.SetSpec(*spec) + return nil } -func (cst *ConsumerStateTracker) TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) { - cst.txSender.TxConflictDetection(ctx, 
finalizationConflict, responseConflict, sameProviderConflict) +func (cst *ConsumerStateTracker) TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error { + err := cst.txSender.TxConflictDetection(ctx, finalizationConflict, responseConflict, sameProviderConflict) + return err } diff --git a/protocol/statetracker/finalization_consensus_updater.go b/protocol/statetracker/finalization_consensus_updater.go index 050049f64b..ca5c7a8dca 100644 --- a/protocol/statetracker/finalization_consensus_updater.go +++ b/protocol/statetracker/finalization_consensus_updater.go @@ -1,8 +1,12 @@ package statetracker import ( + "context" + "strconv" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lavanet/lava/protocol/lavaprotocol" + "github.com/lavanet/lava/utils" ) const ( @@ -12,10 +16,10 @@ const ( type FinalizationConsensusUpdater struct { registeredFinalizationConsensuses []*lavaprotocol.FinalizationConsensus nextBlockForUpdate uint64 - stateQuery *StateQuery + stateQuery *ConsumerStateQuery } -func NewFinalizationConsensusUpdater(consumerAddress sdk.AccAddress, stateQuery *StateQuery) *FinalizationConsensusUpdater { +func NewFinalizationConsensusUpdater(consumerAddress sdk.AccAddress, stateQuery *ConsumerStateQuery) *FinalizationConsensusUpdater { return &FinalizationConsensusUpdater{registeredFinalizationConsensuses: []*lavaprotocol.FinalizationConsensus{}, stateQuery: stateQuery} } @@ -29,10 +33,16 @@ func (fcu *FinalizationConsensusUpdater) UpdaterKey() string { } func (fcu *FinalizationConsensusUpdater) Update(latestBlock int64) { + ctx := context.Background() if int64(fcu.nextBlockForUpdate) > latestBlock { return } - _, epoch, nextBlockForUpdate := fcu.stateQuery.GetPairing(latestBlock) + _, epoch, nextBlockForUpdate, err := fcu.stateQuery.GetPairing(ctx, "", latestBlock) + if err != nil { + 
utils.LavaFormatError("could not get block stats for finzalizationConsensus, trying again later", err, &map[string]string{"latestBlock": strconv.FormatInt(latestBlock, 10)}) + fcu.nextBlockForUpdate += 1 + return + } fcu.nextBlockForUpdate = nextBlockForUpdate for _, finalizationConsensus := range fcu.registeredFinalizationConsensuses { finalizationConsensus.NewEpoch(epoch) diff --git a/protocol/statetracker/pairing_updater.go b/protocol/statetracker/pairing_updater.go index df1b77f1ee..5c46fc9de7 100644 --- a/protocol/statetracker/pairing_updater.go +++ b/protocol/statetracker/pairing_updater.go @@ -1,9 +1,14 @@ package statetracker import ( + "fmt" + "strconv" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/lavanet/lava/protocol/lavasession" + "github.com/lavanet/lava/utils" epochstoragetypes "github.com/lavanet/lava/x/epochstorage/types" + "golang.org/x/net/context" ) const ( @@ -11,18 +16,29 @@ const ( ) type PairingUpdater struct { - consumerSessionManagers []*lavasession.ConsumerSessionManager - nextBlockForUpdate uint64 - stateQuery *StateQuery + consumerSessionManagersMap map[string][]*lavasession.ConsumerSessionManager // key is chainID so we don;t run getPairing more than once per chain + nextBlockForUpdate uint64 + stateQuery *ConsumerStateQuery } -func NewPairingUpdater(consumerAddress sdk.AccAddress, stateQuery *StateQuery) *PairingUpdater { - return &PairingUpdater{consumerSessionManagers: []*lavasession.ConsumerSessionManager{}, stateQuery: stateQuery} +func NewPairingUpdater(consumerAddress sdk.AccAddress, stateQuery *ConsumerStateQuery) *PairingUpdater { + return &PairingUpdater{consumerSessionManagersMap: map[string][]*lavasession.ConsumerSessionManager{}, stateQuery: stateQuery} } -func (pu *PairingUpdater) RegisterPairing(consumerSessionManager *lavasession.ConsumerSessionManager) { - // TODO: also update here for the first time - pu.consumerSessionManagers = append(pu.consumerSessionManagers, consumerSessionManager) +func (pu 
*PairingUpdater) RegisterPairing(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) error { + chainID := consumerSessionManager.RPCEndpoint().ChainID + pairingList, epoch, _, err := pu.stateQuery.GetPairing(context.Background(), chainID, -1) + if err != nil { + return err + } + pu.updateConsummerSessionManager(ctx, pairingList, consumerSessionManager, epoch) + consumerSessionsManagersList, ok := pu.consumerSessionManagersMap[chainID] + if !ok { + pu.consumerSessionManagersMap[chainID] = []*lavasession.ConsumerSessionManager{consumerSessionManager} + return nil + } + pu.consumerSessionManagersMap[chainID] = append(consumerSessionsManagersList, consumerSessionManager) + return nil } func (pu *PairingUpdater) UpdaterKey() string { @@ -30,19 +46,94 @@ func (pu *PairingUpdater) UpdaterKey() string { } func (pu *PairingUpdater) Update(latestBlock int64) { + ctx := context.Background() if int64(pu.nextBlockForUpdate) > latestBlock { return } - pairingList, epoch, nextBlockForUpdate := pu.stateQuery.GetPairing(latestBlock) + nextBlockForUpdateList := []uint64{} + for chainID, consumerSessionManagerList := range pu.consumerSessionManagersMap { + pairingList, epoch, nextBlockForUpdate, err := pu.stateQuery.GetPairing(ctx, chainID, latestBlock) + if err != nil { + utils.LavaFormatError("could not update pairing for chain, trying again next block", err, &map[string]string{"chain": chainID}) + nextBlockForUpdateList = append(nextBlockForUpdateList, pu.nextBlockForUpdate+1) + continue + } else { + nextBlockForUpdateList = append(nextBlockForUpdateList, nextBlockForUpdate) + } + for _, consumerSessionManager := range consumerSessionManagerList { + // same pairing for all apiInterfaces, they pick the right endpoints from inside using our filter function + err := pu.updateConsummerSessionManager(ctx, pairingList, consumerSessionManager, epoch) + if err != nil { + utils.LavaFormatError("failed updating consumer session manager", err, 
&map[string]string{"chainID": chainID, "apiInterface": consumerSessionManager.RPCEndpoint().ApiInterface, "pairingListLen": strconv.Itoa(len(pairingList))}) + continue + } + } + } + nextBlockForUpdateMin := uint64(0) + for idx, blockToUpdate := range nextBlockForUpdateList { + if idx == 0 || blockToUpdate < nextBlockForUpdateMin { + nextBlockForUpdateMin = blockToUpdate + } + } + pu.nextBlockForUpdate = nextBlockForUpdateMin +} - for _, consumerSessionManager := range pu.consumerSessionManagers { - pairingListForThisCSM := filterPairingListByEndpoint(pairingList, consumerSessionManager.RPCEndpoint()) - consumerSessionManager.UpdateAllProviders(epoch, pairingListForThisCSM) +func (pu *PairingUpdater) updateConsummerSessionManager(ctx context.Context, pairingList []epochstoragetypes.StakeEntry, consumerSessionManager *lavasession.ConsumerSessionManager, epoch uint64) (err error) { + pairingListForThisCSM, err := pu.filterPairingListByEndpoint(ctx, pairingList, consumerSessionManager.RPCEndpoint(), epoch) + if err != nil { + return err } - pu.nextBlockForUpdate = nextBlockForUpdate + err = consumerSessionManager.UpdateAllProviders(epoch, pairingListForThisCSM) + return } -func filterPairingListByEndpoint(pairingList []epochstoragetypes.StakeEntry, rpcEndpoint lavasession.RPCEndpoint) (filteredList []*lavasession.ConsumerSessionsWithProvider) { +func (pu *PairingUpdater) filterPairingListByEndpoint(ctx context.Context, pairingList []epochstoragetypes.StakeEntry, rpcEndpoint lavasession.RPCEndpoint, epoch uint64) (filteredList []*lavasession.ConsumerSessionsWithProvider, err error) { // go over stake entries, and filter endpoints that match geolocation and api interface - return + pairing := []*lavasession.ConsumerSessionsWithProvider{} + for _, provider := range pairingList { + // + // Sanity + providerEndpoints := provider.GetEndpoints() + if len(providerEndpoints) == 0 { + utils.LavaFormatError("skipping provider with no endoints", nil, &map[string]string{"Address": 
provider.Address, "ChainID": provider.Chain}) + continue + } + + relevantEndpoints := []epochstoragetypes.Endpoint{} + for _, endpoint := range providerEndpoints { + // only take into account endpoints that use the same api interface and the same geolocation + if endpoint.UseType == rpcEndpoint.ApiInterface && endpoint.Geolocation == rpcEndpoint.Geolocation { + relevantEndpoints = append(relevantEndpoints, endpoint) + } + } + if len(relevantEndpoints) == 0 { + utils.LavaFormatError("skipping provider, No relevant endpoints for apiInterface", nil, &map[string]string{"Address": provider.Address, "ChainID": provider.Chain, "apiInterface": rpcEndpoint.ApiInterface, "Endpoints": fmt.Sprintf("%v", providerEndpoints)}) + continue + } + + maxcu, err := pu.stateQuery.GetMaxCUForUser(ctx, provider.Chain, epoch) + if err != nil { + return nil, err + } + // + pairingEndpoints := make([]*lavasession.Endpoint, len(relevantEndpoints)) + for idx, relevantEndpoint := range relevantEndpoints { + endp := &lavasession.Endpoint{NetworkAddress: relevantEndpoint.IPPORT, Enabled: true, Client: nil, ConnectionRefusals: 0} + pairingEndpoints[idx] = endp + } + + pairing = append(pairing, &lavasession.ConsumerSessionsWithProvider{ + PublicLavaAddress: provider.Address, + Endpoints: pairingEndpoints, + Sessions: map[int64]*lavasession.SingleConsumerSession{}, + MaxComputeUnits: maxcu, + ReliabilitySent: false, + PairingEpoch: epoch, + }) + } + if len(pairing) == 0 { + return nil, utils.LavaFormatError("Failed getting pairing for consumer, pairing is empty", err, &map[string]string{"apiInterface": rpcEndpoint.ApiInterface, "ChainID": rpcEndpoint.ChainID, "geolocation": strconv.FormatUint(rpcEndpoint.Geolocation, 10)}) + } + // replace previous pairing with new providers + return pairing, nil } diff --git a/protocol/statetracker/state_query.go b/protocol/statetracker/state_query.go index 852becce2c..08d497d714 100644 --- a/protocol/statetracker/state_query.go +++ 
b/protocol/statetracker/state_query.go @@ -2,20 +2,84 @@ package statetracker import ( "context" + "strconv" "github.com/cosmos/cosmos-sdk/client" + "github.com/lavanet/lava/utils" epochstoragetypes "github.com/lavanet/lava/x/epochstorage/types" + pairingtypes "github.com/lavanet/lava/x/pairing/types" + spectypes "github.com/lavanet/lava/x/spec/types" ) -type StateQuery struct{} +type StateQuery struct { + SpecQueryClient spectypes.QueryClient + PairingQueryClient pairingtypes.QueryClient + EpochStorageQueryClient epochstoragetypes.QueryClient +} + +func NewStateQuery(ctx context.Context, clientCtx client.Context) *StateQuery { + sq := &StateQuery{} + sq.SpecQueryClient = spectypes.NewQueryClient(clientCtx) + sq.PairingQueryClient = pairingtypes.NewQueryClient(clientCtx) + sq.EpochStorageQueryClient = epochstoragetypes.NewQueryClient(clientCtx) + return sq +} + +type ConsumerStateQuery struct { + StateQuery + clientCtx client.Context + cachedPairings map[string]*pairingtypes.QueryGetPairingResponse +} + +func NewConsumerStateQuery(ctx context.Context, clientCtx client.Context) *ConsumerStateQuery { + csq := &ConsumerStateQuery{StateQuery: *NewStateQuery(ctx, clientCtx), clientCtx: clientCtx, cachedPairings: map[string]*pairingtypes.QueryGetPairingResponse{}} + return csq +} + +func (csq *ConsumerStateQuery) GetPairing(ctx context.Context, chainID string, latestBlock int64) (pairingList []epochstoragetypes.StakeEntry, epoch uint64, nextBlockForUpdate uint64, errRet error) { + if chainID == "" { + // the caller doesn;t care which so just return the first + for key := range csq.cachedPairings { + chainID = key + } + if chainID == "" { + chainID = "LAV1" + utils.LavaFormatWarning("failed to run get pairing as there is no cached entry for empty chainID call, using default chainID", nil, &map[string]string{"chainID": chainID}) + } + } + + if cachedResp, ok := csq.cachedPairings[chainID]; ok { + if cachedResp.BlockOfNextPairing > uint64(latestBlock) { + return 
cachedResp.Providers, cachedResp.CurrentEpoch, cachedResp.BlockOfNextPairing, nil + } + } + + pairingResp, err := csq.PairingQueryClient.GetPairing(ctx, &pairingtypes.QueryGetPairingRequest{ + ChainID: chainID, + Client: csq.clientCtx.FromAddress.String(), + }) + if err != nil { + return nil, 0, 0, utils.LavaFormatError("Failed in get pairing query", err, &map[string]string{}) + } + csq.cachedPairings[chainID] = pairingResp + return pairingResp.Providers, pairingResp.CurrentEpoch, pairingResp.BlockOfNextPairing, nil +} -func (sq *StateQuery) New(ctx context.Context, clientCtx client.Context) (ret *StateQuery, err error) { - // set up the rpcClient necessary to make queries - return sq, nil +func (csq *ConsumerStateQuery) GetMaxCUForUser(ctx context.Context, chainID string, epoch uint64) (maxCu uint64, err error) { + address := csq.clientCtx.FromAddress.String() + UserEntryRes, err := csq.PairingQueryClient.UserEntry(ctx, &pairingtypes.QueryUserEntryRequest{ChainID: chainID, Address: address, Block: epoch}) + if err != nil { + return 0, utils.LavaFormatError("failed querying StakeEntry for consumer", err, &map[string]string{"chainID": chainID, "address": address, "block": strconv.FormatUint(epoch, 10)}) + } + return UserEntryRes.GetMaxCU(), nil } -func (sq *StateQuery) GetPairing(latestBlock int64) (pairingList []epochstoragetypes.StakeEntry, epoch uint64, nextBlockForUpdate uint64) { - // query the node via our clientCtx and run the get pairing query with the client address (in the clientCtx from) - // latestBlock arg can be used for caching the result - return +func (csq *ConsumerStateQuery) GetSpec(ctx context.Context, chainID string) (*spectypes.Spec, error) { + spec, err := csq.SpecQueryClient.Spec(ctx, &spectypes.QueryGetSpecRequest{ + ChainID: chainID, + }) + if err != nil { + return nil, utils.LavaFormatError("Failed Querying spec for chain", err, &map[string]string{"ChainID": chainID}) + } + return &spec.Spec, nil } diff --git 
a/protocol/statetracker/state_tracker.go b/protocol/statetracker/state_tracker.go new file mode 100644 index 0000000000..7b631e781d --- /dev/null +++ b/protocol/statetracker/state_tracker.go @@ -0,0 +1,69 @@ +package statetracker + +import ( + "context" + "sync" + "time" + + "github.com/cosmos/cosmos-sdk/client" + "github.com/cosmos/cosmos-sdk/client/tx" + "github.com/lavanet/lava/protocol/chaintracker" +) + +const ( + BlocksToSaveLavaChainTracker = 1 // we only need the latest block + TendermintConsensusParamsQuery = "consensus_params" +) + +// StateTracker is a class for tracking consumer data from the lava blockchain, such as epoch changes. +// it also allows querying specific data from the blockchain and acts as a single place to send transactions +type StateTracker struct { + chainTracker *chaintracker.ChainTracker + registrationLock sync.RWMutex + newLavaBlockUpdaters map[string]Updater +} + +type Updater interface { + Update(int64) + UpdaterKey() string +} + +func NewStateTracker(ctx context.Context, txFactory tx.Factory, clientCtx client.Context, chainFetcher chaintracker.ChainFetcher) (ret *StateTracker, err error) { + cst := &StateTracker{newLavaBlockUpdaters: map[string]Updater{}} + resultConsensusParams, err := clientCtx.Client.ConsensusParams(ctx, nil) // nil returns latest + if err != nil { + return nil, err + } + chainTrackerConfig := chaintracker.ChainTrackerConfig{ + NewLatestCallback: cst.newLavaBlock, + BlocksToSave: BlocksToSaveLavaChainTracker, + AverageBlockTime: time.Duration(resultConsensusParams.ConsensusParams.Block.TimeIotaMs) * time.Millisecond, + ServerBlockMemory: BlocksToSaveLavaChainTracker, + } + cst.chainTracker, err = chaintracker.New(ctx, chainFetcher, chainTrackerConfig) + return cst, err +} + +func (cst *StateTracker) newLavaBlock(latestBlock int64) { + // go over the registered updaters and trigger update + cst.registrationLock.RLock() + defer cst.registrationLock.RUnlock() + for _, updater := range
cst.newLavaBlockUpdaters { + updater.Update(latestBlock) + } +} + +func (cst *StateTracker) RegisterForUpdates(ctx context.Context, updater Updater) Updater { + cst.registrationLock.Lock() + defer cst.registrationLock.Unlock() + existingUpdater, ok := cst.newLavaBlockUpdaters[updater.UpdaterKey()] + if !ok { + cst.newLavaBlockUpdaters[updater.UpdaterKey()] = updater + existingUpdater = updater + } + return existingUpdater +} + +type EpochUpdatable interface { + UpdateEpoch(epoch uint64) +} diff --git a/protocol/statetracker/statetracker.go b/protocol/statetracker/statetracker.go deleted file mode 100644 index 346cad0deb..0000000000 --- a/protocol/statetracker/statetracker.go +++ /dev/null @@ -1,5 +0,0 @@ -package statetracker - -type EpochUpdatable interface { - UpdateEpoch(epoch uint64) -} diff --git a/protocol/statetracker/tx_sender.go b/protocol/statetracker/tx_sender.go index cc965beb5b..9345a6463d 100644 --- a/protocol/statetracker/tx_sender.go +++ b/protocol/statetracker/tx_sender.go @@ -5,17 +5,102 @@ import ( "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/client/tx" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/lavanet/lava/utils" conflicttypes "github.com/lavanet/lava/x/conflict/types" ) -type TxSender struct{} +const ( + defaultGasPrice = "0.000000001ulava" + defaultGasAdjustment = 1.5 +) + +type TxSender struct { + txFactory tx.Factory + clientCtx client.Context +} -func (ts *TxSender) New(ctx context.Context, txFactory tx.Factory, clientCtx client.Context) (ret *TxSender, err error) { +func NewTxSender(ctx context.Context, clientCtx client.Context, txFactory tx.Factory) (ret *TxSender, err error) { // set up the rpcClient, and factory necessary to make queries + clientCtx.SkipConfirm = true + ts := &TxSender{txFactory: txFactory, clientCtx: clientCtx} + return ts, nil +} + +func (ts *TxSender) SimulateAndBroadCastTxWithRetryOnSeqMismatch(msg sdk.Msg) error { + txf := ts.txFactory.WithGasPrices(defaultGasPrice) + txf = 
txf.WithGasAdjustment(defaultGasAdjustment) + if err := msg.ValidateBasic(); err != nil { + return err + } + clientCtx := ts.clientCtx + txf, err := ts.prepareFactory(txf) + if err != nil { + return err + } + + _, gasUsed, err := tx.CalculateGas(clientCtx, txf, msg) + if err != nil { + return err + } + + txf = txf.WithGas(gasUsed) + + err = tx.GenerateOrBroadcastTxWithFactory(clientCtx, txf, msg) + if err != nil { + return err + } + return nil +} + +// this function is extracted from the tx package so that we can use it locally to set the tx factory correctly +func (ts *TxSender) prepareFactory(txf tx.Factory) (tx.Factory, error) { + clientCtx := ts.clientCtx + from := clientCtx.GetFromAddress() + + if err := clientCtx.AccountRetriever.EnsureExists(clientCtx, from); err != nil { + return txf, err + } + + initNum, initSeq := txf.AccountNumber(), txf.Sequence() + if initNum == 0 || initSeq == 0 { + num, seq, err := clientCtx.AccountRetriever.GetAccountNumberSequence(clientCtx, from) + if err != nil { + return txf, err + } + + if initNum == 0 { + txf = txf.WithAccountNumber(num) + } + + if initSeq == 0 { + txf = txf.WithSequence(seq) + } + } + + return txf, nil +} + +type ConsumerTxSender struct { + *TxSender +} + +func NewConsumerTxSender(ctx context.Context, clientCtx client.Context, txFactory tx.Factory) (ret *ConsumerTxSender, err error) { + txSender, err := NewTxSender(ctx, clientCtx, txFactory) + if err != nil { + return nil, err + } + ts := &ConsumerTxSender{TxSender: txSender} return ts, nil } -func (ts *TxSender) TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) { - // TODO: send a detection tx, simulate, with retry logic for sequence number mismatch - // TODO: make sure we are not spamming the same conflicts, previous code only detecs relay by relay, it has no state trackign wether it reported already 
+func (ts *ConsumerTxSender) TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error { + // TODO: retry logic for sequence number mismatch + // TODO: make sure we are not spamming the same conflicts, previous code only detects relay by relay, it has no state tracking whether it reported already + msg := conflicttypes.NewMsgDetection(ts.clientCtx.FromAddress.String(), finalizationConflict, responseConflict, sameProviderConflict) + err := ts.SimulateAndBroadCastTxWithRetryOnSeqMismatch(msg) + if err != nil { + return utils.LavaFormatError("discrepancyChecker - SimulateAndBroadCastTx Failed", err, nil) + } + return nil } diff --git a/relayer/server.go b/relayer/server.go index c41567a545..a353b7dcf9 100644 --- a/relayer/server.go +++ b/relayer/server.go @@ -391,7 +391,7 @@ func getOrCreateSession(ctx context.Context, userAddr string, req *pairingtypes. isValidBlockHeight := validateRequestedBlockHeight(uint64(req.BlockHeight)) if !isValidBlockHeight { return nil, utils.LavaFormatError("User requested with invalid block height", err, &map[string]string{ - "req.BlockHeight": strconv.FormatInt(req.BlockHeight, 10), + "req.BlockHeight": strconv.FormatInt(req.BlockHeight, 10), "expected": strconv.FormatUint(g_sentry.GetCurrentEpochHeight(), 10), }) }