From 72cc306818a38f2abef68b89d2131936de75633b Mon Sep 17 00:00:00 2001 From: Waleed Gadelkareem Date: Mon, 20 Apr 2020 20:55:39 +0200 Subject: [PATCH] add cache tagging --- README.md | 79 ++++++++++++++++++++++++----------- cache.go | 12 ++++++ cache_test.go | 41 ++++++++++++++++++- example_test.go | 41 +++++++++++++++---- file.go | 107 +++++++++++++++++++++++++++++++++++++++--------- file_test.go | 12 +++++- go.mod | 2 + memory.go | 37 +++++++++++++++++ memory_test.go | 9 ++++ redis.go | 55 ++++++++++++++++++++++--- redis_test.go | 12 +++++- sql.go | 83 +++++++++++++++++++++++++++++++++++++ sql_test.go | 14 ++++++- 13 files changed, 441 insertions(+), 63 deletions(-) diff --git a/README.md b/README.md index 87bffb9..56b6909 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Cachita is a golang file, memory, SQL, Redis cache library - In memory file cache index to avoid unneeded I/O. - [Msgpack](https://msgpack.org/index.html) based binary serialization using [msgpack](https://github.com/vmihailenco/msgpack) library for file caching. - [radix](https://github.com/mediocregopher/radix) Redis client. +- Tag cache and invalidate cache keys based on tags, check in the [examples](https://godoc.org/github.com/gadelkareem/cachita#pkg-examples). API docs: https://godoc.org/github.com/gadelkareem/cachita. @@ -34,7 +35,7 @@ func ExampleCache() { } if cache.Exists("cache_key") { - //do something + // do something } var holder string @@ -43,14 +44,14 @@ func ExampleCache() { panic(err) } - fmt.Printf("%s", holder) //prints "some data" + fmt.Printf("%s", holder) // prints "some data" err = cache.Invalidate("cache_key") if err != nil { panic(err) } - //Output: some data + // Output: some data } @@ -60,29 +61,59 @@ func ExampleCache() { ``` > go test -v -bench=. -benchmem -BenchmarkMemoryCacheWithInt-8 1000000 1218 ns/op 120 B/op 6 allocs/op -BenchmarkMemoryCacheWithString-8 1000000 1234 ns/op 136 B/op 6 allocs/op -BenchmarkMemoryCacheWithMapInterface-8 1000000 1445 ns/op 536 B/op 10 allocs/op -BenchmarkMemoryCacheWithStruct-8 1000000 1588 ns/op 680 B/op 11 allocs/op -BenchmarkMemory_Incr-8 500000 2389 ns/op 192 B/op 10 allocs/op -BenchmarkFileCacheWithInt-8 10000 110629 ns/op 2946 B/op 34 allocs/op -BenchmarkFileCacheWithString-8 10000 117502 ns/op 2968 B/op 35 allocs/op -BenchmarkFileCacheWithMapInterface-8 10000 121150 ns/op 4998 B/op 58 allocs/op -BenchmarkFileCacheWithStruct-8 10000 120383 ns/op 5909 B/op 63 allocs/op -BenchmarkFile_Incr-8 10000 188167 ns/op 7095 B/op 74 allocs/op -BenchmarkRedisCacheWithInt-8 5000 331572 ns/op 703 B/op 25 allocs/op -BenchmarkRedisCacheWithString-8 5000 351982 ns/op 1202 B/op 35 allocs/op -BenchmarkRedisCacheWithMapInterface-8 5000 331931 ns/op 3284 B/op 59 allocs/op -BenchmarkRedisCacheWithStruct-8 5000 336453 ns/op 4184 B/op 64 allocs/op -BenchmarkRedis_Incr-8 2000 774163 ns/op 1598 B/op 45 allocs/op -BenchmarkSqlCacheWithInt-8 1000 2468703 ns/op 5168 B/op 143 allocs/op -BenchmarkSqlCacheWithString-8 1000 2121222 ns/op 5121 B/op 135 allocs/op -BenchmarkSqlCacheWithMapInterface-8 1000 2838557 ns/op 11137 B/op 373 allocs/op -BenchmarkSqlCacheWithStruct-8 1000 1903278 ns/op 13880 B/op 450 allocs/op -BenchmarkSql_Incr-8 500 3175832 ns/op 9693 B/op 268 allocs/op +BenchmarkFileCacheWithInt +BenchmarkFileCacheWithInt-8 10000 116118 ns/op 2447 B/op 31 allocs/op +BenchmarkFileCacheWithString +BenchmarkFileCacheWithString-8 10909 123491 ns/op 2470 B/op 32 allocs/op +BenchmarkFileCacheWithMapInterface +BenchmarkFileCacheWithMapInterface-8 9862 124641 ns/op 4499 B/op 55 allocs/op +BenchmarkFileCacheWithStruct +BenchmarkFileCacheWithStruct-8 9356 130355 ns/op 5404 B/op 60 allocs/op +BenchmarkFile_Incr +BenchmarkFile_Incr-8 6331 192199 ns/op 3113 B/op 44 allocs/op +BenchmarkFile_Tag +BenchmarkFile_Tag-8 3885 286273 ns/op 2720 B/op 47 allocs/op +BenchmarkMemoryCacheWithInt +BenchmarkMemoryCacheWithInt-8 870573 1288 ns/op 120 B/op 6 allocs/op +BenchmarkMemoryCacheWithString +BenchmarkMemoryCacheWithString-8 938899 1161 ns/op 136 B/op 6 allocs/op +BenchmarkMemoryCacheWithMapInterface +BenchmarkMemoryCacheWithMapInterface-8 835402 1618 ns/op 536 B/op 10 allocs/op +BenchmarkMemoryCacheWithStruct +BenchmarkMemoryCacheWithStruct-8 771076 1591 ns/op 680 B/op 11 allocs/op +BenchmarkMemory_Incr +BenchmarkMemory_Incr-8 649772 1784 ns/op 184 B/op 9 allocs/op +BenchmarkMemory_Tag +BenchmarkMemory_Tag-8 361974 3458 ns/op 439 B/op 14 allocs/op +BenchmarkRedisCacheWithInt +BenchmarkRedisCacheWithInt-8 1404 787836 ns/op 492 B/op 21 allocs/op +BenchmarkRedisCacheWithString +BenchmarkRedisCacheWithString-8 1573 775092 ns/op 995 B/op 32 allocs/op +BenchmarkRedisCacheWithMapInterface +BenchmarkRedisCacheWithMapInterface-8 1506 709349 ns/op 3074 B/op 55 allocs/op +BenchmarkRedisCacheWithStruct +BenchmarkRedisCacheWithStruct-8 1714 872728 ns/op 3969 B/op 61 allocs/op +BenchmarkRedis_Incr +BenchmarkRedis_Incr-8 1153 1096139 ns/op 1235 B/op 32 allocs/op +BenchmarkRedis_Tag +BenchmarkRedis_Tag-8 379 3356175 ns/op 8325 B/op 201 allocs/op +BenchmarkSqlCacheWithInt +BenchmarkSqlCacheWithInt-8 277 3960950 ns/op 4741 B/op 115 allocs/op +BenchmarkSqlCacheWithString +BenchmarkSqlCacheWithString-8 280 3979248 ns/op 4679 B/op 106 allocs/op +BenchmarkSqlCacheWithMapInterface +BenchmarkSqlCacheWithMapInterface-8 282 4816726 ns/op 11444 B/op 352 allocs/op +BenchmarkSqlCacheWithStruct +BenchmarkSqlCacheWithStruct-8 230 4375050 ns/op 13730 B/op 425 allocs/op +BenchmarkSql_Incr +BenchmarkSql_Incr-8 199 6042507 ns/op 6220 B/op 154 allocs/op +BenchmarkSql_Tag +BenchmarkSql_Tag-8 57 35618536 ns/op 836763 B/op 967 allocs/op +PASS +ok github.com/gadelkareem/cachita 40.188s ``` -## Howto +## How to Please go through [examples](https://godoc.org/github.com/gadelkareem/cachita#pkg-examples) to get an idea how to use this package. diff --git a/cache.go b/cache.go index 524887b..3f3fa6b 100644 --- a/cache.go +++ b/cache.go @@ -15,8 +15,11 @@ type ( Get(key string, i interface{}) error Put(key string, i interface{}, ttl time.Duration) error // ttl 0:default ttl, -1: keep forever Incr(key string, ttl time.Duration) (int64, error) + Tag(key string, tags ...string) error Exists(key string) bool Invalidate(key string) error + InvalidateMulti(keys ...string) error + InvalidateTags(tags ...string) error } record struct { Data interface{} @@ -64,6 +67,15 @@ func runEvery(ttl time.Duration, f func()) { }() } +func inArr(a []string, x string) bool { + for _, n := range a { + if x == n { + return true + } + } + return false +} + func TypeAssert(source, target interface{}) (err error) { if source == nil { return nil diff --git a/cache_test.go b/cache_test.go index c40e14b..0a36d03 100644 --- a/cache_test.go +++ b/cache_test.go @@ -2,10 +2,11 @@ package cachita import ( "fmt" - "github.com/stretchr/testify/assert" "math/rand" "testing" "time" + + "github.com/stretchr/testify/assert" ) func newCache(c Cache, t *testing.T) { @@ -31,7 +32,7 @@ func cacheExpires(c Cache, t *testing.T, ttl, tts time.Duration) { } } -func test(c Cache, k string, s, d interface{}, t assert.TestingT, f ... func(t assert.TestingT, s, d interface{})) { +func test(c Cache, k string, s, d interface{}, t assert.TestingT, f ...func(t assert.TestingT, s, d interface{})) { k = fmt.Sprintf("%s%d", k, rand.Int()) disableAssert := isBenchmark(t) @@ -89,6 +90,30 @@ func testIncr(c Cache, k string, t assert.TestingT) { } } +func testTag(c Cache, k string, t assert.TestingT) { + k = fmt.Sprintf("%s%d", k, rand.Int()) + k2 := fmt.Sprintf("%s%d", k, rand.Int()) + disableAssert := isBenchmark(t) + + err := c.Put(k, "test", 0) + isError(err, t) + err = c.Put(k2, "test2", 0) + isError(err, t) + + tags := []string{"t1", "t2"} + err = c.Tag(k, tags[0]) + isError(err, t) + err = c.Tag(k2, tags...) + isError(err, t) + + err = c.InvalidateTags(tags...) + isError(err, t) + if !disableAssert { + assert.False(t, c.Exists(k)) + assert.False(t, c.Exists(k2)) + } +} + func benchmarkCacheWithInt(c Cache, b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -191,6 +216,18 @@ func benchmarkCacheIncr(c Cache, b *testing.B) { }) } +func cacheTag(c Cache, t assert.TestingT) { + testTag(c, "x", t) +} + +func benchmarkCacheTag(c Cache, b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cacheTag(c, b) + } + }) +} + func compareMap(t assert.TestingT, s1, d1 interface{}) { s := *s1.(*map[string]interface{}) d := *d1.(*map[string]interface{}) diff --git a/example_test.go b/example_test.go index 366374d..ed1ee6b 100644 --- a/example_test.go +++ b/example_test.go @@ -16,7 +16,7 @@ func ExampleCache() { } if cache.Exists("cache_key") { - //do something + // do something } var holder string @@ -25,14 +25,14 @@ func ExampleCache() { panic(err) } - fmt.Printf("%s", holder) //prints "some data" + fmt.Printf("%s", holder) // prints "some data" err = cache.Invalidate("cache_key") if err != nil { panic(err) } - //Output: some data + // Output: some data } @@ -53,7 +53,7 @@ func ExampleMemory() { } fmt.Printf("%+v", cacheObj) - //Output: map[test:data] + // Output: map[test:data] } @@ -74,16 +74,16 @@ func ExampleFile() { panic(err) } - fmt.Printf("%s", holder) //prints "some data" + fmt.Printf("%s", holder) // prints "some data" - //Output: some data + // Output: some data } func ExampleNewMemoryCache() { - cache := cachita.NewMemoryCache(1*time.Millisecond, 1*time.Minute) //default ttl 1 millisecond + cache := cachita.NewMemoryCache(1*time.Millisecond, 1*time.Minute) // default ttl 1 millisecond - err := cache.Put("cache_key", "some data", 0) //ttl = 0 means use default + err := cache.Put("cache_key", "some data", 0) // ttl = 0 means use default if err != nil { panic(err) } @@ -91,6 +91,29 @@ func ExampleNewMemoryCache() { time.Sleep(2 * time.Millisecond) fmt.Printf("%t", cache.Exists("cache_key")) - //Output: false + // Output: false + +} + +func ExampleTaggedCache() { + cache := cachita.NewMemoryCache(1*time.Millisecond, 1*time.Minute) // default ttl 1 millisecond + + err := cache.Put("cache_key", "some data", 0) // ttl = 0 means use default + if err != nil { + panic(err) + } + + err = cache.Tag("cache_key", "tag1", "tag2") + if err != nil { + panic(err) + } + + err = cache.InvalidateTags("tag2") + if err != nil { + panic(err) + } + + fmt.Printf("%t", cache.Exists("cache_key")) + // Output: false } diff --git a/file.go b/file.go index d200230..241319a 100644 --- a/file.go +++ b/file.go @@ -23,9 +23,11 @@ type file struct { } type fileIndex struct { - sync.RWMutex - records map[string]time.Time - path string + recordsMu sync.RWMutex + records map[string]time.Time + tagsMu sync.Mutex + tags map[string][]string + path string } func File() (Cache, error) { @@ -46,8 +48,9 @@ func File() (Cache, error) { func NewFileCache(dir string, ttl, tickerTtl time.Duration) (Cache, error) { var ( err error + i *fileIndex ) - i, err := newIndex(dir, ttl) + i, err = newIndex(dir, ttl) if err != nil { return nil, err } @@ -112,15 +115,47 @@ func (c *file) path(id string) string { func (c *file) deleteExpired() { expired := c.i.expiredRecords() for _, id := range expired { - os.Remove(c.path(id)) + _ = os.Remove(c.path(id)) } } -//----------------------- fileIndex +func (c *file) InvalidateMulti(keys ...string) (err error) { + var ids []string + for _, key := range keys { + id := Id(key) + err = os.Remove(c.path(id)) + if err != nil && !isNotFound(err) { + return + } + } + c.i.removeMulti(ids...) + return +} + +// tags are only managed via the index +func (c *file) Tag(key string, tags ...string) error { + c.i.tag(Id(key), tags...) + return nil +} + +func (c *file) InvalidateTags(tags ...string) (err error) { + ids := c.i.removeTags(tags...) + for _, id := range ids { + err = os.Remove(c.path(id)) + if err != nil && !isNotFound(err) { + return + } + } + c.i.removeMulti(ids...) + return nil +} + +// ----------------------- fileIndex func newIndex(dir string, ttl time.Duration) (i *fileIndex, err error) { i = &fileIndex{path: filepath.Join(dir, Id(FileIndex))} i.records = make(map[string]time.Time) + i.tags = make(map[string][]string) err = readData(i.path, &i.records) if err != nil && err != ErrNotFound { @@ -131,8 +166,8 @@ func newIndex(dir string, ttl time.Duration) (i *fileIndex, err error) { currentDir string files []os.FileInfo ) - i.Lock() - defer i.Unlock() + i.recordsMu.Lock() + defer i.recordsMu.Unlock() characters := "0123456789abcdef" for _, char1 := range characters { for _, char2 := range characters { @@ -170,8 +205,8 @@ func newIndex(dir string, ttl time.Duration) (i *fileIndex, err error) { } func (i *fileIndex) check(id string) error { - i.RLock() - defer i.RUnlock() + i.recordsMu.RLock() + defer i.recordsMu.RUnlock() expiredAt, exists := i.records[id] if !exists { return ErrNotFound @@ -183,8 +218,8 @@ func (i *fileIndex) check(id string) error { } func (i *fileIndex) expiredRecords() []string { - i.Lock() - defer i.Unlock() + i.recordsMu.Lock() + defer i.recordsMu.Unlock() var ( expired []string records = make(map[string]time.Time) @@ -205,26 +240,55 @@ func (i *fileIndex) expiredRecords() []string { } func (i *fileIndex) add(id string, expiredAt time.Time) { - i.Lock() - defer i.Unlock() + i.recordsMu.Lock() + defer i.recordsMu.Unlock() i.records[id] = expiredAt return } func (i *fileIndex) remove(id string) { - i.Lock() - defer i.Unlock() + i.recordsMu.Lock() + defer i.recordsMu.Unlock() delete(i.records, id) } -//-------------------- +func (i *fileIndex) removeMulti(ids ...string) { + i.recordsMu.Lock() + defer i.recordsMu.Unlock() + for _, id := range ids { + delete(i.records, id) + } +} + +func (i *fileIndex) tag(id string, tags ...string) { + i.tagsMu.Lock() + defer i.tagsMu.Unlock() + for _, t := range tags { + if inArr(i.tags[t], id) { + continue + } + i.tags[t] = append(i.tags[t], id) + } +} + +func (i *fileIndex) removeTags(tags ...string) (ids []string) { + i.tagsMu.Lock() + for _, t := range tags { + ids = append(ids, i.tags[t]...) + delete(i.tags, t) + } + i.tagsMu.Unlock() + return +} + +// -------------------- func exists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { return true, nil } - if os.IsNotExist(err) { + if isNotFound(err) { return false, nil } return false, err @@ -233,7 +297,7 @@ func exists(path string) (bool, error) { func readData(path string, i interface{}) error { data, err := ioutil.ReadFile(path) if err != nil { - if os.IsNotExist(err) || err == io.EOF { + if isNotFound(err) { return ErrNotFound } return err @@ -247,6 +311,7 @@ func readData(path string, i interface{}) error { } return nil } + func writeData(path string, i interface{}) error { data, err := msgpack.Marshal(i) if err != nil { @@ -254,3 +319,7 @@ func writeData(path string, i interface{}) error { } return ioutil.WriteFile(path, data, 0666) } + +func isNotFound(e error) bool { + return os.IsNotExist(e) || e == io.EOF +} diff --git a/file_test.go b/file_test.go index 3f4000c..c47e8aa 100644 --- a/file_test.go +++ b/file_test.go @@ -1,11 +1,12 @@ package cachita import ( - "github.com/stretchr/testify/assert" "os" "path/filepath" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestNewFileCache(t *testing.T) { @@ -99,3 +100,12 @@ func fc(t assert.TestingT) (c Cache) { isError(err, t) return } + +func TestFile_Tag(t *testing.T) { + t.Parallel() + cacheTag(fc(t), t) +} + +func BenchmarkFile_Tag(b *testing.B) { + benchmarkCacheTag(fc(b), b) +} diff --git a/go.mod b/go.mod index 8d0f39d..86380e4 100644 --- a/go.mod +++ b/go.mod @@ -13,3 +13,5 @@ require ( google.golang.org/appengine v1.4.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) + +go 1.13 diff --git a/memory.go b/memory.go index 669b39a..cc98e53 100644 --- a/memory.go +++ b/memory.go @@ -10,6 +10,8 @@ var mCache Cache type memory struct { recordsMu sync.RWMutex records map[string]*record + tagsMu sync.Mutex + tags map[string][]string ttl time.Duration } @@ -23,6 +25,7 @@ func Memory() Cache { func NewMemoryCache(ttl, tickerTtl time.Duration) Cache { c := &memory{ records: make(map[string]*record), + tags: make(map[string][]string), ttl: ttl, } @@ -90,3 +93,37 @@ func (c *memory) deleteExpired() { } c.records = records } + +func (c *memory) InvalidateMulti(keys ...string) error { + c.recordsMu.Lock() + defer c.recordsMu.Unlock() + for _, key := range keys { + delete(c.records, key) + } + return nil +} + +func (c *memory) Tag(key string, tags ...string) error { + c.tagsMu.Lock() + defer c.tagsMu.Unlock() + for _, t := range tags { + if inArr(c.tags[t], key) { + continue + } + c.tags[t] = append(c.tags[t], key) + } + + return nil +} + +func (c *memory) InvalidateTags(tags ...string) error { + c.tagsMu.Lock() + var keys []string + for _, t := range tags { + keys = append(keys, c.tags[t]...) + delete(c.tags, t) + } + c.tagsMu.Unlock() + + return c.InvalidateMulti(keys...) +} diff --git a/memory_test.go b/memory_test.go index 3d28449..c2c1ed5 100644 --- a/memory_test.go +++ b/memory_test.go @@ -59,3 +59,12 @@ func TestMemory_Incr(t *testing.T) { func BenchmarkMemory_Incr(b *testing.B) { benchmarkCacheIncr(Memory(), b) } + +func TestMemory_Tag(t *testing.T) { + t.Parallel() + cacheTag(Memory(), t) +} + +func BenchmarkMemory_Tag(b *testing.B) { + benchmarkCacheTag(Memory(), b) +} diff --git a/redis.go b/redis.go index 7a9ec96..7e35e98 100644 --- a/redis.go +++ b/redis.go @@ -2,9 +2,10 @@ package cachita import ( "fmt" + "time" + "github.com/mediocregopher/radix/v3" "github.com/vmihailenco/msgpack" - "time" ) var rCache Cache @@ -93,17 +94,21 @@ func (c *redis) Incr(key string, ttl time.Duration) (int64, error) { } func (c *redis) Invalidate(key string) error { - return c.pool.Do(radix.FlatCmd(nil, "DEL", c.k(key))) + return c.pool.Do(radix.Cmd(nil, "DEL", c.k(key))) } func (c *redis) Exists(key string) bool { var b bool - c.pool.Do(radix.FlatCmd(&b, "EXISTS", c.k(key))) - return b + err := c.pool.Do(radix.Cmd(&b, "EXISTS", c.k(key))) + return err == nil && b } func (c *redis) k(key string) string { - return fmt.Sprintf("%s:%s", c.prefix, key) + return fmt.Sprintf("%s:keys::%s", c.prefix, key) +} + +func (c *redis) t(tag string) string { + return fmt.Sprintf("%s:tags::%s", c.prefix, tag) } func isInt(i interface{}) bool { @@ -123,3 +128,43 @@ func isInt(i interface{}) bool { } return true } + +func (c *redis) InvalidateMulti(keys ...string) error { + var cmds []radix.CmdAction + for _, k := range keys { + cmds = append(cmds, radix.FlatCmd(nil, "DEL", c.k(k))) + } + return c.pool.Do(radix.Pipeline(cmds...)) +} + +func (c *redis) Tag(key string, tags ...string) (err error) { + rKey := c.k(key) + var cmds []radix.CmdAction + for _, t := range tags { + cmds = append(cmds, radix.FlatCmd(nil, "SADD", c.t(t), rKey)) + } + return c.pool.Do(radix.Pipeline(cmds...)) +} + +func (c *redis) InvalidateTags(tags ...string) error { + var rKeys, rTags []string + for _, t := range tags { + var keys []string + t = c.t(t) + err := c.pool.Do(radix.Cmd(&keys, "SMEMBERS", t)) + if err != nil { + return err + } + rKeys = append(rKeys, keys...) + rTags = append(rTags, t) + } + + var cmds []radix.CmdAction + for _, k := range rKeys { + cmds = append(cmds, radix.FlatCmd(nil, "DEL", k)) + for _, t := range rTags { + cmds = append(cmds, radix.FlatCmd(nil, "SREM", t, k)) + } + } + return c.pool.Do(radix.Pipeline(cmds...)) +} diff --git a/redis_test.go b/redis_test.go index 7ae34f2..8e7d3e7 100644 --- a/redis_test.go +++ b/redis_test.go @@ -1,9 +1,10 @@ package cachita import ( - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestNewRedisCache(t *testing.T) { @@ -65,3 +66,12 @@ func rc(t assert.TestingT) (c Cache) { isError(err, t) return } + +func TestRedis_Tag(t *testing.T) { + t.Parallel() + cacheTag(rc(t), t) +} + +func BenchmarkRedis_Tag(b *testing.B) { + benchmarkCacheTag(rc(b), b) +} diff --git a/sql.go b/sql.go index 84b21a1..08efa1e 100644 --- a/sql.go +++ b/sql.go @@ -2,6 +2,7 @@ package cachita import ( "database/sql" + "fmt" "strconv" "strings" "time" @@ -24,6 +25,11 @@ type row struct { ExpiredAt int64 } +type tagRow struct { + Id string + Keys string +} + func Sql(driverName, dataSourceName string) (Cache, error) { if sCache == nil { sqlDriver, err := sql.Open(driverName, dataSourceName) @@ -165,6 +171,10 @@ func (c *sqlCache) createTable() error { if err != nil { return err } + _, err = c.db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s_tags (id CHAR(32) NOT NULL PRIMARY KEY, keys TEXT NOT NULL)", c.tableName)) + if err != nil { + return err + } return nil } @@ -174,3 +184,76 @@ func (c *sqlCache) placeholder(index int) string { } return "?" } + +func (c *sqlCache) InvalidateMulti(keys ...string) error { + var ids []string + for _, key := range keys { + ids = append(ids, Id(key)) + } + _, err := c.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", c.tableName, c.placeholder(1)), ids) + return err +} + +func (c *sqlCache) Tag(key string, tags ...string) (err error) { + id := Id(key) + var r *tagRow + for _, t := range tags { + r, err = c.tagRow(Id(t)) + if err != nil && err != sql.ErrNoRows { + return + } + r.Keys += fmt.Sprintf(",%s", id) + var query string + if err == sql.ErrNoRows { + query = fmt.Sprintf("INSERT INTO %s_tags (keys, id) VALUES(%s, %s)", c.tableName, c.placeholder(1), c.placeholder(2)) + } else { + query = fmt.Sprintf("UPDATE %s_tags SET keys = %s WHERE id = %s ", c.tableName, c.placeholder(1), c.placeholder(2)) + } + _, err = c.db.Exec(query, r.Keys, r.Id) + if err != nil { + return + } + } + return nil +} + +func (c *sqlCache) tagRow(id string) (r *tagRow, err error) { + r = new(tagRow) + r.Id = id + query := fmt.Sprintf("SELECT keys FROM %s_tags WHERE id = %s", c.tableName, c.placeholder(1)) + err = c.db.QueryRow(query, r.Id).Scan(&r.Keys) + return +} + +func (c *sqlCache) InvalidateTags(tags ...string) (err error) { + var keys string + var r *tagRow + for _, t := range tags { + r, err = c.tagRow(Id(t)) + if err != nil && err != sql.ErrNoRows { + return + } + if err == sql.ErrNoRows || r == nil { + continue + } + keys += fmt.Sprintf(",%s", r.Keys) + } + s := sqlKeys(keys) + _, err = c.db.Exec(fmt.Sprintf("DELETE FROM %s WHERE id IN (%s)", c.tableName, s)) + + return +} +func sqlKeys(s string) string { + ids := strings.Split(s, ",") + var ( + l []string + r string + ) + for _, v := range ids { + if !inArr(l, v) && v != "," && v != "" { + l = append(l, v) + r += fmt.Sprintf("'%s',", v) + } + } + return strings.TrimRight(r, ",") +} diff --git a/sql_test.go b/sql_test.go index d5b45ee..f7503ef 100644 --- a/sql_test.go +++ b/sql_test.go @@ -2,10 +2,11 @@ package cachita import ( "database/sql" - _ "github.com/lib/pq" - "github.com/stretchr/testify/assert" "testing" "time" + + _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" ) func TestNewSqlCache(t *testing.T) { @@ -70,3 +71,12 @@ func sc(t assert.TestingT) (c Cache) { isError(err, t) return } + +func TestSql_Tag(t *testing.T) { + t.Parallel() + cacheTag(sc(t), t) +} + +func BenchmarkSql_Tag(b *testing.B) { + benchmarkCacheTag(sc(b), b) +}