diff --git a/packages/sdk/src/client.ts b/packages/sdk/src/client.ts index c9e0341641..7352711d1c 100644 --- a/packages/sdk/src/client.ts +++ b/packages/sdk/src/client.ts @@ -154,7 +154,12 @@ import { SignerContext } from './signerContext' import { decryptAESGCM, deriveKeyAndIV, encryptAESGCM, uint8ArrayToBase64 } from './crypto_utils' import { makeTags, makeTipTags } from './tags' import { TipEventObject } from '@river-build/generated/dev/typings/ITipping' -import { extractMlsExternalGroup, ExtractMlsExternalGroupResult } from './mls/utils/mlsutils' +import { + extractMlsExternalGroup, + ExtractMlsExternalGroupResult, + mlsCommitsFromStreamView, +} from './mls/utils/mlsutils' +import { MlsMessage } from '@river-build/mls-rs-wasm' export type ClientEvents = StreamEvents & DecryptionEvents @@ -2512,6 +2517,46 @@ export class Client ) } + // helper to return all commits from a specific epoch and forward + // may contain a few additional commits < than the requested epoch + async getMlsCommits(streamId: string, fromSnapshotContainingEpoch: bigint) { + let streamView = this.stream(streamId)?.view + let commits: Uint8Array[] = [] + if (!streamView || !streamView.isInitialized) { + streamView = await this.getStream(streamId) + } + commits = mlsCommitsFromStreamView(streamView) + check(isDefined(streamView), `stream not found: ${streamId}`) + let miniblockNum = streamView.miniblockInfo?.min + check(isDefined(miniblockNum), `miniblockNum not found: ${streamId}`) + + function checkDone() { + for (const commit of commits) { + try { + const message = MlsMessage.fromBytes(commit) + if (message.epoch) { + return message.epoch <= fromSnapshotContainingEpoch + } + } catch { + // ignore + } + } + return false + } + + while (!checkDone()) { + const header = await this.getMiniblockHeader(streamId, miniblockNum) + const mls = header.snapshot?.members?.mls + check(isDefined(mls), `mls snapshot not found: ${streamId}`) + commits = mls.commitsSinceLastSnapshot.concat(commits) + if (miniblockNum === 0n) { + break + } + miniblockNum = header.prevSnapshotMiniblockNum + } + return commits + } + // Used during testing userDeviceKey(): UserDevice { return { diff --git a/packages/sdk/src/mls/utils/mlsutils.ts b/packages/sdk/src/mls/utils/mlsutils.ts index 121a48edcd..079e75ce3b 100644 --- a/packages/sdk/src/mls/utils/mlsutils.ts +++ b/packages/sdk/src/mls/utils/mlsutils.ts @@ -1,12 +1,40 @@ import { check } from '@river-build/dlog' import { IStreamStateView } from '../../streamStateView' +import { StreamTimelineEvent } from '../../types' + export type ExtractMlsExternalGroupResult = { externalGroupSnapshot: Uint8Array groupInfoMessage: Uint8Array commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] } +function commitFromEvent( + event: StreamTimelineEvent, +): { commit: Uint8Array; groupInfoMessage: Uint8Array } | undefined { + const payload = event.remoteEvent?.event.payload + if (payload?.case !== 'memberPayload') { + return undefined + } + if (payload?.value?.content.case !== 'mls') { + return undefined + } + + const mlsPayload = payload.value.content.value + switch (mlsPayload.content.case) { + case 'externalJoin': + case 'welcomeMessage': + return { + commit: mlsPayload.content.value.commit, + groupInfoMessage: mlsPayload.content.value.groupInfoMessage, + } + case undefined: + return undefined + default: + return undefined + } +} + export function extractMlsExternalGroup( streamView: IStreamStateView, ): ExtractMlsExternalGroupResult | undefined { @@ -33,30 +61,26 @@ export function extractMlsExternalGroup( check(groupInfoMessage !== undefined, 'no groupInfoMessage found') const commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] = [] for (let i = indexOfLastSnapshot; i < streamView.timeline.length; i++) { - const event = streamView.timeline[i] - const payload = event.remoteEvent?.event.payload - if (payload?.case !== 'memberPayload') { - continue + const commit = commitFromEvent(streamView.timeline[i]) + if (commit) { + commits.push(commit) } - if (payload?.value?.content.case !== 'mls') { + } + return { externalGroupSnapshot, groupInfoMessage, commits: commits } +} + +export function mlsCommitsFromStreamView(streamView: IStreamStateView): Uint8Array[] { + const commits: Uint8Array[] = [] + const firstMiniblockNum = streamView.miniblockInfo?.min ?? 0n + for (let i = 0; i < streamView.timeline.length; i++) { + // the events in the first miniblock (with snapshot) are already accounted for + if (streamView.timeline[i].miniblockNum === firstMiniblockNum) { continue } - - const mlsPayload = payload.value.content.value - switch (mlsPayload.content.case) { - case 'externalJoin': - case 'welcomeMessage': - commits.push({ - commit: mlsPayload.content.value.commit, - groupInfoMessage: mlsPayload.content.value.groupInfoMessage, - }) - break - - case undefined: - break - default: - break + const commit = commitFromEvent(streamView.timeline[i]) + if (commit) { + commits.push(commit.commit) } } - return { externalGroupSnapshot, groupInfoMessage, commits: commits } + return commits } diff --git a/packages/sdk/src/tests/multi_ne/mls.test.ts b/packages/sdk/src/tests/multi_ne/mls.test.ts index 11cb645b94..19ece0da82 100644 --- a/packages/sdk/src/tests/multi_ne/mls.test.ts +++ b/packages/sdk/src/tests/multi_ne/mls.test.ts @@ -36,6 +36,7 @@ describe('mlsTests', () => { let bobClient: Client let bobMlsGroup: MlsGroup let aliceClient: Client + let observerClient: Client let bobMlsClient: MlsClient let aliceMlsClient: MlsClient let aliceMlsClient2: MlsClient @@ -49,8 +50,12 @@ describe('mlsTests', () => { const commits: Uint8Array[] = [] beforeAll(async () => { - bobClient = await makeInitAndStartClient() - aliceClient = await makeInitAndStartClient() + ;[bobClient, aliceClient, observerClient] = await Promise.all([ + makeInitAndStartClient(), + makeInitAndStartClient(), + makeInitAndStartClient(), + ]) + const clientOptions: MlsClientOptions = { withAllowExternalCommit: true, withRatchetTreeExtension: false, @@ -782,4 +787,22 @@ describe('mlsTests', () => { welcomeMessage.signaturePublicKeys.find((val) => bin_equal(val, signature)), ).toBeDefined() }) + + test('fetch all commits from epoch 0n (from streamview)', async () => { + const allCommits = await bobClient.getMlsCommits(streamId, 0n) + expect(allCommits.length).toBe(commits.length) + expect(allCommits.length).toBeGreaterThan(2) + for (let i = 0; i < allCommits.length; i++) { + expect(bin_equal(allCommits[i], commits[i])).toBe(true) + } + }) + + test('fetch all commits from epoch 0n (from get stream)', async () => { + const allCommits = await observerClient.getMlsCommits(streamId, 0n) + expect(allCommits.length).toBe(commits.length) + expect(allCommits.length).toBeGreaterThan(2) + for (let i = 0; i < allCommits.length; i++) { + expect(bin_equal(allCommits[i], commits[i])).toBe(true) + } + }) }) diff --git a/packages/sdk/vitest.config.multi_legacy.ts b/packages/sdk/vitest.config.multi_legacy.ts index 4ee1a8db13..8525fb5944 100644 --- a/packages/sdk/vitest.config.multi_legacy.ts +++ b/packages/sdk/vitest.config.multi_legacy.ts @@ -1,8 +1,8 @@ import { defineConfig, mergeConfig } from 'vitest/config' -import { rootConfig } from '../../vitest.config.mjs' +import { sdkRootConfig } from './vitest.sdk.rootConfig' export default mergeConfig( - rootConfig, + sdkRootConfig, defineConfig({ test: { environment: 'happy-dom', diff --git a/packages/sdk/vitest.config.multi_ne.ts b/packages/sdk/vitest.config.multi_ne.ts index c283791d20..78eb8e6115 100644 --- a/packages/sdk/vitest.config.multi_ne.ts +++ b/packages/sdk/vitest.config.multi_ne.ts @@ -1,9 +1,8 @@ import { defineConfig, mergeConfig } from 'vitest/config' -import { rootConfig } from '../../vitest.config.mjs' -import wasm from 'vite-plugin-wasm' +import { sdkRootConfig } from './vitest.sdk.rootConfig' export default mergeConfig( - rootConfig, + sdkRootConfig, defineConfig({ test: { environment: 'happy-dom', @@ -14,12 +13,6 @@ export default mergeConfig( hookTimeout: 120_000, testTimeout: 120_000, setupFiles: './vitest.setup.ts', - server: { - deps: { - inline: ['@river-build/mls-rs-wasm'], - }, - }, }, - plugins: [wasm()], }), ) diff --git a/packages/sdk/vitest.config.unit.ts b/packages/sdk/vitest.config.unit.ts index d77a2d6505..e8b8b9ead6 100644 --- a/packages/sdk/vitest.config.unit.ts +++ b/packages/sdk/vitest.config.unit.ts @@ -1,9 +1,8 @@ import { defineConfig, mergeConfig } from 'vitest/config' -import wasm from 'vite-plugin-wasm' -import { rootConfig } from '../../vitest.config.mjs' +import { sdkRootConfig } from './vitest.sdk.rootConfig' export default mergeConfig( - rootConfig, + sdkRootConfig, defineConfig({ test: { environment: 'happy-dom', @@ -11,12 +10,6 @@ export default mergeConfig( hookTimeout: 120_000, testTimeout: 120_000, setupFiles: './vitest.setup.ts', - server: { - deps: { - inline: ['@river-build/mls-rs-wasm'], - }, - }, }, - plugins: [wasm()], }), ) diff --git a/packages/sdk/vitest.sdk.rootConfig.ts b/packages/sdk/vitest.sdk.rootConfig.ts new file mode 100644 index 0000000000..2775640147 --- /dev/null +++ b/packages/sdk/vitest.sdk.rootConfig.ts @@ -0,0 +1,17 @@ +import { defineConfig, mergeConfig } from 'vitest/config' +import { rootConfig } from '../../vitest.config.mjs' +import wasm from 'vite-plugin-wasm' + +export const sdkRootConfig = mergeConfig( + rootConfig, + defineConfig({ + test: { + server: { + deps: { + inline: ['@river-build/mls-rs-wasm'], + }, + }, + }, + plugins: [wasm()], + }), +)