Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add host key to account token #257

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/add_host_public_key_to_accounttoken.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
default: major
---

# Add host public key to AccountToken
2 changes: 2 additions & 0 deletions rhp/v4/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ func (a Account) EncodeTo(e *types.Encoder) { e.Write(a[:]) }
func (a *Account) DecodeFrom(d *types.Decoder) { d.Read(a[:]) }

func (at AccountToken) encodeTo(e *types.Encoder) {
at.HostKey.EncodeTo(e)
at.Account.EncodeTo(e)
e.WriteTime(at.ValidUntil)
at.Signature.EncodeTo(e)
}

func (at *AccountToken) decodeFrom(d *types.Decoder) {
at.HostKey.DecodeFrom(d)
at.Account.DecodeFrom(d)
at.ValidUntil = d.ReadTime()
at.Signature.DecodeFrom(d)
Expand Down
49 changes: 49 additions & 0 deletions rhp/v4/encoding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package rhp

import (
"bytes"
"math"
"reflect"
"testing"
"time"

"go.sia.tech/core/types"
"lukechampine.com/frand"
)

type rhpEncodable[T any] interface {
*T
encodeTo(*types.Encoder)
decodeFrom(*types.Decoder)
}

func testRoundtrip[T any, PT rhpEncodable[T]](a PT) func(t *testing.T) {
return func(t *testing.T) {
buf := bytes.NewBuffer(nil)
enc := types.NewEncoder(buf)

a.encodeTo(enc)
if err := enc.Flush(); err != nil {
t.Fatal(err)
}

b := new(T)
dec := types.NewBufDecoder(buf.Bytes())
PT(b).decodeFrom(dec)

if !reflect.DeepEqual(a, b) {
t.Log(a)
t.Log(reflect.ValueOf(b).Elem())
t.Fatal("expected rountrip to match")
}
}
}

func TestEncodingRoundtrip(t *testing.T) {
t.Run("AccountToken", testRoundtrip(&AccountToken{
HostKey: frand.Entropy256(),
Account: frand.Entropy256(),
ValidUntil: time.Unix(int64(frand.Intn(math.MaxInt)), 0),
Signature: types.Signature(frand.Bytes(64)),
}))
}
25 changes: 2 additions & 23 deletions rhp/v4/rhp.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,6 @@ func (hp HostPrices) SigHash() types.Hash256 {
return h.Sum()
}

// Validate checks the host prices for validity. It returns an error if the
// prices have expired or the signature is invalid.
func (hp *HostPrices) Validate(pk types.PublicKey) error {
if time.Until(hp.ValidUntil) <= 0 {
return ErrPricesExpired
}
if !pk.VerifyHash(hp.SigHash(), hp.Signature) {
return ErrInvalidSignature
}
return nil
}

// HostSettings specify the settings of a host.
type HostSettings struct {
ProtocolVersion [3]uint8 `json:"protocolVersion"`
Expand Down Expand Up @@ -208,6 +196,7 @@ func (a *Account) UnmarshalText(b []byte) error {

// An AccountToken authorizes an account action.
type AccountToken struct {
HostKey types.PublicKey `json:"hostKey"`
Account Account `json:"account"`
ValidUntil time.Time `json:"validUntil"`
Signature types.Signature `json:"signature"`
Expand All @@ -216,22 +205,12 @@ type AccountToken struct {
// SigHash returns the hash of the account token used for signing.
func (at *AccountToken) SigHash() types.Hash256 {
h := types.NewHasher()
at.HostKey.EncodeTo(h.E)
at.Account.EncodeTo(h.E)
h.E.WriteTime(at.ValidUntil)
return h.Sum()
}

// Validate verifies the account token is valid for use. It returns an error if
// the token has expired or the signature is invalid.
func (at AccountToken) Validate() error {
if time.Now().After(at.ValidUntil) {
return NewRPCError(ErrorCodeBadRequest, "account token expired")
} else if !types.PublicKey(at.Account).VerifyHash(at.SigHash(), at.Signature) {
return ErrInvalidSignature
}
return nil
}

// GenerateAccount generates a pair of private key and Account from a secure
// entropy source.
func GenerateAccount() (types.PrivateKey, Account) {
Expand Down
45 changes: 36 additions & 9 deletions rhp/v4/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,42 @@ package rhp
import (
"errors"
"fmt"
"time"

"go.sia.tech/core/types"
)

// Validate checks the host prices for validity. It returns an error if the
// prices have expired or the signature is invalid.
func (hp *HostPrices) Validate(pk types.PublicKey) error {
if time.Until(hp.ValidUntil) <= 0 {
return ErrPricesExpired
}
if !pk.VerifyHash(hp.SigHash(), hp.Signature) {
return ErrInvalidSignature
}
return nil
}

// Validate verifies the account token is valid for use. It returns an error if
// the token has expired or the signature is invalid.
func (at AccountToken) Validate(hostKey types.PublicKey) error {
switch {
case at.HostKey != hostKey:
return NewRPCError(ErrorCodeBadRequest, "host key mismatch")
case time.Now().After(at.ValidUntil):
return NewRPCError(ErrorCodeBadRequest, "account token expired")
case !types.PublicKey(at.Account).VerifyHash(at.SigHash(), at.Signature):
return ErrInvalidSignature
}
return nil
}

// Validate validates a read sector request.
func (req *RPCReadSectorRequest) Validate(pk types.PublicKey) error {
if err := req.Prices.Validate(pk); err != nil {
func (req *RPCReadSectorRequest) Validate(hostKey types.PublicKey) error {
if err := req.Prices.Validate(hostKey); err != nil {
return fmt.Errorf("prices are invalid: %w", err)
} else if err := req.Token.Validate(); err != nil {
} else if err := req.Token.Validate(hostKey); err != nil {
return fmt.Errorf("token is invalid: %w", err)
}
switch {
Expand All @@ -26,10 +53,10 @@ func (req *RPCReadSectorRequest) Validate(pk types.PublicKey) error {
}

// Validate validates a write sector request.
func (req *RPCWriteSectorRequest) Validate(pk types.PublicKey) error {
if err := req.Prices.Validate(pk); err != nil {
func (req *RPCWriteSectorRequest) Validate(hostKey types.PublicKey) error {
if err := req.Prices.Validate(hostKey); err != nil {
return fmt.Errorf("prices are invalid: %w", err)
} else if err := req.Token.Validate(); err != nil {
} else if err := req.Token.Validate(hostKey); err != nil {
return fmt.Errorf("token is invalid: %w", err)
}
switch {
Expand Down Expand Up @@ -200,10 +227,10 @@ func (req *RPCRefreshContractRequest) Validate(pk types.PublicKey, existingTotal
}

// Validate checks that the request is valid
func (req *RPCVerifySectorRequest) Validate(pk types.PublicKey) error {
if err := req.Prices.Validate(pk); err != nil {
func (req *RPCVerifySectorRequest) Validate(hostKey types.PublicKey) error {
if err := req.Prices.Validate(hostKey); err != nil {
return fmt.Errorf("prices are invalid: %w", err)
} else if err := req.Token.Validate(); err != nil {
} else if err := req.Token.Validate(hostKey); err != nil {
return fmt.Errorf("token is invalid: %w", err)
} else if req.LeafIndex >= LeavesPerSector {
return fmt.Errorf("leaf index must be less than %d", LeavesPerSector)
Expand Down
40 changes: 40 additions & 0 deletions rhp/v4/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package rhp

import (
"errors"
"strings"
"testing"
"time"

"go.sia.tech/core/types"
"lukechampine.com/frand"
)

func TestValidateAccountToken(t *testing.T) {
hostKey := types.GeneratePrivateKey().PublicKey()
renterKey := types.GeneratePrivateKey()
account := Account(renterKey.PublicKey())

ac := AccountToken{
HostKey: hostKey,
Account: account,
ValidUntil: time.Now().Add(-time.Minute),
}

if err := ac.Validate(frand.Entropy256()); !strings.Contains(err.Error(), "host key mismatch") {
t.Fatalf("expected host key mismatch, got %v", err)
} else if err := ac.Validate(hostKey); !strings.Contains(err.Error(), "token expired") {
t.Fatalf("expected token expired, got %v", err)
}

ac.ValidUntil = time.Now().Add(time.Minute)
if err := ac.Validate(hostKey); !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("expected ErrInvalidSignature, got %v", err)
}

ac.Signature = renterKey.SignHash(ac.SigHash())

if err := ac.Validate(hostKey); err != nil {
t.Fatal(err)
}
}
Loading