diff --git a/index.ts b/index.ts index dea3d14..27bb69f 100644 --- a/index.ts +++ b/index.ts @@ -1,22 +1,28 @@ /// -import { getFunctionAddress } from "convex/server"; import { DataModelFromSchemaDefinition, + DefaultFunctionArgs, DocumentByName, FunctionReference, FunctionReturnType, + GenericActionCtx, GenericDataModel, GenericDocument, GenericMutationCtx, GenericSchema, HttpRouter, OptionalRestArgs, + PublicHttpAction, + RegisteredAction, + RegisteredMutation, + RegisteredQuery, SchemaDefinition, StorageActionWriter, SystemDataModel, UserIdentity, actionGeneric, + getFunctionAddress, httpActionGeneric, makeFunctionReference, mutationGeneric, @@ -124,9 +130,7 @@ class DatabaseFake { // - When the top-level mutation commits, the writes are applied to the // database. // - When a mutation is rolled back, last level of writes is discarded. - private _writes: Array< - Record - > = []; + private _writes: Array> = []; private _schema: { schemaValidation: boolean; @@ -206,10 +210,7 @@ class DatabaseFake { return this._storage[storageId]; } - private _addWrite( - id: DocumentId, - newValue: StoredDocument | null, - ) { + private _addWrite(id: DocumentId, newValue: StoredDocument | null) { if (this._writes.length === 0) { throw new Error(`Write outside of transaction ${id}`); } @@ -614,16 +615,20 @@ function evaluateFilter( filter: any, ): Value | undefined { if (filter.$eq !== undefined) { - return compareValues( - evaluateFilter(document, filter.$eq[0]), - evaluateFilter(document, filter.$eq[1]) - ) === 0; + return ( + compareValues( + evaluateFilter(document, filter.$eq[0]), + evaluateFilter(document, filter.$eq[1]), + ) === 0 + ); } if (filter.$neq !== undefined) { - return compareValues( - evaluateFilter(document, filter.$neq[0]), - evaluateFilter(document, filter.$neq[1]) - ) !== 0; + return ( + compareValues( + evaluateFilter(document, filter.$neq[0]), + evaluateFilter(document, filter.$neq[1]), + ) !== 0 + ); } if (filter.$and !== undefined) { return filter.$and.every((child: any) => evaluateFilter(document, child)); @@ -1103,13 +1108,23 @@ function asyncSyscallImpl() { }); if (udfType === "query") { return JSON.stringify( - convexToJson(await withAuth().queryFromPath(functionPath, /* isNested */ true, udfArgs)), + convexToJson( + await withAuth().queryFromPath( + functionPath, + /* isNested */ true, + udfArgs, + ), + ), ); } if (udfType === "mutation") { return JSON.stringify( convexToJson( - await withAuth().mutationFromPath(functionPath, /* isNested */ true, udfArgs), + await withAuth().mutationFromPath( + functionPath, + /* isNested */ true, + udfArgs, + ), ), ); } @@ -1118,11 +1133,7 @@ function asyncSyscallImpl() { ); } case "1.0/createFunctionHandle": { - const { - name, - reference, - functionHandle, - } = args; + const { name, reference, functionHandle } = args; const functionPath = await getFunctionPathFromAddress({ name, reference, @@ -1159,19 +1170,22 @@ function asyncSyscallImpl() { const componentPath = getCurrentComponentPath(); setTimeout( (async () => { - const canceled = await withAuth().runInComponent(componentPath, async () => { - const job = db.get(jobId) as ScheduledFunction; - if (job.state.kind === "canceled") { - return true; - } - if (job.state.kind !== "pending") { - throw new Error( - `\`convexTest\` invariant error: Unexpected scheduled function state when starting it: ${job.state.kind}`, - ); - } - db.patch(jobId, { state: { kind: "inProgress" } }); - return false; - }); + const canceled = await withAuth().runInComponent( + componentPath, + async () => { + const job = db.get(jobId) as ScheduledFunction; + if (job.state.kind === "canceled") { + return true; + } + if (job.state.kind !== "pending") { + throw new Error( + `\`convexTest\` invariant error: Unexpected scheduled function state when starting it: ${job.state.kind}`, + ); + } + db.patch(jobId, { state: { kind: "inProgress" } }); + return false; + }, + ); if (canceled) { return; } @@ -1315,11 +1329,14 @@ async function blobSha(blob: Blob) { async function waitForInProgressScheduledFunctions(): Promise { let hadScheduledFunctions = false; for (const componentPath of Object.keys(getConvexGlobal().components)) { - const inProgressJobs = (await withAuth().runInComponent(componentPath, async (ctx) => { - return (await ctx.db.system.query("_scheduled_functions").collect()).filter( - (job: ScheduledFunction) => job.state.kind === "inProgress", - ); - })) as ScheduledFunction[]; + const inProgressJobs = (await withAuth().runInComponent( + componentPath, + async (ctx) => { + return ( + await ctx.db.system.query("_scheduled_functions").collect() + ).filter((job: ScheduledFunction) => job.state.kind === "inProgress"); + }, + )) as ScheduledFunction[]; let numRemaining = inProgressJobs.length; if (numRemaining === 0) { continue; @@ -1455,7 +1472,9 @@ export type TestConvexForDataModelAndIdentity< function getComponentInfo(componentPath: string): ComponentInfo { const convex = getConvexGlobal(); if (convex.components[componentPath] === undefined) { - throw new Error(`Component "${componentPath}" is not registered. Call "t.registerComponent".`); + throw new Error( + `Component "${componentPath}" is not registered. Call "t.registerComponent".`, + ); } return convex.components[componentPath]; } @@ -1474,8 +1493,7 @@ function getTransactionManager() { function getCurrentComponentPath() { const functionStack = getTransactionManager().functionStack; - const currentFunctionPath = - functionStack[functionStack.length - 1]; + const currentFunctionPath = functionStack[functionStack.length - 1]; return currentFunctionPath?.componentPath ?? ROOT_COMPONENT_PATH; } @@ -1578,10 +1596,7 @@ class TransactionManager { private _markTransactionDone: (() => void) | null = null; public functionStack: FunctionPath[] = []; - async begin( - functionPath: FunctionPath, - isNested: boolean, - ) { + async begin(functionPath: FunctionPath, isNested: boolean) { // Take a lock only for the top-level of each transaction. // Nested transactions are not isolated so if you `Promise.all` on multiple // `ctx.runMutation` or `ctx.runQuery` calls, they won't be serialized. @@ -1732,13 +1747,17 @@ function withAuth(auth: AuthFake = new AuthFake()) { }; const byTypeWithPath = { - queryFromPath: async (functionPath: FunctionPath, isNested: boolean, args: any) => { + queryFromPath: async ( + functionPath: FunctionPath, + isNested: boolean, + args: any, + ) => { const func = await getFunctionFromPath(functionPath, "query"); - validateValidator(JSON.parse(func.exportArgs()), args ?? {}); + validateValidator(JSON.parse((func as any).exportArgs()), args ?? {}); const q = queryGeneric({ handler: (ctx: any, a: any) => { const testCtx = { ...ctx, auth }; - return func(testCtx, a); + return getHandler(func)(testCtx, a); }, }); const transactionManager = getTransactionManager(); @@ -1759,14 +1778,20 @@ function withAuth(auth: AuthFake = new AuthFake()) { args: any, ): Promise => { const func = await getFunctionFromPath(functionPath, "mutation"); - validateValidator(JSON.parse(func.exportArgs()), args ?? {}); - - return await runTransaction(func, args, {}, functionPath, isNested); + validateValidator(JSON.parse((func as any).exportArgs()), args ?? {}); + + return await runTransaction( + getHandler(func), + args, + {}, + functionPath, + isNested, + ); }, actionFromPath: async (functionPath: FunctionPath, args: any) => { const func = await getFunctionFromPath(functionPath, "action"); - validateValidator(JSON.parse(func.exportArgs()), args ?? {}); + validateValidator(JSON.parse((func as any).exportArgs()), args ?? {}); const a = actionGeneric({ handler: (ctx: any, a: any) => { @@ -1777,7 +1802,7 @@ function withAuth(auth: AuthFake = new AuthFake()) { runAction: byType.action, auth, }; - return func(testCtx, a); + return getHandler(func)(testCtx, a); }, }); getTransactionManager().beginAction(functionPath); @@ -1798,21 +1823,35 @@ function withAuth(auth: AuthFake = new AuthFake()) { const byType = { query: async (functionReference: any, args: any) => { - const functionPath = await getFunctionPathFromReference(functionReference); - return await byTypeWithPath.queryFromPath(functionPath, /* isNested */ false, args); + const functionPath = + await getFunctionPathFromReference(functionReference); + return await byTypeWithPath.queryFromPath( + functionPath, + /* isNested */ false, + args, + ); }, mutation: async (functionReference: any, args: any): Promise => { - const functionPath = await getFunctionPathFromReference(functionReference); - return await byTypeWithPath.mutationFromPath(functionPath, /* isNested */ false, args); + const functionPath = + await getFunctionPathFromReference(functionReference); + return await byTypeWithPath.mutationFromPath( + functionPath, + /* isNested */ false, + args, + ); }, action: async (functionReference: any, args: any) => { - const functionPath = await getFunctionPathFromReference(functionReference); + const functionPath = + await getFunctionPathFromReference(functionReference); return await byTypeWithPath.actionFromPath(functionPath, args); }, }; - const run = async (componentPath: string, handler: (ctx: any) => T): Promise => { + const run = async ( + componentPath: string, + handler: (ctx: any) => T, + ): Promise => { // Grab StorageActionWriter from action ctx const a = actionGeneric({ handler: async ({ storage }: any) => { @@ -1850,10 +1889,18 @@ function withAuth(auth: AuthFake = new AuthFake()) { fun: async (functionPath: FunctionPath, args: any) => { const func = await getFunctionFromPath(functionPath, "any"); if (func.isQuery) { - return await byTypeWithPath.queryFromPath(functionPath, /* isNested */ false, args); + return await byTypeWithPath.queryFromPath( + functionPath, + /* isNested */ false, + args, + ); } if (func.isMutation) { - return await byTypeWithPath.mutationFromPath(functionPath, /* isNested */ false, args); + return await byTypeWithPath.mutationFromPath( + functionPath, + /* isNested */ false, + args, + ); } if (func.isAction) { return await byTypeWithPath.actionFromPath(functionPath, args); @@ -1881,8 +1928,7 @@ function withAuth(auth: AuthFake = new AuthFake()) { runAction: byType.action, auth, }; - // TODO: Remove `any`, it's needed because of a bug in Convex types - return func(testCtx, a) as any; + return getHandler(func)(testCtx, a); }); const response = await ( a as unknown as { @@ -1908,14 +1954,17 @@ function withAuth(auth: AuthFake = new AuthFake()) { // Stop after a fixed number of iterations to avoid infinite loops. for (let i = 0; i < maxIterations; i++) { advanceTimers(); - const hadScheduledFunctions = await waitForInProgressScheduledFunctions(); + const hadScheduledFunctions = + await waitForInProgressScheduledFunctions(); if (!hadScheduledFunctions) { return; } } - throw new Error("finishAllScheduledFunctions: too many iterations. " - + "Check for infinitely recursive scheduled functions, " - + "or increase maxIterations."); + throw new Error( + "finishAllScheduledFunctions: too many iterations. " + + "Check for infinitely recursive scheduled functions, " + + "or increase maxIterations.", + ); }, }; } @@ -1993,10 +2042,35 @@ async function getFunctionPathFromReference( return await getFunctionPathFromAddress(functionAddress); } -async function getFunctionFromPath( +type RegisteredFunctions = { + query: RegisteredQuery; + mutation: RegisteredMutation; + action: RegisteredAction; + any: any; +}; + +type RegisteredFunctionKind = keyof RegisteredFunctions; + +function getHandler( + func: RegisteredAction, +): (ctx: GenericActionCtx, args: Args) => Promise; +function getHandler( + func: RegisteredMutation, +): (ctx: GenericActionCtx, args: Args) => Promise; +function getHandler( + func: RegisteredQuery, +): (ctx: GenericActionCtx, args: Args) => Promise; +function getHandler( + func: PublicHttpAction, +): (ctx: GenericActionCtx, args: Request) => Promise; +function getHandler(func: any): (ctx: any, args: any) => any { + return "_handler" in func ? func["_handler"] : func; +} + +async function getFunctionFromPath( functionPath: FunctionPath, - type: "query" | "mutation" | "action" | "any", -) { + type: T, +): Promise { // "queries/messages:list" -> ["queries/messages", "list"] const [modulePath, maybeExportName] = functionPath.udfPath.split(":"); const exportName = @@ -2012,7 +2086,7 @@ async function getFunctionFromPath( `Expected a Convex function exported from module "${modulePath}" as \`${exportName}\`, but there is no such export.`, ); } - if (typeof func !== "function") { + if (typeof getHandler(func) !== "function") { throw new Error( `Expected a Convex function exported from module "${modulePath}" as \`${exportName}\`, but got: ${func}`, );