diff --git a/.changeset/perfect-zebras-live.md b/.changeset/perfect-zebras-live.md new file mode 100644 index 0000000000..c15d9f3e4d --- /dev/null +++ b/.changeset/perfect-zebras-live.md @@ -0,0 +1,6 @@ +--- +"@llamaindex/unit-test": patch +"@llamaindex/workflow": patch +--- + +feat(workflow): allow send event with no output diff --git a/packages/workflow/src/workflow-context.ts b/packages/workflow/src/workflow-context.ts index 624916f8b2..763e5d18dc 100644 --- a/packages/workflow/src/workflow-context.ts +++ b/packages/workflow/src/workflow-context.ts @@ -13,19 +13,18 @@ export type StepHandler< AnyWorkflowEventConstructor | StartEventConstructor, ...(AnyWorkflowEventConstructor | StopEventConstructor)[], ] = [AnyWorkflowEventConstructor | StartEventConstructor], - Out extends [ - AnyWorkflowEventConstructor | StartEventConstructor, - ...(AnyWorkflowEventConstructor | StopEventConstructor)[], - ] = [AnyWorkflowEventConstructor | StopEventConstructor], + Out extends (AnyWorkflowEventConstructor | StopEventConstructor)[] = [], > = ( context: HandlerContext, ...events: { [K in keyof Inputs]: InstanceType; } ) => Promise< - { - [K in keyof Out]: InstanceType; - }[number] + Out extends [] + ? void + : { + [K in keyof Out]: InstanceType; + }[number] >; export type ReadonlyStepMap = ReadonlyMap< @@ -275,7 +274,7 @@ export class WorkflowContext */ #createStreamEvents(): AsyncIterableIterator> { const isPendingEvents = new WeakSet>(); - const pendingTasks = new Set>>(); + const pendingTasks = new Set | void>>(); const enqueuedEvents = new Set>(); const stream = new ReadableStream>({ start: async (controller) => { @@ -325,102 +324,104 @@ export class WorkflowContext } const [steps, inputsMap, outputsMap] = this.#getStepFunction(event); - const nextEventPromises: Promise>[] = [ - ...steps, - ] - .map((step) => { - const inputs = [...(inputsMap.get(step) ?? [])]; - const acceptableInputs: WorkflowEvent[] = - this.#pendingInputQueue.filter((event) => - inputs.some((input) => event instanceof input), - ); - const events: WorkflowEvent[] = flattenEvents( - inputs, - [event, ...acceptableInputs], - ); - // remove the event from the queue, in case of infinite loop - events.forEach((event) => { - const protocolIdx = this.#queue.findIndex( - (protocol) => - protocol.type === "event" && - protocol.event === event, + const nextEventPromises: Promise | void>[] = + [...steps] + .map((step) => { + const inputs = [...(inputsMap.get(step) ?? [])]; + const acceptableInputs: WorkflowEvent[] = + this.#pendingInputQueue.filter((event) => + inputs.some((input) => event instanceof input), + ); + const events: WorkflowEvent[] = flattenEvents( + inputs, + [event, ...acceptableInputs], ); - if (protocolIdx !== -1) { - this.#queue.splice(protocolIdx, 1); + // remove the event from the queue, in case of infinite loop + events.forEach((event) => { + const protocolIdx = this.#queue.findIndex( + (protocol) => + protocol.type === "event" && + protocol.event === event, + ); + if (protocolIdx !== -1) { + this.#queue.splice(protocolIdx, 1); + } + }); + if (events.length !== inputs.length) { + if (this.#verbose) { + console.log( + `Not enough inputs for step ${step.name}, waiting for more events`, + ); + } + // not enough to run the step, push back to the queue + this.#sendEvent(event); + isPendingEvents.add(event); + return null; + } + if (isPendingEvents.has(event)) { + isPendingEvents.delete(event); } - }); - if (events.length !== inputs.length) { if (this.#verbose) { console.log( - `Not enough inputs for step ${step.name}, waiting for more events`, + `Running step ${step.name} with inputs ${events}`, ); } - // not enough to run the step, push back to the queue - this.#sendEvent(event); - isPendingEvents.add(event); - return null; - } - if (isPendingEvents.has(event)) { - isPendingEvents.delete(event); - } - if (this.#verbose) { - console.log( - `Running step ${step.name} with inputs ${events}`, - ); - } - const data = this.data; - return (step as StepHandler) - .call( - null, - { - get data() { - return data; + const data = this.data; + return (step as StepHandler) + .call( + null, + { + get data() { + return data; + }, + sendEvent: this.#sendEvent, + requireEvent: this.#requireEvent, }, - sendEvent: this.#sendEvent, - requireEvent: this.#requireEvent, - }, - // @ts-expect-error IDK why - ...events.sort((a, b) => { - const aIndex = inputs.indexOf( - a.constructor as AnyWorkflowEventConstructor, - ); - const bIndex = inputs.indexOf( - b.constructor as AnyWorkflowEventConstructor, - ); - return aIndex - bIndex; - }), - ) - .then((nextEvent) => { - if (this.#verbose) { - console.log( - `Step ${step.name} completed, next event is ${nextEvent}`, - ); - } - const outputs = outputsMap.get(step) ?? []; - if ( - !outputs.some( - (output) => nextEvent.constructor === output, - ) - ) { - if (this.#strict) { - const error = Error( - `Step ${step.name} returned an unexpected output event ${nextEvent}`, + // @ts-expect-error IDK why + ...events.sort((a, b) => { + const aIndex = inputs.indexOf( + a.constructor as AnyWorkflowEventConstructor, ); - controller.error(error); - } else { - console.warn( - `Step ${step.name} returned an unexpected output event ${nextEvent}`, + const bIndex = inputs.indexOf( + b.constructor as AnyWorkflowEventConstructor, ); + return aIndex - bIndex; + }), + ) + .then((nextEvent: void | WorkflowEvent) => { + if (nextEvent === undefined) { + return; } - } - if (!(nextEvent instanceof StopEvent)) { - this.#pendingInputQueue.unshift(nextEvent); - this.#sendEvent(nextEvent); - } - return nextEvent; - }); - }) - .filter((promise) => promise !== null); + if (this.#verbose) { + console.log( + `Step ${step.name} completed, next event is ${nextEvent}`, + ); + } + const outputs = outputsMap.get(step) ?? []; + if ( + !outputs.some( + (output) => nextEvent.constructor === output, + ) + ) { + if (this.#strict) { + const error = Error( + `Step ${step.name} returned an unexpected output event ${nextEvent}`, + ); + controller.error(error); + } else { + console.warn( + `Step ${step.name} returned an unexpected output event ${nextEvent}`, + ); + } + } + if (!(nextEvent instanceof StopEvent)) { + this.#pendingInputQueue.unshift(nextEvent); + this.#sendEvent(nextEvent); + } + return nextEvent; + }); + }) + .filter((promise) => promise !== null); nextEventPromises.forEach((promise) => { pendingTasks.add(promise); promise @@ -433,6 +434,9 @@ export class WorkflowContext }); Promise.race(nextEventPromises) .then((fastestNextEvent) => { + if (fastestNextEvent === undefined) { + return; + } if (!enqueuedEvents.has(fastestNextEvent)) { controller.enqueue(fastestNextEvent); enqueuedEvents.add(fastestNextEvent); @@ -441,7 +445,10 @@ export class WorkflowContext }) .then(async (fastestNextEvent) => Promise.all(nextEventPromises).then((nextEvents) => { - for (const nextEvent of nextEvents) { + const events = nextEvents.filter( + (event) => event !== undefined, + ); + for (const nextEvent of events) { // do not enqueue the same event twice if (fastestNextEvent !== nextEvent) { if (!enqueuedEvents.has(nextEvent)) { diff --git a/packages/workflow/src/workflow.ts b/packages/workflow/src/workflow.ts index 1e2abab2d3..0ec6f4dc20 100644 --- a/packages/workflow/src/workflow.ts +++ b/packages/workflow/src/workflow.ts @@ -57,10 +57,7 @@ export class Workflow { AnyWorkflowEventConstructor | StartEventConstructor, ...(AnyWorkflowEventConstructor | StopEventConstructor)[], ], - const Out extends [ - AnyWorkflowEventConstructor | StopEventConstructor, - ...(AnyWorkflowEventConstructor | StopEventConstructor)[], - ], + const Out extends (AnyWorkflowEventConstructor | StopEventConstructor)[], >( parameters: StepParameters, stepFn: ( @@ -69,9 +66,11 @@ export class Workflow { [K in keyof In]: InstanceType; } ) => Promise< - { - [K in keyof Out]: InstanceType; - }[number] + Out extends [] + ? void + : { + [K in keyof Out]: InstanceType; + }[number] >, ): this { const { inputs, outputs } = parameters; diff --git a/unit/workflow/workflow.test.ts b/unit/workflow/workflow.test.ts index 4d2104ef47..522bc3a376 100644 --- a/unit/workflow/workflow.test.ts +++ b/unit/workflow/workflow.test.ts @@ -531,6 +531,21 @@ describe("workflow basic", () => { const result = await myWorkflow.run("start"); expect(result.data).toBe("query result"); }); + + test("allow output with send event", async () => { + const myFlow = new Workflow({ verbose: true }); + myFlow.addStep( + { + inputs: [StartEvent], + outputs: [], + }, + async (context, ev) => { + context.sendEvent(new StopEvent(`Hello ${ev.data}!`)); + }, + ); + const result = myFlow.run("world"); + expect((await result).data).toBe("Hello world!"); + }); }); describe("workflow event loop", () => {