[Proof] Implement scalable proof validation #1031

Merged
merged 14 commits into from Jan 30, 2025
Changes from 13 commits
762 changes: 726 additions & 36 deletions api/poktroll/proof/event.pulsar.go

Large diffs are not rendered by default.

193 changes: 157 additions & 36 deletions api/poktroll/proof/types.pulsar.go

Large diffs are not rendered by default.

28 changes: 14 additions & 14 deletions pkg/crypto/protocol/proof_path.go
@@ -6,20 +6,8 @@ import (
"github.com/pokt-network/smt"
)

// SMT specification used for the proof verification.
var (
newHasher = sha256.New
SmtSpec smt.TrieSpec
)

func init() {
// Use a spec that does not prehash values in the smst. This returns a nil value
// hasher for the proof verification in order to avoid hashing the value twice.
SmtSpec = smt.NewTrieSpec(
newHasher(), true,
smt.WithValueHasher(nil),
)
}
// newHasher is the hash function used by the SMT specification.
var newHasher = sha256.New

// GetPathForProof computes the path to be used for proof validation by hashing
// the block hash and session id.
@@ -31,3 +19,15 @@ func GetPathForProof(blockHash []byte, sessionId string) []byte {

return hasher.Sum(nil)
}

// NewSMTSpec returns the SMT specification used for the proof verification.
// It uses a new hasher at every call to avoid concurrency issues that could be
// caused by a shared hasher.
func NewSMTSpec() *smt.TrieSpec {
trieSpec := smt.NewTrieSpec(
newHasher(), true,
smt.WithValueHasher(nil),
)

return &trieSpec
}
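
To make the concurrency rationale concrete, here is a minimal sketch (not part of this diff) of verifying several closest-merkle proofs in parallel, each goroutine getting its own spec. The `smt.VerifyClosestProof` signature is an assumption based on the pokt-network/smt package API; adjust to the actual signature if it differs.

```go
package main

import (
	"fmt"
	"sync"

	"github.com/pokt-network/poktroll/pkg/crypto/protocol"
	"github.com/pokt-network/smt"
)

// verifyConcurrently sketches why NewSMTSpec returns a fresh spec per call:
// sha256 hasher state is not goroutine-safe, so a shared spec would race.
func verifyConcurrently(proofs []*smt.SparseMerkleClosestProof, roots [][]byte) error {
	var wg sync.WaitGroup
	errCh := make(chan error, len(proofs))

	for i := range proofs {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			// Fresh spec (and therefore fresh hasher) per goroutine.
			spec := protocol.NewSMTSpec()
			// ASSUMPTION: VerifyClosestProof(proof, root, spec) (bool, error).
			valid, err := smt.VerifyClosestProof(proofs[i], roots[i], spec)
			if err != nil || !valid {
				errCh <- fmt.Errorf("proof %d invalid: %v", i, err)
			}
		}(i)
	}

	wg.Wait()
	close(errCh)
	return <-errCh // nil when all proofs verified
}
```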
11 changes: 11 additions & 0 deletions proto/poktroll/proof/event.proto
@@ -43,3 +43,14 @@ message EventProofUpdated {
uint64 num_estimated_compute_units = 5 [(gogoproto.jsontag) = "num_estimated_compute_units"];
cosmos.base.v1beta1.Coin claimed_upokt = 6 [(gogoproto.jsontag) = "claimed_upokt"];
}

// Event emitted after a proof has been checked for validity in the proof module's
// EndBlocker.
message EventProofValidityChecked {
poktroll.proof.Proof proof = 1 [(gogoproto.jsontag) = "proof"];
uint64 block_height = 2 [(gogoproto.jsontag) = "block_height"];
poktroll.proof.ClaimProofStatus proof_status = 3 [(gogoproto.jsontag) = "proof_status"];
// reason is the string representation of the error that led to the proof being
// marked as invalid (e.g. "invalid closest merkle proof", "invalid relay request signature")
string reason = 4 [(gogoproto.jsontag) = "reason"];
}
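
For context, a hedged sketch of how the proof module's EndBlocker might populate and emit this event. `EnsureValidProofSignaturesAndClosestPath` is named in this diff's comments, but its exact signature and this emission flow are assumptions, not the PR's code.

```go
import (
	sdk "github.com/cosmos/cosmos-sdk/types"

	"github.com/pokt-network/poktroll/x/proof/types"
)

// checkAndEmitProofValidity is an illustrative sketch, not the PR's code.
func (k Keeper) checkAndEmitProofValidity(ctx sdk.Context, proof *types.Proof) error {
	proofStatus := types.ClaimProofStatus_VALIDATED
	reason := ""

	// ASSUMPTION: the keeper method takes (ctx, proof) and returns an error.
	if err := k.EnsureValidProofSignaturesAndClosestPath(ctx, proof); err != nil {
		proofStatus = types.ClaimProofStatus_INVALID
		reason = err.Error() // e.g. "invalid closest merkle proof"
	}

	return ctx.EventManager().EmitTypedEvent(&types.EventProofValidityChecked{
		Proof:       proof,
		BlockHeight: uint64(ctx.BlockHeight()),
		ProofStatus: proofStatus,
		Reason:      reason,
	})
}
```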
18 changes: 16 additions & 2 deletions proto/poktroll/proof/types.proto
@@ -24,11 +24,17 @@ message Proof {

// Claim is the serialized object stored onchain for claims pending to be proven
message Claim {
// Address of the supplier's operator that submitted this claim.
string supplier_operator_address = 1 [(cosmos_proto.scalar) = "cosmos.AddressString"]; // the address of the supplier's operator that submitted this claim
// The session header of the session that this claim is for.

// Session header this claim is for.
poktroll.session.SessionHeader session_header = 2;
// Root hash returned from smt.SMST#Root().

// Root hash from smt.SMST#Root().
bytes root_hash = 3;

// Important: This field MUST only be set by proofKeeper#EnsureValidProofSignaturesAndClosestPath
ClaimProofStatus proof_validation_status = 4;
}

enum ProofRequirementReason {
@@ -43,3 +49,11 @@ enum ClaimProofStage {
SETTLED = 2;
EXPIRED = 3;
}

// Status of proof validation for a claim
// Default is PENDING_VALIDATION regardless of proof requirement
enum ClaimProofStatus {
PENDING_VALIDATION = 0;
VALIDATED = 1;
INVALID = 2;
}
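
A brief sketch of how a downstream consumer (e.g. settlement in the tokenomics module, which is outside this diff) might branch on the new status. The getter name follows standard gogoproto codegen, and the settlement policy shown is purely illustrative.

```go
import prooftypes "github.com/pokt-network/poktroll/x/proof/types"

// isClaimSettleable illustrates one possible policy; the PR's actual
// settlement logic is not shown in this diff.
func isClaimSettleable(claim *prooftypes.Claim) bool {
	switch claim.GetProofValidationStatus() {
	case prooftypes.ClaimProofStatus_VALIDATED:
		return true // proof checked in the EndBlocker and found valid
	case prooftypes.ClaimProofStatus_INVALID:
		return false // proof checked and rejected
	default:
		return false // PENDING_VALIDATION: not yet checked
	}
}
```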
1 change: 1 addition & 0 deletions testutil/testtree/tree.go
@@ -152,5 +152,6 @@ func NewClaim(
SupplierOperatorAddress: supplierOperatorAddr,
SessionHeader: sessionHeader,
RootHash: rootHash,
ProofValidationStatus: prooftypes.ClaimProofStatus_PENDING_VALIDATION,
}
}
134 changes: 79 additions & 55 deletions x/proof/keeper/msg_server_submit_proof.go
@@ -19,84 +19,87 @@ import (
sharedtypes "github.com/pokt-network/poktroll/x/shared/types"
)

// SubmitProof is the server handler to submit and store a proof onchain.
// A proof that's stored onchain is what leads to rewards (i.e. inflation)
// downstream, making this a critical part of the protocol.
// SubmitProof is the server message handler that stores a valid
// proof onchain, enabling downstream reward distribution.
//
// Note that the validation of the proof is done in `EnsureValidProof`. However,
// preliminary checks are done in the handler to prevent sybil or DoS attacks on
// full nodes because storing and validating proofs is expensive.
// IMPORTANT: Full proof validation occurs in EnsureValidProofSignaturesAndClosestPath.
// This handler performs preliminary validation to prevent sybil/DoS attacks.
//
// We are playing a balance of security and efficiency here, where enough validation
// is done on proof submission, and exhaustive validation is done during session
// settlement.
// There is a security & performance tradeoff between the handler and the end blocker:
// - Basic validation on submission (here)
// - Exhaustive validation in endblocker (EnsureValidProofSignaturesAndClosestPath)
//
// The entity sending the SubmitProof messages does not necessarily need
// to correspond to the supplier signing the proof. For example, a single entity
// could (theoretically) batch multiple proofs (signed by the corresponding supplier)
// into one transaction to save on transaction fees.
// Note: Proof submitter may differ from supplier signer, allowing batched submissions
// to optimize transaction fees.
func (k msgServer) SubmitProof(
ctx context.Context,
msg *types.MsgSubmitProof,
) (_ *types.MsgSubmitProofResponse, err error) {
sdkCtx := cosmostypes.UnwrapSDKContext(ctx)

// Declare claim to reference in telemetry.
var (
claim = new(types.Claim)
claim *types.Claim
isExistingProof bool
numRelays uint64
numClaimComputeUnits uint64
sessionHeader *sessiontypes.SessionHeader
)

logger := k.Logger().With("method", "SubmitProof")
sdkCtx := cosmostypes.UnwrapSDKContext(ctx)
logger.Info("About to start submitting proof")

// Basic validation of the SubmitProof message.
if err = msg.ValidateBasic(); err != nil {
logger.Error("failed to validate the submitProof message")
return nil, status.Error(codes.InvalidArgument, err.Error())
}
logger.Info("validated the submitProof message")

// Compare msg session header w/ onchain session header.
session, err := k.queryAndValidateSessionHeader(ctx, msg.GetSessionHeader(), msg.GetSupplierOperatorAddress())
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
sessionHeader = msg.GetSessionHeader()
supplierOperatorAddress := msg.GetSupplierOperatorAddress()

// Defer telemetry calls so that they reference the final values of the relevant variables.
defer k.finalizeSubmitProofTelemetry(session, msg, isExistingProof, numRelays, numClaimComputeUnits, err)
logger = logger.With(
"session_id", sessionHeader.GetSessionId(),
"application_address", sessionHeader.GetApplicationAddress(),
"service_id", sessionHeader.GetServiceId(),
"session_end_height", sessionHeader.GetSessionEndBlockHeight(),
"supplier_operator_address", supplierOperatorAddress,
)
logger.Info("validated the submitProof message")

if err = k.deductProofSubmissionFee(ctx, msg.GetSupplierOperatorAddress()); err != nil {
logger.Error(fmt.Sprintf("failed to deduct proof submission fee: %v", err))
// Defer telemetry calls so that they reference the final values of the relevant variables.
defer k.finalizeSubmitProofTelemetry(sessionHeader, msg, isExistingProof, numRelays, numClaimComputeUnits, err)

// Construct the proof from the message.
proof := newProofFromMsg(msg)

// EnsureWellFormedProof ensures proper proof formation by verifying:
// - Proof structure
// - Associated claim
// - Relay session headers
// - Submission timing within required window
if err = k.EnsureWellFormedProof(ctx, proof); err != nil {
logger.Error(fmt.Sprintf("failed to ensure well-formed proof: %v", err))
return nil, status.Error(codes.FailedPrecondition, err.Error())
}

// Construct the proof
proof := types.Proof{
SupplierOperatorAddress: msg.GetSupplierOperatorAddress(),
SessionHeader: session.GetHeader(),
ClosestMerkleProof: msg.GetProof(),
logger.Info("ensured the proof is well-formed")

// Retrieve the claim associated with the proof.
// The claim should ALWAYS exist since the proof validation in EnsureWellFormedProof
// retrieves and validates the associated claim.
foundClaim, claimFound := k.GetClaim(ctx, sessionHeader.GetSessionId(), supplierOperatorAddress)
if !claimFound {
logger.Error("failed to find the claim associated with the proof")
return nil, status.Error(codes.FailedPrecondition, types.ErrProofClaimNotFound.Error())
}

// Helper for logging the same metadata throughout this function's calls
logger = logger.With(
"session_id", proof.SessionHeader.SessionId,
"session_end_height", proof.SessionHeader.SessionEndBlockHeight,
"supplier_operator_address", proof.SupplierOperatorAddress)
claim = &foundClaim

// Validate proof message commit height is within the respective session's
// proof submission window using the onchain session header.
if err = k.validateProofWindow(ctx, proof.SessionHeader, proof.SupplierOperatorAddress); err != nil {
if err = k.deductProofSubmissionFee(ctx, supplierOperatorAddress); err != nil {
logger.Error(fmt.Sprintf("failed to deduct proof submission fee: %v", err))
return nil, status.Error(codes.FailedPrecondition, err.Error())
}

// Retrieve the corresponding claim for the proof submitted so it can be
// used in the proof validation below.
claim, err = k.queryAndValidateClaimForProof(ctx, proof.SessionHeader, proof.SupplierOperatorAddress)
if err != nil {
return nil, status.Error(codes.Internal, types.ErrProofClaimNotFound.Wrap(err.Error()).Error())
}

// Check if a proof is required for the claim.
proofRequirement, err := k.ProofRequirementForClaim(ctx, claim)
if err != nil {
@@ -120,7 +123,7 @@ func (k msgServer) SubmitProof(
}

// Get the service ID relayMiningDifficulty to calculate the claimed uPOKT.
serviceId := session.GetHeader().GetServiceId()
serviceId := sessionHeader.GetServiceId()
sharedParams := k.sharedKeeper.GetParams(ctx)
relayMiningDifficulty, _ := k.serviceKeeper.GetRelayMiningDifficulty(ctx, serviceId)

@@ -131,7 +134,7 @@ func (k msgServer) SubmitProof(
_, isExistingProof = k.GetProof(ctx, proof.SessionHeader.SessionId, proof.SupplierOperatorAddress)

// Upsert the proof
k.UpsertProof(ctx, proof)
k.UpsertProof(ctx, *proof)
logger.Info("successfully upserted the proof")

// Emit the appropriate event based on whether the claim was created or updated.
@@ -141,7 +144,7 @@ func (k msgServer) SubmitProof(
proofUpsertEvent = proto.Message(
&types.EventProofUpdated{
Claim: claim,
Proof: &proof,
Proof: proof,
NumRelays: numRelays,
NumClaimedComputeUnits: numClaimComputeUnits,
NumEstimatedComputeUnits: numEstimatedComputUnits,
@@ -152,14 +155,15 @@ func (k msgServer) SubmitProof(
proofUpsertEvent = proto.Message(
&types.EventProofSubmitted{
Claim: claim,
Proof: &proof,
Proof: proof,
NumRelays: numRelays,
NumClaimedComputeUnits: numClaimComputeUnits,
NumEstimatedComputeUnits: numEstimatedComputUnits,
ClaimedUpokt: &claimedUPOKT,
},
)
}

if err = sdkCtx.EventManager().EmitTypedEvent(proofUpsertEvent); err != nil {
return nil, status.Error(
codes.Internal,
@@ -172,7 +176,7 @@ }
}

return &types.MsgSubmitProofResponse{
Proof: &proof,
Proof: proof,
}, nil
}

@@ -322,10 +326,17 @@ func (k Keeper) getProofRequirementSeedBlockHash(

// finalizeSubmitProofTelemetry finalizes telemetry updates for SubmitProof, incrementing counters as needed.
// Meant to run deferred.
func (k msgServer) finalizeSubmitProofTelemetry(session *sessiontypes.Session, msg *types.MsgSubmitProof, isExistingProof bool, numRelays, numClaimComputeUnits uint64, err error) {
func (k msgServer) finalizeSubmitProofTelemetry(
sessionHeader *sessiontypes.SessionHeader,
msg *types.MsgSubmitProof,
isExistingProof bool,
numRelays,
numClaimComputeUnits uint64,
err error,
) {
if !isExistingProof {
serviceId := session.Header.ServiceId
applicationAddress := session.Header.ApplicationAddress
serviceId := sessionHeader.ServiceId
applicationAddress := sessionHeader.ApplicationAddress
supplierOperatorAddress := msg.GetSupplierOperatorAddress()
claimProofStage := types.ClaimProofStage_PROVEN.String()

@@ -337,7 +348,11 @@ func (k msgServer) finalizeSubmitProofTelemetry(session *sessiontypes.Session, m

// finalizeProofRequirementTelemetry finalizes telemetry updates for proof requirements.
// Meant to run deferred.
func (k Keeper) finalizeProofRequirementTelemetry(requirementReason types.ProofRequirementReason, claim *types.Claim, err error) {
func (k Keeper) finalizeProofRequirementTelemetry(
requirementReason types.ProofRequirementReason,
claim *types.Claim,
err error,
) {
telemetry.ProofRequirementCounter(
requirementReason.String(),
claim.SessionHeader.ServiceId,
@@ -346,3 +361,12 @@ func (k Keeper) finalizeProofRequirementTelemetry(requirementReason types.ProofR
err,
)
}

// newProofFromMsg creates a new proof from a MsgSubmitProof message.
func newProofFromMsg(msg *types.MsgSubmitProof) *types.Proof {
return &types.Proof{
SupplierOperatorAddress: msg.GetSupplierOperatorAddress(),
SessionHeader: msg.GetSessionHeader(),
ClosestMerkleProof: msg.GetProof(),
}
}
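
Since the message signer need not be the supplier that signed the proof, batching is possible. Here is a rough sketch of what a fee-optimizing client-side batcher could look like; this code is not in the PR, though `SetMsgs` is part of the standard cosmos-sdk `client.TxBuilder` interface.

```go
import (
	"github.com/cosmos/cosmos-sdk/client"
	sdk "github.com/cosmos/cosmos-sdk/types"

	"github.com/pokt-network/poktroll/x/proof/types"
)

// batchProofMsgs packs several supplier-signed proofs into one transaction
// to amortize transaction fees. Illustrative sketch only.
func batchProofMsgs(txBuilder client.TxBuilder, proofMsgs []*types.MsgSubmitProof) error {
	msgs := make([]sdk.Msg, 0, len(proofMsgs))
	for _, msg := range proofMsgs {
		msgs = append(msgs, msg)
	}
	return txBuilder.SetMsgs(msgs...)
}
```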
4 changes: 2 additions & 2 deletions x/proof/keeper/msg_server_submit_proof_test.go
@@ -630,7 +630,7 @@ func TestMsgServer_SubmitProof_Error(t *testing.T) {
},
msgSubmitProofToExpectedErrorFn: func(msgSubmitProof *prooftypes.MsgSubmitProof) error {
return status.Error(
codes.InvalidArgument,
codes.FailedPrecondition,
prooftypes.ErrProofInvalidSessionId.Wrapf(
"session ID does not match onchain session ID; expected %q, got %q",
validSessionHeader.GetSessionId(),
@@ -652,7 +652,7 @@ func TestMsgServer_SubmitProof_Error(t *testing.T) {
},
msgSubmitProofToExpectedErrorFn: func(msgSubmitProof *prooftypes.MsgSubmitProof) error {
return status.Error(
codes.InvalidArgument,
codes.FailedPrecondition,
prooftypes.ErrProofNotFound.Wrapf(
"supplier operator address %q not found in session ID %q",
wrongSupplierOperatorAddr,
11 changes: 8 additions & 3 deletions x/proof/keeper/proof.go
@@ -91,11 +91,16 @@ func (k Keeper) RemoveProof(ctx context.Context, sessionId, supplierOperatorAddr
)
}

// GetAllProofs returns all proof
func (k Keeper) GetAllProofs(ctx context.Context) (proofs []types.Proof) {
// GetAllProofsIterator returns an iterator for all proofs in the store
func (k Keeper) GetAllProofsIterator(ctx context.Context) storetypes.Iterator {
storeAdapter := runtime.KVStoreAdapter(k.storeService.OpenKVStore(ctx))
primaryStore := prefix.NewStore(storeAdapter, types.KeyPrefix(types.ProofPrimaryKeyPrefix))
iterator := storetypes.KVStorePrefixIterator(primaryStore, []byte{})
return storetypes.KVStorePrefixIterator(primaryStore, []byte{})
}

// GetAllProofs returns all proofs in the store
func (k Keeper) GetAllProofs(ctx context.Context) (proofs []types.Proof) {
iterator := k.GetAllProofsIterator(ctx)

defer iterator.Close()

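
The iterator is what makes the EndBlocker validation scalable: rather than materializing every proof with GetAllProofs, the validation loop can stream proofs one at a time, so memory use stays flat as the proof set grows. A minimal sketch of that pattern, assuming the keeper has the usual `cdc` codec field from standard cosmos-sdk module scaffolding:

```go
import (
	"context"

	"github.com/pokt-network/poktroll/x/proof/types"
)

// forEachProof streams proofs from the store, applying fn to each without
// ever holding the full proof set in memory. Sketch; the unmarshaling
// details are assumptions.
func (k Keeper) forEachProof(ctx context.Context, fn func(types.Proof) error) error {
	iterator := k.GetAllProofsIterator(ctx)
	defer iterator.Close()

	for ; iterator.Valid(); iterator.Next() {
		var proof types.Proof
		if err := k.cdc.Unmarshal(iterator.Value(), &proof); err != nil {
			return err
		}
		if err := fn(proof); err != nil {
			return err
		}
	}
	return nil
}
```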