Skip to content

Commit

Permalink
Deduplicate repeated checks of linked wallets for entitlements (#46)
Browse files Browse the repository at this point in the history
Current implementation checked entitlements over all wallets once for
each wallet. This fix evaluates entitlement rules for all wallets a
single time.
  • Loading branch information
clemire authored May 23, 2024
1 parent 446e726 commit 04c6cdb
Showing 1 changed file with 77 additions and 80 deletions.
157 changes: 77 additions & 80 deletions core/node/auth/auth_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@ type ChainAuthArgs struct {
linkedWallets string // a serialized list of linked wallets to comply with the cache key constraints
}

// Replaces principal with given wallet and returns new copy of args.
func (args *ChainAuthArgs) withWallet(wallet common.Address) *ChainAuthArgs {
ret := *args
ret.principal = wallet
return &ret
}

func (args *ChainAuthArgs) withLinkedWallets(linkedWallets []common.Address) *ChainAuthArgs {
ret := *args
var builder strings.Builder
Expand Down Expand Up @@ -216,7 +209,11 @@ func (ca *chainAuth) IsEntitled(ctx context.Context, cfg *config.Config, args *C
return nil
}

func (ca *chainAuth) isWalletEntitled(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) {
func (ca *chainAuth) areLinkedWalletsEntitled(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (bool, error) {
log := dlog.FromCtx(ctx)
if args.kind == chainAuthKindSpace {
log.Debug("isWalletEntitled", "kind", "space", "args", args)
Expand All @@ -232,7 +229,11 @@ func (ca *chainAuth) isWalletEntitled(ctx context.Context, cfg *config.Config, a
}
}

func (ca *chainAuth) isSpaceEnabledUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) isSpaceEnabledUncached(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (CacheResult, error) {
// This is awkward as we want enabled to be cached for 15 minutes, but the API returns the inverse
isDisabled, err := ca.spaceContract.IsSpaceDisabled(ctx, args.spaceId)
return &boolCacheResult{allowed: !isDisabled}, err
Expand Down Expand Up @@ -261,7 +262,11 @@ func (ca *chainAuth) checkSpaceEnabled(ctx context.Context, cfg *config.Config,
}
}

func (ca *chainAuth) isChannelEnabledUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) isChannelEnabledUncached(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (CacheResult, error) {
// This is awkward as we want enabled to be cached for 15 minutes, but the API returns the inverse
isDisabled, err := ca.spaceContract.IsChannelDisabled(ctx, args.spaceId, args.channelId)
return &boolCacheResult{allowed: !isDisabled}, err
Expand Down Expand Up @@ -343,7 +348,11 @@ func deserializeWallets(serialized string) []common.Address {
return linkedWallets
}

func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) isEntitledToSpaceUncached(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (CacheResult, error) {
log := dlog.FromCtx(ctx)
log.Debug("isEntitledToSpaceUncached", "args", args)
result, cacheHit, err := ca.entitlementManagerCache.executeUsingCache(
Expand All @@ -369,9 +378,21 @@ func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config.

temp := (result.(*timestampedCacheValue).Result())

if args.principal == temp.(*entitlementCacheResult).owner {
log.Debug("owner is entitled to space", "spaceId", args.spaceId, "userId", args.principal)
return &boolCacheResult{allowed: true}, nil
wallets := deserializeWallets(args.linkedWallets)

for _, wallet := range wallets {
if wallet == temp.(*entitlementCacheResult).owner {
log.Debug(
"owner is entitled to space",
"spaceId",
args.spaceId,
"userId",
wallet,
"principal",
args.principal,
)
return &boolCacheResult{allowed: true}, nil
}
}

entitlementData := temp.(*entitlementCacheResult) // Assuming result is of *entitlementCacheResult type
Expand All @@ -381,8 +402,7 @@ func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config.
if ent.entitlementType == "RuleEntitlement" {
re := ent.ruleEntitlement
log.Debug("RuleEntitlement", "ruleEntitlement", re)
result, err := entitlement.EvaluateRuleData(ctx, cfg, deserializeWallets(args.linkedWallets), re)

result, err := entitlement.EvaluateRuleData(ctx, cfg, wallets, re)
if err != nil {
return &boolCacheResult{allowed: false}, AsRiverError(err).Func("isEntitledToSpace")
}
Expand Down Expand Up @@ -429,15 +449,28 @@ func (ca *chainAuth) isEntitledToSpace(ctx context.Context, cfg *config.Config,
return isEntitled.IsAllowed(), nil
}

func (ca *chainAuth) isEntitledToChannelUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
allowed, err := ca.spaceContract.IsEntitledToChannel(
ctx,
args.spaceId,
args.channelId,
args.principal,
args.permission,
)
return &boolCacheResult{allowed: allowed}, err
func (ca *chainAuth) isEntitledToChannelUncached(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (CacheResult, error) {
wallets := deserializeWallets(args.linkedWallets)
for _, wallet := range wallets {
allowed, err := ca.spaceContract.IsEntitledToChannel(
ctx,
args.spaceId,
args.channelId,
wallet,
args.permission,
)
if err != nil {
return &boolCacheResult{allowed: false}, err
}
if allowed {
return &boolCacheResult{allowed: true}, nil
}
}
return &boolCacheResult{allowed: false}, nil
}

func (ca *chainAuth) isEntitledToChannel(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) {
Expand All @@ -458,11 +491,6 @@ func (ca *chainAuth) isEntitledToChannel(ctx context.Context, cfg *config.Config
return isEntitled.IsAllowed(), nil
}

type entitlementCheckResult struct {
allowed bool
err error
}

func (ca *chainAuth) getLinkedWallets(ctx context.Context, rootKey common.Address) ([]common.Address, error) {
log := dlog.FromCtx(ctx)

Expand Down Expand Up @@ -508,7 +536,11 @@ func (ca *chainAuth) checkMembership(
* If any of the operations fail before getting positive result, the whole operation fails.
* A prerequisite for this function is that one of the linked wallets is a member of the space.
*/
func (ca *chainAuth) checkEntitlement(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) {
func (ca *chainAuth) checkEntitlement(
ctx context.Context,
cfg *config.Config,
args *ChainAuthArgs,
) (CacheResult, error) {
log := dlog.FromCtx(ctx)

ctx, cancel := context.WithTimeout(ctx, time.Millisecond*time.Duration(ca.contractCallsTimeoutMs))
Expand Down Expand Up @@ -570,55 +602,20 @@ func (ca *chainAuth) checkEntitlement(ctx context.Context, cfg *config.Config, a
}

// Now that we know the user is a member of the space, we can check entitlements.
resultsChan := make(chan entitlementCheckResult, len(wallets))
var wg sync.WaitGroup

// Get linked wallets and check them in parallel.
wg.Add(1)
go func() {
// defer here is essential since we are (mis)using WaitGroup here.
// It is ok to increment the WaitGroup once it is being waited on as long as the counter is not zero
// (see https://pkg.go.dev/sync#WaitGroup)
// We are adding new goroutines to the WaitGroup in the loop below, so we need to make sure that the counter is always > 0.
defer wg.Done()
if len(wallets) > ca.linkedWalletsLimit {
log.Error("too many wallets linked to the root key", "rootKey", args.principal, "wallets", len(wallets))
resultsChan <- entitlementCheckResult{allowed: false, err: fmt.Errorf("too many wallets linked to the root key: %d", len(wallets)-1)}
return
}
// Check all wallets in parallel.
for _, wallet := range wallets {
wg.Add(1)
go func(address common.Address) {
defer wg.Done()
result, err := ca.isWalletEntitled(ctx, cfg, args.withWallet(address))
resultsChan <- entitlementCheckResult{allowed: result, err: err}
}(wallet)
}
}()

go func() {
wg.Wait()
close(resultsChan)
}()

for opResult := range resultsChan {
if opResult.err != nil {
// we don't check for context cancellation error here because
// * if it is a timeout it has to propagate
// * the explicit cancel happens only here, so it is not possible.

// Cancel all inflight requests.
cancel()
// Any error is a failure.
return &boolCacheResult{allowed: false}, opResult.err
}
if opResult.allowed {
// We have the result we need, cancel all inflight requests.
cancel()
if len(wallets) > ca.linkedWalletsLimit {
log.Error("too many wallets linked to the root key", "rootKey", args.principal, "wallets", len(wallets))
return &boolCacheResult{
allowed: false,
}, fmt.Errorf(
"too many wallets linked to the root key: %d",
len(wallets)-1,
)
}

return &boolCacheResult{allowed: true}, nil
}
result, err := ca.areLinkedWalletsEntitled(ctx, cfg, args)
if err != nil {
return &boolCacheResult{allowed: false}, err
}
return &boolCacheResult{allowed: false}, nil

return &boolCacheResult{allowed: result}, nil
}

0 comments on commit 04c6cdb

Please sign in to comment.