diff --git a/core/node/auth/auth_impl.go b/core/node/auth/auth_impl.go index 4f494ead93..8901d1e171 100644 --- a/core/node/auth/auth_impl.go +++ b/core/node/auth/auth_impl.go @@ -76,13 +76,6 @@ type ChainAuthArgs struct { linkedWallets string // a serialized list of linked wallets to comply with the cache key constraints } -// Replaces principal with given wallet and returns new copy of args. -func (args *ChainAuthArgs) withWallet(wallet common.Address) *ChainAuthArgs { - ret := *args - ret.principal = wallet - return &ret -} - func (args *ChainAuthArgs) withLinkedWallets(linkedWallets []common.Address) *ChainAuthArgs { ret := *args var builder strings.Builder @@ -216,7 +209,11 @@ func (ca *chainAuth) IsEntitled(ctx context.Context, cfg *config.Config, args *C return nil } -func (ca *chainAuth) isWalletEntitled(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) { +func (ca *chainAuth) areLinkedWalletsEntitled( + ctx context.Context, + cfg *config.Config, + args *ChainAuthArgs, +) (bool, error) { log := dlog.FromCtx(ctx) if args.kind == chainAuthKindSpace { log.Debug("isWalletEntitled", "kind", "space", "args", args) @@ -232,7 +229,11 @@ func (ca *chainAuth) isWalletEntitled(ctx context.Context, cfg *config.Config, a } } -func (ca *chainAuth) isSpaceEnabledUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) { +func (ca *chainAuth) isSpaceEnabledUncached( + ctx context.Context, + cfg *config.Config, + args *ChainAuthArgs, +) (CacheResult, error) { // This is awkward as we want enabled to be cached for 15 minutes, but the API returns the inverse isDisabled, err := ca.spaceContract.IsSpaceDisabled(ctx, args.spaceId) return &boolCacheResult{allowed: !isDisabled}, err @@ -261,7 +262,11 @@ func (ca *chainAuth) checkSpaceEnabled(ctx context.Context, cfg *config.Config, } } -func (ca *chainAuth) isChannelEnabledUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) { +func (ca *chainAuth) isChannelEnabledUncached( + ctx context.Context, + cfg *config.Config, + args *ChainAuthArgs, +) (CacheResult, error) { // This is awkward as we want enabled to be cached for 15 minutes, but the API returns the inverse isDisabled, err := ca.spaceContract.IsChannelDisabled(ctx, args.spaceId, args.channelId) return &boolCacheResult{allowed: !isDisabled}, err @@ -343,7 +348,11 @@ func deserializeWallets(serialized string) []common.Address { return linkedWallets } -func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) { +func (ca *chainAuth) isEntitledToSpaceUncached( + ctx context.Context, + cfg *config.Config, + args *ChainAuthArgs, +) (CacheResult, error) { log := dlog.FromCtx(ctx) log.Debug("isEntitledToSpaceUncached", "args", args) result, cacheHit, err := ca.entitlementManagerCache.executeUsingCache( @@ -369,9 +378,21 @@ func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config. temp := (result.(*timestampedCacheValue).Result()) - if args.principal == temp.(*entitlementCacheResult).owner { - log.Debug("owner is entitled to space", "spaceId", args.spaceId, "userId", args.principal) - return &boolCacheResult{allowed: true}, nil + wallets := deserializeWallets(args.linkedWallets) + + for _, wallet := range wallets { + if wallet == temp.(*entitlementCacheResult).owner { + log.Debug( + "owner is entitled to space", + "spaceId", + args.spaceId, + "userId", + wallet, + "principal", + args.principal, + ) + return &boolCacheResult{allowed: true}, nil + } } entitlementData := temp.(*entitlementCacheResult) // Assuming result is of *entitlementCacheResult type @@ -381,8 +402,7 @@ func (ca *chainAuth) isEntitledToSpaceUncached(ctx context.Context, cfg *config. if ent.entitlementType == "RuleEntitlement" { re := ent.ruleEntitlement log.Debug("RuleEntitlement", "ruleEntitlement", re) - result, err := entitlement.EvaluateRuleData(ctx, cfg, deserializeWallets(args.linkedWallets), re) - + result, err := entitlement.EvaluateRuleData(ctx, cfg, wallets, re) if err != nil { return &boolCacheResult{allowed: false}, AsRiverError(err).Func("isEntitledToSpace") } @@ -429,15 +449,28 @@ func (ca *chainAuth) isEntitledToSpace(ctx context.Context, cfg *config.Config, return isEntitled.IsAllowed(), nil } -func (ca *chainAuth) isEntitledToChannelUncached(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) { - allowed, err := ca.spaceContract.IsEntitledToChannel( - ctx, - args.spaceId, - args.channelId, - args.principal, - args.permission, - ) - return &boolCacheResult{allowed: allowed}, err +func (ca *chainAuth) isEntitledToChannelUncached( + ctx context.Context, + cfg *config.Config, + args *ChainAuthArgs, +) (CacheResult, error) { + wallets := deserializeWallets(args.linkedWallets) + for _, wallet := range wallets { + allowed, err := ca.spaceContract.IsEntitledToChannel( + ctx, + args.spaceId, + args.channelId, + wallet, + args.permission, + ) + if err != nil { + return &boolCacheResult{allowed: false}, err + } + if allowed { + return &boolCacheResult{allowed: true}, nil + } + } + return &boolCacheResult{allowed: false}, nil } func (ca *chainAuth) isEntitledToChannel(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (bool, error) { @@ -458,11 +491,6 @@ func (ca *chainAuth) isEntitledToChannel(ctx context.Context, cfg *config.Config return isEntitled.IsAllowed(), nil } -type entitlementCheckResult struct { - allowed bool - err error -} - func (ca *chainAuth) getLinkedWallets(ctx context.Context, rootKey common.Address) ([]common.Address, error) { log := dlog.FromCtx(ctx) @@ -508,7 +536,11 @@ func (ca *chainAuth) checkMembership( * If any of the operations fail before getting positive result, the whole operation fails. * A prerequisite for this function is that one of the linked wallets is a member of the space. */ -func (ca *chainAuth) checkEntitlement(ctx context.Context, cfg *config.Config, args *ChainAuthArgs) (CacheResult, error) { +func (ca *chainAuth) checkEntitlement( + ctx context.Context, + cfg *config.Config, + args *ChainAuthArgs, +) (CacheResult, error) { log := dlog.FromCtx(ctx) ctx, cancel := context.WithTimeout(ctx, time.Millisecond*time.Duration(ca.contractCallsTimeoutMs)) @@ -570,55 +602,20 @@ func (ca *chainAuth) checkEntitlement(ctx context.Context, cfg *config.Config, a } // Now that we know the user is a member of the space, we can check entitlements. - resultsChan := make(chan entitlementCheckResult, len(wallets)) - var wg sync.WaitGroup - - // Get linked wallets and check them in parallel. - wg.Add(1) - go func() { - // defer here is essential since we are (mis)using WaitGroup here. - // It is ok to increment the WaitGroup once it is being waited on as long as the counter is not zero - // (see https://pkg.go.dev/sync#WaitGroup) - // We are adding new goroutines to the WaitGroup in the loop below, so we need to make sure that the counter is always > 0. - defer wg.Done() - if len(wallets) > ca.linkedWalletsLimit { - log.Error("too many wallets linked to the root key", "rootKey", args.principal, "wallets", len(wallets)) - resultsChan <- entitlementCheckResult{allowed: false, err: fmt.Errorf("too many wallets linked to the root key: %d", len(wallets)-1)} - return - } - // Check all wallets in parallel. - for _, wallet := range wallets { - wg.Add(1) - go func(address common.Address) { - defer wg.Done() - result, err := ca.isWalletEntitled(ctx, cfg, args.withWallet(address)) - resultsChan <- entitlementCheckResult{allowed: result, err: err} - }(wallet) - } - }() - - go func() { - wg.Wait() - close(resultsChan) - }() - - for opResult := range resultsChan { - if opResult.err != nil { - // we don't check for context cancellation error here because - // * if it is a timeout it has to propagate - // * the explicit cancel happens only here, so it is not possible. - - // Cancel all inflight requests. - cancel() - // Any error is a failure. - return &boolCacheResult{allowed: false}, opResult.err - } - if opResult.allowed { - // We have the result we need, cancel all inflight requests. - cancel() + if len(wallets) > ca.linkedWalletsLimit { + log.Error("too many wallets linked to the root key", "rootKey", args.principal, "wallets", len(wallets)) + return &boolCacheResult{ + allowed: false, + }, fmt.Errorf( + "too many wallets linked to the root key: %d", + len(wallets)-1, + ) + } - return &boolCacheResult{allowed: true}, nil - } + result, err := ca.areLinkedWalletsEntitled(ctx, cfg, args) + if err != nil { + return &boolCacheResult{allowed: false}, err } - return &boolCacheResult{allowed: false}, nil + + return &boolCacheResult{allowed: result}, nil }