Skip to content

Commit

Permalink
PRT-300 consumer state tracker (#285)
Browse files Browse the repository at this point in the history
* finished consumer state tracker

* lint

* fix some bugs

* rpcconsumer fixes

* fix missing VrfSk

* fixed review changes

---------

Co-authored-by: Aleksao998 <[email protected]>
  • Loading branch information
omerlavanet and 0xAleksaOpacic authored Feb 5, 2023
1 parent c94cc89 commit c62c1fe
Show file tree
Hide file tree
Showing 15 changed files with 459 additions and 132 deletions.
8 changes: 4 additions & 4 deletions cmd/lavad/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ func main() {
utils.LavaFormatInfo("cache service connected", &map[string]string{"address": cacheAddr})
}
}
rpcConsumer.Start(ctx, txFactory, clientCtx, rpcEndpoints, requiredResponses, vrf_sk, cache)
return nil
err = rpcConsumer.Start(ctx, txFactory, clientCtx, rpcEndpoints, requiredResponses, vrf_sk, cache)
return err
},
}

Expand Down Expand Up @@ -425,8 +425,8 @@ func main() {
if err != nil {
utils.LavaFormatFatal("error fetching chainproxy.ParallelConnectionsFlag", err, nil)
}
rpcProvider.Start(ctx, txFactory, clientCtx, rpcProviderEndpoints, cache, numberOfNodeParallelConnections)
return nil
err = rpcProvider.Start(ctx, txFactory, clientCtx, rpcProviderEndpoints, cache, numberOfNodeParallelConnections)
return err
},
}

Expand Down
31 changes: 31 additions & 0 deletions protocol/chainlib/chain_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ package chainlib
import (
"context"
"fmt"

"github.com/cosmos/cosmos-sdk/client"
)

const (
TendermintStatusQuery = "status"
)

type ChainFetcher struct {
Expand All @@ -22,3 +28,28 @@ func NewChainFetcher(ctx context.Context, chainProxy ChainProxy /*here needs som
cf := &ChainFetcher{chainProxy: chainProxy}
return cf
}

// LavaChainFetcher fetches block information from the lava chain itself via
// the cosmos-sdk client context's underlying RPC client.
type LavaChainFetcher struct {
	clientCtx client.Context // provides the RPC client used for status queries
}

// FetchLatestBlockNum returns the latest block height reported by the node's
// status endpoint.
func (lcf *LavaChainFetcher) FetchLatestBlockNum(ctx context.Context) (int64, error) {
	status, err := lcf.clientCtx.Client.Status(ctx)
	if err != nil {
		return 0, err
	}
	return status.SyncInfo.LatestBlockHeight, nil
}

// FetchBlockHashByNum returns the hash of the block at height blockNum.
//
// NOTE(review): the original implementation called Status and returned the
// *latest* block hash, silently ignoring blockNum entirely — callers asking
// for a historical block got the wrong hash. Query the requested height via
// the RPC client's Block method instead.
func (lcf *LavaChainFetcher) FetchBlockHashByNum(ctx context.Context, blockNum int64) (string, error) {
	resultBlock, err := lcf.clientCtx.Client.Block(ctx, &blockNum)
	if err != nil {
		return "", err
	}
	return resultBlock.BlockID.Hash.String(), nil
}

// NewLavaChainFetcher constructs a LavaChainFetcher around the given client
// context. The ctx argument is currently unused.
func NewLavaChainFetcher(ctx context.Context, clientCtx client.Context) *LavaChainFetcher {
	return &LavaChainFetcher{clientCtx: clientCtx}
}
14 changes: 10 additions & 4 deletions protocol/chaintracker/chain_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,10 +335,16 @@ func (ct *ChainTracker) serve(ctx context.Context, listenAddr string) error {
return nil
}

func New(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker) {
config.validate()
// New creates a ChainTracker, validates its configuration, starts its polling
// loop, and spins up its gRPC server on config.ServerAddress.
//
// On any failure (invalid config, start, or serve) it returns a nil tracker
// with the error: the original returned the partially-started tracker
// alongside a non-nil serve error via a naked return, which violates the Go
// convention that the value half of (T, error) is meaningless on error.
func New(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker, err error) {
	if err = config.validate(); err != nil {
		return nil, err
	}
	chainTracker = &ChainTracker{forkCallback: config.ForkCallback, newLatestCallback: config.NewLatestCallback, blocksToSave: config.BlocksToSave, chainFetcher: chainFetcher, latestBlockNum: 0, serverBlockMemory: config.ServerBlockMemory}
	if err = chainTracker.start(ctx, config.AverageBlockTime); err != nil {
		return nil, err
	}
	if err = chainTracker.serve(ctx, config.ServerAddress); err != nil {
		return nil, err
	}
	return chainTracker, nil
}
16 changes: 8 additions & 8 deletions protocol/chaintracker/chain_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ func TestChainTracker(t *testing.T) {
currentLatestBlockInMock := mockChainFetcher.AdvanceBlock()

chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)

chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
for _, advancement := range tt.advancements {
for i := 0; i < int(advancement); i++ {
currentLatestBlockInMock = mockChainFetcher.AdvanceBlock()
Expand Down Expand Up @@ -176,8 +176,8 @@ func TestChainTrackerRangeOnly(t *testing.T) {
currentLatestBlockInMock := mockChainFetcher.AdvanceBlock()

chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)

chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
for _, advancement := range tt.advancements {
for i := 0; i < int(advancement); i++ {
currentLatestBlockInMock = mockChainFetcher.AdvanceBlock()
Expand Down Expand Up @@ -256,8 +256,8 @@ func TestChainTrackerCallbacks(t *testing.T) {
callbackCalledNewLatest = true
}
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback, NewLatestCallback: newBlockCallback}
chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)

chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
t.Run("one long test", func(t *testing.T) {
for _, tt := range tests {
utils.LavaFormatInfo("started test "+tt.name, nil)
Expand Down Expand Up @@ -338,8 +338,8 @@ func TestChainTrackerMaintainMemory(t *testing.T) {
callbackCalledFork = true
}
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback}
chainTracker := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)

chainTracker, err := chaintracker.New(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
t.Run("one long test", func(t *testing.T) {
for _, tt := range tests {
utils.LavaFormatInfo("started test "+tt.name, nil)
Expand Down
14 changes: 9 additions & 5 deletions protocol/rpcconsumer/rpcconsumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ var (

type ConsumerStateTrackerInf interface {
RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager)
RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser)
RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser, chainID string) error
RegisterFinalizationConsensusForUpdates(context.Context, *lavaprotocol.FinalizationConsensus)
TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict)
TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error
}

type RPCConsumer struct {
Expand All @@ -46,11 +46,12 @@ type RPCConsumer struct {
// spawns a new RPCConsumer server with all its processes and internals ready for communications
func (rpcc *RPCConsumer) Start(ctx context.Context, txFactory tx.Factory, clientCtx client.Context, rpcEndpoints []*lavasession.RPCEndpoint, requiredResponses int, vrf_sk vrf.PrivateKey, cache *performance.Cache) (err error) {
// spawn up ConsumerStateTracker
consumerStateTracker := statetracker.ConsumerStateTracker{}
rpcc.consumerStateTracker, err = consumerStateTracker.New(ctx, txFactory, clientCtx)
lavaChainFetcher := chainlib.NewLavaChainFetcher(ctx, clientCtx)
consumerStateTracker, err := statetracker.NewConsumerStateTracker(ctx, txFactory, clientCtx, lavaChainFetcher)
if err != nil {
return err
}
rpcc.consumerStateTracker = consumerStateTracker
rpcc.rpcConsumerServers = make(map[string]*RPCConsumerServer, len(rpcEndpoints))

keyName, err := sigs.GetKeyName(clientCtx)
Expand Down Expand Up @@ -78,7 +79,10 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, txFactory tx.Factory, client
if err != nil {
return err
}
consumerStateTracker.RegisterChainParserForSpecUpdates(ctx, chainParser)
err = consumerStateTracker.RegisterChainParserForSpecUpdates(ctx, chainParser, rpcEndpoint.ChainID)
if err != nil {
return err
}
finalizationConsensus := &lavaprotocol.FinalizationConsensus{}
consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus)
rpcc.rpcConsumerServers[key] = &RPCConsumerServer{}
Expand Down
17 changes: 13 additions & 4 deletions protocol/rpcconsumer/rpcconsumer_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type RPCConsumerServer struct {
}

type ConsumerTxSender interface {
TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict)
TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error
}

func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndpoint *lavasession.RPCEndpoint,
Expand All @@ -58,6 +58,7 @@ func (rpccs *RPCConsumerServer) ServeRPCRequests(ctx context.Context, listenEndp
rpccs.cache = cache
rpccs.consumerTxSender = consumerStateTracker
rpccs.requiredResponses = requiredResponses
rpccs.VrfSk = vrfSk
pLogs, err := common.NewRPCConsumerLogs()
if err != nil {
utils.LavaFormatFatal("failed creating RPCConsumer logs", err, nil)
Expand Down Expand Up @@ -217,11 +218,16 @@ func (rpccs *RPCConsumerServer) sendRelayToProvider(
}
// get here only if performed a regular relay successfully
expectedBH, numOfProviders := rpccs.finalizationConsensus.ExpectedBlockHeight(rpccs.chainParser)
err = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, epoch, reply.LatestBlock, chainMessage.GetServiceApi().ComputeUnits, relayLatency, expectedBH, numOfProviders, rpccs.consumerSessionManager.GetAtomicPairingAddressesLength()) // session done successfully
pairingAddressesLen := rpccs.consumerSessionManager.GetAtomicPairingAddressesLength()
latestBlock := relayResult.Reply.LatestBlock
err = rpccs.consumerSessionManager.OnSessionDone(singleConsumerSession, epoch, latestBlock, chainMessage.GetServiceApi().ComputeUnits, relayLatency, expectedBH, numOfProviders, pairingAddressesLen) // session done successfully

// set cache in a non blocking call
go func() {
err2 := rpccs.cache.SetEntry(ctx, relayRequest, chainMessage.GetInterface().Interface, nil, chainID, dappID, reply, relayResult.Finalized) // caching in the portal doesn't care about hashes
new_ctx := context.Background()
new_ctx, cancel := context.WithTimeout(new_ctx, lavaprotocol.DataReliabilityTimeoutIncrease)
defer cancel()
err2 := rpccs.cache.SetEntry(new_ctx, relayRequest, chainMessage.GetInterface().Interface, nil, chainID, dappID, relayResult.Reply, relayResult.Finalized) // caching in the portal doesn't care about hashes
if err2 != nil && !performance.NotInitialisedError.Is(err2) {
utils.LavaFormatWarning("error updating cache with new entry", err2, nil)
}
Expand Down Expand Up @@ -399,7 +405,10 @@ func (rpccs *RPCConsumerServer) sendDataReliabilityRelayIfApplicable(ctx context
report, conflicts := lavaprotocol.VerifyReliabilityResults(relayResult, dataReliabilityVerifications, numberOfReliabilitySessions)
if report {
for _, conflict := range conflicts {
rpccs.consumerTxSender.TxConflictDetection(ctx, nil, conflict, nil)
err := rpccs.consumerTxSender.TxConflictDetection(ctx, nil, conflict, nil)
if err != nil {
utils.LavaFormatError("could not send detection Transaction", err, &map[string]string{"conflict": fmt.Sprintf("%+v", conflict)})
}
}
}
// detectionMessage = conflicttypes.NewMsgDetection(consumerAddress, nil, &responseConflict, nil)
Expand Down
8 changes: 5 additions & 3 deletions protocol/rpcprovider/rpcprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ func (rpcp *RPCProvider) Start(ctx context.Context, txFactory tx.Factory, client
_, avergaeBlockTime, blocksToFinalization, blocksInFinalizationData := chainParser.ChainBlockStats()
blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData)
chainTrackerConfig := chaintracker.ChainTrackerConfig{
ServerAddress: rpcProviderEndpoint.NodeUrl,
BlocksToSave: blocksToSaveChainTracker,
AverageBlockTime: avergaeBlockTime, // divide here to make the querying more often so we don't miss block changes by that much
AverageBlockTime: avergaeBlockTime,
ServerBlockMemory: ChainTrackerDefaultMemory + blocksToSaveChainTracker,
}
chainFetcher := chainlib.NewChainFetcher(ctx, chainProxy)
chainTracker := chaintracker.New(ctx, chainFetcher, chainTrackerConfig)
chainTracker, err := chaintracker.New(ctx, chainFetcher, chainTrackerConfig)
if err != nil {
utils.LavaFormatFatal("failed creating chain tracker", err, &map[string]string{"chainTrackerConfig": fmt.Sprintf("%+v", chainTrackerConfig)})
}
reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker)
providerStateTracker.RegisterReliabilityManagerForVoteUpdates(ctx, reliabilityManager)

Expand Down
95 changes: 28 additions & 67 deletions protocol/statetracker/consumer_state_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package statetracker
import (
"context"
"fmt"
"sync"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/tx"
Expand All @@ -14,100 +13,62 @@ import (
"github.com/lavanet/lava/protocol/lavasession"
"github.com/lavanet/lava/utils"
conflicttypes "github.com/lavanet/lava/x/conflict/types"
spectypes "github.com/lavanet/lava/x/spec/types"
)

// ConsumerStateTracker (CST) is a class for tracking consumer data from the lava blockchain, such as epoch changes.
// it also allows querying specific data from the blockchain and acts as a single place to send transactions
type ConsumerStateTracker struct {
consumerAddress sdk.AccAddress
chainTracker *chaintracker.ChainTracker
stateQuery *StateQuery
txSender *TxSender
registrationLock sync.RWMutex
newLavaBlockUpdaters map[string]Updater
consumerAddress sdk.AccAddress
stateQuery *ConsumerStateQuery
txSender *ConsumerTxSender
*StateTracker
}

type Updater interface {
Update(int64)
UpdaterKey() string
}

func (cst *ConsumerStateTracker) New(ctx context.Context, txFactory tx.Factory, clientCtx client.Context) (ret *ConsumerStateTracker, err error) {
// set up StateQuery
// Spin up chain tracker on the lava node, its address is in the --node flag (or its default), on new block call to newLavaBlock
// use StateQuery to get the lava spec and spin up the chain tracker with the right params
// set up txSender the same way

stateQuery := StateQuery{}
cst.stateQuery, err = stateQuery.New(ctx, clientCtx)
func NewConsumerStateTracker(ctx context.Context, txFactory tx.Factory, clientCtx client.Context, chainFetcher chaintracker.ChainFetcher) (ret *ConsumerStateTracker, err error) {
stateTrackerBase, err := NewStateTracker(ctx, txFactory, clientCtx, chainFetcher)
if err != nil {
return nil, err
}

txSender := TxSender{}
cst.txSender, err = txSender.New(ctx, txFactory, clientCtx)
txSender, err := NewConsumerTxSender(ctx, clientCtx, txFactory)
if err != nil {
return nil, err
}
cst.consumerAddress = clientCtx.FromAddress
cst := &ConsumerStateTracker{StateTracker: stateTrackerBase, stateQuery: NewConsumerStateQuery(ctx, clientCtx), txSender: txSender}
return cst, nil
}

// newLavaBlock is the callback invoked when a new lava block is observed; it
// fans the latest block height out to every registered updater.
func (cst *ConsumerStateTracker) newLavaBlock(latestBlock int64) {
	// go over the registered updaters and trigger update
	// read lock only: we iterate the updater map but never mutate it here
	cst.registrationLock.RLock()
	defer cst.registrationLock.RUnlock()
	for _, updater := range cst.newLavaBlockUpdaters {
		updater.Update(latestBlock)
	}
}

func (cst *ConsumerStateTracker) RegisterConsumerSessionManagerForPairingUpdates(ctx context.Context, consumerSessionManager *lavasession.ConsumerSessionManager) {
// register this CSM to get the updated pairing list when a new epoch starts
cst.registrationLock.Lock()
defer cst.registrationLock.Unlock()
// make sure new lava block exists as a callback in stateTracker
// add updatePairingForRegistered as a callback on a new block

var pairingUpdater *PairingUpdater = nil // UpdaterKey is nil safe
pairingUpdater_raw, ok := cst.newLavaBlockUpdaters[pairingUpdater.UpdaterKey()]
pairingUpdater := NewPairingUpdater(cst.consumerAddress, cst.stateQuery)
pairingUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, pairingUpdater)
pairingUpdater, ok := pairingUpdaterRaw.(*PairingUpdater)
if !ok {
pairingUpdater = NewPairingUpdater(cst.consumerAddress, cst.stateQuery)
cst.newLavaBlockUpdaters[pairingUpdater.UpdaterKey()] = pairingUpdater
utils.LavaFormatFatal("invalid updater type returned from RegisterForUpdates", nil, &map[string]string{"updater": fmt.Sprintf("%+v", pairingUpdaterRaw)})
}
pairingUpdater, ok = pairingUpdater_raw.(*PairingUpdater)
if !ok {
utils.LavaFormatFatal("invalid_updater_key in RegisterConsumerSessionManagerForPairingUpdates", nil, &map[string]string{"updaters_map": fmt.Sprintf("%+v", cst.newLavaBlockUpdaters)})
}
pairingUpdater.RegisterPairing(consumerSessionManager)
pairingUpdater.RegisterPairing(ctx, consumerSessionManager)
}

func (cst *ConsumerStateTracker) RegisterFinalizationConsensusForUpdates(ctx context.Context, finalizationConsensus *lavaprotocol.FinalizationConsensus) {
cst.registrationLock.Lock()
defer cst.registrationLock.Unlock()

var finalizationConsensusUpdater *FinalizationConsensusUpdater = nil // UpdaterKey is nil safe
finalizationConsensusUpdater_raw, ok := cst.newLavaBlockUpdaters[finalizationConsensusUpdater.UpdaterKey()]
finalizationConsensusUpdater := NewFinalizationConsensusUpdater(cst.consumerAddress, cst.stateQuery)
finalizationConsensusUpdaterRaw := cst.StateTracker.RegisterForUpdates(ctx, finalizationConsensusUpdater)
finalizationConsensusUpdater, ok := finalizationConsensusUpdaterRaw.(*FinalizationConsensusUpdater)
if !ok {
finalizationConsensusUpdater = NewFinalizationConsensusUpdater(cst.consumerAddress, cst.stateQuery)
cst.newLavaBlockUpdaters[finalizationConsensusUpdater.UpdaterKey()] = finalizationConsensusUpdater
}
finalizationConsensusUpdater, ok = finalizationConsensusUpdater_raw.(*FinalizationConsensusUpdater)
if !ok {
utils.LavaFormatFatal("invalid_updater_key in RegisterFinalizationConsensusForUpdates", nil, &map[string]string{"updaters_map": fmt.Sprintf("%+v", cst.newLavaBlockUpdaters)})
utils.LavaFormatFatal("invalid updater type returned from RegisterForUpdates", nil, &map[string]string{"updater": fmt.Sprintf("%+v", finalizationConsensusUpdaterRaw)})
}
finalizationConsensusUpdater.RegisterFinalizationConsensus(finalizationConsensus)
}

func (cst *ConsumerStateTracker) RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser) {
// register this chainParser for spec updates
// currently just set the first one, and have a TODO to handle spec changes
// get the spec and set it into the chainParser
spec := spectypes.Spec{}
chainParser.SetSpec(spec)
// RegisterChainParserForSpecUpdates fetches the spec for chainID from the lava
// chain state query and sets it on the given chain parser once. Despite the
// name, no ongoing update subscription is registered yet (see TODO below).
// Returns an error if the spec query fails.
func (cst *ConsumerStateTracker) RegisterChainParserForSpecUpdates(ctx context.Context, chainParser chainlib.ChainParser, chainID string) error {
	// TODO: handle spec changes
	spec, err := cst.stateQuery.GetSpec(ctx, chainID)
	if err != nil {
		return err
	}
	chainParser.SetSpec(*spec)
	return nil
}

func (cst *ConsumerStateTracker) TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) {
cst.txSender.TxConflictDetection(ctx, finalizationConflict, responseConflict, sameProviderConflict)
// TxConflictDetection forwards the given conflict evidence (finalization,
// response, or same-provider finalization conflicts; any may be nil) to the
// consumer tx sender, which builds and broadcasts the detection transaction.
// The sender's error is returned directly instead of through a throwaway
// local variable.
func (cst *ConsumerStateTracker) TxConflictDetection(ctx context.Context, finalizationConflict *conflicttypes.FinalizationConflict, responseConflict *conflicttypes.ResponseConflict, sameProviderConflict *conflicttypes.FinalizationConflict) error {
	return cst.txSender.TxConflictDetection(ctx, finalizationConflict, responseConflict, sameProviderConflict)
}
Loading

0 comments on commit c62c1fe

Please sign in to comment.