From 1e8a3d941ba5cfef2c478dd5bac4e4a4b4d67830 Mon Sep 17 00:00:00 2001 From: Hugues Chocart Date: Tue, 23 Jul 2024 11:15:40 +0100 Subject: [PATCH] feat: eval performances (#436) Co-authored-by: Vince Loewe --- .github/workflows/build-push-deploy.yml | 2 +- ops | 2 +- package-lock.json | 24 +++ packages/backend/package.json | 1 + packages/backend/src/api/v1/evaluator.ts | 6 +- packages/backend/src/api/v1/projects/index.ts | 99 +++++++-- packages/backend/src/api/v1/runs/index.ts | 185 +++++++++------- packages/backend/src/api/v1/views.ts | 6 +- packages/backend/src/checks/index.ts | 18 +- packages/backend/src/evaluators/assertion.ts | 24 +++ packages/backend/src/evaluators/index.ts | 4 +- packages/backend/src/evaluators/language.ts | 1 - .../evaluators/{assert.ts => old-assert.ts} | 1 + packages/backend/src/evaluators/pii.ts | 11 +- packages/backend/src/evaluators/topics.ts | 82 +++++-- packages/backend/src/evaluators/toxicity.ts | 68 +++++- .../backend/src/jobs/realtime-evaluators.ts | 4 +- packages/backend/src/utils/calcCost.ts | 14 +- packages/backend/src/utils/ingest.ts | 2 - packages/db/0026.sql | 1 + .../components/SmartViewer/HighlightPii.tsx | 51 +++++ .../components/SmartViewer/Message.tsx | 104 +++++++-- .../components/SmartViewer/MessageViewer.tsx | 2 +- .../components/SmartViewer/RenderJson.tsx | 14 +- .../components/SmartViewer/index.module.css | 13 ++ .../frontend/components/SmartViewer/index.tsx | 1 + .../frontend/components/analytics/BarList.tsx | 2 +- .../frontend/components/blocks/RunChat.tsx | 8 +- .../components/checks/ChecksInputs.tsx | 6 +- .../frontend/components/checks/Picker.tsx | 12 +- .../frontend/components/layout/Sidebar.tsx | 16 +- packages/frontend/package.json | 3 +- packages/frontend/pages/evaluations/new.tsx | 8 - .../pages/evaluations/realtime/index.tsx | 135 +++++++----- .../pages/evaluations/realtime/new.tsx | 13 +- packages/frontend/pages/logs/index.tsx | 141 ++++-------- packages/frontend/pages/settings/index.tsx | 195 +++++++++++------ packages/frontend/pages/settings/models.tsx | 14 +- packages/frontend/utils/colors.ts | 15 ++ .../frontend/utils/dataHooks/evaluators.ts | 3 +- packages/frontend/utils/dataHooks/index.ts | 51 ++++- packages/frontend/utils/dataHooks/views.ts | 10 +- packages/frontend/utils/enrichment.tsx | 187 +++++++++------- packages/frontend/utils/evaluators.ts | 200 +++--------------- packages/frontend/utils/hooks.ts | 36 ++-- packages/ml/README.md | 21 -- packages/ml/lang.py | 21 -- packages/ml/main.py | 42 ---- packages/ml/pii.py | 130 ------------ packages/ml/requirements.txt | 7 - packages/ml/toxicity.py | 35 --- packages/shared/checks/index.ts | 79 +++---- packages/shared/checks/serialize.ts | 48 +++-- packages/shared/evaluators/index.ts | 7 +- 54 files changed, 1159 insertions(+), 1026 deletions(-) create mode 100644 packages/backend/src/evaluators/assertion.ts rename packages/backend/src/evaluators/{assert.ts => old-assert.ts} (97%) create mode 100644 packages/db/0026.sql create mode 100644 packages/frontend/components/SmartViewer/HighlightPii.tsx delete mode 100644 packages/ml/README.md delete mode 100644 packages/ml/lang.py delete mode 100644 packages/ml/main.py delete mode 100644 packages/ml/pii.py delete mode 100644 packages/ml/requirements.txt delete mode 100644 packages/ml/toxicity.py diff --git a/.github/workflows/build-push-deploy.yml b/.github/workflows/build-push-deploy.yml index ceeb583e..eea55cba 100644 --- a/.github/workflows/build-push-deploy.yml +++ b/.github/workflows/build-push-deploy.yml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - service: [backend, radar, ml] + service: [backend, realtime-evaluators, ml] steps: - name: Check out the private Ops repo uses: actions/checkout@v4 diff --git a/ops b/ops index 0d1f1379..eb7857ed 160000 --- a/ops +++ b/ops @@ -1 +1 @@ -Subproject commit 0d1f1379e582bcc73eac7017b413471a069136eb +Subproject commit eb7857ed8163e45c576cb5a2b3f6bc822ff82177 diff --git a/package-lock.json b/package-lock.json index 7256f996..051028e3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -5081,6 +5081,11 @@ "node": ">= 0.4" } }, + "node_modules/highlight-words-core": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/highlight-words-core/-/highlight-words-core-1.2.2.tgz", + "integrity": "sha512-BXUKIkUuh6cmmxzi5OIbUJxrG8OAk2MqoL1DtO3Wo9D2faJg2ph5ntyuQeLqaHJmzER6H5tllCDA9ZnNe9BVGg==" + }, "node_modules/http-assert": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/http-assert/-/http-assert-1.5.0.tgz", @@ -5874,6 +5879,11 @@ "node": ">= 0.6" } }, + "node_modules/memoize-one": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/memoize-one/-/memoize-one-4.0.3.tgz", + "integrity": "sha512-QmpUu4KqDmX0plH4u+tf0riMc1KHE1+lw95cMrLlXQAFOx/xnBtwhZ52XJxd9X2O6kwKBqX32kmhbhlobD0cuw==" + }, "node_modules/merge-stream": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/merge-stream/-/merge-stream-2.0.0.tgz", @@ -7263,6 +7273,19 @@ "react": "^18.3.1" } }, + "node_modules/react-highlight-words": { + "version": "0.20.0", + "resolved": "https://registry.npmjs.org/react-highlight-words/-/react-highlight-words-0.20.0.tgz", + "integrity": "sha512-asCxy+jCehDVhusNmCBoxDf2mm1AJ//D+EzDx1m5K7EqsMBIHdZ5G4LdwbSEXqZq1Ros0G0UySWmAtntSph7XA==", + "dependencies": { + "highlight-words-core": "^1.2.0", + "memoize-one": "^4.0.0", + "prop-types": "^15.5.8" + }, + "peerDependencies": { + "react": "^0.14.0 || ^15.0.0 || ^16.0.0-0 || ^17.0.0-0 || ^18.0.0-0" + } + }, "node_modules/react-is": { "version": "16.13.1", "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", @@ -9423,6 +9446,7 @@ "react": "^18.3.1", "react-confetti": "^6.1.0", "react-dom": "^18.3.1", + "react-highlight-words": "^0.20.0", "react-json-view-lite": "^1.2.1", "recharts": "^2.12.7", "shared": "*", diff --git a/packages/backend/package.json b/packages/backend/package.json index e72e8d15..078a437e 100644 --- a/packages/backend/package.json +++ b/packages/backend/package.json @@ -4,6 +4,7 @@ "scripts": { "start": "npm run migrate:db && tsx src/index.ts", "start:radar": "tsx src/radar.ts", + "start:realtime-evaluators": "tsx src/realtime-evaluators.ts", "dev:realtime-evaluators": "tsx --env-file=.env --watch src/realtime-evaluators.ts", "migrate:db": "tsx src/migrate.ts", "build": "tsup src/index.ts --format esm", diff --git a/packages/backend/src/api/v1/evaluator.ts b/packages/backend/src/api/v1/evaluator.ts index 395c0818..84230c91 100644 --- a/packages/backend/src/api/v1/evaluator.ts +++ b/packages/backend/src/api/v1/evaluator.ts @@ -44,7 +44,6 @@ evaluators.get("/:id", async (ctx: Context) => { }) evaluators.post("/", async (ctx: Context) => { - console.log(ctx.request.body) const requestBody = z.object({ ownerId: z.string().optional(), name: z.string(), @@ -53,7 +52,7 @@ evaluators.post("/", async (ctx: Context) => { type: z.string(), mode: z.string(), params: z.record(z.any()), - filters: z.string(), + filters: z.array(z.any()), }) const { projectId } = ctx.state @@ -63,7 +62,6 @@ evaluators.post("/", async (ctx: Context) => { const [insertedEvaluator] = await sql` insert into evaluator ${sql({ ...evaluator, - filters: deserializeLogic(evaluator.filters), projectId, })} returning * @@ -79,7 +77,7 @@ evaluators.patch("/:id", async (ctx: Context) => { type: z.string(), mode: z.string(), params: z.record(z.any()), - filters: z.record(z.any()), + filters: z.array(z.any()), }) const { projectId } = ctx.state diff --git a/packages/backend/src/api/v1/projects/index.ts b/packages/backend/src/api/v1/projects/index.ts index 0092e233..d8f2743f 100644 --- a/packages/backend/src/api/v1/projects/index.ts +++ b/packages/backend/src/api/v1/projects/index.ts @@ -164,40 +164,93 @@ projects.patch( async (ctx: Context) => { const bodySchema = z.object({ name: z.string().optional(), - filters: z.array(z.any()).optional(), - }); - const { name, filters } = bodySchema.parse(ctx.request.body); - const { projectId } = ctx.params; - const { userId } = ctx.state; + }) + const { name } = bodySchema.parse(ctx.request.body) + const { projectId } = ctx.params + const { userId } = ctx.state // TODO: this should be in a middleware - const hasProjectAccess = await checkProjectAccess(projectId, userId); + const hasProjectAccess = await checkProjectAccess(projectId, userId) if (!hasProjectAccess) { - ctx.throw(401, "Unauthorized"); + ctx.throw(401, "Unauthorized") } - if (name) { await sql` - update project - set name = ${name} - where id = ${projectId} - `; + update project set name = ${name} where id = ${projectId} + ` } - if (filters) { - await sql` - insert into ingestion_rule (project_id, type, filters) - values (${projectId}, 'filtering', ${filters}) - on conflict (project_id, type) - do update set filters = excluded.filters - `; + ctx.status = 200 + ctx.body = {} + }, +) + +projects.get( + "/:projectId/rules", + checkAccess("projects", "update"), + async (ctx: Context) => { + const { projectId } = ctx.params + const { userId } = ctx.state + + const hasProjectAccess = await checkProjectAccess(projectId, userId) + if (!hasProjectAccess) { + ctx.throw(401, "Unauthorized") } - ctx.status = 200; - ctx.body = {}; - } -); + const rules = + await sql`select * from ingestion_rule where project_id = ${projectId}` + + ctx.body = rules + }, +) + +projects.post( + "/:projectId/rules", + checkAccess("projects", "update"), + async (ctx: Context) => { + const bodySchema = z.object({ + type: z.enum(["filtering", "masking"]).default("filtering"), + filters: z.array(z.any()).optional(), + }) + const { type, filters } = bodySchema.parse(ctx.request.body) + const { projectId } = ctx.params + const { userId } = ctx.state + + const hasProjectAccess = await checkProjectAccess(projectId, userId) + if (!hasProjectAccess) { + ctx.throw(401, "Unauthorized") + } + await sql` + insert into ingestion_rule (project_id, type, filters) + values (${projectId}, ${type}, ${filters}) + on conflict (project_id, type) + do update set filters = excluded.filters + ` + + ctx.status = 200 + ctx.body = {} + }, +) + +projects.delete( + "/:projectId/rules", + checkAccess("projects", "update"), + async (ctx: Context) => { + const { projectId } = ctx.params + const { userId } = ctx.state + + const hasProjectAccess = await checkProjectAccess(projectId, userId) + if (!hasProjectAccess) { + ctx.throw(401, "Unauthorized") + } + + await sql`delete from ingestion_rule where project_id = ${projectId}` + + ctx.status = 200 + ctx.body = {} + }, +) export default projects diff --git a/packages/backend/src/api/v1/runs/index.ts b/packages/backend/src/api/v1/runs/index.ts index 9474c0dd..9c1a44f4 100644 --- a/packages/backend/src/api/v1/runs/index.ts +++ b/packages/backend/src/api/v1/runs/index.ts @@ -99,7 +99,6 @@ function formatRun(run: any) { status: run.status, siblingRunId: run.siblingRunId, params: processParams(run.params), - metadata: run.metadata, user: run.externalUserId && { id: run.externalUserId, @@ -110,81 +109,95 @@ function formatRun(run: any) { }, } - // TODO: c'est horrible - // const evaluationResults = run.evaluationResults.find( - // (result) => result.evaluatorType === "language", - // ) - // const languageDetections = evaluationResults?.result - // if ( - // languageDetections?.input && - // languageDetections?.output && - // languageDetections?.error - // ) { - // if (Array.isArray(formattedRun.input)) { - // for (let i = 0; i < formattedRun.input.length; i++) { - // if ( - // typeof formattedRun.input[i] === "object" && - // languageDetections.input - // ) { - // formattedRun.input[i].languageDetection = languageDetections.input[i] - // } - // } - // } else if (formattedRun.input && typeof formattedRun.input === "object") { - // formattedRun.input.languageDetection = languageDetections.input[0] - // } - - // if (Array.isArray(formattedRun.output)) { - // for (let i = 0; i < run.output.length; i++) { - // if (typeof formattedRun.output[i] === "object") { - // formattedRun.output[i].languageDetection = - // languageDetections.output[i] - // } - // } - // } else if (formattedRun.output && typeof formattedRun.input === "object") { - // formattedRun.output.languageDetection = languageDetections.output[0] - // } - - // if (formattedRun.error && typeof formattedRun.input === "object") { - // formattedRun.error.languageDetection = languageDetections.error[0] - // } - // } - - // const sentimentEvaluationResults = run.evaluationResults.find( - // (result) => result.evaluatorType === "sentiment", - // ) - // const sentimentAnalyses = sentimentEvaluationResults?.result - // if ( - // sentimentAnalyses?.input && - // sentimentAnalyses?.output && - // sentimentAnalyses?.error - // ) { - // if (Array.isArray(formattedRun.input)) { - // for (let i = 0; i < formattedRun.input.length; i++) { - // if (typeof formattedRun.input[i] === "object") { - // formattedRun.input[i].sentimentAnalysis = sentimentAnalyses.input[i] - // } - // } - // } else if (formattedRun.input && typeof formattedRun.input === "object") { - // formattedRun.input.sentimentAnalysis = sentimentAnalyses.input[0] - // } - // if (Array.isArray(formattedRun.output)) { - // for (let i = 0; i < run.output.length; i++) { - // if (formattedRun.output && typeof formattedRun.output[i] === "object") { - // formattedRun.output[i].sentimentAnalysis = sentimentAnalyses.output[i] - // } - // } - // } else if (formattedRun.output && typeof formattedRun.input === "object") { - // formattedRun.output.sentimentAnalysis = sentimentAnalyses.output[0] - // } - // if (formattedRun.error && typeof formattedRun.input === "object") { - // formattedRun.error.sentimentAnalysis = sentimentAnalyses.error[0] - // } - // } - - // for (let evaluationResult of run.evaluationResults || []) { - // formattedRun[`enrichment-${evaluationResult.evaluatorId}`] = - // evaluationResult - // } + try { + // TODO: put in process input function + if (Array.isArray(formattedRun.input)) { + for (const message of formattedRun.input) { + message.enrichments = [] + } + } else if (typeof formattedRun.input === "object") { + formattedRun.input.enrichments = [] + } + + if (Array.isArray(formattedRun.output)) { + for (const message of formattedRun.output) { + message.enrichments = [] + } + } else if (formattedRun.output && typeof formattedRun.output === "object") { + formattedRun.output.enrichments = [] + } + + if (formattedRun.error && typeof formattedRun.error === "object") { + formattedRun.error.enrichments = [] + } + + for (const { + result, + evaluatorType, + evaluatorId, + } of run.evaluationResults) { + if (!result?.input || !result?.output || !result?.error) { + continue + } + + if (Array.isArray(formattedRun.input)) { + for (let i = 0; i < formattedRun.input.length; i++) { + const message = formattedRun.input[i] + if (typeof message === "object") { + message.enrichments.push({ + result: result.input[i], + type: evaluatorType, + id: evaluatorId, + }) + } + } + } else if (formattedRun.input && typeof formattedRun.input === "object") { + formattedRun.input.enrichments.push({ + result: result.input[0], + type: evaluatorType, + id: evaluatorId, + }) + } + + if (Array.isArray(formattedRun.output)) { + for (let i = 0; i < formattedRun.output.length; i++) { + const message = formattedRun.output[i] + if (typeof message === "object") { + message.enrichments.push({ + result: result.output[i], + type: evaluatorType, + id: evaluatorId, + }) + } + } + } else if ( + formattedRun.output && + typeof formattedRun.output === "object" + ) { + formattedRun.output.enrichments.push({ + result: result.output[0], + type: evaluatorType, + id: evaluatorId, + }) + } + + if (formattedRun.error && typeof formattedRun.error === "object") { + formattedRun.error.enrichments.push({ + result: result.error[0], + type: evaluatorType, + id: evaluatorId, + }) + } + } + } catch (error) { + console.error(error) + } + + // TODO: put in an array nammed enrichment instead + for (let evaluationResult of run.evaluationResults || []) { + formattedRun[`enrichment-${evaluationResult.evaluatorId}`] = + evaluationResult + } return formattedRun } @@ -222,13 +235,26 @@ runs.get("/", async (ctx: Context) => { eu.last_seen as user_last_seen, eu.props as user_props, t.slug as template_slug, - rpfc.feedback as parent_feedback + rpfc.feedback as parent_feedback, + coalesce(array_agg( + jsonb_build_object( + 'evaluatorName', e.name, + 'evaluatorSlug', e.slug, + 'evaluatorType', e.type, + 'evaluatorId', e.id, + 'result', er.result, + 'createdAt', er.created_at, + 'updatedAt', er.updated_at + ) + ) filter (where er.run_id is not null), '{}') as evaluation_results from public.run r left join external_user eu on r.external_user_id = eu.id left join run_parent_feedback_cache rpfc on r.id = rpfc.id left join template_version tv on r.template_version_id = tv.id left join template t on tv.template_id = t.id + left join evaluation_result_v2 er on r.id = er.run_id + left join evaluator e on er.evaluator_id = e.id where r.project_id = ${projectId} ${parentRunCheck} @@ -258,6 +284,8 @@ runs.get("/", async (ctx: Context) => { left join run_parent_feedback_cache rpfc on r.id = rpfc.id left join template_version tv on r.template_version_id = tv.id left join template t on tv.template_id = t.id + left join evaluation_result_v2 er on r.id = er.run_id + left join evaluator e on er.evaluator_id = e.id where r.project_id = ${projectId} ${parentRunCheck} @@ -417,6 +445,9 @@ runs.get("/:id", async (ctx) => { from run r left join external_user eu on r.external_user_id = eu.id + left join run_parent_feedback_cache rpfc on r.id = rpfc.id + left join template_version tv on r.template_version_id = tv.id + left join template t on tv.template_id = t.id left join evaluation_result_v2 er on r.id = er.run_id left join evaluator e on er.evaluator_id = e.id where diff --git a/packages/backend/src/api/v1/views.ts b/packages/backend/src/api/v1/views.ts index e7677979..4256894a 100644 --- a/packages/backend/src/api/v1/views.ts +++ b/packages/backend/src/api/v1/views.ts @@ -3,7 +3,7 @@ import sql from "@/src/utils/db" import { clearUndefined } from "@/src/utils/ingest" import Context from "@/src/utils/koa" import Router from "koa-router" -import { CheckLogic } from "shared" + import { z } from "zod" const views = new Router({ @@ -15,6 +15,7 @@ const ViewSchema = z.object({ data: z.any(), columns: z.any(), icon: z.string().optional(), + type: z.enum(["llm", "thread", "trace"]), }) views.get("/", checkAccess("logs", "list"), async (ctx: Context) => { @@ -39,7 +40,7 @@ views.post("/", async (ctx: Context) => { const { projectId, userId } = ctx.state const validatedData = ViewSchema.parse(ctx.request.body) - const { name, data, columns, icon } = validatedData + const { name, data, columns, icon, type } = validatedData const [insertedCheck] = await sql` insert into view ${sql({ @@ -49,6 +50,7 @@ views.post("/", async (ctx: Context) => { data, columns, icon, + type, })} returning * ` diff --git a/packages/backend/src/checks/index.ts b/packages/backend/src/checks/index.ts index fb678dea..407616c7 100644 --- a/packages/backend/src/checks/index.ts +++ b/packages/backend/src/checks/index.ts @@ -182,17 +182,21 @@ export const CHECK_RUNNERS: CheckRunner[] = [ sql`e.type = 'pii'`, or( types.map((type: string) => { - const jsonSql = [{ type }] - return sql`( - er.result::jsonb -> 'input' @> ${sql.json(jsonSql)} - OR - er.result::jsonb -> 'output' @> ${sql.json(jsonSql)} + return sql`EXISTS ( + SELECT 1 + FROM jsonb_array_elements(er.result -> 'input') as input_array + WHERE input_array @> ${sql.json([{ type }])} + ) OR EXISTS ( + SELECT 1 + FROM jsonb_array_elements(er.result -> 'output') as output_array + WHERE output_array @> ${sql.json([{ type }])} )` }), ), ]) }, }, + { id: "sentiment", sql: ({ sentiment }) => { @@ -435,12 +439,14 @@ export const CHECK_RUNNERS: CheckRunner[] = [ id: "tokens", // sum completion_tokens and prompt_tokens if field is total sql: ({ field, operator, tokens }) => { + if (!tokens) return sql`true` + if (field === "total") { return sql`completion_tokens + prompt_tokens ${postgresOperators( operator, )} ${tokens}` } else { - return sql`${sql(field + "_tokens")} ${postgresOperators( + return sql`${sql(field + "_tokens")} ${postgresisOperators( operator, )} ${tokens}` } diff --git a/packages/backend/src/evaluators/assertion.ts b/packages/backend/src/evaluators/assertion.ts new file mode 100644 index 00000000..409f48e6 --- /dev/null +++ b/packages/backend/src/evaluators/assertion.ts @@ -0,0 +1,24 @@ +import { callML } from "@/src/utils/ml" +import { Run } from "shared" + +interface Params { + statement: string + model: string +} + +export async function evaluate(run: Run, params: Params) { + try { + const { statement, model } = params + console.log(statement) + const result = await callML("assertion", { + input: run.input, + output: run.output, + statement, + model, + }) + console.log(result) + return result + } catch (error) { + console.error(error) + } +} diff --git a/packages/backend/src/evaluators/index.ts b/packages/backend/src/evaluators/index.ts index a2f423f2..b7d5ad3d 100644 --- a/packages/backend/src/evaluators/index.ts +++ b/packages/backend/src/evaluators/index.ts @@ -1,6 +1,6 @@ import * as pii from "./pii" import * as language from "./language" -import * as assert from "./assert" +import * as assertion from "./assertion" import * as tone from "./tone" import * as topics from "./topics" import * as toxicity from "./toxicity" @@ -11,7 +11,7 @@ import * as replies from "./replies" const evaluators = { pii, language, - assert, + assertion, tone, topics, toxicity, diff --git a/packages/backend/src/evaluators/language.ts b/packages/backend/src/evaluators/language.ts index 0a513a34..f8cead40 100644 --- a/packages/backend/src/evaluators/language.ts +++ b/packages/backend/src/evaluators/language.ts @@ -132,7 +132,6 @@ export async function evaluate(run: Run, params: unknown) { } // TODO: zod for languages, SHOLUD NOT INGEST IN DB IF NOT CORRECT FORMAT - return languages } diff --git a/packages/backend/src/evaluators/assert.ts b/packages/backend/src/evaluators/old-assert.ts similarity index 97% rename from packages/backend/src/evaluators/assert.ts rename to packages/backend/src/evaluators/old-assert.ts index 6482d502..747d60c4 100644 --- a/packages/backend/src/evaluators/assert.ts +++ b/packages/backend/src/evaluators/old-assert.ts @@ -7,6 +7,7 @@ interface AssertParams { conditions: string[] } +// Used in playground export async function evaluate(run: Run, params: AssertParams) { const { conditions } = params diff --git a/packages/backend/src/evaluators/pii.ts b/packages/backend/src/evaluators/pii.ts index e23e766e..bfe7f3ca 100644 --- a/packages/backend/src/evaluators/pii.ts +++ b/packages/backend/src/evaluators/pii.ts @@ -58,20 +58,23 @@ export async function evaluate(run: Run, params: Params) { error: errorPIIs, } - // TODO: zod for languages, SHOLUD NOT INGEST IN DB IF NOT CORRECT FORMAT return PIIs } -// TODO: type async function detectPIIs( texts: string[], entityTypes: string[] = [], customRegexes: string[] = [], - excludedEntities: string[] = [] + excludedEntities: string[] = [], ): Promise { try { - return callML("pii", { texts, entityTypes, customRegexes, excludedEntities }) + return callML("pii", { + texts, + entityTypes, + customRegexes, + excludedEntities, + }) } catch (error) { console.error(error) console.log(texts) diff --git a/packages/backend/src/evaluators/topics.ts b/packages/backend/src/evaluators/topics.ts index a104a420..954c490f 100644 --- a/packages/backend/src/evaluators/topics.ts +++ b/packages/backend/src/evaluators/topics.ts @@ -1,38 +1,78 @@ import { Run } from "shared" -import openai from "@/src/utils/openai" -import lunary from "lunary" -import { lastMsg } from "../checks" +import { callML } from "../utils/ml" interface TopicsParams { topics: string[] } -export async function evaluate(run: Run, params: TopicsParams) { - const { topics } = params +// TOOD: refacto this with all the other parsing function already in use +function parseMessages(messages: unknown) { + if (!messages) { + return [""] + } + if (typeof messages === "string" && messages.length) { + return [messages] + } - const input = lastMsg(run.input) + `\n\n` + lastMsg(run.output) + if (messages === "__NOT_INGESTED__") { + return [""] + } - const topicsList = topics.join("\n") + if (Array.isArray(messages)) { + let contentArray = [] + for (const message of messages) { + if (message?.type === "system") { + continue + } + let content = message.content || message.text + if (typeof content === "string" && content.length) { + contentArray.push(content) + } else { + contentArray.push(JSON.stringify(message)) + } + } + return contentArray + } - const template = await lunary.renderTemplate("topics", { - input, - topics: topicsList, - }) + if (typeof messages === "object") { + return [JSON.stringify(messages)] + } - const res = await openai.chat.completions.create(template) + return [""] +} - const output = res.choices[0]?.message?.content +export async function evaluate(run: Run, params: TopicsParams) { + const { topics } = params + const input = parseMessages(run.input) + const output = parseMessages(run.output) + const error = parseMessages(run.error) - if (!output) return [] + const [inputTopics, outputTopics, errorTopics] = await Promise.all([ + detectTopics(input, topics), + detectTopics(output, topics), + detectTopics(error, topics), + ]) - // if the first line is 'None' as instructed in the prompt, return an empty array - if (output.split("\n")[0].toLowerCase().includes("none")) { - return [] + const result = { + input: inputTopics, + output: outputTopics, + error: errorTopics, } - const results = output - .split("\n") - .map((line: string) => line.toLowerCase().replace(".", "").trim()) + return result +} - return results +async function detectTopics( + texts: string[], + topics: string[] = [], +): Promise { + try { + return callML("topic", { + texts, + topics, + }) + } catch (error) { + console.error(error) + console.log(texts) + } } diff --git a/packages/backend/src/evaluators/toxicity.ts b/packages/backend/src/evaluators/toxicity.ts index 726f74ac..b877962f 100644 --- a/packages/backend/src/evaluators/toxicity.ts +++ b/packages/backend/src/evaluators/toxicity.ts @@ -1,17 +1,65 @@ +import { callML } from "@/src/utils/ml" import { Run } from "shared" -import { callML } from "../utils/ml" -import { lastMsg } from "../checks" + +// TOOD: refacto this with all the other parsing function already in use +function parseMessages(messages: unknown) { + if (!messages) { + return [""] + } + if (typeof messages === "string" && messages.length) { + return [messages] + } + + if (messages === "__NOT_INGESTED__") { + return [""] + } + + if (Array.isArray(messages)) { + let contentArray = [] + for (const message of messages) { + let content = message.content || message.text + if (typeof content === "string" && content.length) { + contentArray.push(content) + } else { + contentArray.push(JSON.stringify(message)) + } + } + return contentArray + } + + if (typeof messages === "object") { + return [JSON.stringify(messages)] + } + + return [""] +} export async function evaluate(run: Run) { - const text = lastMsg(run.input) + lastMsg(run.output) - if (!text.length) { - return null + const input = parseMessages(run.input) + const output = parseMessages(run.output) + const error = parseMessages(run.error) + + const [inputToxicity, outputToxicity] = await Promise.all([ + detectToxicity(input), + detectToxicity(output), + ]) + + const toxicity = { + input: inputToxicity, + output: outputToxicity, + error: error.map((e) => null), } - const toxicityLabels = await callML("toxicity", { - text, - }) + // TODO: zod for languages, SHOLUD NOT INGEST IN DB IF NOT CORRECT FORMAT + return toxicity +} - // format: ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult' ...] - return toxicityLabels +// TODO: type +async function detectToxicity(texts: string[]): Promise { + try { + return callML("toxicity", { texts }) + } catch (error) { + console.error(error) + console.log(texts) + } } diff --git a/packages/backend/src/jobs/realtime-evaluators.ts b/packages/backend/src/jobs/realtime-evaluators.ts index d7fddc2c..5d3a8ae3 100644 --- a/packages/backend/src/jobs/realtime-evaluators.ts +++ b/packages/backend/src/jobs/realtime-evaluators.ts @@ -6,7 +6,7 @@ import { RealtimeEvaluator } from "shared/evaluators" import { sleep } from "../utils/misc" import evaluators from "../evaluators" -const RUNS_BATCH_SIZE = 30 +const RUNS_BATCH_SIZE = 10 async function runEvaluator(evaluator: RealtimeEvaluator, run: Run) { try { @@ -71,7 +71,7 @@ async function evaluatorJob() { from evaluator e where - mode = 'realtime' + mode = 'realtime' and project_id = 'befa0759-fbf1-4c5e-a51c-dd2dbe70f053' order by random() ` diff --git a/packages/backend/src/utils/calcCost.ts b/packages/backend/src/utils/calcCost.ts index d82bd545..651d3689 100644 --- a/packages/backend/src/utils/calcCost.ts +++ b/packages/backend/src/utils/calcCost.ts @@ -227,12 +227,20 @@ export async function calcRunCost(run: any) { let inputUnits = 0 let outputUnits = 0 + let inputCost = 0 + let outputCost = 0 + if (mapping.unit === "TOKENS") { inputUnits = run.promptTokens || 0 outputUnits = run.completionTokens || 0 + + inputCost = (inputCost * inputUnits) / 1_000_000 + outputCost = (outputCost * outputUnits) / 1_000_000 } else if (mapping.unit === "MILLISECONDS") { inputUnits = run.duration || 0 outputUnits = 0 + + inputCost = inputCost * inputUnits } else if (mapping.unit === "CHARACTERS") { inputUnits = (typeof run.input === "string" ? run.input : JSON.stringify(run.input)) @@ -242,10 +250,10 @@ export async function calcRunCost(run: any) { ? run.output : JSON.stringify(run.output) ).length || 0 - } - const inputCost = (mapping.inputCost * inputUnits) / 1_000_000 - const outputCost = (mapping.outputCost * outputUnits) / 1_000_000 + inputCost = (inputCost * inputUnits) / 1_000_000 + outputCost = (outputCost * outputUnits) / 1_000_000 + } const finalCost = Number((inputCost + outputCost).toFixed(5)) diff --git a/packages/backend/src/utils/ingest.ts b/packages/backend/src/utils/ingest.ts index 12486d8e..ed54325a 100644 --- a/packages/backend/src/utils/ingest.ts +++ b/packages/backend/src/utils/ingest.ts @@ -1,5 +1,3 @@ -// import { completeRunUsage } from "@/lib/countTokens" - import { completeRunUsage } from "./countToken" import sql from "./db" diff --git a/packages/db/0026.sql b/packages/db/0026.sql new file mode 100644 index 00000000..0e5ec7f7 --- /dev/null +++ b/packages/db/0026.sql @@ -0,0 +1 @@ +alter table view add column type text not null default 'llm'; \ No newline at end of file diff --git a/packages/frontend/components/SmartViewer/HighlightPii.tsx b/packages/frontend/components/SmartViewer/HighlightPii.tsx new file mode 100644 index 00000000..83aa77eb --- /dev/null +++ b/packages/frontend/components/SmartViewer/HighlightPii.tsx @@ -0,0 +1,51 @@ +import Highlighter from "react-highlight-words" +import { useProject, useProjectRules } from "@/utils/dataHooks" +import { Tooltip } from "@mantine/core" +import { getPIIColor } from "@/utils/colors" +import classes from "./index.module.css" + +export default function HighlightPii({ + text, + piiDetection, +}: { + text: string + piiDetection: { type: string; entity: string }[] // Contains the detected PII +}) { + if (!piiDetection || piiDetection.length === 0) { + return <>{text} + } + + const { maskingRule } = useProjectRules() + + const HighlightBadge = ({ children }) => { + const piiType = piiDetection.find((pii) => pii.entity === children)?.type + const bgColor = `light-dark(var(--mantine-color-${getPIIColor(piiType)}-2), var(--mantine-color-${getPIIColor(piiType)}-9))` + const length = children.length + return ( + + + {maskingRule ? "x".repeat(length) : children} + + + ) + } + + return ( + pii.entity)} + autoEscape={true} + caseSensitive={true} + textToHighlight={text} + /> + ) +} diff --git a/packages/frontend/components/SmartViewer/Message.tsx b/packages/frontend/components/SmartViewer/Message.tsx index cefec41c..891938df 100644 --- a/packages/frontend/components/SmartViewer/Message.tsx +++ b/packages/frontend/components/SmartViewer/Message.tsx @@ -1,4 +1,4 @@ -import { getColorForRole } from "@/utils/colors" +import { getColorForRole, getPIIColor } from "@/utils/colors" import { ActionIcon, Box, @@ -29,11 +29,12 @@ import ProtectedText from "../blocks/ProtectedText" import { RenderJson } from "./RenderJson" import classes from "./index.module.css" -import { useEffect } from "react" +import { useEffect, useMemo } from "react" import { openConfirmModal } from "@mantine/modals" import { getFlagEmoji, getLanguageName } from "@/utils/format" import { renderSentimentEnrichment } from "@/utils/enrichment" +import HighlightPii from "./HighlightPii" const ghostTextAreaStyles = { variant: "unstyled", @@ -47,7 +48,14 @@ const ghostTextAreaStyles = { width: "100%", } -function RenderFunction({ color, editable, onChange, compact, data }) { +function RenderFunction({ + color, + editable, + onChange, + compact, + data, + piiDetection, +}) { return ( ) : (
-          
+          
         
)}
) } -function FunctionCallMessage({ data, color, compact }) { - return +function FunctionCallMessage({ data, color, compact, piiDetection }) { + return ( + + ) } -function ToolCallsMessage({ toolCalls, editable, onChange, color, compact }) { +function ToolCallsMessage({ + toolCalls, + editable, + onChange, + color, + compact, + piiDetection, +}) { return ( <> {toolCalls.map((toolCall, index) => ( @@ -138,6 +164,7 @@ function ToolCallsMessage({ toolCalls, editable, onChange, color, compact }) { { const newToolCalls = [...toolCalls] newToolCalls[index].function = newData @@ -178,7 +205,13 @@ function ToolCallsMessage({ toolCalls, editable, onChange, color, compact }) { ) } -function TextMessage({ data, compact, onChange = () => {}, editable = false }) { +function TextMessage({ + data, + compact, + onChange = () => {}, + piiDetection, + editable = false, +}) { const text = data.content || data.text return ( @@ -192,10 +225,15 @@ function TextMessage({ data, compact, onChange = () => {}, editable = false }) { onChange={(e) => onChange({ ...data, content: e.target.value })} {...ghostTextAreaStyles} /> - ) : compact ? ( - text?.substring(0, 150) // truncate text to render less ) : ( - text + )} @@ -265,7 +303,7 @@ function ChatMessageContent({ data, color, compact, - + piiDetection, onChange, editable, }) { @@ -303,6 +341,7 @@ function ChatMessageContent({ @@ -322,6 +361,7 @@ function ChatMessageContent({ onChange({ ...data, toolCalls })} compact={compact} @@ -397,8 +437,6 @@ export function ChatMessage({ const color = getColorForRole(data?.role) - const codeBg = `light-dark(rgba(255,255,255,0.5), rgba(0,0,0,0.6))` - // Add/remove the 'id' and 'name' props required on tool calls useEffect(() => { if (!data || !editable) return @@ -434,6 +472,23 @@ export function ChatMessage({ } }, [data, editable]) + const sentiment = useMemo(() => { + return data?.enrichments?.find( + (enrichment) => enrichment.type === "sentiment", + )?.result + }, [data?.enrichments]) + + const piiDetection = useMemo(() => { + return data?.enrichments?.find((enrichment) => enrichment.type === "pii") + ?.result + }, [data?.enrichments]) + + const language = useMemo(() => { + return data?.enrichments?.find( + (enrichment) => enrichment.type === "language", + )?.result + }, [data?.enrichments]) + return ( )} - {/* {renderSentimentEnrichment(data?.sentimentAnalysis?.score)} */} - {/* {data?.languageDetection && ( + {renderSentimentEnrichment(sentiment?.score)} + {language && ( - {getFlagEmoji(data.languageDetection.isoCode)} + {getFlagEmoji(language.isoCode)} - )} */} + )} )} { + return enrichments?.find((enrichment) => enrichment.type === "pii")?.result + }, [enrichments]) + return ( <> - {content} + + + {extra} diff --git a/packages/frontend/components/SmartViewer/MessageViewer.tsx b/packages/frontend/components/SmartViewer/MessageViewer.tsx index 4ba57cda..39dedfb8 100644 --- a/packages/frontend/components/SmartViewer/MessageViewer.tsx +++ b/packages/frontend/components/SmartViewer/MessageViewer.tsx @@ -10,7 +10,7 @@ function getLastMessage(messages) { return messages } -export default function MessageViewer({ data, compact }) { +export default function MessageViewer({ data, compact, piiDetection }) { const obj = Array.isArray(data) ? data : [data] return compact ? ( diff --git a/packages/frontend/components/SmartViewer/RenderJson.tsx b/packages/frontend/components/SmartViewer/RenderJson.tsx index 8504b346..f72fb478 100644 --- a/packages/frontend/components/SmartViewer/RenderJson.tsx +++ b/packages/frontend/components/SmartViewer/RenderJson.tsx @@ -3,8 +3,9 @@ import ProtectedText from "../blocks/ProtectedText" // import { JsonView, defaultStyles } from "react-json-view-lite" // import errorHandler from "@/utils/errors" import ErrorBoundary from "../blocks/ErrorBoundary" +import HighlightPii from "./HighlightPii" -export const Json = ({ data, compact }) => { +export const Json = ({ data, compact, piiDetection }) => { if (!data) return null const parsed = useMemo(() => { @@ -31,13 +32,18 @@ export const Json = ({ data, compact }) => { return ( - {compact ? JSON.stringify(parsed) : JSON.stringify(parsed, null, 2)} + ) } -export const RenderJson = ({ data, compact }) => ( +export const RenderJson = ({ data, compact, piiDetection }) => ( - + ) diff --git a/packages/frontend/components/SmartViewer/index.module.css b/packages/frontend/components/SmartViewer/index.module.css index 16157227..73c6b825 100644 --- a/packages/frontend/components/SmartViewer/index.module.css +++ b/packages/frontend/components/SmartViewer/index.module.css @@ -16,6 +16,19 @@ padding: var(--mantine-spacing-xs) !important; } +.piiBadge { + padding: 2px 4px; + border-radius: 4px; + /* background-color: var(--mantine-color-red-1); */ + /* color: var(--mantine-color-red-9); */ + font-size: 12px; + font-weight: 500; + + &.blurred { + filter: blur(3px); + } +} + .paper { padding: 12px; padding-top: 8px; diff --git a/packages/frontend/components/SmartViewer/index.tsx b/packages/frontend/components/SmartViewer/index.tsx index 1c9d4d30..8f33d307 100644 --- a/packages/frontend/components/SmartViewer/index.tsx +++ b/packages/frontend/components/SmartViewer/index.tsx @@ -15,6 +15,7 @@ import { ChatMessage } from "./Message" import MessageViewer from "./MessageViewer" import { RenderJson } from "./RenderJson" import classes from "./index.module.css" +import HighlightPii from "./HighlightPii" const checkIsMessage = (obj) => { return ( diff --git a/packages/frontend/components/analytics/BarList.tsx b/packages/frontend/components/analytics/BarList.tsx index 90c91bce..22cf287e 100644 --- a/packages/frontend/components/analytics/BarList.tsx +++ b/packages/frontend/components/analytics/BarList.tsx @@ -24,7 +24,7 @@ type BarListProps = { function BarList({ data, columns, filterZero = true }: BarListProps) { const dataColumns = columns.filter((col) => !col.bar && col.key) const main = dataColumns.find((col) => col.main) || dataColumns[0] - const mainTotal = data.reduce((acc, item) => acc + (item[main.key] || 0), 0) + const mainTotal = data?.reduce((acc, item) => acc + (item[main.key] || 0), 0) const scheme = useComputedColorScheme() if (!data) return <>No data. diff --git a/packages/frontend/components/blocks/RunChat.tsx b/packages/frontend/components/blocks/RunChat.tsx index 52d2c3d0..4cbcae0f 100644 --- a/packages/frontend/components/blocks/RunChat.tsx +++ b/packages/frontend/components/blocks/RunChat.tsx @@ -42,6 +42,8 @@ function parseMessageFromRun(run) { ), id: run.id, feedback: run.feedback, + enrichments: msg.enrichments, + ...(siblingRunId && { siblingRunId }), ...(OUTPUT_ROLES.includes(role) && { took: @@ -71,18 +73,20 @@ function Message({ const runId = router?.query?.selected const { updateFeedback } = useRun(msg.id) const { data: relatedRuns } = useProjectSWR(runId && `/runs/${runId}/related`) + return ( <> - {!!msg.took && ( + {/* {!!msg.took && ( {msg.took}ms - )} + )} */} {msg.role !== "user" && ( ) }, - text: ({ placeholder, width, value, onChange }) => { + text: ({ placeholder, width, value, minimal, onChange }) => { return ( onChange(e.currentTarget.value)} diff --git a/packages/frontend/components/checks/Picker.tsx b/packages/frontend/components/checks/Picker.tsx index e780533c..33f6d15e 100644 --- a/packages/frontend/components/checks/Picker.tsx +++ b/packages/frontend/components/checks/Picker.tsx @@ -14,6 +14,7 @@ export function RenderCheckNode({ node, disabled, checks, + showAndOr, setNode, removeNode, }: { @@ -21,6 +22,7 @@ export function RenderCheckNode({ node: CheckLogic checks: Check[] disabled?: boolean + showAndOr?: boolean setNode: (node: CheckLogic | LogicData) => void removeNode: () => void }) { @@ -35,6 +37,7 @@ export function RenderCheckNode({ key={i} checks={checks} disabled={disabled} + showAndOr={showAndOr} node={n as CheckLogic} removeNode={() => { const newNode = [...node] @@ -50,7 +53,8 @@ export function RenderCheckNode({ ) return node.map((n, i) => { - const showOperator = i !== 0 && i !== node.length - 1 && !minimal + const showOperator = + i !== 0 && i !== node.length - 1 && (!minimal || showAndOr) return showOperator ? ( {showCheckNode(n, i)} @@ -58,7 +62,7 @@ export function RenderCheckNode({