Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP close commit gap #2044

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion packages/sdk/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,12 @@ import { SignerContext } from './signerContext'
import { decryptAESGCM, deriveKeyAndIV, encryptAESGCM, uint8ArrayToBase64 } from './crypto_utils'
import { makeTags, makeTipTags } from './tags'
import { TipEventObject } from '@river-build/generated/dev/typings/ITipping'
import { extractMlsExternalGroup, ExtractMlsExternalGroupResult } from './mls/utils/mlsutils'
import {
extractMlsExternalGroup,
ExtractMlsExternalGroupResult,
mlsCommitsFromStreamView,
} from './mls/utils/mlsutils'
import { MlsMessage } from '@river-build/mls-rs-wasm'

export type ClientEvents = StreamEvents & DecryptionEvents

Expand Down Expand Up @@ -2512,6 +2517,44 @@ export class Client
)
}

async getMlsCommits(streamId: string, fromEpoch: bigint) {
let streamView = this.stream(streamId)?.view
let commits: Uint8Array[] = []
if (!streamView || !streamView.isInitialized) {
streamView = await this.getStream(streamId)
}
commits = mlsCommitsFromStreamView(streamView)
check(isDefined(streamView), `stream not found: ${streamId}`)
let miniblockNum = streamView.miniblockInfo?.min
check(isDefined(miniblockNum), `miniblockNum not found: ${streamId}`)

function checkDone(commits: Uint8Array[]) {
erikolsson marked this conversation as resolved.
Show resolved Hide resolved
for (const commit of commits) {
try {
const message = MlsMessage.fromBytes(commit)
if (message.epoch) {
return message.epoch <= fromEpoch
}
} catch {
// ignore
}
}
return false
}

while (!checkDone(commits)) {
const header = await this.getMiniblockHeader(streamId, miniblockNum)
const mls = header.snapshot?.members?.mls
check(isDefined(mls), `mls snapshot not found: ${streamId}`)
commits = mls.commitsSinceLastSnapshot.concat(commits)
if (miniblockNum === 0n) {
break
}
miniblockNum = header.prevSnapshotMiniblockNum
}
return commits
}

// Used during testing
userDeviceKey(): UserDevice {
return {
Expand Down
67 changes: 46 additions & 21 deletions packages/sdk/src/mls/utils/mlsutils.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
import { check } from '@river-build/dlog'
import { IStreamStateView } from '../../streamStateView'

import { StreamTimelineEvent } from '../../types'

export type ExtractMlsExternalGroupResult = {
externalGroupSnapshot: Uint8Array
groupInfoMessage: Uint8Array
commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[]
}

function commitFromEvent(
event: StreamTimelineEvent,
): { commit: Uint8Array; groupInfoMessage: Uint8Array } | undefined {
const payload = event.remoteEvent?.event.payload
if (payload?.case !== 'memberPayload') {
return undefined
}
if (payload?.value?.content.case !== 'mls') {
return undefined
}

const mlsPayload = payload.value.content.value
switch (mlsPayload.content.case) {
case 'externalJoin':
erikolsson marked this conversation as resolved.
Show resolved Hide resolved
case 'welcomeMessage':
return {
commit: mlsPayload.content.value.commit,
groupInfoMessage: mlsPayload.content.value.groupInfoMessage,
}
case undefined:
return undefined
default:
return undefined
}
}

export function extractMlsExternalGroup(
streamView: IStreamStateView,
): ExtractMlsExternalGroupResult | undefined {
Expand All @@ -33,30 +61,27 @@ export function extractMlsExternalGroup(
check(groupInfoMessage !== undefined, 'no groupInfoMessage found')
const commits: { commit: Uint8Array; groupInfoMessage: Uint8Array }[] = []
for (let i = indexOfLastSnapshot; i < streamView.timeline.length; i++) {
const event = streamView.timeline[i]
const payload = event.remoteEvent?.event.payload
if (payload?.case !== 'memberPayload') {
continue
const commit = commitFromEvent(streamView.timeline[i])
if (commit) {
commits.push(commit)
}
if (payload?.value?.content.case !== 'mls') {
}
return { externalGroupSnapshot, groupInfoMessage, commits: commits }
}

export function mlsCommitsFromStreamView(streamView: IStreamStateView): Uint8Array[] {
const commits: Uint8Array[] = []
const firstMiniblockNum = streamView.miniblockInfo?.min ?? 0n
for (let i = 0; i < streamView.timeline.length; i++) {
if (streamView.timeline[i].miniblockNum == firstMiniblockNum) {
continue
}

const mlsPayload = payload.value.content.value
switch (mlsPayload.content.case) {
case 'externalJoin':
case 'welcomeMessage':
commits.push({
commit: mlsPayload.content.value.commit,
groupInfoMessage: mlsPayload.content.value.groupInfoMessage,
})
break

case undefined:
break
default:
break
const commit = commitFromEvent(streamView.timeline[i])
if (commit) {
commits.push(commit.commit)
}
}
return { externalGroupSnapshot, groupInfoMessage, commits: commits }
return commits
}

// export function mlsCommitsFromMiniblockHeader(miniblockHeader: MiniblockHeader) {}
27 changes: 25 additions & 2 deletions packages/sdk/src/tests/multi_ne/mls.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ describe('mlsTests', () => {
let bobClient: Client
let bobMlsGroup: MlsGroup
let aliceClient: Client
let observerClient: Client
let bobMlsClient: MlsClient
let aliceMlsClient: MlsClient
let aliceMlsClient2: MlsClient
Expand All @@ -49,8 +50,12 @@ describe('mlsTests', () => {
const commits: Uint8Array[] = []

beforeAll(async () => {
bobClient = await makeInitAndStartClient()
aliceClient = await makeInitAndStartClient()
;[bobClient, aliceClient, observerClient] = await Promise.all([
makeInitAndStartClient(),
makeInitAndStartClient(),
makeInitAndStartClient(),
])

const clientOptions: MlsClientOptions = {
withAllowExternalCommit: true,
withRatchetTreeExtension: false,
Expand Down Expand Up @@ -782,4 +787,22 @@ describe('mlsTests', () => {
welcomeMessage.signaturePublicKeys.find((val) => bin_equal(val, signature)),
).toBeDefined()
})

test('fetch all commits from epoch 0n (from streamview)', async () => {
const allCommits = await bobClient.getMlsCommits(streamId, 0n)
expect(allCommits.length).toBe(commits.length)
expect(allCommits.length).toBeGreaterThan(2)
for (let i = 0; i < allCommits.length; i++) {
expect(bin_equal(allCommits[i], commits[i])).toBe(true)
}
})

test('fetch all commits from epoch 0n (from get stream)', async () => {
const allCommits = await observerClient.getMlsCommits(streamId, 0n)
expect(allCommits.length).toBe(commits.length)
expect(allCommits.length).toBeGreaterThan(2)
for (let i = 0; i < allCommits.length; i++) {
expect(bin_equal(allCommits[i], commits[i])).toBe(true)
}
})
})
Loading