Skip to content

Commit

Permalink
refactor: schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Feb 25, 2025
1 parent 3226aed commit 66490f3
Show file tree
Hide file tree
Showing 7 changed files with 654 additions and 565 deletions.
106 changes: 24 additions & 82 deletions src/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {evaluateAnswer, evaluateQuestion} from "./tools/evaluator";
import {analyzeSteps} from "./tools/error-analyzer";
import {TokenTracker} from "./utils/token-tracker";
import {ActionTracker} from "./utils/action-tracker";
import {StepAction, AnswerAction, KnowledgeItem, EvaluationCriteria, SearchResult} from "./types";
import {StepAction, AnswerAction, KnowledgeItem, SearchResult, EvaluationType} from "./types";
import {TrackerContext} from "./types";
import {search} from "./tools/jina-search";
// import {grounding} from "./tools/grounding";
Expand All @@ -21,73 +21,14 @@ import {CodeSandbox} from "./tools/code-sandbox";
import {serperSearch} from './tools/serper-search';
import {getUnvisitedURLs, normalizeUrl} from "./utils/url-tools";
import {buildMdFromAnswer, chooseK, removeExtraLineBreaks, removeHTMLtags} from "./utils/text-tools";
import {MAX_QUERIES_PER_STEP, MAX_REFLECT_PER_STEP, MAX_URLS_PER_STEP, Schemas} from "./utils/schemas";

async function sleep(ms: number) {
const seconds = Math.ceil(ms / 1000);
console.log(`Waiting ${seconds}s...`);
return new Promise(resolve => setTimeout(resolve, ms));
}

const MAX_URLS_PER_STEP = 2
const MAX_QUERIES_PER_STEP = 5
const MAX_REFLECT_PER_STEP = 3

function getSchema(allowReflect: boolean, allowRead: boolean, allowAnswer: boolean, allowSearch: boolean, allowCoding: boolean, languageStyle: string = 'same language as the question') {
const actions: string[] = [];
const properties: Record<string, z.ZodTypeAny> = {
action: z.enum(['placeholder']), // Will update later with actual actions
think: z.string().describe(`Explain why choose this action, what's the chain-of-thought behind choosing this action, use the first-person narrative.`).max(500)
};

if (allowSearch) {
actions.push("search");
properties.searchRequests = z.array(
z.string().max(30)
.describe(`A natual language search request in ${languageStyle}. Based on the deep intention behind the original question and the expected answer format.`)).describe(`Required when action='search'. Always prefer a single request, only add another request if the original question covers multiple aspects or elements and one search request is definitely not enough, each request focus on one specific aspect of the original question. Minimize mutual information between each request. Maximum ${MAX_QUERIES_PER_STEP} search requests.`).max(MAX_QUERIES_PER_STEP);
}

if (allowCoding) {
actions.push("coding");
properties.codingIssue = z.string().max(500)
.describe("Required when action='coding'. Describe what issue to solve with coding, format like a github issue ticket. Specify the input value when it is short.").optional();
}

if (allowAnswer) {
actions.push("answer");
properties.references = z.array(
z.object({
exactQuote: z.string().describe("Exact relevant quote from the document, must be a soundbite, short and to the point, no fluff").max(30),
url: z.string().describe("source URL; must be directly from the context")
}).required()
).describe("Required when action='answer'. Must be an array of references that support the answer, each reference must contain an exact quote and the URL of the document").optional();
properties.answer = z.string()
.describe(`Required when action='answer'. Must be definitive, no ambiguity, uncertainty, or disclaimers. Must in ${languageStyle} and confident. Use markdown footnote syntax like [^1], [^2] to refer the corresponding reference item`).optional();
}

if (allowReflect) {
actions.push("reflect");
properties.questionsToAnswer = z.array(
z.string().describe("each question must be a single line, Questions must be: Original (not variations of existing questions); Focused on single concepts; Under 20 words; Non-compound/non-complex")
).max(MAX_REFLECT_PER_STEP)
.describe(`Required when action='reflect'. List of most important questions to fill the knowledge gaps of finding the answer to the original question. Maximum provide ${MAX_REFLECT_PER_STEP} reflect questions.`).optional();
}

if (allowRead) {
actions.push("visit");
properties.URLTargets = z.array(z.string())
.max(MAX_URLS_PER_STEP)
.describe(`Required when action='visit'. Must be an array of URLs, choose up the most relevant ${MAX_URLS_PER_STEP} URLs to visit`).optional();
}

// Update the enum values after collecting all actions
properties.action = z.enum(actions as [string, ...string[]])
.describe("Must match exactly one action type");

return z.object(properties);

}



function getPrompt(
context?: string[],
Expand Down Expand Up @@ -192,7 +133,7 @@ ${learnedStrategy}
if (allURLs && allURLs.length > 0) {
urlList = allURLs
.filter(r => 'url' in r)
.map(r => ` + "${r.url}": "${r.title}"`)
.map(r => ` + "${r.url}": "${r.title}"`)
.join('\n');
}

Expand Down Expand Up @@ -290,37 +231,38 @@ ${actionSections.join('\n\n')}
}



const allContext: StepAction[] = []; // all steps in the current session, including those leads to wrong results

function updateContext(step: any) {
allContext.push(step)
}





export async function getResponse(question?: string,
tokenBudget: number = 1_000_000,
maxBadAttempts: number = 3,
existingContext?: Partial<TrackerContext>,
messages?: Array<CoreAssistantMessage | CoreUserMessage>
): Promise<{ result: StepAction; context: TrackerContext; visitedURLs: string[], readURLs: string[] }> {
const context: TrackerContext = {
tokenTracker: existingContext?.tokenTracker || new TokenTracker(tokenBudget),
actionTracker: existingContext?.actionTracker || new ActionTracker()
};

let step = 0;
let totalStep = 0;
let badAttempts = 0;
let schema: ZodObject<any> = getSchema(true, true, true, true, true)

question = question?.trim() as string;
if (messages && messages.length > 0) {
question = (messages[messages.length - 1]?.content as string).trim();
} else {
messages = [{role: 'user', content: question.trim()}]
}

const SchemaGen = new Schemas(question);
const context: TrackerContext = {
tokenTracker: existingContext?.tokenTracker || new TokenTracker(tokenBudget),
actionTracker: existingContext?.actionTracker || new ActionTracker()
};

let schema: ZodObject<any> = SchemaGen.getAgentSchema(true, true, true, true, true)
const gaps: string[] = [question]; // All questions to be answered including the orginal question
const allQuestions = [question];
const allKeywords = [];
Expand All @@ -338,7 +280,7 @@ export async function getResponse(question?: string,

const allURLs: Record<string, SearchResult> = {};
const visitedURLs: string[] = [];
const evaluationMetrics: Record<string, EvaluationCriteria> = {};
const evaluationMetrics: Record<string, EvaluationType[]> = {};
while (context.tokenTracker.getTotalUsage().totalTokens < tokenBudget && badAttempts <= maxBadAttempts) {
// add 1s delay to avoid rate limiting
step++;
Expand All @@ -349,7 +291,8 @@ export async function getResponse(question?: string,
allowReflect = allowReflect && (gaps.length <= 1);
const currentQuestion: string = gaps.length > 0 ? gaps.shift()! : question
if (!evaluationMetrics[currentQuestion]) {
evaluationMetrics[currentQuestion] = await evaluateQuestion(currentQuestion, context)
evaluationMetrics[currentQuestion] =
await evaluateQuestion(currentQuestion, context, SchemaGen)
}

// update all urls with buildURLMap
Expand All @@ -371,8 +314,7 @@ export async function getResponse(question?: string,
getUnvisitedURLs(allURLs, visitedURLs),
false,
);
schema = getSchema(allowReflect, allowRead, allowAnswer, allowSearch, allowCoding,
evaluationMetrics[currentQuestion].languageStyle)
schema = SchemaGen.getAgentSchema(allowReflect, allowRead, allowAnswer, allowSearch, allowCoding)
const generator = new ObjectGeneratorSafe(context.tokenTracker);
const result = await generator.generateObject({
model: 'agent',
Expand Down Expand Up @@ -420,10 +362,11 @@ export async function getResponse(question?: string,

context.actionTracker.trackThink(`But wait, let me evaluate the answer first.`)

const {response: evaluation} = await evaluateAnswer(currentQuestion, thisStep,
const evaluation = await evaluateAnswer(currentQuestion, thisStep,
evaluationMetrics[currentQuestion],
context,
visitedURLs
visitedURLs,
SchemaGen
);

if (currentQuestion.trim() === question) {
Expand Down Expand Up @@ -462,7 +405,7 @@ The evaluator thinks your answer is bad because:
${evaluation.think}
`);
// store the bad context and reset the diary context
const {response: errorAnalysis} = await analyzeSteps(diaryContext, context);
const errorAnalysis = await analyzeSteps(diaryContext, context, SchemaGen);

allKnowledge.push({
question: currentQuestion,
Expand Down Expand Up @@ -554,7 +497,7 @@ But then you realized you have asked them before. You decided to to think out of
thisStep.searchRequests = chooseK((await dedupQueries(thisStep.searchRequests, [], context.tokenTracker)).unique_queries, MAX_QUERIES_PER_STEP);

// rewrite queries
let {queries: keywordsQueries} = await rewriteQuery(thisStep, context);
let {queries: keywordsQueries} = await rewriteQuery(thisStep, context, SchemaGen);
// avoid exisitng searched queries
keywordsQueries = chooseK((await dedupQueries(keywordsQueries, allKeywords, context.tokenTracker)).unique_queries, MAX_QUERIES_PER_STEP);

Expand Down Expand Up @@ -717,7 +660,7 @@ You decided to think out of the box or cut from a completely different angle.`);
allowRead = false;
}
} else if (thisStep.action === 'coding' && thisStep.codingIssue) {
const sandbox = new CodeSandbox({allContext, visitedURLs, allURLs, allKnowledge}, context);
const sandbox = new CodeSandbox({allContext, visitedURLs, allURLs, allKnowledge}, context, SchemaGen);
try {
const result = await sandbox.solve(thisStep.codingIssue);
allKnowledge.push({
Expand Down Expand Up @@ -778,8 +721,7 @@ But unfortunately, you failed to solve the issue. You need to think out of the b
true,
);

schema = getSchema(false, false, true, false, false,
evaluationMetrics[question]?.languageStyle || 'same language as the question');
schema = SchemaGen.getAgentSchema(false, false, true, false, false);
const generator = new ObjectGeneratorSafe(context.tokenTracker);
const result = await generator.generateObject({
model: 'agentBeastMode',
Expand Down
117 changes: 55 additions & 62 deletions src/tools/code-sandbox.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
import { z } from 'zod';
import { ObjectGeneratorSafe } from "../utils/safe-generator";
import {TrackerContext} from "../types";

// Define the response schema for code generation
const codeGenerationSchema = z.object({
think: z.string().describe('Short explain or comments on the thought process behind the code, in first person.').max(200),
code: z.string().describe('The JavaScript code that solves the problem and always use \'return\' statement to return the result. Focus on solving the core problem; No need for error handling or try-catch blocks or code comments. No need to declare variables that are already available, especially big long strings or arrays.'),
});

// Define the types
interface CodeGenerationResponse {
code: string;
}
import {ObjectGeneratorSafe} from "../utils/safe-generator";
import {CodeGenResponse, TrackerContext} from "../types";
import {Schemas} from "../utils/schemas";


interface SandboxResult {
success: boolean;
Expand Down Expand Up @@ -72,33 +62,36 @@ export class CodeSandbox {
private generator: ObjectGeneratorSafe;
private maxAttempts: number;
private context: Record<string, any>;
private schemaGen: Schemas;

constructor(
context: any = {},
trackers?: TrackerContext,
maxAttempts: number = 3
trackers: TrackerContext,
schemaGen: Schemas,
maxAttempts: number = 3,
) {
this.trackers = trackers;
this.generator = new ObjectGeneratorSafe(trackers?.tokenTracker);
this.maxAttempts = maxAttempts;
this.context = context;
this.schemaGen = schemaGen;
}

private async generateCode(
problem: string,
previousAttempts: Array<{ code: string; error?: string }> = []
): Promise<CodeGenerationResponse> {
): Promise<CodeGenResponse> {
const prompt = getPrompt(problem, analyzeStructure(this.context), previousAttempts);

const result = await this.generator.generateObject({
model: 'coder',
schema: codeGenerationSchema,
schema: this.schemaGen.getCodeGeneratorSchema(),
prompt,
});

this.trackers?.actionTracker.trackThink(result.object.think);

return result.object;
return result.object as CodeGenResponse;
}

private evaluateCode(code: string): SandboxResult {
Expand Down Expand Up @@ -143,7 +136,7 @@ export class CodeSandbox {
for (let i = 0; i < this.maxAttempts; i++) {
// Generate code
const generation = await this.generateCode(problem, attempts);
const { code } = generation;
const {code} = generation;

console.log(`Coding attempt ${i + 1}:`, code);
// Evaluate the code
Expand Down Expand Up @@ -180,61 +173,61 @@ export class CodeSandbox {
}

function formatValue(value: any): string {
if (value === null) return 'null';
if (value === undefined) return 'undefined';
if (value === null) return 'null';
if (value === undefined) return 'undefined';

const type = typeof value;
const type = typeof value;

if (type === 'string') {
// Clean and truncate string value
const cleaned = value.replace(/\n/g, ' ').replace(/\s+/g, ' ').trim();
return cleaned.length > 50 ?
`"${cleaned.slice(0, 47)}..."` :
`"${cleaned}"`;
}
if (type === 'string') {
// Clean and truncate string value
const cleaned = value.replace(/\n/g, ' ').replace(/\s+/g, ' ').trim();
return cleaned.length > 50 ?
`"${cleaned.slice(0, 47)}..."` :
`"${cleaned}"`;
}

if (type === 'number' || type === 'boolean') {
return String(value);
}
if (type === 'number' || type === 'boolean') {
return String(value);
}

if (value instanceof Date) {
return `"${value.toISOString()}"`;
}
if (value instanceof Date) {
return `"${value.toISOString()}"`;
}

return '';
return '';
}

export function analyzeStructure(value: any, indent = ''): string {
if (value === null) return 'null';
if (value === undefined) return 'undefined';
if (value === null) return 'null';
if (value === undefined) return 'undefined';

const type = typeof value;
const type = typeof value;

if (type === 'function') {
return 'Function';
}
if (type === 'function') {
return 'Function';
}

// Handle atomic types with example values
if (type !== 'object' || value instanceof Date) {
const formattedValue = formatValue(value);
return `${type}${formattedValue ? ` (example: ${formattedValue})` : ''}`;
}
// Handle atomic types with example values
if (type !== 'object' || value instanceof Date) {
const formattedValue = formatValue(value);
return `${type}${formattedValue ? ` (example: ${formattedValue})` : ''}`;
}

if (Array.isArray(value)) {
if (value.length === 0) return 'Array<unknown>';
const sampleItem = value[0];
return `Array<${analyzeStructure(sampleItem, indent + ' ')}>`;
}
if (Array.isArray(value)) {
if (value.length === 0) return 'Array<unknown>';
const sampleItem = value[0];
return `Array<${analyzeStructure(sampleItem, indent + ' ')}>`;
}

const entries = Object.entries(value);
if (entries.length === 0) return '{}';
const entries = Object.entries(value);
if (entries.length === 0) return '{}';

const properties = entries
.map(([key, val]) => {
const analyzed = analyzeStructure(val, indent + ' ');
return `${indent} "${key}": ${analyzed}`;
})
.join(',\n');
const properties = entries
.map(([key, val]) => {
const analyzed = analyzeStructure(val, indent + ' ');
return `${indent} "${key}": ${analyzed}`;
})
.join(',\n');

return `{\n${properties}\n${indent}}`;
return `{\n${properties}\n${indent}}`;
}
Loading

0 comments on commit 66490f3

Please sign in to comment.