From d9b566e050126e018f226f25f14b8e77b2487a30 Mon Sep 17 00:00:00 2001 From: aPaladiychuk Date: Mon, 6 May 2024 17:47:49 +0300 Subject: [PATCH] invalidate llm model after user update it --- backend/api/handler/persona.go | 14 +++++++++++--- backend/core/ai/builder.go | 13 ++++++++++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/backend/api/handler/persona.go b/backend/api/handler/persona.go index f03c8831..2986ed90 100644 --- a/backend/api/handler/persona.go +++ b/backend/api/handler/persona.go @@ -1,6 +1,7 @@ package handler import ( + "cognix.ch/api/v2/core/ai" "cognix.ch/api/v2/core/bll" "cognix.ch/api/v2/core/parameters" "cognix.ch/api/v2/core/security" @@ -13,10 +14,11 @@ import ( type PersonaHandler struct { personaBL bll.PersonaBL + aiBuilder *ai.Builder } -func NewPersonaHandler(personaBL bll.PersonaBL) *PersonaHandler { - return &PersonaHandler{personaBL: personaBL} +func NewPersonaHandler(personaBL bll.PersonaBL, aiBuilder *ai.Builder) *PersonaHandler { + return &PersonaHandler{personaBL: personaBL, aiBuilder: aiBuilder} } func (h *PersonaHandler) Mount(route *gin.Engine, authMiddleware gin.HandlerFunc) { handler := route.Group("/api/manage/personas") @@ -112,13 +114,16 @@ func (h *PersonaHandler) Update(c *gin.Context, identity *security.Identity) err return utils.ErrorBadRequest.New("id should be presented") } var param parameters.PersonaParam - if err := server.BindJsonAndValidate(c, ¶m); err != nil { + if err = server.BindJsonAndValidate(c, ¶m); err != nil { return err } persona, err := h.personaBL.Update(c.Request.Context(), id, identity.User, ¶m) if err != nil { return err } + if persona.LLM != nil { + h.aiBuilder.Invalidate(persona.LLM) + } return server.JsonResult(c, http.StatusOK, persona) } @@ -146,5 +151,8 @@ func (h *PersonaHandler) Archive(c *gin.Context, identity *security.Identity) er if err != nil { return err } + if persona.LLM != nil { + h.aiBuilder.Invalidate(persona.LLM) + } return server.JsonResult(c, http.StatusOK, persona) } diff --git a/backend/core/ai/builder.go b/backend/core/ai/builder.go index 68214a12..df601f78 100644 --- a/backend/core/ai/builder.go +++ b/backend/core/ai/builder.go @@ -1,9 +1,13 @@ package ai -import "cognix.ch/api/v2/core/model" +import ( + "cognix.ch/api/v2/core/model" + "sync" +) type Builder struct { clients map[int64]OpenAIClient + mx sync.Mutex } func NewBuilder() *Builder { @@ -11,6 +15,8 @@ func NewBuilder() *Builder { } func (b *Builder) New(llm *model.LLM) OpenAIClient { + b.mx.Lock() + defer b.mx.Unlock() if client, ok := b.clients[llm.ID.IntPart()]; ok { return client } @@ -18,3 +24,8 @@ func (b *Builder) New(llm *model.LLM) OpenAIClient { b.clients[llm.ID.IntPart()] = client return client } +func (b *Builder) Invalidate(llm *model.LLM) { + b.mx.Lock() + delete(b.clients, llm.ID.IntPart()) + b.mx.Unlock() +}