Skip to content

Commit

Permalink
feat: make systemPrompt sync and optional (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdjastrzebski authored Dec 11, 2024
1 parent c9731e5 commit 287c02b
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 45 deletions.
5 changes: 5 additions & 0 deletions .changeset/four-spies-mate.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@callstack/byorg-core': minor
---

core: make systemPrompt both optional and sync-only
2 changes: 1 addition & 1 deletion docs/src/docs/core/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export type RequestContext = {
resolvedEntities: EntityInfo;

/** Function for generating a system prompt */
systemPrompt: () => Promise<string> | string;
systemPrompt: () => string | null;

/**
* Received partial response update with response streaming.
Expand Down
12 changes: 2 additions & 10 deletions docs/src/docs/core/usage.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,13 @@ Here is step by step how to start with byorg.
languageModel: openAiModel,
});
```
### Prepare a system prompt function
```js
import { RequestContext } from '@callstack/byorg-core';

const systemPrompt = (context: RequestContext): Promise<string> | string => {
return "You are a helpful AI assistant named Byorg";
};
```

### Create an Application instance
```js
import { VercelChatModelAdapter } from '@callstack/byorg-core';

const app = createApp({
chatModel,
systemPrompt,
chatModel
});
```
### Process user messages
Expand Down
4 changes: 2 additions & 2 deletions examples/bare/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ const chatModel = new VercelChatModelAdapter({
// import { createChatMockModel } from '@callstack/byorg-core';
// const chatModel = createChatMockModel();

const SYSTEM_PROMPT = 'Your name is Byorg. You are an AI Assistant.';
const SYSTEM_PROMPT = 'Your name is Byorg. You are a helpful AI Assistant.';

const app = createApp({
chatModel,
systemPrompt: () => SYSTEM_PROMPT,
systemPrompt: SYSTEM_PROMPT,
});

// Create a readline interface for user input
Expand Down
6 changes: 2 additions & 4 deletions examples/discord/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@ const chatModel = new VercelChatModelAdapter({
// import { createChatMockModel } from '@callstack/byorg-core';
// const chatModel = createChatMockModel();

const systemPrompt = () => {
return 'Your name is Byorg. You are a helpful AI Assistant.';
};
const SYSTEM_PROMPT = 'Your name is Byorg. You are a helpful AI Assistant.';

const app = createApp({
chatModel,
systemPrompt,
systemPrompt: SYSTEM_PROMPT,
});

const discord = await createDiscordApp({ app });
Expand Down
4 changes: 2 additions & 2 deletions examples/slack/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ const chatModel = new VercelChatModelAdapter({
// import { createChatMockModel } from '@callstack/byorg-core';
// const chatModel = createChatMockModel();

const SYSTEM_PROMPT = 'Your name is Byorg. You are an AI Assistant.';
const SYSTEM_PROMPT = 'Your name is Byorg. You are a helpful AI Assistant.';

const app = createApp({
chatModel,
systemPrompt: () => SYSTEM_PROMPT,
systemPrompt: SYSTEM_PROMPT,
});

const slack = createSlackApp({
Expand Down
2 changes: 0 additions & 2 deletions packages/core/src/__tests__/application.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ test('basic non-streaming test', async () => {
const testModel = createMockChatModel({ delay: 0, seed: 3 });
const app = createApp({
chatModel: testModel,
systemPrompt: () => '',
});

const result = await app.processMessages(messages);
Expand All @@ -32,7 +31,6 @@ test('basic streaming test', async () => {
const testModel = createMockChatModel({ delay: 0, seed: 3 });
const app = createApp({
chatModel: testModel,
systemPrompt: () => '',
});

const onPartialResponse = vitest.fn();
Expand Down
63 changes: 63 additions & 0 deletions packages/core/src/ai/__tests__/vercel.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { expect, test } from 'vitest';
import { createMockVercelModel } from '../../mock/vercel-mock-model.js';
import { createApp } from '../../application.js';
import { Message } from '../../domain.js';
import { VercelChatModelAdapter } from '../vercel.js';

const messages: Message[] = [{ role: 'user', content: 'Hello' }];

test('sends system messages when system prompt is provided', async () => {
const { languageModel, calls } = createMockVercelModel({ text: 'Hello, world!' });
const app = createApp({
chatModel: new VercelChatModelAdapter({ languageModel }),
systemPrompt: 'This is system prompt',
});

await app.processMessages(messages);
expect(calls.length).toBe(1);
expect(calls[0].prompt).toContainEqual({
role: 'system',
content: 'This is system prompt',
});
});

test('does not send system message when system prompt is not provided', async () => {
const { languageModel, calls } = createMockVercelModel({ text: 'Hello, world!' });
const app = createApp({
chatModel: new VercelChatModelAdapter({ languageModel }),
});

await app.processMessages(messages);
expect(calls.length).toBe(1);
expect(calls[0].prompt).not.toContainEqual({
role: 'system',
});
});

test('does not send system message when system prompt returns null', async () => {
const { languageModel, calls } = createMockVercelModel({ text: 'Hello, world!' });
const app = createApp({
chatModel: new VercelChatModelAdapter({ languageModel }),
systemPrompt: () => null,
});

await app.processMessages(messages);
expect(calls.length).toBe(1);
expect(calls[0].prompt).not.toContainEqual({
role: 'system',
});
});

test('does not send system message when system prompt returns empty string', async () => {
const { languageModel, calls } = createMockVercelModel({ text: 'Hello, world!' });
const app = createApp({
chatModel: new VercelChatModelAdapter({ languageModel }),
systemPrompt: () => '',
});

await app.processMessages(messages);
expect(calls.length).toBe(1);
expect(calls[0].prompt).not.toContainEqual({
role: 'system',
});
});
37 changes: 20 additions & 17 deletions packages/core/src/ai/vercel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,20 @@ export class VercelChatModelAdapter implements ChatModel {
constructor(private readonly _options: VercelChatModelAdapterOptions) {}

async generateResponse(context: RequestContext): Promise<AssistantResponse> {
const messages = context.messages;
let systemPrompt = context.systemPrompt();
if (systemPrompt) {
const entitiesPrompt = formatResolvedEntities(context.resolvedEntities);
if (entitiesPrompt) {
systemPrompt += '\n\n' + entitiesPrompt;
}
}

const systemPrompt = await context.systemPrompt();
const entitiesPrompt = formatResolvedEntities(context.resolvedEntities);
const finalSystemPrompt = [systemPrompt, entitiesPrompt].join('\n\n');
const messages: CoreMessage[] = [];
if (systemPrompt) {
messages.push({ role: 'system', content: systemPrompt });
}

// TODO: Use userId in anonymous case
const resolvedMessages: CoreMessage[] = [
{ role: 'system' as const, content: finalSystemPrompt },
...messages.map(toMessageParam),
];
messages.push(...context.messages.map(toMessageParam));

const getRunToolFunction =
<TParams extends z.ZodSchema, TOutput>(
Expand All @@ -94,7 +97,7 @@ export class VercelChatModelAdapter implements ChatModel {

const executionContext: AiExecutionContext = {
tools,
messages: resolvedMessages,
messages: messages,
};

const executionResult = context.onPartialResponse
Expand Down Expand Up @@ -229,14 +232,14 @@ function toNumberOrZero(n: number): number {
return Number.isNaN(n) ? 0 : n;
}

function formatResolvedEntities(entities: Record<string, unknown>): string {
function formatResolvedEntities(entities: Record<string, unknown>): string | null {
if (Object.keys(entities).length === 0) {
return '';
return null;
}

return `ENTITY DICTIONARY: \n
${Object.entries(entities)
.map(([key, value]) => `'${key}' is '${JSON.stringify(value)}'`)
.join('\n')}
`;
return `### ENTITY DICTIONARY ###\n
${Object.entries(entities)
.map(([key, value]) => `'${key}' is '${JSON.stringify(value)}'`)
.join('\n')}
`;
}
12 changes: 6 additions & 6 deletions packages/core/src/application.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ export type ErrorHandler = (
) => Promise<MessageResponse> | MessageResponse;

export type ApplicationConfig = {
systemPrompt: (context: RequestContext) => Promise<string> | string;
chatModel: ChatModel | ((context: RequestContext) => ChatModel);
systemPrompt?: ((context: RequestContext) => string | null) | string;
plugins?: ApplicationPlugin[];
errorHandler?: ErrorHandler;
};
Expand Down Expand Up @@ -89,8 +89,6 @@ export function createApp(config: ApplicationConfig): Application {
const performance = new PerformanceTimeline();
performance.markStart(PerformanceMarks.processMessages);

const onPartialResponse = options?.onPartialResponse;

const context: RequestContext = {
messages,
get lastMessage() {
Expand All @@ -102,14 +100,16 @@ export function createApp(config: ApplicationConfig): Application {

return lastMessage;
},

systemPrompt: () =>
typeof config.systemPrompt === 'function'
? config.systemPrompt(context)
: (config.systemPrompt ?? null),
onPartialResponse: options?.onPartialResponse,
tools,
references: getReferenceStorage(),
resolvedEntities: {},
onPartialResponse,
extras: options?.extras ?? {},
performance,
systemPrompt: () => config.systemPrompt(context),
};

const handler = async () => {
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/domain.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export type RequestContext = {
tools: ApplicationTool[];
references: ReferenceStorage;
resolvedEntities: EntityInfo;
systemPrompt: () => Promise<string> | string;
systemPrompt: () => string | null;
onPartialResponse?: (text: string) => void;
extras: MessageRequestExtras;
performance: PerformanceTimeline;
Expand Down
24 changes: 24 additions & 0 deletions packages/core/src/mock/vercel-mock-model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { LanguageModelV1CallOptions } from 'ai';
import { MockLanguageModelV1 } from 'ai/test';

export type MockVercelModelOptions = {
text: string;
};

export function createMockVercelModel({ text }: MockVercelModelOptions) {
const calls: LanguageModelV1CallOptions[] = [];
const languageModel = new MockLanguageModelV1({
doGenerate: (options) => {
calls.push(options);
return Promise.resolve({
rawCall: { rawPrompt: null, rawSettings: {} },
finishReason: 'stop',
usage: { promptTokens: 10, completionTokens: 20 },
text,
});
},
// TODO: add doStream when needed
});

return { languageModel, calls };
}

0 comments on commit 287c02b

Please sign in to comment.