Skip to content

Commit

Permalink
fix: playground tools (#544)
Browse files Browse the repository at this point in the history
Co-authored-by: Vince Loewe <[email protected]>
  • Loading branch information
hughcrt and vincelwt authored Sep 5, 2024
1 parent 1db2df3 commit 5811705
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 49 deletions.
2 changes: 1 addition & 1 deletion packages/backend/src/api/v1/orgs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ orgs.post("/playground", async (ctx: Context) => {
const requestBodySchema = z.object({
content: z.array(z.any()).or(z.string()),
extra: z.any(),
variables: z.record(z.string()).optional().default({}),
variables: z.record(z.string()).nullable().optional().default({}),
})
const { content, extra, variables } = requestBodySchema.parse(
ctx.request.body,
Expand Down
17 changes: 9 additions & 8 deletions packages/backend/src/api/v1/template-versions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ const versions = new Router({
// Otherwise it returns stuff like maxTokens instead of max_tokens and OpenAI breaks
const unCameledSql = postgres(process.env.DATABASE_URL!)

export function unCamelExtras(version: any) {
version.extra = unCamelObject(version.extra)
return version
}

//Warning: Route used by SDK to fetch the latest version of a template
versions.get("/latest", async (ctx: Context) => {
const { projectId } = ctx.state
Expand Down Expand Up @@ -47,8 +52,6 @@ versions.get("/latest", async (ctx: Context) => {
ctx.throw("Template not found, is the project ID correct?", 404)
}

latestVersion.extra = unCamelObject(latestVersion.extra)

// This makes sure OpenAI messages are not camel cased as used in the app
// For example: message.toolCallId instead of message.tool_call_id
if (typeof latestVersion.content !== "string") {
Expand All @@ -57,7 +60,7 @@ versions.get("/latest", async (ctx: Context) => {
)
}

ctx.body = latestVersion
ctx.body = unCamelExtras(latestVersion)
})

versions.get("/:id", async (ctx: Context) => {
Expand All @@ -78,13 +81,11 @@ versions.get("/:id", async (ctx: Context) => {
ctx.throw(401, "You do not have access to this ressource.")
}

version.extra = unCamelObject(version.extra)

const [template] = await sql`
select * from template where project_id = ${projectId} and id = ${version.templateId}
`

ctx.body = { ...version, template }
ctx.body = { ...unCamelExtras(version), template }
})

versions.patch(
Expand Down Expand Up @@ -120,7 +121,7 @@ versions.patch(
}

const [updatedTemplateVersion] = await sql`
update template_version
update template_version
set ${sql(
clearUndefined({
content: sql.json(content),
Expand All @@ -136,7 +137,7 @@ versions.patch(
returning *
`

ctx.body = updatedTemplateVersion
ctx.body = unCamelExtras(updatedTemplateVersion)
},
)

Expand Down
21 changes: 9 additions & 12 deletions packages/backend/src/api/v1/templates.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Context from "@/src/utils/koa"
import { unCamelObject } from "@/src/utils/misc"
import Router from "koa-router"
import { z } from "zod"
import { unCamelExtras } from "./template-versions"

const templates = new Router({
prefix: "/templates",
Expand Down Expand Up @@ -34,9 +35,7 @@ templates.get("/", async (ctx: Context) => {

// uncamel each template's versions' extras' keys
for (const template of templates) {
for (const version of template.versions) {
version.extra = unCamelObject(version.extra)
}
template.versions = template.versions.map(unCamelExtras)
}

ctx.body = templates
Expand All @@ -62,11 +61,7 @@ templates.get("/latest", async (ctx: Context) => {
tv.created_at desc;
`

for (const version of templateVersions) {
version.extra = unCamelObject(version.extra)
}

ctx.body = templateVersions
ctx.body = templateVersions.map(unCamelExtras)
})

// insert template + a first version, and return the template with versions
Expand All @@ -93,13 +88,15 @@ templates.post("/", checkAccess("prompts", "create"), async (ctx: Context) => {
})} returning *
`

delete extra.stop

const [templateVersion] = await sql`
insert into template_version ${sql(
clearUndefined({
templateId: template.id,
content: sql.json(content),
extra: sql.json(unCamelObject(extra)),
testValues: sql.json(testValues),
extra: sql.json(unCamelObject(clearUndefined(extra))),
testValues: testValues ? sql.json(testValues) : undefined,
isDraft: isDraft,
notes,
}),
Expand All @@ -108,7 +105,7 @@ templates.post("/", checkAccess("prompts", "create"), async (ctx: Context) => {

ctx.body = {
...template,
versions: [templateVersion],
versions: [unCamelExtras(templateVersion)],
}
})

Expand Down Expand Up @@ -207,7 +204,7 @@ templates.post(
returning *
`

ctx.body = templateVersion
ctx.body = unCamelExtras(templateVersion)
},
)

Expand Down
9 changes: 8 additions & 1 deletion packages/backend/src/utils/misc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ export async function setDefaultBody(ctx: Context, next: Next) {
}

export function unCamelObject(obj: any): any {
if (Array.isArray(obj)) {
return obj.map(unCamelObject)
}
if (typeof obj !== "object" || obj === null) {
return obj
}
const newObj: any = {}
for (const key in obj) {
newObj[key.replace(/([A-Z])/g, "_$1").toLowerCase()] = obj[key]
const newKey = key.replace(/([A-Z])/g, "_$1").toLowerCase()
newObj[newKey] = unCamelObject(obj[key])
}
return newObj
}
Expand Down
32 changes: 19 additions & 13 deletions packages/frontend/components/SmartViewer/Message.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import {
useComputedColorScheme,
} from "@mantine/core"
import {
IconCircleMinus,
IconInfoCircle,
IconRobot,
IconTool,
IconTrash,
IconUser,
} from "@tabler/icons-react"
import Image from "next/image"
Expand Down Expand Up @@ -178,8 +178,12 @@ function ToolCallsMessage({
{editable && (
<ActionIcon
color="red"
variant="transparent"
className={classes.toolCallActionIcon}
size={22}
size="sm"
pos="absolute"
top="35px"
right="2px"
onClick={() => {
openConfirmModal({
title: "Are you sure?",
Expand All @@ -196,7 +200,7 @@ function ToolCallsMessage({
})
}}
>
<IconTrash size={16} />
<IconCircleMinus size="14" />
</ActionIcon>
)}
</Box>
Expand Down Expand Up @@ -525,16 +529,18 @@ export function ChatMessage({
{data.role}
</Text>
)}
<Group>
{sentiment && <SentimentEnrichment2 score={sentiment?.score} />}
{language && (
<Tooltip
label={`${getLanguageName(language.isoCode)} (${Number(language.confidence.toFixed(3))})`}
>
<Box>{getFlagEmoji(language.isoCode)}</Box>
</Tooltip>
)}
</Group>
{!editable && (
<Group>
{sentiment && <SentimentEnrichment2 score={sentiment?.score} />}
{language && (
<Tooltip
label={`${getLanguageName(language.isoCode)} (${Number(language.confidence.toFixed(3))})`}
>
<Box>{getFlagEmoji(language.isoCode)}</Box>
</Tooltip>
)}
</Group>
)}
</Group>
)}
<ChatMessageContent
Expand Down
2 changes: 1 addition & 1 deletion packages/frontend/components/prompts/PromptEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ export function PromptEditor({
onChange(newContent)
}}
>
<IconCircleMinus size="12" />
<IconCircleMinus size="14" />
</ActionIcon>
</Box>
))}
Expand Down
100 changes: 87 additions & 13 deletions packages/frontend/components/prompts/Provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,52 @@ import { useState } from "react"
import Link from "next/link"
import { IconInfoCircle, IconTools } from "@tabler/icons-react"

function convertOpenAIToolsToAnthropic(openAITools) {
return openAITools.map((openAITool) => {
const openAIFunction = openAITool.function

if (!openAIFunction) {
return openAITool
}

const anthropicTool = {
name: openAIFunction.name,
description: openAIFunction.description,
input_schema: {
type: "object",
properties: {},
required: openAIFunction.parameters.required || [],
},
}

for (const [key, value] of Object.entries(
openAIFunction.parameters.properties,
)) {
anthropicTool.input_schema.properties[key] = {
type: value.type,
description: value.description,
}
}

return anthropicTool
})
}

function convertAnthropicToolsToOpenAI(anthropicTools) {
return anthropicTools.map((anthropicTool) => ({
type: "function",
function: {
name: anthropicTool.name,
description: anthropicTool.description,
parameters: {
type: "object",
properties: anthropicTool.input_schema.properties,
required: anthropicTool.input_schema.required,
},
},
}))
}

export const ParamItem = ({ name, value, description }) => (
<Group justify="space-between">
<Group gap={5}>
Expand All @@ -40,6 +86,19 @@ export const ParamItem = ({ name, value, description }) => (
</Group>
)

const validateToolCalls = (toolCalls: any[]) => {
if (!Array.isArray(toolCalls)) return false

const isNameDescriptionFormat = toolCalls.every(
(t) => t.name && t.description && t.input_schema,
)
const isFunctionTypeFormat = toolCalls.every(
(t) => t.type === "function" && t.function?.name,
)

return isNameDescriptionFormat || isFunctionTypeFormat
}

const isNullishButNotZero = (val: any) =>
val === undefined || val === null || val === ""

Expand Down Expand Up @@ -78,19 +137,6 @@ export default function ProviderEditor({
},
})

const validateToolCalls = (toolCalls: any[]) => {
if (!Array.isArray(toolCalls)) return false

const isNameDescriptionFormat = toolCalls.every(
(t) => t.name && t.description && t.input_schema,
)
const isFunctionTypeFormat = toolCalls.every(
(t) => t.type === "function" && t.function?.name,
)

return isNameDescriptionFormat || isFunctionTypeFormat
}

return (
<>
<ParamItem
Expand All @@ -107,9 +153,35 @@ export default function ProviderEditor({
inputMode="search"
value={value?.model}
onChange={(model) => {
if (!model || !value.model) {
return
}
// Handle conversion between OpenAI and Anthropic tools format
const isPreviousProviderOpenAI =
value.model.startsWith("gpt") || value.model.includes("mistral")
const isNewProviderOpenAI =
model.startsWith("gpt") || model.includes("mistral")

const isPreviousProviderAnthropic =
value.model.startsWith("claude")

const isNewProviderAnthropic = model.startsWith("claude")

let updatedTools = value.config.tools

if (isPreviousProviderOpenAI && isNewProviderAnthropic) {
updatedTools = convertOpenAIToolsToAnthropic(value.config.tools)
} else if (isPreviousProviderAnthropic && isNewProviderOpenAI) {
updatedTools = convertAnthropicToolsToOpenAI(value.config.tools)
}

onChange({
...value,
model,
config: {
...value.config,
tools: updatedTools,
},
})
}}
/>
Expand Down Expand Up @@ -258,6 +330,8 @@ export default function ProviderEditor({
? undefined
: JSON.parse(jsonrepair(tempJSON.trim()))

console.log(empty, repaired)

if (!empty && !validateToolCalls(repaired)) {
throw new Error("Invalid tool calls format")
}
Expand Down
6 changes: 6 additions & 0 deletions packages/frontend/pages/prompts/[[...id]].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ function Playground() {
const run = await fetcher.get(`/runs/${clone}?projectId=${project?.id}`)

if (run?.input) {
if (Array.isArray(run.input)) {
for (const input of run.input) {
delete input.enrichments
}
}

setTemplateVersion({
// ...templateVersion,
content: run.input,
Expand Down

0 comments on commit 5811705

Please sign in to comment.