diff --git a/backend/api/handler/auth.go b/backend/api/handler/auth.go index e4071fd8..59d9852e 100644 --- a/backend/api/handler/auth.go +++ b/backend/api/handler/auth.go @@ -60,7 +60,7 @@ func (h *AuthHandler) Mount(route *gin.Engine, authMiddleware gin.HandlerFunc) { func (h *AuthHandler) SignIn(c *gin.Context) error { var param parameters.LoginParam if err := c.ShouldBindQuery(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong redirect url") + return utils.ErrorBadRequest.Wrap(err, "wrong redirect url") } zap.S().Infof("redirec to %s", param.RedirectURL) buf, err := json.Marshal(parameters.OAuthParam{Action: oauth.LoginState}) @@ -94,7 +94,7 @@ func (h *AuthHandler) Callback(c *gin.Context) error { buf, err := base64.URLEncoding.DecodeString(c.Query(oauth.StateNameGoogle)) if err != nil { - return utils.InvalidInput.Wrap(err, "wrong state") + return utils.ErrorBadRequest.Wrap(err, "wrong state") } var state parameters.OAuthParam if err = json.Unmarshal(buf, &state); err != nil { @@ -169,10 +169,10 @@ func (h *AuthHandler) Callback(c *gin.Context) error { //func (h *AuthHandler) Invite(c *gin.Context, identity *security.Identity) error { // var param parameters.InviteParam // if err := c.BindJSON(¶m); err != nil { -// return utils.InvalidInput.Wrap(err, "can not parse payload") +// return utils.ErrorBadRequest.Wrap(err, "can not parse payload") // } // if err := param.Validate(); err != nil { -// return utils.InvalidInput.Wrap(err, err.Error()) +// return utils.ErrorBadRequest.Wrap(err, err.Error()) // } // // url, err := h.authBL.Invite(c.Request.Context(), identity, ¶m) @@ -195,7 +195,7 @@ func (h *AuthHandler) Callback(c *gin.Context) error { // // // //key, err := base64.URLEncoding.DecodeString(param) // //if err != nil { -// // return utils.InvalidInput.Wrap(err, "wrong state") +// // return utils.ErrorBadRequest.Wrap(err, "wrong state") // //} // ////value, err := h.storage.Pull(string(key)) // // diff --git a/backend/api/handler/chat.go b/backend/api/handler/chat.go index 37567383..bc86369b 100644 --- a/backend/api/handler/chat.go +++ b/backend/api/handler/chat.go @@ -61,7 +61,7 @@ func (h *ChatHandler) GetSessions(c *gin.Context, identity *security.Identity) e func (h *ChatHandler) GetByID(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } session, err := h.chatBL.GetSessionByID(c.Request.Context(), identity.User, id) if err != nil { @@ -82,11 +82,8 @@ func (h *ChatHandler) GetByID(c *gin.Context, identity *security.Identity) error // @Router /chats/create-chat-session [post] func (h *ChatHandler) CreateSession(c *gin.Context, identity *security.Identity) error { var param parameters.CreateChatSession - if err := c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") - } - if err := param.Validate(); err != nil { - return utils.InvalidInput.Wrap(err, err.Error()) + if err := server.BindJsonAndValidate(c, ¶m); err != nil { + return err } session, err := h.chatBL.CreateSession(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -107,8 +104,8 @@ func (h *ChatHandler) CreateSession(c *gin.Context, identity *security.Identity) // @Router /chats/send-message [post] func (h *ChatHandler) SendMessage(c *gin.Context, identity *security.Identity) error { var param parameters.CreateChatMessageRequest - if err := c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + if err := server.BindJsonAndValidate(c, ¶m); err != nil { + return err } assistant, err := h.chatBL.SendMessage(c, identity.User, ¶m) if err != nil { @@ -138,10 +135,10 @@ func (h *ChatHandler) SendMessage(c *gin.Context, identity *security.Identity) e func (h *ChatHandler) MessageFeedback(c *gin.Context, identity *security.Identity) error { var param parameters.MessageFeedbackParam if err := c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + return utils.ErrorBadRequest.Wrap(err, "wrong payload") } if err := param.Validate(); err != nil { - return utils.InvalidInput.Wrap(err, err.Error()) + return utils.ErrorBadRequest.Wrap(err, err.Error()) } feedback, err := h.chatBL.FeedbackMessage(c, identity.User, param.ID.IntPart(), param.Vote == parameters.MessageFeedbackUpvote) if err != nil { diff --git a/backend/api/handler/connector.go b/backend/api/handler/connector.go index fbb36fbc..c5e610bf 100644 --- a/backend/api/handler/connector.go +++ b/backend/api/handler/connector.go @@ -72,7 +72,7 @@ func (h *ConnectorHandler) GetById(c *gin.Context) error { } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } connectors, err := h.connectorBL.GetByID(c.Request.Context(), identity.User, id) @@ -99,10 +99,10 @@ func (h *ConnectorHandler) Create(c *gin.Context) error { } var param parameters.CreateConnectorParam if err = c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + return utils.ErrorBadRequest.Wrap(err, "wrong payload") } if err = param.Validate(); err != nil { - return utils.InvalidInput.Wrap(err, err.Error()) + return utils.ErrorBadRequest.Wrap(err, err.Error()) } connector, err := h.connectorBL.Create(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -129,11 +129,11 @@ func (h *ConnectorHandler) Update(c *gin.Context) error { } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } var param parameters.UpdateConnectorParam if err = c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + return utils.ErrorBadRequest.Wrap(err, "wrong payload") } connector, err := h.connectorBL.Update(c.Request.Context(), id, identity.User, ¶m) if err != nil { @@ -173,11 +173,11 @@ func (h *ConnectorHandler) GetSourceTypes(c *gin.Context, identity *security.Ide func (h *ConnectorHandler) Archive(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } action := c.Param("action") if !(action == ActionRestore || action == ActionDelete) { - return utils.InvalidInput.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) + return utils.ErrorBadRequest.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) } credential, err := h.connectorBL.Archive(c.Request.Context(), identity.User, id, action == ActionRestore) if err != nil { diff --git a/backend/api/handler/credential.go b/backend/api/handler/credential.go index 98c1c11e..fca489c0 100644 --- a/backend/api/handler/credential.go +++ b/backend/api/handler/credential.go @@ -46,7 +46,7 @@ func (h *CredentialHandler) GetAll(c *gin.Context) error { } var param parameters.GetAllCredentialsParam if err = c.BindQuery(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong parameters") + return utils.ErrorBadRequest.Wrap(err, "wrong parameters") } credentials, err := h.credentialBl.GetAll(c.Request.Context(), ident.User, ¶m) if err != nil { @@ -72,7 +72,7 @@ func (h *CredentialHandler) GetByID(c *gin.Context) error { } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } credential, err := h.credentialBl.GetByID(c.Request.Context(), ident.User, id) @@ -99,10 +99,10 @@ func (h *CredentialHandler) Create(c *gin.Context) error { } var param parameters.CreateCredentialParam if err = c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + return utils.ErrorBadRequest.Wrap(err, "wrong payload") } if err = param.Validate(); err != nil { - return utils.InvalidInput.Wrap(err, err.Error()) + return utils.ErrorBadRequest.Wrap(err, err.Error()) } credential, err := h.credentialBl.Create(c.Request.Context(), ident.User, ¶m) if err != nil { @@ -129,11 +129,11 @@ func (h *CredentialHandler) Update(c *gin.Context) error { } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } var param parameters.UpdateCredentialParam if err = c.BindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + return utils.ErrorBadRequest.Wrap(err, "wrong payload") } credential, err := h.credentialBl.Update(c.Request.Context(), id, ident.User, ¶m) if err != nil { @@ -156,11 +156,11 @@ func (h *CredentialHandler) Update(c *gin.Context) error { func (h *CredentialHandler) Archive(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } action := c.Param("action") if !(action == ActionRestore || action == ActionDelete) { - return utils.InvalidInput.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) + return utils.ErrorBadRequest.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) } credential, err := h.credentialBl.Archive(c.Request.Context(), identity.User, id, action == ActionRestore) if err != nil { diff --git a/backend/api/handler/document_set.go b/backend/api/handler/document_set.go index f54dccc3..a9703f9b 100644 --- a/backend/api/handler/document_set.go +++ b/backend/api/handler/document_set.go @@ -44,7 +44,7 @@ func (h *DocumentSetHandler) Mount(router *gin.Engine, authMiddleware gin.Handle func (h *DocumentSetHandler) GetAll(c *gin.Context, identity *security.Identity) error { var param parameters.ArchivedParam if err := c.ShouldBindQuery(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "failed to bind query params") + return utils.ErrorBadRequest.Wrap(err, "failed to bind query params") } documentSets, err := h.documentSetBL.GetByUser(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -66,7 +66,7 @@ func (h *DocumentSetHandler) GetAll(c *gin.Context, identity *security.Identity) func (h *DocumentSetHandler) GetByID(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } documentSet, err := h.documentSetBL.GetByID(c, identity.User, id) if err != nil { @@ -89,7 +89,7 @@ func (h *DocumentSetHandler) GetByID(c *gin.Context, identity *security.Identity func (h *DocumentSetHandler) Create(c *gin.Context, identity *security.Identity) error { var param parameters.DocumentSetParam if err := c.ShouldBindJSON(¶m); err != nil { - return utils.InvalidInput.New("invalid params") + return utils.ErrorBadRequest.New("invalid params") } documentSet, err := h.documentSetBL.Create(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -112,12 +112,12 @@ func (h *DocumentSetHandler) Create(c *gin.Context, identity *security.Identity) func (h *DocumentSetHandler) Update(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } var param parameters.DocumentSetParam if err = c.ShouldBindJSON(¶m); err != nil { - return utils.InvalidInput.New("invalid params") + return utils.ErrorBadRequest.New("invalid params") } documentSet, err := h.documentSetBL.Update(c.Request.Context(), identity.User, id, ¶m) if err != nil { @@ -140,11 +140,11 @@ func (h *DocumentSetHandler) Update(c *gin.Context, identity *security.Identity) func (h *DocumentSetHandler) Delete(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } action := c.Param("action") if !(action == ActionRestore || action == ActionDelete) { - return utils.InvalidInput.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) + return utils.ErrorBadRequest.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) } var documentSet *model.DocumentSet switch action { diff --git a/backend/api/handler/embedding_model.go b/backend/api/handler/embedding_model.go index 5ef17477..e2650f8f 100644 --- a/backend/api/handler/embedding_model.go +++ b/backend/api/handler/embedding_model.go @@ -43,7 +43,7 @@ func (h *EmbeddingModelHandler) Mount(router *gin.Engine, authMiddleware gin.Han func (h *EmbeddingModelHandler) GetAll(c *gin.Context, identity *security.Identity) error { var param parameters.ArchivedParam if err := c.ShouldBindQuery(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong parameters") + return utils.ErrorBadRequest.Wrap(err, "wrong parameters") } result, err := h.embeddingModelBL.GetAll(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -65,7 +65,7 @@ func (h *EmbeddingModelHandler) GetAll(c *gin.Context, identity *security.Identi func (h *EmbeddingModelHandler) GetByID(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } embeddingModel, err := h.embeddingModelBL.GetByID(c.Request.Context(), identity.User, id) if err != nil { @@ -87,7 +87,7 @@ func (h *EmbeddingModelHandler) GetByID(c *gin.Context, identity *security.Ident func (h *EmbeddingModelHandler) Create(c *gin.Context, identity *security.Identity) error { var param parameters.EmbeddingModelParam if err := c.ShouldBind(¶m); err != nil { - return utils.InvalidInput.New("invalid params") + return utils.ErrorBadRequest.New("invalid params") } embeddingModel, err := h.embeddingModelBL.Create(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -110,11 +110,11 @@ func (h *EmbeddingModelHandler) Create(c *gin.Context, identity *security.Identi func (h *EmbeddingModelHandler) Update(c *gin.Context, identity *security.Identity) error { var param parameters.EmbeddingModelParam if err := c.ShouldBind(¶m); err != nil { - return utils.InvalidInput.New("invalid params") + return utils.ErrorBadRequest.New("invalid params") } id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } embeddingModel, err := h.embeddingModelBL.Update(c.Request.Context(), identity.User, id, ¶m) if err != nil { @@ -137,11 +137,11 @@ func (h *EmbeddingModelHandler) Update(c *gin.Context, identity *security.Identi func (h *EmbeddingModelHandler) Delete(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } action := c.Param("action") if !(action == ActionRestore || action == ActionDelete) { - return utils.InvalidInput.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) + return utils.ErrorBadRequest.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) } var embedingModel *model.EmbeddingModel switch action { diff --git a/backend/api/handler/persona.go b/backend/api/handler/persona.go index dca33b09..f03c8831 100644 --- a/backend/api/handler/persona.go +++ b/backend/api/handler/persona.go @@ -41,7 +41,7 @@ func (h *PersonaHandler) Mount(route *gin.Engine, authMiddleware gin.HandlerFunc func (h *PersonaHandler) GetAll(c *gin.Context, identity *security.Identity) error { var param parameters.ArchivedParam if err := c.BindQuery(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong parameters") + return utils.ErrorBadRequest.Wrap(err, "wrong parameters") } personas, err := h.personaBL.GetAll(c.Request.Context(), identity.User, param.Archived) if err != nil { @@ -64,7 +64,7 @@ func (h *PersonaHandler) GetAll(c *gin.Context, identity *security.Identity) err func (h *PersonaHandler) GetByID(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } persona, err := h.personaBL.GetByID(c.Request.Context(), identity.User, id) if err != nil { @@ -85,8 +85,8 @@ func (h *PersonaHandler) GetByID(c *gin.Context, identity *security.Identity) er // @Router /manage/personas [post] func (h *PersonaHandler) Create(c *gin.Context, identity *security.Identity) error { var param parameters.PersonaParam - if err := c.ShouldBindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + if err := server.BindJsonAndValidate(c, ¶m); err != nil { + return err } persona, err := h.personaBL.Create(c.Request.Context(), identity.User, ¶m) if err != nil { @@ -109,11 +109,11 @@ func (h *PersonaHandler) Create(c *gin.Context, identity *security.Identity) err func (h *PersonaHandler) Update(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } var param parameters.PersonaParam - if err = c.ShouldBindJSON(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong payload") + if err := server.BindJsonAndValidate(c, ¶m); err != nil { + return err } persona, err := h.personaBL.Update(c.Request.Context(), id, identity.User, ¶m) if err != nil { @@ -136,11 +136,11 @@ func (h *PersonaHandler) Update(c *gin.Context, identity *security.Identity) err func (h *PersonaHandler) Archive(c *gin.Context, identity *security.Identity) error { id, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil || id == 0 { - return utils.InvalidInput.New("id should be presented") + return utils.ErrorBadRequest.New("id should be presented") } action := c.Param("action") if !(action == ActionRestore || action == ActionDelete) { - return utils.InvalidInput.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) + return utils.ErrorBadRequest.Newf("invalid action: should be %s or %s", ActionRestore, ActionDelete) } persona, err := h.personaBL.Archive(c.Request.Context(), identity.User, id, action == ActionRestore) if err != nil { diff --git a/backend/api/handler/tenant.go b/backend/api/handler/tenant.go index 0241761e..a2f474f0 100644 --- a/backend/api/handler/tenant.go +++ b/backend/api/handler/tenant.go @@ -66,10 +66,10 @@ func (h *TenantHandler) AddUser(c *gin.Context, identity *security.Identity) err var param parameters.AddUserParam if err := c.ShouldBind(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong parameter") + return utils.ErrorBadRequest.Wrap(err, "wrong parameter") } if err := param.Validate(); err != nil { - return utils.InvalidInput.New(err.Error()) + return utils.ErrorBadRequest.New(err.Error()) } user, err := h.tenantBL.AddUser(c.Request.Context(), identity.User, param.Email, param.Role) if err != nil { @@ -100,10 +100,10 @@ func (h *TenantHandler) EditUser(c *gin.Context, identity *security.Identity) er } var param parameters.EditUserParam if err := c.ShouldBind(¶m); err != nil { - return utils.InvalidInput.Wrap(err, "wrong parameter") + return utils.ErrorBadRequest.Wrap(err, "wrong parameter") } if err := param.Validate(); err != nil { - return utils.InvalidInput.New(err.Error()) + return utils.ErrorBadRequest.New(err.Error()) } user, err := h.tenantBL.UpdateUser(c.Request.Context(), identity.User, userID, param.Role) if err != nil { diff --git a/backend/api/model.go b/backend/api/model.go index 35aa8b74..4057cb66 100644 --- a/backend/api/model.go +++ b/backend/api/model.go @@ -5,6 +5,7 @@ import ( "cognix.ch/api/v2/core/oauth" "cognix.ch/api/v2/core/repository" "cognix.ch/api/v2/core/server" + "cognix.ch/api/v2/core/utils" "github.com/caarlos0/env/v10" "github.com/gin-gonic/gin" "go.uber.org/fx" @@ -46,6 +47,7 @@ func ReadConfig() (*Config, error) { if err != nil { return nil, err } + utils.InitLogger(cfg.Debug) return cfg, nil } diff --git a/backend/core/ai/open-ai.go b/backend/core/ai/open-ai.go index 98d45b25..2f3722f2 100644 --- a/backend/core/ai/open-ai.go +++ b/backend/core/ai/open-ai.go @@ -6,6 +6,10 @@ import ( openai "github.com/sashabaranov/go-openai" ) +var SupportedModels = map[string]bool{ + openai.GPT3Dot5Turbo: true, +} + type ( Response struct { Message string @@ -47,6 +51,7 @@ func (o *openAIClient) Request(ctx context.Context, message string) (*Response, } func NewOpenAIClient(llm *model.LLM) OpenAIClient { + return &openAIClient{ client: openai.NewClient(llm.ApiKey), modelID: llm.ModelID, diff --git a/backend/core/bll/auth.go b/backend/core/bll/auth.go index 949ff3d8..91850cbf 100644 --- a/backend/core/bll/auth.go +++ b/backend/core/bll/auth.go @@ -46,7 +46,7 @@ func (a *authBL) SignUp(ctx context.Context, identity *oauth.IdentityResponse) ( return nil, err } if exists { - return nil, utils.InvalidInput.New("user already exists") + return nil, utils.ErrorBadRequest.New("user already exists") } user := model.User{ ID: uuid.New(), @@ -73,7 +73,7 @@ func (a *authBL) SignUp(ctx context.Context, identity *oauth.IdentityResponse) ( // return "", err // } // if exists { -// return "", utils.InvalidInput.New("user already registered.") +// return "", utils.ErrorBadRequest.New("user already registered.") // } // //buf, err := json.Marshal(parameters.OAuthParam{Action: oauth.InviteState, // // Role: param.Role, diff --git a/backend/core/bll/chat.go b/backend/core/bll/chat.go index c4332cc9..a48e3369 100644 --- a/backend/core/bll/chat.go +++ b/backend/core/bll/chat.go @@ -45,7 +45,7 @@ func (b *chatBL) FeedbackMessage(ctx *gin.Context, user *model.User, id int64, v } func (b *chatBL) SendMessage(ctx *gin.Context, user *model.User, param *parameters.CreateChatMessageRequest) (*responder.Manager, error) { - chatSession, err := b.chatRepo.GetSessionByID(ctx.Request.Context(), user.ID, param.ChatSessionId.IntPart()) + chatSession, err := b.chatRepo.GetSessionByID(ctx.Request.Context(), user.ID, param.ChatSessionID.IntPart()) if err != nil { return nil, err } @@ -81,7 +81,7 @@ func (b *chatBL) CreateSession(ctx context.Context, user *model.User, param *par return nil, err } if !exists { - return nil, utils.InvalidInput.New("persona is not exists") + return nil, utils.ErrorBadRequest.New("persona is not exists") } session := model.ChatSession{ UserID: user.ID, diff --git a/backend/core/bll/connector.go b/backend/core/bll/connector.go index c72faec0..3e6f7bd1 100644 --- a/backend/core/bll/connector.go +++ b/backend/core/bll/connector.go @@ -73,7 +73,7 @@ func (c *connectorBL) Create(ctx context.Context, user *model.User, param *param return nil, err } if cred.Source != model.SourceType(param.Source) { - return nil, utils.InvalidInput.New("wrong credential source") + return nil, utils.ErrorBadRequest.New("wrong credential source") } conn.CredentialID = decimal.NewNullDecimal(cred.ID) } @@ -95,7 +95,7 @@ func (c *connectorBL) Update(ctx context.Context, id int64, user *model.User, pa return nil, err } if cred.Source != conn.Source { - return nil, utils.InvalidInput.New("wrong credential source") + return nil, utils.ErrorBadRequest.New("wrong credential source") } } conn.ConnectorSpecificConfig = param.ConnectorSpecificConfig diff --git a/backend/core/bll/credential.go b/backend/core/bll/credential.go index 322f68f1..a521d9e0 100644 --- a/backend/core/bll/credential.go +++ b/backend/core/bll/credential.go @@ -32,7 +32,7 @@ func (c *credentialBL) Archive(ctx context.Context, user *model.User, id int64, return nil, utils.ErrorPermission.New("you do not have permission") } if !restore && len(credential.Connectors) > 0 { - return nil, utils.InvalidInput.New("there are still associated connectors") + return nil, utils.ErrorBadRequest.New("there are still associated connectors") } if !restore { credential.DeletedDate = pg.NullTime{time.Now().UTC()} diff --git a/backend/core/bll/persona.go b/backend/core/bll/persona.go index ffbd7ceb..50318d13 100644 --- a/backend/core/bll/persona.go +++ b/backend/core/bll/persona.go @@ -39,7 +39,7 @@ func (b *personaBL) Archive(ctx context.Context, user *model.User, id int64, res return nil, err } if len(persona.ChatSessions) > 0 { - return nil, utils.InvalidInput.New("persona is used in chat sessions") + return nil, utils.ErrorBadRequest.New("persona is used in chat sessions") } if restore { persona.DeletedDate = pg.NullTime{} @@ -57,7 +57,7 @@ func (b *personaBL) Create(ctx context.Context, user *model.User, param *paramet starterMessages, err := json.Marshal(param.StarterMessages) if err != nil { - return nil, utils.InvalidInput.Wrap(err, "fail to marshal starter messages") + return nil, utils.ErrorBadRequest.Wrap(err, "fail to marshal starter messages") } persona := model.Persona{ Name: param.Name, @@ -101,7 +101,7 @@ func (b *personaBL) Update(ctx context.Context, id int64, user *model.User, para } starterMessages, err := json.Marshal(param.StarterMessages) if err != nil { - return nil, utils.InvalidInput.Wrap(err, "fail to marshal starter messages") + return nil, utils.ErrorBadRequest.Wrap(err, "fail to marshal starter messages") } persona.Name = param.Name persona.Description = param.Description diff --git a/backend/core/bll/tenant.go b/backend/core/bll/tenant.go index 2fa6a2a4..a8f8b51b 100644 --- a/backend/core/bll/tenant.go +++ b/backend/core/bll/tenant.go @@ -37,7 +37,7 @@ func (b *tenantBL) AddUser(ctx context.Context, user *model.User, email, role st return nil, err } if exists { - return nil, utils.InvalidInput.New("user already exists") + return nil, utils.ErrorBadRequest.New("user already exists") } newUser := &model.User{ ID: uuid.New(), diff --git a/backend/core/model/llm.go b/backend/core/model/llm.go index d7075c4c..431e9412 100644 --- a/backend/core/model/llm.go +++ b/backend/core/model/llm.go @@ -13,7 +13,7 @@ type LLM struct { Name string `json:"name,omitempty"` ModelID string `json:"model_id,omitempty"` TenantID uuid.UUID `json:"tenant_id,omitempty"` - Url string `json:"url,omitempty"` + Url string `json:"url,omitempty" pg:",use_zero"` ApiKey string `json:"api_key"` Endpoint string `json:"endpoint,omitempty"` CreatedDate time.Time `json:"created_date,omitempty"` diff --git a/backend/core/model/prompt.go b/backend/core/model/prompt.go index 3636d1b9..93a89797 100644 --- a/backend/core/model/prompt.go +++ b/backend/core/model/prompt.go @@ -13,7 +13,7 @@ type Prompt struct { PersonaID decimal.Decimal `json:"persona_id,omitempty"` UserID uuid.UUID `json:"user_id,omitempty"` Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` + Description string `json:"description,omitempty" pg:",use_zero"` SystemPrompt string `json:"system_prompt,omitempty" pg:",use_zero"` TaskPrompt string `json:"task_prompt,omitempty" pg:",use_zero"` IncludeCitations bool `json:"include_citations,omitempty" pg:",use_zero"` diff --git a/backend/core/parameters/chat.go b/backend/core/parameters/chat.go index 48e52daa..d6657d07 100644 --- a/backend/core/parameters/chat.go +++ b/backend/core/parameters/chat.go @@ -2,6 +2,7 @@ package parameters import ( "cognix.ch/api/v2/core/model" + "fmt" validation "github.com/go-ozzo/ozzo-validation/v4" "github.com/shopspring/decimal" "time" @@ -20,19 +21,31 @@ type CreateChatSession struct { func (v CreateChatSession) Validate() error { return validation.ValidateStruct(&v, - validation.Field(&v.PersonaID, validation.Required), + validation.Field(&v.PersonaID, validation.Required, + validation.By(func(value interface{}) error { + if v.PersonaID.IsZero() { + return fmt.Errorf("persona_id is zero") + } + return nil + })), ) } type CreateChatMessageRequest struct { - ChatSessionId decimal.Decimal `json:"chat_session_id,omitempty"` - ParentMessageId decimal.Decimal `json:"parent_message_id,omitempty"` - Message string `json:"message,omitempty"` - PromptId decimal.Decimal `json:"prompt_id,omitempty"` - SearchDocIds []decimal.Decimal `json:"search_doc_ids,omitempty"` - RetrievalOptions RetrievalDetails `json:"retrieval_options,omitempty"` - QueryOverride string `json:"query_override,omitempty"` - NoAiAnswer bool `json:"no_ai_answer,omitempty"` + ChatSessionID decimal.Decimal `json:"chat_session_id,omitempty"` + ParentMessageID decimal.Decimal `json:"parent_message_id,omitempty"` + Message string `json:"message,omitempty"` + PromptID decimal.Decimal `json:"prompt_id,omitempty"` + SearchDocIds []decimal.Decimal `json:"search_doc_ids,omitempty"` + //RetrievalOptions RetrievalDetails `json:"retrieval_options,omitempty"` + QueryOverride string `json:"query_override,omitempty"` + NoAiAnswer bool `json:"no_ai_answer,omitempty"` +} + +func (v CreateChatMessageRequest) Validate() error { + return validation.ValidateStruct(&v, + validation.Field(&v.ChatSessionID, validation.Required), + validation.Field(&v.Message, validation.Required)) } type RetrievalDetails struct { diff --git a/backend/core/parameters/persona.go b/backend/core/parameters/persona.go index 4eba7cfb..e708d3c9 100644 --- a/backend/core/parameters/persona.go +++ b/backend/core/parameters/persona.go @@ -1,6 +1,10 @@ package parameters -import validation "github.com/go-ozzo/ozzo-validation/v4" +import ( + "cognix.ch/api/v2/core/ai" + "fmt" + validation "github.com/go-ozzo/ozzo-validation/v4" +) type PersonaParam struct { Name string `json:"name"` @@ -22,6 +26,13 @@ type StarterMessage struct { func (v PersonaParam) Validate() error { return validation.ValidateStruct(&v, - validation.Field(&v.ModelID, validation.Required), + validation.Field(&v.Name, validation.Required), + validation.Field(&v.ModelID, validation.Required, + validation.By(func(value interface{}) error { + if _, ok := ai.SupportedModels[v.ModelID]; !ok { + return fmt.Errorf("model %s not supported", v.ModelID) + } + return nil + })), validation.Field(&v.APIKey, validation.Required)) } diff --git a/backend/core/server/auth_middleware.go b/backend/core/server/auth_middleware.go index cdbac828..0605b694 100644 --- a/backend/core/server/auth_middleware.go +++ b/backend/core/server/auth_middleware.go @@ -6,7 +6,6 @@ import ( "cognix.ch/api/v2/core/utils" "context" "github.com/gin-gonic/gin" - "net/http" "strings" "time" ) @@ -29,7 +28,7 @@ func (m *AuthMiddleware) RequireAuth(c *gin.Context) { //Get the bearer Token tokenString := c.GetHeader("Authorization") if tokenString == "" { - c.JSON(http.StatusUnauthorized, gin.H{"status": http.StatusUnauthorized, "message": "Authorization token is required"}) + handleError(c, utils.ErrorUnauthorized.New("Authorization token is required")) c.Abort() return } @@ -37,26 +36,26 @@ func (m *AuthMiddleware) RequireAuth(c *gin.Context) { extractedToken := strings.Split(tokenString, "Bearer ") if len(extractedToken) != 2 { - c.JSON(http.StatusBadRequest, gin.H{"status": http.StatusBadRequest, "message": "Incorrect format of authorization token"}) + handleError(c, utils.ErrorBadRequest.New("Incorrect format of authorization token")) c.Abort() return } claims, err := m.jwtService.ParseAndValidate(strings.TrimSpace(extractedToken[1])) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"status": http.StatusBadRequest, "message": "Token is not valid"}) + handleError(c, utils.ErrorBadRequest.New("Token is not valid")) c.Abort() return } if claims.ExpiresAt != 0 && time.Now().Unix() > claims.ExpiresAt { - c.AbortWithStatus(http.StatusUnauthorized) + handleError(c, utils.ErrorUnauthorized.New("token expired")) c.Abort() return } if claims.User, err = m.userRepo.GetByIDAndTenantID(c.Request.Context(), claims.User.ID, claims.User.TenantID); err != nil { - c.AbortWithStatus(http.StatusUnauthorized) + handleError(c, utils.ErrorUnauthorized.Wrap(err, "wrong user")) c.Abort() return } diff --git a/backend/core/server/router.go b/backend/core/server/router.go index 7fb7bad5..33a76d27 100644 --- a/backend/core/server/router.go +++ b/backend/core/server/router.go @@ -3,7 +3,8 @@ package server import ( "cognix.ch/api/v2/core/security" "cognix.ch/api/v2/core/utils" - "github.com/uptrace/opentelemetry-go-extra/otelzap" + validation "github.com/go-ozzo/ozzo-validation/v4" + "go.uber.org/zap" "net/http" "github.com/gin-gonic/gin" @@ -47,7 +48,6 @@ func handleError(c *gin.Context, err error) { ew.Code = http.StatusInternalServerError ew.Message = err.Error() } - otelzap.S().Errorf("[%s] %v", ew.Message, ew.Original) errResp := JsonErrorResponse{ Status: int(ew.Code), Error: ew.Message, @@ -55,6 +55,7 @@ func handleError(c *gin.Context, err error) { if ew.Original != nil { errResp.OriginalError = ew.Original.Error() } + zap.S().Errorf("[%s] %v", ew.Message, ew.Original) c.JSON(int(ew.Code), errResp) } } @@ -67,3 +68,15 @@ func JsonResult(c *gin.Context, status int, data interface{}) error { }) return nil } + +func BindJsonAndValidate(c *gin.Context, data interface{}) error { + if err := c.BindJSON(data); err != nil { + return utils.ErrorBadRequest.Wrap(err, "wrong payload") + } + if vl, ok := data.(validation.Validatable); ok { + if err := vl.Validate(); err != nil { + return utils.ErrorBadRequest.New(err.Error()) + } + } + return nil +} diff --git a/backend/core/utils/errors.go b/backend/core/utils/errors.go index 0279872a..e899d40f 100644 --- a/backend/core/utils/errors.go +++ b/backend/core/utils/errors.go @@ -13,10 +13,11 @@ type Errors struct { type ErrorWrap int const ( - ErrorPermission ErrorWrap = http.StatusForbidden - NotFound ErrorWrap = http.StatusNotFound - Internal ErrorWrap = http.StatusInternalServerError - InvalidInput ErrorWrap = http.StatusBadRequest + ErrorPermission ErrorWrap = http.StatusForbidden + NotFound ErrorWrap = http.StatusNotFound + Internal ErrorWrap = http.StatusInternalServerError + ErrorBadRequest ErrorWrap = http.StatusBadRequest + ErrorUnauthorized ErrorWrap = http.StatusUnauthorized ) func (e Errors) Error() string {