Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use sync.Pool for Relation slices. #60

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package cache
import (
"sync"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/model"
"github.com/aserto-dev/azm/model/diff"
stts "github.com/aserto-dev/azm/stats"
Expand All @@ -17,15 +19,17 @@ type (
)

type Cache struct {
model *model.Model
mtx sync.RWMutex
model *model.Model
mtx sync.RWMutex
relsPool *graph.RelationsPool
}

// New, create new model cache instance.
func New(m *model.Model) *Cache {
return &Cache{
model: m,
mtx: sync.RWMutex{},
model: m,
mtx: sync.RWMutex{},
relsPool: mempool.NewSlicePool[*dsc.Relation](),
}
}

Expand Down
12 changes: 9 additions & 3 deletions cache/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
)

func (c *Cache) Check(req *dsr.CheckRequest, relReader graph.RelationReader) (*dsr.CheckResponse, error) {
checker := graph.NewCheck(c.model, req, relReader)
c.mtx.RLock()
defer c.mtx.RUnlock()

checker := graph.NewCheck(c.model, req, relReader, c.relsPool)

ctx := pb.NewStruct()

Expand All @@ -26,15 +29,18 @@ type graphSearch interface {
}

func (c *Cache) GetGraph(req *dsr.GetGraphRequest, relReader graph.RelationReader) (*dsr.GetGraphResponse, error) {
c.mtx.RLock()
defer c.mtx.RUnlock()

var (
search graphSearch
err error
)

if req.ObjectId == "" {
search, err = graph.NewObjectSearch(c.model, req, relReader)
search, err = graph.NewObjectSearch(c.model, req, relReader, c.relsPool)
} else {
search, err = graph.NewSubjectSearch(c.model, req, relReader)
search, err = graph.NewSubjectSearch(c.model, req, relReader, c.relsPool)
}

if err != nil {
Expand Down
37 changes: 27 additions & 10 deletions graph/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ type Checker struct {
getRels RelationReader

memo *checkMemo
pool *RelationsPool
}

func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader) *Checker {
func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader, pool *RelationsPool) *Checker {
return &Checker{
m: m,
params: &relation{
Expand All @@ -29,6 +30,7 @@ func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader) *Che
},
getRels: reader,
memo: newCheckMemo(req.Trace),
pool: pool,
}
}

Expand Down Expand Up @@ -88,7 +90,16 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) {
r := c.m.Objects[params.ot].Relations[params.rel]
steps := c.m.StepRelation(r, params.st)

// Reuse the same slice in all steps.
relsPtr := c.pool.Get()
defer func() {
*relsPtr = (*relsPtr)[:0]
c.pool.Put(relsPtr)
}()

for _, step := range steps {
*relsPtr = (*relsPtr)[:0]

req := &dsc.Relation{
ObjectType: params.ot.String(),
ObjectId: params.oid.String(),
Expand All @@ -103,27 +114,26 @@ func (c *Checker) checkRelation(params *relation) (checkStatus, error) {
req.SubjectRelation = step.Relation.String()
}

rels, err := c.getRels(req)
if err != nil {
if err := c.getRels(req, relsPtr); err != nil {
return checkStatusFalse, err
}

switch {
case step.IsDirect():
for _, rel := range rels {
for _, rel := range *relsPtr {
if rel.SubjectId == params.sid.String() {
return checkStatusTrue, nil
}
}

case step.IsWildcard():
if len(rels) > 0 {
if len(*relsPtr) > 0 {
// We have a wildcard match.
return checkStatusTrue, nil
}

case step.IsSubject():
for _, rel := range rels {
for _, rel := range *relsPtr {
if status, err := c.check(&relation{
ot: step.Object,
oid: ObjectID(rel.SubjectId),
Expand Down Expand Up @@ -190,17 +200,21 @@ func (c *Checker) checkPermission(params *relation) (checkStatus, error) {

func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relations, error) {
if pt.IsArrow() {
// Resolve the base of the arrow.
rels, err := c.getRels(&dsc.Relation{
query := &dsc.Relation{
ObjectType: params.ot.String(),
ObjectId: params.oid.String(),
Relation: pt.Base.String(),
})
}

relsPtr := c.pool.Get()

// Resolve the base of the arrow.
err := c.getRels(query, relsPtr)
if err != nil {
return relations{}, err
}

expanded := lo.Map(rels, func(rel *dsc.Relation, _ int) *relation {
expanded := lo.Map(*relsPtr, func(rel *dsc.Relation, _ int) *relation {
return &relation{
ot: model.ObjectName(rel.SubjectType),
oid: ObjectID(rel.SubjectId),
Expand All @@ -210,6 +224,9 @@ func (c *Checker) expandTerm(pt *model.PermissionTerm, params *relation) (relati
}
})

*relsPtr = (*relsPtr)[:0]
c.pool.Put(relsPtr)

return expanded, nil
}

Expand Down
6 changes: 5 additions & 1 deletion graph/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (
"testing"

azmgraph "github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -67,11 +69,13 @@ func TestCheck(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, m)

pool := mempool.NewSlicePool[*dsc.Relation]()

for _, test := range tests {
t.Run(test.check, func(tt *testing.T) {
assert := assert.New(tt)

checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations)
checker := azmgraph.NewCheck(m, checkReq(test.check), rels.GetRelations, pool)

res, err := checker.Check()
assert.NoError(err)
Expand Down
20 changes: 12 additions & 8 deletions graph/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ObjectSearch struct {
wildcardSearch *SubjectSearch
}

func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader) (*ObjectSearch, error) {
func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *RelationsPool) (*ObjectSearch, error) {
params := searchParams(req)
if err := validate(m, params); err != nil {
return nil, err
Expand All @@ -40,13 +40,15 @@ func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationRe
getRels: invertedRelationReader(im, reader),
memo: newSearchMemo(req.Trace),
explain: req.Explain,
pool: pool,
}},
wildcardSearch: &SubjectSearch{graphSearch{
m: im,
params: wildcardParams(iParams),
getRels: invertedRelationReader(im, reader),
memo: newSearchMemo(req.Trace),
explain: req.Explain,
pool: pool,
}},
}, nil
}
Expand Down Expand Up @@ -125,22 +127,24 @@ func wildcardParams(params *relation) *relation {
}

func invertedRelationReader(m *model.Model, reader RelationReader) RelationReader {
return func(r *dsc.Relation) ([]*dsc.Relation, error) {
return func(r *dsc.Relation, out *Relations) error {
ir := uninvertRelation(m, relationFromProto(r))
res, err := reader(ir.asProto())
if err != nil {
return nil, err
if err := reader(ir.asProto(), out); err != nil {
return err
}

return lo.Map(res, func(r *dsc.Relation, _ int) *dsc.Relation {
return &dsc.Relation{
res := *out
for i, r := range res {
res[i] = &dsc.Relation{
ObjectType: r.SubjectType,
ObjectId: r.SubjectId,
Relation: r.Relation,
SubjectType: r.ObjectType,
SubjectId: r.ObjectId,
}
}), nil
}

return nil
}
}

Expand Down
5 changes: 4 additions & 1 deletion graph/objects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/model"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
Expand Down Expand Up @@ -35,11 +36,13 @@ func TestSearchObjects(t *testing.T) {
im.Validate(model.SkipNameValidation, model.AllowPermissionInArrowBase),
)

pool := mempool.NewSlicePool[*dsc.Relation]()

for _, test := range searchObjectsTests {
t.Run(test.search, func(tt *testing.T) {
assert := assert.New(tt)

objSearch, err := graph.NewObjectSearch(m, graphReq(test.search), rels.GetRelations)
objSearch, err := graph.NewObjectSearch(m, graphReq(test.search), rels.GetRelations, pool)
assert.NoError(err)

res, err := objSearch.Search()
Expand Down
30 changes: 19 additions & 11 deletions graph/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"strings"

"github.com/aserto-dev/azm/internal/mempool"
"github.com/aserto-dev/azm/model"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
Expand All @@ -12,21 +13,27 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

type ObjectID = model.ObjectID
type (
ObjectID = model.ObjectID

// RelationReader retrieves relations that match the given filter.
type RelationReader func(*dsc.Relation) ([]*dsc.Relation, error)
Relations = []*dsc.Relation

type searchPath relations
// RelationReader retrieves relations that match the given filter.
RelationReader func(*dsc.Relation, *Relations) error

type object struct {
Type model.ObjectName
ID ObjectID
}
RelationsPool = mempool.Pool[*Relations]

searchPath relations

// The results of a search is a map where the key is a matching relations
// and the value is a list of paths that connect the search object and subject.
type searchResults map[object][]searchPath
object struct {
Type model.ObjectName
ID ObjectID
}

// The results of a search is a map where the key is a matching relations
// and the value is a list of paths that connect the search object and subject.
searchResults map[object][]searchPath
)

// Objects returns the objects from the search results.
func (r searchResults) Objects() []*dsc.ObjectIdentifier {
Expand Down Expand Up @@ -92,6 +99,7 @@ type graphSearch struct {

memo *searchMemo
explain bool
pool *RelationsPool
}

func validate(m *model.Model, params *relation) error {
Expand Down
Loading
Loading