Skip to content

Commit

Permalink
feat(app): add redis cache to go server
Browse files Browse the repository at this point in the history
  • Loading branch information
ramchaik committed Aug 27, 2024
1 parent cca51c1 commit f330818
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,5 @@ npm-debug.log*
yarn-debug.log*
yarn-error.log*

**/tmp/**
**/llm_env/**
5 changes: 4 additions & 1 deletion app/cmd/api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"log"

"nous/internal/cache"
"nous/internal/config"
"nous/internal/database"
"nous/internal/llmclient"
Expand All @@ -21,7 +22,9 @@ func main() {
}
defer db.Close()

llmClient := llmclient.NewClient(cfg.LLMBaseURL, nil)
redisCache := cache.NewRedisCache(cfg.RedisAddr)

llmClient := llmclient.NewClient(cfg.LLMBaseURL, nil, redisCache)

srv := server.New(cfg, db, llmClient)
if err := srv.Run(cfg.ServerAddr); err != nil {
Expand Down
3 changes: 3 additions & 0 deletions app/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ go 1.22.4
require (
github.com/bytedance/sonic v1.12.1 // indirect
github.com/bytedance/sonic/loader v0.2.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gabriel-vasile/mimetype v1.4.5 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/gin-gonic/gin v1.10.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.22.0 // indirect
github.com/go-redis/redis/v8 v8.11.5 // indirect
github.com/goccy/go-json v0.10.3 // indirect
github.com/golang-migrate/migrate/v4 v4.17.1 // indirect
github.com/google/uuid v1.6.0 // indirect
Expand Down
6 changes: 6 additions & 0 deletions app/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@ github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKz
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM=
github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/gabriel-vasile/mimetype v1.4.5 h1:J7wGKdGu33ocBOhGy0z653k/lFKLFDPJMG8Gql0kxn4=
github.com/gabriel-vasile/mimetype v1.4.5/go.mod h1:ibHel+/kbxn9x2407k1izTA1S81ku1z/DlgOW2QE0M4=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
Expand All @@ -21,6 +25,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.22.0 h1:k6HsTZ0sTnROkhS//R0O+55JgM8C4Bx7ia+JlgcnOao=
github.com/go-playground/validator/v10 v10.22.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-migrate/migrate/v4 v4.17.1 h1:4zQ6iqL6t6AiItphxJctQb3cFqWiSpMnX7wLTPnnYO4=
Expand Down
73 changes: 73 additions & 0 deletions app/internal/cache/redis.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package cache

import (
"bytes"
"compress/gzip"
"context"
"crypto/sha256"
"encoding/hex"
"io"
"time"

"github.com/go-redis/redis/v8"
)

// Cacher is the caching contract of this package. It exposes raw byte
// get/set operations, gzip-compressed variants of both, and a helper for
// deriving a hashed key. RedisCache in this file provides all five methods.
type Cacher interface {
// Get returns the bytes stored under key, or an error.
Get(ctx context.Context, key string) ([]byte, error)
// Set stores value under key with the given expiration.
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
// GetCompressed returns the decompressed form of a value stored under key.
GetCompressed(ctx context.Context, key string) ([]byte, error)
// SetCompressed compresses value and stores it under key with the given
// expiration.
SetCompressed(ctx context.Context, key string, value []byte, expiration time.Duration) error
// HashKey maps an arbitrary key string to a derived key string.
HashKey(key string) string
}

// RedisCache is a cache whose operations are issued through a go-redis
// client. Construct it with NewRedisCache.
type RedisCache struct {
// client is the go-redis handle used by Get and Set.
client *redis.Client
}

// NewRedisCache builds a RedisCache whose client targets addr. Only the
// address is configured; every other client option is left at its zero
// value. The server is not pinged here, so an unreachable address is not
// reported by this function.
func NewRedisCache(addr string) *RedisCache {
	opts := redis.Options{Addr: addr}
	return &RedisCache{client: redis.NewClient(&opts)}
}

// Get issues a GET for key through the underlying client and returns the
// reply as bytes. Any error reported by the client command is returned
// unchanged.
func (rc *RedisCache) Get(ctx context.Context, key string) ([]byte, error) {
	reply := rc.client.Get(ctx, key)
	return reply.Bytes()
}

// Set issues a SET of value under key with the given expiration through the
// underlying client and returns the command's error, if any.
func (rc *RedisCache) Set(ctx context.Context, key string, value []byte, expiration time.Duration) error {
	status := rc.client.Set(ctx, key, value, expiration)
	return status.Err()
}

// GetCompressed fetches the value stored under key via Get and returns its
// gzip-decompressed contents.
//
// An error from Get is returned as-is. A payload that does not start with a
// valid gzip header makes gzip.NewReader fail, and that error is returned
// too. Errors hit while inflating are returned by io.ReadAll together with
// whatever bytes were read before the failure.
func (rc *RedisCache) GetCompressed(ctx context.Context, key string) ([]byte, error) {
	stored, getErr := rc.Get(ctx, key)
	if getErr != nil {
		return nil, getErr
	}

	zr, openErr := gzip.NewReader(bytes.NewReader(stored))
	if openErr != nil {
		return nil, openErr
	}
	defer zr.Close()

	return io.ReadAll(zr)
}

// SetCompressed gzip-compresses value and stores the result under key with
// the given expiration via Set.
//
// It returns the first error encountered while compressing (writing or
// finalizing the gzip stream) or, otherwise, the error returned by Set.
// Nothing is stored if compression fails.
func (rc *RedisCache) SetCompressed(ctx context.Context, key string, value []byte, expiration time.Duration) error {
	var compressedBuf bytes.Buffer
	gzipWriter := gzip.NewWriter(&compressedBuf)
	if _, err := gzipWriter.Write(value); err != nil {
		return err
	}
	// Close flushes any buffered data and writes the gzip footer. If it
	// fails the stream is incomplete, so the error must not be dropped and
	// the truncated payload must not be cached.
	if err := gzipWriter.Close(); err != nil {
		return err
	}

	return rc.Set(ctx, key, compressedBuf.Bytes(), expiration)
}

// HashKey returns the lowercase hexadecimal SHA-256 digest of key. The
// receiver's state is not consulted.
func (rc *RedisCache) HashKey(key string) string {
	digest := sha256.Sum256([]byte(key))
	return hex.EncodeToString(digest[:])
}
2 changes: 2 additions & 0 deletions app/internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ type Config struct {
ServerAddr string
DatabasePath string
LLMBaseURL string
RedisAddr string
}

func Load() (*Config, error) {
Expand All @@ -26,5 +27,6 @@ func Load() (*Config, error) {
ServerAddr: ":8080",
DatabasePath: "./nous.db",
LLMBaseURL: "http://localhost:5000",
RedisAddr: "localhost:6379",
}, nil
}
2 changes: 1 addition & 1 deletion app/internal/handlers/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (h *ChatAPIHandler) Predict(c *gin.Context) {
return
}

predictResp, err := h.llmClient.Predict(request.Query)
predictResp, err := h.llmClient.Predict(c.Request.Context(), request.Query)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get prediction: " + err.Error()})
return
Expand Down
35 changes: 30 additions & 5 deletions app/internal/llmclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package llmclient

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"nous/internal/cache"
"time"
)

type LLMClient interface {
Predict(question string) (*PredictResponse, error)
Predict(ctx context.Context, question string) (*PredictResponse, error)
}

type HTTPClient interface {
Expand All @@ -18,7 +21,8 @@ type HTTPClient interface {

type Client struct {
BaseURL string
HTTPClient *http.Client
HTTPClient HTTPClient
Cache cache.Cacher
}

type PredictRequest struct {
Expand All @@ -30,17 +34,32 @@ type PredictResponse struct {
Steps []string `json:"steps"`
}

func NewClient(baseURL string, httpClient HTTPClient) LLMClient {
// NewClient returns an LLMClient that sends requests to baseURL and uses
// cache for response caching.
//
// httpClient is the transport used for outgoing requests; when it is nil a
// default *http.Client is used instead. The supplied client is stored on the
// returned Client rather than being replaced by a fresh default, so callers
// (for example tests) can inject their own implementation.
func NewClient(baseURL string, httpClient HTTPClient, cache cache.Cacher) LLMClient {
	if httpClient == nil {
		httpClient = &http.Client{}
	}
	return &Client{
		BaseURL:    baseURL,
		HTTPClient: httpClient,
		Cache:      cache,
	}
}

func (c *Client) Predict(question string) (*PredictResponse, error) {
func (c *Client) Predict(ctx context.Context, question string) (*PredictResponse, error) {
// Hash the cache key
cacheKey := c.Cache.HashKey(fmt.Sprintf("predict:%s", question))

// Try to get from cache
cachedResponse, err := c.Cache.GetCompressed(ctx, cacheKey)
if err == nil {
var predictResp PredictResponse
err = json.Unmarshal(cachedResponse, &predictResp)
if err == nil {
return &predictResp, nil
}
}

// Cache miss, proceed with API call
reqBody := PredictRequest{
Question: question,
}
Expand All @@ -49,7 +68,7 @@ func (c *Client) Predict(question string) (*PredictResponse, error) {
return nil, fmt.Errorf("error marshalling request body: %v", err)
}

req, err := http.NewRequest("POST", c.BaseURL+"/predict", bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, "POST", c.BaseURL+"/predict", bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
Expand All @@ -76,5 +95,11 @@ func (c *Client) Predict(question string) (*PredictResponse, error) {
return nil, fmt.Errorf("error unmarshaling response body: %v", err)
}

// Cache the response
jsonResponse, err := json.Marshal(predictResp)
if err == nil {
c.Cache.SetCompressed(ctx, cacheKey, jsonResponse, time.Hour) // Cache for 1 hour
}

return &predictResp, nil
}
4 changes: 2 additions & 2 deletions app/internal/ui/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (h *ChatUIHandler) RenderChatPage(c *gin.Context) {
return
}

predictResp, err := h.llmClient.Predict(query)
predictResp, err := h.llmClient.Predict(c.Request.Context(), query)
if err != nil {
c.HTML(http.StatusInternalServerError, "error.html", gin.H{"error": "Failed to get prediction: " + err.Error()})
return
Expand All @@ -47,7 +47,7 @@ func (h *ChatUIHandler) HandleChatMessage(c *gin.Context) {
return
}

predictResp, err := h.llmClient.Predict(query)
predictResp, err := h.llmClient.Predict(c.Request.Context(), query)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to get prediction: "+err.Error())
return
Expand Down

0 comments on commit f330818

Please sign in to comment.