diff --git a/src/curry/memo.ts b/src/curry/memo.ts index d44c8060..c86d5529 100644 --- a/src/curry/memo.ts +++ b/src/curry/memo.ts @@ -1,16 +1,22 @@ -import type { NoInfer } from 'radashi' +import { type NoInfer, isArray, selectFirst, sift } from 'radashi' +type KeyOrKeys = string | (string | undefined)[] type Cache = Record function memoize( cache: Cache, func: (...args: TArgs) => TResult, - keyFunc: ((...args: TArgs) => string) | null, + getKeyFunc: ((...args: TArgs) => KeyOrKeys) | null, + setKeyFunc: ((...args: TArgs) => KeyOrKeys) | null, ttl: number | null, ) { return function callWithMemo(...args: any): TResult { - const key = keyFunc ? keyFunc(...args) : JSON.stringify({ args }) - const existing = cache[key] + const keyOrKeys = getKeyFunc + ? getKeyFunc(...args) + : JSON.stringify({ args }) + const keys = isArray(keyOrKeys) ? sift(keyOrKeys) : [keyOrKeys] + + const existing = selectFirst(keys, key => cache[key]) if (existing !== undefined) { if (!existing.exp) { return existing.value @@ -20,16 +26,22 @@ function memoize( } } const result = func(...args) - cache[key] = { - exp: ttl ? new Date().getTime() + ttl : null, - value: result, + + const setKeyOrKeys = setKeyFunc ? setKeyFunc(...args) : keys + const setKeys = isArray(setKeyOrKeys) ? sift(setKeyOrKeys) : [setKeyOrKeys] + for (const key of setKeys) { + cache[key] = { + exp: ttl ? new Date().getTime() + ttl : null, + value: result, + } } return result } } export interface MemoOptions { - key?: (...args: TArgs) => string + key?: (...args: TArgs) => KeyOrKeys + setKey?: (...args: TArgs) => KeyOrKeys ttl?: number } @@ -61,5 +73,11 @@ export function memo( func: (...args: TArgs) => TResult, options: MemoOptions> = {}, ): (...args: TArgs) => TResult { - return memoize({}, func, options.key ?? null, options.ttl ?? null) + return memoize( + {}, + func, + options.key ?? null, + options.setKey ?? null, + options.ttl ?? null, + ) } diff --git a/tests/curry/memo.test.ts b/tests/curry/memo.test.ts index e2f8d11f..f050aac1 100644 --- a/tests/curry/memo.test.ts +++ b/tests/curry/memo.test.ts @@ -7,6 +7,7 @@ describe('memo', () => { const resultB = func() expect(resultA).toBe(resultB) }) + test('uses key to identify unique calls', () => { const func = _.memo( (arg: { user: { id: string } }) => { @@ -23,6 +24,32 @@ describe('memo', () => { expect(resultA).toBe(resultA2) expect(resultB).not.toBe(resultA) }) + + test('uses multiple keys to identify unique calls', () => { + const rawFn = vi.fn((arg: { id: string; withAdditionalStuff: boolean }) => { + if (arg.withAdditionalStuff) { + // do stuff + } + + return arg.id + }) + + const func = _.memo(rawFn, { + key: arg => + arg.withAdditionalStuff + ? `${arg.id}_withAdditionalStuff` + : [`${arg.id}`, `${arg.id}_withAdditionalStuff`], // we also look for the shared key + setKey: arg => + arg.withAdditionalStuff + ? [`${arg.id}`, `${arg.id}_withAdditionalStuff`] // we also set the shared key + : `${arg.id}`, + }) + + func({ id: '1', withAdditionalStuff: true }) + func({ id: '1', withAdditionalStuff: false }) + expect(rawFn).toHaveBeenCalledTimes(1) + }) + test('calls function again when first value expires', async () => { vi.useFakeTimers() const func = _.memo(() => new Date().getTime(), { @@ -33,6 +60,7 @@ describe('memo', () => { const resultB = func() expect(resultA).not.toBe(resultB) }) + test('does not call function again when first value has not expired', async () => { vi.useFakeTimers() const func = _.memo(() => new Date().getTime(), {