Skip to content

Commit

Permalink
Add memory pool for timestamppb.Timestamp
Browse files Browse the repository at this point in the history
  • Loading branch information
ronenh committed Dec 8, 2024
1 parent b2198ea commit 6161ae1
Show file tree
Hide file tree
Showing 11 changed files with 94 additions and 40 deletions.
5 changes: 2 additions & 3 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cache
import (
"sync"

"github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
"github.com/aserto-dev/azm/model/diff"
Expand All @@ -21,15 +20,15 @@ type (
type Cache struct {
model *model.Model
mtx sync.RWMutex
relsPool *graph.RelationsPool
relsPool *mempool.RelationsPool
}

// New, create new model cache instance.
func New(m *model.Model) *Cache {
return &Cache{
model: m,
mtx: sync.RWMutex{},
relsPool: mempool.NewCollectionPool[dsc.Relation, *dsc.Relation](),
relsPool: mempool.NewRelationsPool(),
}
}

Expand Down
5 changes: 3 additions & 2 deletions graph/check.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graph

import (
"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
Expand All @@ -15,10 +16,10 @@ type Checker struct {
getRels RelationReader

memo *checkMemo
pool *RelationsPool
pool *mempool.RelationsPool
}

func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader, pool *RelationsPool) *Checker {
func NewCheck(m *model.Model, req *dsr.CheckRequest, reader RelationReader, pool *mempool.RelationsPool) *Checker {
return &Checker{
m: m,
params: &relation{
Expand Down
5 changes: 2 additions & 3 deletions graph/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
azmgraph "github.com/aserto-dev/azm/graph"
"github.com/aserto-dev/azm/mempool"
v3 "github.com/aserto-dev/azm/v3"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -70,7 +69,7 @@ func TestCheck(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, m)

pool := mempool.NewCollectionPool[dsc.Relation]()
pool := mempool.NewRelationsPool()

for _, test := range tests {
t.Run(test.check, func(tt *testing.T) {
Expand All @@ -94,7 +93,7 @@ func BenchmarkCheck(b *testing.B) {
b.Fatalf("failed to load model: %s", err)
}

pool := mempool.NewCollectionPool[dsc.Relation]()
pool := mempool.NewRelationsPool()

b.ResetTimer()
for _, test := range tests {
Expand Down
5 changes: 3 additions & 2 deletions graph/objects.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package graph
import (
"strings"

"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
Expand All @@ -16,7 +17,7 @@ type ObjectSearch struct {
wildcardSearch *SubjectSearch
}

func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *RelationsPool) (*ObjectSearch, error) {
func NewObjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *mempool.RelationsPool) (*ObjectSearch, error) {
params := searchParams(req)
if err := validate(m, params); err != nil {
return nil, err
Expand Down Expand Up @@ -127,7 +128,7 @@ func wildcardParams(params *relation) *relation {
}

func invertedRelationReader(m *model.Model, reader RelationReader) RelationReader {
return func(r *dsc.Relation, relPool MessagePool[dsc.Relation, *dsc.Relation], out *Relations) error {
return func(r *dsc.Relation, relPool MessagePool[*dsc.Relation], out *Relations) error {
ir := uninvertRelation(m, relationFromProto(r))
if err := reader(ir.asProto(), relPool, out); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion graph/objects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestSearchObjects(t *testing.T) {
im.Validate(model.SkipNameValidation, model.AllowPermissionInArrowBase),
)

pool := mempool.NewCollectionPool[dsc.Relation]()
pool := mempool.NewRelationsPool()

for _, test := range searchObjectsTests {
t.Run(test.search, func(tt *testing.T) {
Expand Down
11 changes: 5 additions & 6 deletions graph/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ import (
)

type (
ObjectID = model.ObjectID
Relations = []*dsc.Relation
RelationsPool = mempool.CollectionPool[dsc.Relation, *dsc.Relation]
ObjectID = model.ObjectID
Relations = []*dsc.Relation

searchPath relations

Expand All @@ -30,12 +29,12 @@ type (
searchResults map[object][]searchPath
)

type MessagePool[M any, T mempool.Resetable[M]] interface {
type MessagePool[T any] interface {
Get() T
Put(T)
}

type RelationPool = MessagePool[dsc.Relation, *dsc.Relation]
type RelationPool = MessagePool[*dsc.Relation]

// RelationReader retrieves relations that match the given filter.
type RelationReader func(*dsc.Relation, RelationPool, *Relations) error
Expand Down Expand Up @@ -104,7 +103,7 @@ type graphSearch struct {

memo *searchMemo
explain bool
pool *RelationsPool
pool *mempool.RelationsPool
}

func validate(m *model.Model, params *relation) error {
Expand Down
8 changes: 7 additions & 1 deletion graph/subjects.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package graph

import (
"github.com/aserto-dev/azm/mempool"
"github.com/aserto-dev/azm/model"
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
dsr "github.com/aserto-dev/go-directory/aserto/directory/reader/v3"
Expand All @@ -12,7 +13,12 @@ type SubjectSearch struct {
graphSearch
}

func NewSubjectSearch(m *model.Model, req *dsr.GetGraphRequest, reader RelationReader, pool *RelationsPool) (*SubjectSearch, error) {
func NewSubjectSearch(
m *model.Model,
req *dsr.GetGraphRequest,
reader RelationReader,
pool *mempool.RelationsPool,
) (*SubjectSearch, error) {
params := searchParams(req)
if err := validate(m, params); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion graph/subjects_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func TestSearchSubjects(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, m)

pool := mempool.NewCollectionPool[dsc.Relation]()
pool := mempool.NewRelationsPool()

for _, test := range searchSubjectsTests {
t.Run(test.search, func(tt *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions graph/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func parseRelation(r string) *relation {
return rel
}

func (r *relation) proto(pool graph.MessagePool[dsc.Relation, *dsc.Relation]) *dsc.Relation {
func (r *relation) proto(pool graph.MessagePool[*dsc.Relation]) *dsc.Relation {
rel := pool.Get()
rel.ObjectType = r.ObjectType.String()
rel.ObjectId = r.ObjectID.String()
Expand Down Expand Up @@ -97,7 +97,7 @@ func NewRelationsReader(rels ...string) RelationsReader {
})
}

func (r RelationsReader) GetRelations(req *dsc.Relation, pool graph.MessagePool[dsc.Relation, *dsc.Relation], out *graph.Relations) error {
func (r RelationsReader) GetRelations(req *dsc.Relation, pool graph.MessagePool[*dsc.Relation], out *graph.Relations) error {
ot := model.ObjectName(req.ObjectType)
oid := model.ObjectID(req.ObjectId)
rn := model.RelationName(req.Relation)
Expand Down
40 changes: 21 additions & 19 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package mempool

import "sync"
import (
"sync"
)

const defaultSliceCapacity = 128
const defaultSliceCapacity = 2048

type Pool[T any] struct {
sync.Pool
Expand Down Expand Up @@ -33,45 +35,45 @@ func NewSlicePool[T any]() *Pool[*[]T] {
})
}

type Resetable[T any] interface {
Reset()
*T
type Allocator[T any] interface {
New() T
Reset(T)
}

type CollectionPool[M any, T Resetable[M]] struct {
type CollectionPool[T any] struct {
slicePool *Pool[*[]T]
msgPool *Pool[T]
alloc Allocator[T]
}

func NewCollectionPool[M any, T Resetable[M]]() *CollectionPool[M, T] {
return &CollectionPool[M, T]{
func NewCollectionPool[T any](alloc Allocator[T]) *CollectionPool[T] {
return &CollectionPool[T]{
slicePool: NewSlicePool[T](),
alloc: alloc,
msgPool: NewPool(func() T {
return new(M)
return alloc.New()
}),
}
}

func (p CollectionPool[M, T]) GetSlice() *[]T {
// return p.slicePool.Get()
return p.slicePool.New().(*[]T)
func (p CollectionPool[T]) GetSlice() *[]T {
return p.slicePool.Get()
}

func (p *CollectionPool[M, T]) PutSlice(s *[]T) {
func (p *CollectionPool[T]) PutSlice(s *[]T) {
for _, item := range *s {
item.Reset()
p.alloc.Reset(item)
p.msgPool.Put(item)
}

*s = (*s)[:0]
p.slicePool.Put(s)
}

func (p *CollectionPool[M, T]) Get() T {
// return p.msgPool.Get()
return p.msgPool.New().(T)
func (p *CollectionPool[T]) Get() T {
return p.msgPool.Get()
}

func (p *CollectionPool[M, T]) Put(m T) {
p.msgPool.Put(m)
func (p *CollectionPool[T]) Put(t T) {
p.msgPool.Put(t)
}
47 changes: 47 additions & 0 deletions mempool/relation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package mempool

import (
dsc "github.com/aserto-dev/go-directory/aserto/directory/common/v3"
"google.golang.org/protobuf/types/known/timestamppb"
)

type RelationsPool = CollectionPool[*dsc.Relation]

func NewRelationsPool() *RelationsPool {
return NewCollectionPool[*dsc.Relation](NewRelationAllocator())
}

type RelationAllocator struct {
tsPool *Pool[*timestamppb.Timestamp]
}

func NewRelationAllocator() *RelationAllocator {
return &RelationAllocator{
tsPool: NewPool[*timestamppb.Timestamp](
func() *timestamppb.Timestamp {
return new(timestamppb.Timestamp)
}),
}
}

func (ra *RelationAllocator) New() *dsc.Relation {
rel := new(dsc.Relation)
rel.CreatedAt = ra.tsPool.Get()
rel.UpdatedAt = ra.tsPool.Get()
return rel
}

func (ra *RelationAllocator) Reset(rel *dsc.Relation) {
if rel.CreatedAt != nil {
rel.CreatedAt.Reset()
ra.tsPool.Put(rel.CreatedAt)
}
if rel.UpdatedAt != nil {
rel.UpdatedAt.Reset()
ra.tsPool.Put(rel.UpdatedAt)
}

rel.Reset()
rel.CreatedAt = ra.tsPool.Get()
rel.UpdatedAt = ra.tsPool.Get()
}

0 comments on commit 6161ae1

Please sign in to comment.