From bc6114ce37aba2d4768eb14244442e3e974d4c75 Mon Sep 17 00:00:00 2001 From: Alula <6276139+alula@users.noreply.github.com> Date: Wed, 13 Nov 2024 08:17:12 +0100 Subject: [PATCH 1/6] Update udpqueue-api --- ext-udpqueue/build.gradle | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ext-udpqueue/build.gradle b/ext-udpqueue/build.gradle index 1daaa62..a730e69 100644 --- a/ext-udpqueue/build.gradle +++ b/ext-udpqueue/build.gradle @@ -1,8 +1,6 @@ -def udpqueueVersion = '0.2.7' - dependencies { compileOnly project(':core') implementation 'dev.arbjerg:lava-common:2.2.1' - implementation group: 'club.minnced', name: 'udpqueue-api', version: '0.1.1' + implementation group: 'club.minnced', name: 'udpqueue-api', version: '0.2.9' compileOnly group: 'org.jetbrains', name: 'annotations', version: '13.0' } From 07c391c9bfda40b697f66a0c3cf1630037576a60 Mon Sep 17 00:00:00 2001 From: Alula <6276139+alula@users.noreply.github.com> Date: Wed, 13 Nov 2024 08:17:54 +0100 Subject: [PATCH 2/6] Fix Koe TestBot and update libraries --- testbot/build.gradle | 6 +++--- .../java/moe/kyokobot/koe/testbot/TestBot.java | 15 ++++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/testbot/build.gradle b/testbot/build.gradle index f1361fc..e543afe 100644 --- a/testbot/build.gradle +++ b/testbot/build.gradle @@ -1,8 +1,8 @@ dependencies { implementation project(':core') implementation project(':ext-udpqueue') - implementation group: 'net.dv8tion', name: 'JDA', version: '5.0.2' - implementation group: 'dev.arbjerg', name: 'lavaplayer', version: '2.2.1' - implementation group: 'com.github.lavalink-devs', name: 'lavaplayer-youtube-source', version: '1.0.5' + implementation group: 'net.dv8tion', name: 'JDA', version: '5.2.1' + implementation group: 'dev.arbjerg', name: 'lavaplayer', version: '2.2.2' + implementation group: 'dev.lavalink.youtube', name: 'v2', version: '1.8.3' implementation group: 'ch.qos.logback', name: 'logback-classic', version: '1.4.14' } diff --git a/testbot/src/main/java/moe/kyokobot/koe/testbot/TestBot.java b/testbot/src/main/java/moe/kyokobot/koe/testbot/TestBot.java index 80d078b..c259412 100644 --- a/testbot/src/main/java/moe/kyokobot/koe/testbot/TestBot.java +++ b/testbot/src/main/java/moe/kyokobot/koe/testbot/TestBot.java @@ -11,6 +11,9 @@ import com.sedmelluq.discord.lavaplayer.track.AudioTrack; import com.sedmelluq.discord.lavaplayer.track.playback.MutableAudioFrame; import dev.lavalink.youtube.YoutubeAudioSourceManager; +import dev.lavalink.youtube.clients.AndroidMusic; +import dev.lavalink.youtube.clients.AndroidTestsuite; +import dev.lavalink.youtube.clients.WebEmbedded; import io.netty.buffer.ByteBuf; import moe.kyokobot.koe.*; import moe.kyokobot.koe.media.OpusAudioFrameProvider; @@ -25,6 +28,7 @@ import net.dv8tion.jda.api.hooks.ListenerAdapter; import net.dv8tion.jda.api.hooks.VoiceDispatchInterceptor; import net.dv8tion.jda.api.requests.GatewayIntent; +import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,7 +54,7 @@ public class TestBot extends ListenerAdapter implements VoiceDispatchInterceptor private Koe koe; private KoeClient koeClient; private AudioPlayerManager playerManager; - private Map playerMap = new ConcurrentHashMap<>(); + private final Map playerMap = new ConcurrentHashMap<>(); public TestBot(String token) { this.token = token; @@ -66,7 +70,8 @@ public void start() { public void stop() { try { logger.info("Shutting down..."); - koeClient.close(); + if (koeClient != null) + koeClient.close(); Thread.sleep(250); jda.shutdownNow(); Thread.sleep(500); @@ -89,14 +94,14 @@ public JDA createJDA() { public AudioPlayerManager createAudioPlayerManager() { var manager = new DefaultAudioPlayerManager(); - manager.registerSourceManager(new YoutubeAudioSourceManager()); + manager.registerSourceManager(new YoutubeAudioSourceManager(new AndroidMusic(), new AndroidTestsuite(), new WebEmbedded())); manager.registerSourceManager(SoundCloudAudioSourceManager.createDefault()); manager.registerSourceManager(new HttpAudioSourceManager()); return manager; } @Override - public void onReady(ReadyEvent event) { + public void onReady(@NotNull ReadyEvent event) { koeClient = koe.newClient(jda.getSelfUser().getIdLong()); } @@ -114,7 +119,7 @@ public void onVoiceServerUpdate(VoiceServerUpdate voiceServerUpdate) { @Override public boolean onVoiceStateUpdate(VoiceStateUpdate voiceStateUpdate) { - if (voiceStateUpdate.getVoiceState().getIdLong() == jda.getSelfUser().getIdLong() && voiceStateUpdate.getChannel().getIdLong() == 0) { + if (voiceStateUpdate.getVoiceState().getIdLong() == jda.getSelfUser().getIdLong() && voiceStateUpdate.getChannel() == null) { koeClient.destroyConnection(voiceStateUpdate.getGuildIdLong()); } return true; From e24ad2d7f9188f5cdd44e4b4b170f1a148c2e2c1 Mon Sep 17 00:00:00 2001 From: Alula <6276139+alula@users.noreply.github.com> Date: Wed, 13 Nov 2024 08:37:15 +0100 Subject: [PATCH 3/6] Make V8 the default gateway version, update opcodes --- .../moe/kyokobot/koe/KoeOptionsBuilder.java | 2 +- .../koe/gateway/MediaGatewayV4Connection.java | 4 ++-- .../koe/gateway/MediaGatewayV5Connection.java | 6 +++--- .../koe/gateway/MediaGatewayV8Connection.java | 6 +++--- .../java/moe/kyokobot/koe/gateway/Op.java | 21 ++++++++++++++++--- 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/core/src/main/java/moe/kyokobot/koe/KoeOptionsBuilder.java b/core/src/main/java/moe/kyokobot/koe/KoeOptionsBuilder.java index 2e33740..3b3bbea 100644 --- a/core/src/main/java/moe/kyokobot/koe/KoeOptionsBuilder.java +++ b/core/src/main/java/moe/kyokobot/koe/KoeOptionsBuilder.java @@ -39,7 +39,7 @@ public class KoeOptionsBuilder { : NioDatagramChannel.class; this.byteBufAllocator = new PooledByteBufAllocator(); - this.gatewayVersion = GatewayVersion.V4; + this.gatewayVersion = GatewayVersion.V8; this.framePollerFactory = new NettyFramePollerFactory(); this.highPacketPriority = true; } diff --git a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV4Connection.java b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV4Connection.java index 50ba7ec..23ce0f4 100644 --- a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV4Connection.java +++ b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV4Connection.java @@ -116,7 +116,7 @@ protected void handlePayload(JsonObject object) { logger.debug("Resumed successfully"); break; } - case Op.CLIENT_CONNECT: { + case Op.VIDEO: { var data = object.getObject("d"); var user = data.getString("user_id"); var audioSsrc = data.getInt("audio_ssrc", 0); @@ -195,7 +195,7 @@ private void selectProtocol(String protocol) { .add("data", udpInfo) .combine(udpInfo)); - sendInternalPayload(Op.CLIENT_CONNECT, new JsonObject() + sendInternalPayload(Op.VIDEO, new JsonObject() .add("audio_ssrc", ssrc) .add("video_ssrc", 0) .add("rtx_ssrc", 0)); diff --git a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV5Connection.java b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV5Connection.java index 0bf6deb..3eb65c8 100644 --- a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV5Connection.java +++ b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV5Connection.java @@ -114,7 +114,7 @@ protected void handlePayload(JsonObject object) { logger.debug("Resumed successfully"); break; } - case Op.CLIENT_CONNECT: { + case Op.VIDEO: { var data = object.getObject("d"); var user = data.getString("user_id"); var audioSsrc = data.getInt("audio_ssrc", 0); @@ -129,7 +129,7 @@ protected void handlePayload(JsonObject object) { connection.getDispatcher().userDisconnected(user); break; } - case Op.VIDEO_SINK_WANTS: { + case Op.MEDIA_SINK_WANTS: { // Sent only if `video` flag was true while identifying. At time of writing this comment Discord forces // it to false on bots (so.. user bot time? /s) due to voice server bug that broke clients or something. // After receiving this opcode client can send op 12 with ssrcs for video (audio + 1) @@ -211,7 +211,7 @@ private void selectProtocol(String protocol) { this.updateSpeaking(0); - sendInternalPayload(Op.CLIENT_CONNECT, new JsonObject() + sendInternalPayload(Op.VIDEO, new JsonObject() .add("audio_ssrc", ssrc) .add("video_ssrc", 0) .add("rtx_ssrc", 0)); diff --git a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java index 0f03c8d..04ae51e 100644 --- a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java +++ b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java @@ -120,7 +120,7 @@ protected void handlePayload(JsonObject object) { logger.debug("Resumed successfully"); break; } - case Op.CLIENT_CONNECT: { + case Op.VIDEO: { var data = object.getObject("d"); var user = data.getString("user_id"); var audioSsrc = data.getInt("audio_ssrc", 0); @@ -135,7 +135,7 @@ protected void handlePayload(JsonObject object) { connection.getDispatcher().userDisconnected(user); break; } - case Op.VIDEO_SINK_WANTS: { + case Op.MEDIA_SINK_WANTS: { // Sent only if `video` flag was true while identifying. At time of writing this comment Discord forces // it to false on bots (so.. user bot time? /s) due to voice server bug that broke clients or something. // After receiving this opcode client can send op 12 with ssrcs for video (audio + 1) @@ -219,7 +219,7 @@ private void selectProtocol(String protocol) { this.updateSpeaking(0); - sendInternalPayload(Op.CLIENT_CONNECT, new JsonObject() + sendInternalPayload(Op.VIDEO, new JsonObject() .add("audio_ssrc", ssrc) .add("video_ssrc", 0) .add("rtx_ssrc", 0)); diff --git a/core/src/main/java/moe/kyokobot/koe/gateway/Op.java b/core/src/main/java/moe/kyokobot/koe/gateway/Op.java index d9f0b9a..ea31ed9 100644 --- a/core/src/main/java/moe/kyokobot/koe/gateway/Op.java +++ b/core/src/main/java/moe/kyokobot/koe/gateway/Op.java @@ -16,9 +16,24 @@ private Op() { public static final int HELLO = 8; public static final int RESUMED = 9; // public static final int DUNNO = 10; - // public static final int DUNNO = 11; - public static final int CLIENT_CONNECT = 12; // thx b1nzy + public static final int CLIENT_CONNECT = 11; + public static final int VIDEO = 12; public static final int CLIENT_DISCONNECT = 13; public static final int CODECS = 14; - public static final int VIDEO_SINK_WANTS = 15; + public static final int MEDIA_SINK_WANTS = 15; + public static final int VOICE_BACKEND_VERSION = 16; + public static final int CHANNEL_OPTIONS_UPDATE = 17; + public static final int CLIENT_FLAGS = 18; + public static final int SPEED_TEST = 19; + public static final int PLATFORM = 20; + public static final int SECURE_FRAMES_PREPARE_PROTOCOL_TRANSITION = 21; + public static final int SECURE_FRAMES_EXECUTE_TRANSITION = 22; + public static final int SECURE_FRAMES_READY_FOR_TRANSITION = 23; + public static final int SECURE_FRAMES_PREPARE_EPOCH = 24; + public static final int MLS_EXTERNAL_SENDER_PACKAGE = 25; + public static final int MLS_KEY_PACKAGE = 26; + public static final int MLS_PROPOSALS = 27; + public static final int MLS_COMMIT_WELCOME = 28; + public static final int MLS_PREPARE_COMMIT_TRANSITION = 29; + public static final int MLS_WELCOME = 30; } From 1a37ac19f622e1acb0116abfc39a2a7a8636f080 Mon Sep 17 00:00:00 2001 From: Alula <6276139+alula@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:22:39 +0100 Subject: [PATCH 4/6] Implement handling of binary websocket payloads --- .../AbstractMediaGatewayConnection.java | 31 +++++++++++++++++-- .../koe/gateway/MediaGatewayV8Connection.java | 12 ++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/moe/kyokobot/koe/gateway/AbstractMediaGatewayConnection.java b/core/src/main/java/moe/kyokobot/koe/gateway/AbstractMediaGatewayConnection.java index ead302c..d36d946 100644 --- a/core/src/main/java/moe/kyokobot/koe/gateway/AbstractMediaGatewayConnection.java +++ b/core/src/main/java/moe/kyokobot/koe/gateway/AbstractMediaGatewayConnection.java @@ -1,6 +1,7 @@ package moe.kyokobot.koe.gateway; import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; @@ -16,8 +17,8 @@ import io.netty.handler.ssl.SslHandler; import io.netty.util.concurrent.EventExecutor; import moe.kyokobot.koe.VoiceServerInfo; -import moe.kyokobot.koe.internal.NettyBootstrapFactory; import moe.kyokobot.koe.internal.MediaConnectionImpl; +import moe.kyokobot.koe.internal.NettyBootstrapFactory; import moe.kyokobot.koe.internal.json.JsonObject; import moe.kyokobot.koe.internal.json.JsonParser; import moe.kyokobot.koe.internal.util.NettyFutureWrapper; @@ -116,6 +117,10 @@ public boolean isOpen() { protected abstract void handlePayload(JsonObject object); + protected void handleBinaryPayload(ByteBuf buffer) { + // no-op + } + protected void onClose(int code, @Nullable String reason, boolean remote) { if (!closed) { closed = true; @@ -148,14 +153,29 @@ public void sendInternalPayload(int op, Object d) { sendRaw(new JsonObject().add("op", op).add("d", d)); } + public void sendBinaryInternalPayload(char op, ByteBuf buffer) { + var frame = channel.alloc().buffer(1 + buffer.readableBytes()); + frame.writeByte(op); + frame.writeBytes(buffer); + sendRaw(frame); + buffer.release(); + } + protected void sendRaw(JsonObject object) { if (channel != null && channel.isOpen()) { var data = object.toString(); - logger.trace("<- {}", data); + logger.trace("<-T {}", data); channel.writeAndFlush(new TextWebSocketFrame(data)); } } + protected void sendRaw(ByteBuf buffer) { + if (channel != null && channel.isOpen()) { + logger.trace("<-B {}", buffer); + channel.writeAndFlush(new BinaryWebSocketFrame(buffer)); + } + } + private class WebSocketClientHandler extends SimpleChannelInboundHandler { private final WebSocketClientHandshaker handshaker; @@ -210,9 +230,14 @@ protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Except if (msg instanceof TextWebSocketFrame) { var frame = (TextWebSocketFrame) msg; var object = JsonParser.object().from(frame.content()); - logger.trace("-> {}", object); + logger.trace("->T {}", object); frame.release(); handlePayload(object); + } else if (msg instanceof BinaryWebSocketFrame) { + var frame = (BinaryWebSocketFrame) msg; + logger.trace("->B {}", frame.content()); + handleBinaryPayload(frame.content()); + frame.release(); } else if (msg instanceof CloseWebSocketFrame) { var frame = (CloseWebSocketFrame) msg; if (logger.isDebugEnabled()) { diff --git a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java index 04ae51e..d38d08d 100644 --- a/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java +++ b/core/src/main/java/moe/kyokobot/koe/gateway/MediaGatewayV8Connection.java @@ -1,5 +1,6 @@ package moe.kyokobot.koe.gateway; +import io.netty.buffer.ByteBuf; import moe.kyokobot.koe.VoiceServerInfo; import moe.kyokobot.koe.codec.Codec; import moe.kyokobot.koe.codec.DefaultCodecs; @@ -46,7 +47,9 @@ protected void identify() { .addAsString("user_id", connection.getClient().getClientId()) .add("session_id", voiceServerInfo.getSessionId()) .add("token", voiceServerInfo.getToken()) - .add("video", true)); + .add("video", true) + .add("max_dave_protocol_version", 0) + .add("streams", new JsonArray())); } @Override @@ -151,6 +154,13 @@ protected void handlePayload(JsonObject object) { } } + @Override + protected void handleBinaryPayload(ByteBuf buffer) { + sequence = buffer.readUnsignedShort(); + + var op = buffer.readUnsignedByte(); + } + @Override protected void onClose(int code, @Nullable String reason, boolean remote) { super.onClose(code, reason, remote); From 36726d4fa191922ce36cb8884eb1b332743006e9 Mon Sep 17 00:00:00 2001 From: Alula <6276139+alula@users.noreply.github.com> Date: Sun, 17 Nov 2024 13:47:56 +0100 Subject: [PATCH 5/6] Add our BouncyCastle MLS fork to the tree --- README.md | 2 + build.gradle | 8 + core/build.gradle | 12 +- dave/build.gradle | 6 + .../moe/kyokobot/koe/dave/mls/MLSSession.java | 665 +++++ .../moe/kyokobot/koe/dave/util/MLSUtil.java | 46 + ext-udpqueue/build.gradle | 2 +- mls/LICENSE | 27 + mls/README.md | 31 + mls/build.gradle | 6 + .../moe/kyokobot/koe/mls/GroupKeySet.java | 308 +++ .../moe/kyokobot/koe/mls/KeyGeneration.java | 43 + .../kyokobot/koe/mls/KeyScheduleEpoch.java | 486 ++++ .../moe/kyokobot/koe/mls/TranscriptHash.java | 88 + .../kyokobot/koe/mls/TreeKEM/LeafIndex.java | 116 + .../kyokobot/koe/mls/TreeKEM/LeafNode.java | 330 +++ .../koe/mls/TreeKEM/LeafNodeHashInput.java | 35 + .../koe/mls/TreeKEM/LeafNodeSource.java | 29 + .../kyokobot/koe/mls/TreeKEM/LifeTime.java | 52 + .../moe/kyokobot/koe/mls/TreeKEM/Node.java | 84 + .../kyokobot/koe/mls/TreeKEM/NodeIndex.java | 181 ++ .../koe/mls/TreeKEM/OptionalNode.java | 62 + .../koe/mls/TreeKEM/ParentHashInput.java | 39 + .../kyokobot/koe/mls/TreeKEM/ParentNode.java | 42 + .../koe/mls/TreeKEM/ParentNodeHashInput.java | 41 + .../koe/mls/TreeKEM/TreeHashInput.java | 64 + .../koe/mls/TreeKEM/TreeKEMPrivateKey.java | 318 +++ .../koe/mls/TreeKEM/TreeKEMPublicKey.java | 871 ++++++ .../moe/kyokobot/koe/mls/TreeKEM/Utils.java | 16 + .../java/moe/kyokobot/koe/mls/TreeSize.java | 35 + .../koe/mls/codec/AuthenticatedContent.java | 168 ++ .../kyokobot/koe/mls/codec/Capabilities.java | 76 + .../kyokobot/koe/mls/codec/Certificate.java | 22 + .../moe/kyokobot/koe/mls/codec/Commit.java | 87 + .../kyokobot/koe/mls/codec/ContentType.java | 33 + .../kyokobot/koe/mls/codec/Credential.java | 84 + .../koe/mls/codec/CredentialType.java | 51 + .../koe/mls/codec/EncryptedGroupSecrets.java | 33 + .../moe/kyokobot/koe/mls/codec/Extension.java | 102 + .../kyokobot/koe/mls/codec/ExtensionType.java | 62 + .../koe/mls/codec/ExternalSender.java | 46 + .../kyokobot/koe/mls/codec/FramedContent.java | 163 ++ .../moe/kyokobot/koe/mls/codec/Grease.java | 55 + .../kyokobot/koe/mls/codec/GroupContext.java | 76 + .../moe/kyokobot/koe/mls/codec/GroupInfo.java | 133 + .../kyokobot/koe/mls/codec/GroupSecrets.java | 40 + .../koe/mls/codec/HPKECiphertext.java | 42 + .../kyokobot/koe/mls/codec/KeyPackage.java | 116 + .../koe/mls/codec/MLSInputStream.java | 246 ++ .../kyokobot/koe/mls/codec/MLSMessage.java | 512 ++++ .../koe/mls/codec/MLSOutputStream.java | 128 + .../moe/kyokobot/koe/mls/codec/NodeType.java | 32 + .../moe/kyokobot/koe/mls/codec/PSKType.java | 32 + .../kyokobot/koe/mls/codec/PathSecret.java | 33 + .../koe/mls/codec/PreSharedKeyID.java | 141 + .../koe/mls/codec/PrivateMessage.java | 234 ++ .../moe/kyokobot/koe/mls/codec/Proposal.java | 418 +++ .../kyokobot/koe/mls/codec/ProposalOrRef.java | 77 + .../koe/mls/codec/ProposalOrRefType.java | 32 + .../kyokobot/koe/mls/codec/ProposalType.java | 58 + .../koe/mls/codec/ProtocolVersion.java | 27 + .../kyokobot/koe/mls/codec/PublicMessage.java | 119 + .../koe/mls/codec/ResumptionPSKUsage.java | 33 + .../moe/kyokobot/koe/mls/codec/Sender.java | 94 + .../kyokobot/koe/mls/codec/SenderType.java | 27 + .../kyokobot/koe/mls/codec/UpdatePath.java | 53 + .../koe/mls/codec/UpdatePathNode.java | 46 + .../koe/mls/codec/ValidatedContent.java | 13 + .../moe/kyokobot/koe/mls/codec/Varint.java | 67 + .../moe/kyokobot/koe/mls/codec/Welcome.java | 144 + .../kyokobot/koe/mls/codec/WireFormat.java | 28 + .../moe/kyokobot/koe/mls/crypto/MlsAead.java | 16 + .../koe/mls/crypto/MlsCipherSuite.java | 227 ++ .../moe/kyokobot/koe/mls/crypto/MlsKdf.java | 20 + .../kyokobot/koe/mls/crypto/MlsSigner.java | 32 + .../moe/kyokobot/koe/mls/crypto/Secret.java | 141 + .../kyokobot/koe/mls/crypto/bc/BcMlsAead.java | 104 + .../kyokobot/koe/mls/crypto/bc/BcMlsKdf.java | 64 + .../koe/mls/crypto/bc/BcMlsSigner.java | 261 ++ .../koe/mls/protocol/CachedProposal.java | 18 + .../koe/mls/protocol/CachedUpdate.java | 22 + .../moe/kyokobot/koe/mls/protocol/Group.java | 2415 +++++++++++++++++ settings.gradle | 2 + 83 files changed, 11343 insertions(+), 7 deletions(-) create mode 100644 dave/build.gradle create mode 100644 dave/src/main/java/moe/kyokobot/koe/dave/mls/MLSSession.java create mode 100644 dave/src/main/java/moe/kyokobot/koe/dave/util/MLSUtil.java create mode 100644 mls/LICENSE create mode 100644 mls/README.md create mode 100644 mls/build.gradle create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/GroupKeySet.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/KeyGeneration.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/KeyScheduleEpoch.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TranscriptHash.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/LeafIndex.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/LeafNode.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/LeafNodeHashInput.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/LeafNodeSource.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/LifeTime.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/Node.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/NodeIndex.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/OptionalNode.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/ParentHashInput.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/ParentNode.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/ParentNodeHashInput.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/TreeHashInput.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/TreeKEMPrivateKey.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/TreeKEMPublicKey.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeKEM/Utils.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/TreeSize.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/AuthenticatedContent.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Capabilities.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Certificate.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Commit.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ContentType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Credential.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/CredentialType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/EncryptedGroupSecrets.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Extension.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ExtensionType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ExternalSender.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/FramedContent.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Grease.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/GroupContext.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/GroupInfo.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/GroupSecrets.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/HPKECiphertext.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/KeyPackage.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/MLSInputStream.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/MLSMessage.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/MLSOutputStream.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/NodeType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/PSKType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/PathSecret.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/PreSharedKeyID.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/PrivateMessage.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Proposal.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ProposalOrRef.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ProposalOrRefType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ProposalType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ProtocolVersion.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/PublicMessage.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ResumptionPSKUsage.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Sender.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/SenderType.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/UpdatePath.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/UpdatePathNode.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/ValidatedContent.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Varint.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/Welcome.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/codec/WireFormat.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/MlsAead.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/MlsCipherSuite.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/MlsKdf.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/MlsSigner.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/Secret.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/bc/BcMlsAead.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/bc/BcMlsKdf.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/crypto/bc/BcMlsSigner.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/protocol/CachedProposal.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/protocol/CachedUpdate.java create mode 100644 mls/src/main/java/moe/kyokobot/koe/mls/protocol/Group.java diff --git a/README.md b/README.md index 30b822c..dcd56ca 100644 --- a/README.md +++ b/README.md @@ -49,3 +49,5 @@ Koe includes modified/stripped-down parts based on following open-source project - [tweetnacl-java](https://github.com/InstantWebP2P/tweetnacl-java) (Poly1305, SecretBox) - [nanojson](https://github.com/mmastrac/nanojson) (modified for bytebuf support, changed the API a bit and etc.) +- [BouncyCastle MLS](https://github.com/bcgit/bc-java/tree/1.79) (see README.md in mls module for more info) +- [libdave](https://github.com/discord/libdave) (Koe's DAVE implementation is heavily based on reference C++ implementation) diff --git a/build.gradle b/build.gradle index 6c204e3..17876d2 100644 --- a/build.gradle +++ b/build.gradle @@ -21,6 +21,14 @@ def getGitVersion() { return [versionStr.toString().trim(), true] } +ext { + nettyVersion = '4.1.112.Final' + slf4jVersion = '1.8.0-beta4' + tinkVersion = '1.15.0' + bouncyCastleVersion = '1.79' + jetbrainsAnnotationsVersion = '13.0' +} + subprojects { apply plugin: 'maven-publish' apply plugin: 'java-library' diff --git a/core/build.gradle b/core/build.gradle index 9b757c1..354155c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -1,8 +1,8 @@ dependencies { - api group: 'io.netty', name: 'netty-transport', version: '4.1.112.Final' - implementation group: 'io.netty', name: 'netty-codec-http', version: '4.1.112.Final' - implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: '4.1.112.Final', classifier: 'linux-x86_64' - implementation group: 'org.slf4j', name: 'slf4j-api', version: '1.8.0-beta4' - implementation group: 'com.google.crypto.tink', name: 'tink', version: '1.14.1' - compileOnly group: 'org.jetbrains', name: 'annotations', version: '13.0' + api group: 'io.netty', name: 'netty-transport', version: "$nettyVersion" + implementation group: 'io.netty', name: 'netty-codec-http', version: "$nettyVersion" + implementation group: 'io.netty', name: 'netty-transport-native-epoll', version: "$nettyVersion", classifier: 'linux-x86_64' + implementation group: 'org.slf4j', name: 'slf4j-api', version: "$slf4jVersion" + implementation group: 'com.google.crypto.tink', name: 'tink', version: "$tinkVersion" + compileOnly group: 'org.jetbrains', name: 'annotations', version: "$jetbrainsAnnotationsVersion" } diff --git a/dave/build.gradle b/dave/build.gradle new file mode 100644 index 0000000..47d0327 --- /dev/null +++ b/dave/build.gradle @@ -0,0 +1,6 @@ +dependencies { + implementation project(':mls') + implementation group: 'org.slf4j', name: 'slf4j-api', version: "$slf4jVersion" + compileOnly group: 'org.jetbrains', name: 'annotations', version: "$jetbrainsAnnotationsVersion" +} + diff --git a/dave/src/main/java/moe/kyokobot/koe/dave/mls/MLSSession.java b/dave/src/main/java/moe/kyokobot/koe/dave/mls/MLSSession.java new file mode 100644 index 0000000..b163016 --- /dev/null +++ b/dave/src/main/java/moe/kyokobot/koe/dave/mls/MLSSession.java @@ -0,0 +1,665 @@ +package moe.kyokobot.koe.dave.mls; + +import moe.kyokobot.koe.dave.DAVEException; +import moe.kyokobot.koe.dave.KeyRatchet; +import moe.kyokobot.koe.dave.PersistentKeyManager; +import moe.kyokobot.koe.dave.util.MLSUtil; +import moe.kyokobot.koe.mls.TreeKEM.LeafNode; +import moe.kyokobot.koe.mls.TreeKEM.LifeTime; +import moe.kyokobot.koe.mls.codec.*; +import moe.kyokobot.koe.mls.crypto.Secret; +import moe.kyokobot.koe.mls.protocol.Group; +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.generators.SCrypt; +import org.bouncycastle.util.encoders.Hex; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.security.SecureRandom; +import java.util.*; +import java.util.function.Consumer; + +public class MLSSession { + private static final Logger logger = LoggerFactory.getLogger(MLSSession.class); + + @Nullable + private final PersistentKeyManager persistentKeyManager; + private int protocolVersion; + private byte[] groupId = new byte[0]; + private final String signingKeyId; + private String selfUserId = ""; + private LeafNode selfLeafNode; + private AsymmetricCipherKeyPair selfSigPrivateKey; + private AsymmetricCipherKeyPair selfHPKEPrivateKey; + + private AsymmetricCipherKeyPair joinInitPrivateKey; + private KeyPackage joinKeyPackage; + + private ExternalSender externalSender; + + private Group pendingGroupState; + private MLSMessage pendingGroupCommit; + + private Group outboundCachedGroupState; + + private Group currentState; + private Map roster; + + private Group stateWithProposals; + private final Deque proposalQueue = new LinkedList<>(); + + /** + * Creates a new MLS session. + * + * @param signingKeyId ID of a persistent key to use for this session. + * @param persistentKeyManager Optional instance of {@link PersistentKeyManager} if you want to use persistent keys. + */ + public MLSSession(@Nullable String signingKeyId, @Nullable PersistentKeyManager persistentKeyManager) { + this.persistentKeyManager = persistentKeyManager; + this.signingKeyId = signingKeyId == null ? "" : signingKeyId; + } + + public void init(int protocolVersion, long groupId, String selfUserId, AsymmetricCipherKeyPair transientKey) throws DAVEException { + this.selfUserId = selfUserId; + this.protocolVersion = protocolVersion; + this.groupId = MLSUtil.bigEndianBytesFrom(groupId); + + logger.info("Initializing MLS session with protocol version {} and group ID {}", protocolVersion, groupId); + + initLeafNode(selfUserId, transientKey); + createPendingGroup(); + } + + public void reset() { + logger.info("Resetting MLS session"); + + clearPendingState(); + + currentState = null; + outboundCachedGroupState = null; + + protocolVersion = 0; + groupId = null; + } + + public void setProtocolVersion(int version) { + if (protocolVersion != version) { + // when we need to retain backwards compatibility + // there may be some changes to the MLS objects required here + // until then we can just update the stored version + protocolVersion = version; + } + } + + public int getProtocolVersion() { + return protocolVersion; + } + + public byte[] getLastEpochAuthenticator() throws DAVEException { + if (currentState == null) { + throw new DAVEException("Cannot get epoch authenticator without an established MLS group"); + } + + return currentState.getEpochAuthenticator(); + } + + public void setExternalSender(byte[] marshalledExternalSender) throws DAVEException { + if (currentState != null) { + throw new DAVEException("Cannot set external sender after joining/creating an MLS group"); + } + + logger.info("Unmarshalling MLS external sender"); + logger.info("Sender: {}", Hex.encode(marshalledExternalSender)); + + try { + externalSender = new ExternalSender(new MLSInputStream(marshalledExternalSender)); + + if (groupId != null && groupId.length > 0) { + createPendingGroup(); + } + } catch (Exception e) { + throw new DAVEException("Failed to unmarshal external sender", e); + } + } + + @Nullable + public byte[] processProposals(byte[] proposals, Set recognizedUserIDs) throws DAVEException { + try { + if (pendingGroupState == null && currentState == null) { + throw new DAVEException("Cannot process proposals without any pending or established MLS group state"); + } + + if (stateWithProposals == null) { + stateWithProposals = (pendingGroupState != null ? pendingGroupState : currentState); + } + + logger.info("Processing MLS proposals message of {} bytes", proposals.length); + logger.info("Proposals: {}", Hex.toHexString(proposals)); + + MLSInputStream inStream = new MLSInputStream(proposals); + + boolean isRevoke = (boolean) inStream.read(boolean.class); + logger.info("Revoking: {}", isRevoke); + + var suite = stateWithProposals.getSuite(); + + if (isRevoke) { + var refs = new ArrayList(); + inStream.readList(refs, byte[].class); + + for (byte[] ref : refs) { + boolean found = false; + for (var it = proposalQueue.iterator(); it.hasNext(); ) { + var prop = it.next(); + if (Arrays.equals(prop.ref, ref)) { + found = true; + it.remove(); + break; + } + } + + if (!found) { + throw new DAVEException("Cannot revoke unrecognized proposal ref"); + } + } + + stateWithProposals = (pendingGroupState != null ? pendingGroupState : currentState); + + for (QueuedProposal prop : proposalQueue) { + stateWithProposals.handle(prop.content, null, null); + } + } else { + var messages = new ArrayList(); + inStream.readList(messages, MLSMessage.class); + + for (var proposalMessage : messages) { + var validatedMessage = stateWithProposals.unwrap(proposalMessage); + + if (!validateProposalMessage(validatedMessage.getAuthenticatedContent(), stateWithProposals, recognizedUserIDs)) { + return null; + } + + stateWithProposals.handle(validatedMessage.getAuthenticatedContent(), null, null); + + var ref = suite.refHash(MLSOutputStream.encode(validatedMessage.getAuthenticatedContent()), "MLS 1.0 Proposal Reference"); + + proposalQueue.add(new QueuedProposal(validatedMessage.getAuthenticatedContent(), ref)); + } + } + + // generate a commit + var commitSecret = new byte[suite.getKDF().getHashLength()]; + SecureRandom random = new SecureRandom(); + random.nextBytes(commitSecret); + + // Commit the Add and PSK proposals + var commitOpts = new Group.CommitOptions( + List.of(), // no extra proposals + true, // inline tree in welcome + false, // do not force path + new Group.LeafNodeOptions() + ); + + var commit = stateWithProposals.commit(new Secret(commitSecret), commitOpts, new Group.MessageOptions(), + new Group.CommitParameters(Group.NORMAL_COMMIT_PARAMS)); + + logger.info("Prepared commit/welcome/next state for MLS group from received proposals"); + + var outStream = new MLSOutputStream(); + outStream.write(commit.message); + + pendingGroupCommit = commit.message; + + if (!commit.message.welcome.getSecrets().isEmpty()) { + outStream.write(commit.message.welcome); + } + + outboundCachedGroupState = commit.group; + + logger.info("Output: {}", Hex.toHexString(outStream.toByteArray())); + + return outStream.toByteArray(); + } catch (Exception e) { + throw new DAVEException("Failed to parse MLS proposals", e); + } + } + + private boolean isRecognizedUserID(Credential cred, Set recognizedUserIDs) { + String uid = UserCredential.userCredentialToString(cred, protocolVersion); + if (uid.isEmpty()) { + logger.error("Attempted to verify credential of unexpected type"); + return false; + } + + if (!recognizedUserIDs.contains(uid)) { + logger.error("Attempted to verify credential for unrecognized user ID: {}", uid); + return false; + } + + return true; + } + + private boolean validateProposalMessage(AuthenticatedContent message, Group targetGroup, Set recognizedUserIds) { + if (message.getWireFormat() != WireFormat.mls_private_message) { + logger.error("MLS proposal message must be PublicMessage"); + return false; + } + + if (message.getContent().getEpoch() != targetGroup.getEpoch()) { + logger.error("MLS proposal message must be for current epoch ({}) != {}", message.getContent().getEpoch(), targetGroup.getEpoch()); + return false; + } + + if (message.getContent().getContentType() != ContentType.PROPOSAL) { + logger.error("ProcessProposals called with non-proposal message"); + return false; + } + + if (message.getContent().getSender().getSenderType() != SenderType.EXTERNAL) { + logger.error("MLS proposal must be from external sender"); + return false; + } + + var proposal = message.getContent().getProposal(); + switch (proposal.getProposalType()) { + case ADD: + var credential = proposal.getAdd().keyPackage.getLeafNode().getCredential(); + if (!isRecognizedUserID(credential, recognizedUserIds)) { + logger.error("MLS add proposal must be for recognized user"); + return false; + } + break; + case REMOVE: + // Remove proposals are always allowed (mlspp will validate that it's a recognized user) + break; + default: + logger.error("MLS proposal must be add or remove"); + return false; + } + + return true; + } + + private boolean canProcessCommit(MLSMessage commit) { + if (stateWithProposals == null) { + return false; + } + + if (!Arrays.equals(commit.getGroupId(), groupId)) { + logger.error("MLS commit message was for unexpected group"); + return false; + } + + return true; + } + + @NotNull + public RosterVariant processCommit(byte[] commit) { + try { + logger.info("Processing commit"); + logger.info("Commit: {}", Hex.toHexString(commit)); + + var commitMessage = (MLSMessage) MLSInputStream.decode(commit, MLSMessage.class); + + if (!canProcessCommit(commitMessage)) { + logger.error("ProcessCommit called with unprocessable MLS commit"); + return new RosterVariant.Ignored(); + } + + Group optionalCachedGroup = null; + if (outboundCachedGroupState != null) { + optionalCachedGroup = outboundCachedGroupState; + } + + var newState = stateWithProposals.handle(stateWithProposals.unwrap(commitMessage).getAuthenticatedContent(), + optionalCachedGroup, null); + + if (newState == null) { + logger.error("MLS commit handling did not produce a new state"); + return new RosterVariant.Failed(); + } + + logger.info("Successfully processed MLS commit, updating state; our leaf index is {}; current epoch is {}", + newState.getIndex().value(), newState.getEpoch()); + + var ret = replaceState(newState); + outboundCachedGroupState = null; + clearPendingState(); + + return ret; + } catch (Exception e) { + logger.error("Failed to process MLS commit: {}", e.getMessage()); + return new RosterVariant.Failed(); + } + } + + @Nullable + public RosterVariant.RosterMap processWelcome(byte[] welcome, Set recognizedUserIDs) { + try { + if (!hasCryptographicStateForWelcome()) { + logger.error("Missing local crypto state necessary to process MLS welcome"); + return null; + } + + if (externalSender == null) { + logger.error("Cannot process MLS welcome without an external sender"); + return null; + } + + if (currentState != null) { + logger.error("Cannot process MLS welcome after joining/creating an MLS group"); + return null; + } + + logger.info("Processing welcome: {}", Hex.toHexString(welcome)); + + var unmarshalledWelcome = (MLSMessage) MLSInputStream.decode(welcome, MLSMessage.class); + var suite = unmarshalledWelcome.welcome.getSuite(); + + // TODO: BC MLS does redundant serialization ;w; + var newState = new Group(suite.getHPKE().serializePrivateKey(joinInitPrivateKey.getPrivate()), + selfHPKEPrivateKey, + suite.getHPKE().serializePrivateKey(selfSigPrivateKey.getPrivate()), + joinKeyPackage, + unmarshalledWelcome.welcome, + null, + new HashMap<>(), + new HashMap<>()); + + if (!verifyWelcomeState(newState, recognizedUserIDs)) { + logger.error("Group received in MLS welcome is not valid"); + return null; + } + + logger.info("Successfully welcomed to MLS Group, our leaf index is {}; current epoch is {}", + newState.getIndex().value(), newState.getEpoch()); + + var ret = replaceState(newState); + + clearPendingState(); + + return ret; + } catch (Exception e) { + logger.error("Failed to create group state from MLS welcome: {}", e.getMessage()); + return null; + } + } + + private RosterVariant.RosterMap replaceState(Group state) { + var newRoster = new HashMap(); + for (var node : state.roster()) { + if (node.getCredential().getCredentialType() == CredentialType.basic) { + newRoster.put(MLSUtil.fromBigEndianBytes(node.getCredential().getIdentity()), node.getSignatureKey()); + } + } + + var changeMap = new HashMap(); + + newRoster.forEach((key, value) -> { + if (!roster.containsKey(key)) { + changeMap.put(key, new byte[0]); + } + }); + + roster.forEach((key, value) -> { + if (!newRoster.containsKey(key)) { + changeMap.put(key, new byte[0]); + } + }); + + roster = newRoster; + currentState = state; + + return new RosterVariant.RosterMap(changeMap); + } + + private boolean hasCryptographicStateForWelcome() { + return joinKeyPackage != null && joinInitPrivateKey != null && selfSigPrivateKey != null && selfHPKEPrivateKey != null; + } + + private boolean verifyWelcomeState(Group state, Set recognizedUserIDs) { + if (externalSender == null) { + logger.error("Cannot verify MLS welcome without an external sender"); + return false; + } + + var ext = state.getExtensions().stream().filter(extension -> extension.extensionType == ExtensionType.EXTERNAL_SENDERS).findFirst(); + if (ext.isEmpty()) { + logger.error("MLS welcome missing external senders extension"); + return false; + } + + List senders; + try { + senders = ext.get().getSenders(); + } catch (IOException e) { + logger.error("Failed to read external senders", e); + return false; + } + + if (senders.size() != 1) { + logger.error("MLS welcome lists unexpected number of external senders: {}", senders.size()); + return false; + } + + if (!senders.get(0).equals(externalSender)) { + logger.error("MLS welcome lists unexpected external sender"); + return false; + } + + for (var leaf : state.roster()) { + if (!isRecognizedUserID(leaf.getCredential(), recognizedUserIDs)) { + logger.error("MLS welcome lists unrecognized user ID"); + // TRACK_MLS_ERROR("Welcome message lists unrecognized user ID"); + // return false; + } + } + + return true; + } + + private void initLeafNode(String selfUserId, @Nullable AsymmetricCipherKeyPair transientKey) throws DAVEException { + var cipherSuite = Parameters.ciphersuiteForProtocolVersion(protocolVersion); + + if (transientKey == null) { + if (signingKeyId.isEmpty()) { + //Generate a new key pair + transientKey = cipherSuite.generateSignatureKeyPair(); + } else if (persistentKeyManager != null) { + transientKey = persistentKeyManager.getKeyPair(signingKeyId, protocolVersion); + } else { + throw new DAVEException("Did not receive MLS signature private key!"); + } + } + + selfSigPrivateKey = transientKey; + + var selfCredential = UserCredential.createUserCredential(selfUserId, protocolVersion); + + selfHPKEPrivateKey = cipherSuite.getHPKE().generatePrivateKey(); + + try { + Objects.requireNonNull(selfSigPrivateKey); + Objects.requireNonNull(selfHPKEPrivateKey); + + selfLeafNode = new LeafNode(cipherSuite, + cipherSuite.getHPKE().serializePublicKey(selfHPKEPrivateKey.getPublic()), + cipherSuite.serializeSignaturePublicKey(selfSigPrivateKey.getPublic()), + selfCredential, + Parameters.leafNodeCapabilitiesForProtocolVersion(protocolVersion), + new LifeTime(), + Parameters.leafNodeExtensionsForProtocolVersion(protocolVersion), + cipherSuite.serializeSignaturePrivateKey(selfSigPrivateKey.getPrivate())); + + logger.info("Created MLS leaf node"); + } catch (Exception e) { + throw new DAVEException("Failed to initialize MLS leaf node", e); + } + } + + private void resetJoinKeyPackage() throws DAVEException { + try { + if (selfLeafNode == null) { + throw new DAVEException("Cannot initialize join key package without a leaf node"); + } + + var cipherSuite = Parameters.ciphersuiteForProtocolVersion(protocolVersion); + + joinInitPrivateKey = cipherSuite.getHPKE().generatePrivateKey(); + + joinKeyPackage = new KeyPackage(cipherSuite, + cipherSuite.getHPKE().serializePublicKey(joinInitPrivateKey.getPublic()), + selfLeafNode, + Parameters.leafNodeExtensionsForProtocolVersion(protocolVersion), + cipherSuite.serializeSignaturePrivateKey(selfSigPrivateKey.getPrivate())); + + logger.info("Generated key package: {}", Hex.toHexString(MLSOutputStream.encode(joinKeyPackage))); + } catch (Exception e) { + throw new DAVEException("Failed to initialize join key package", e); + } + } + + private void createPendingGroup() throws DAVEException { + try { + if (groupId == null || groupId.length == 0) { + throw new DAVEException("Cannot create MLS group without a group ID"); + } + + if (externalSender == null) { + throw new DAVEException("Cannot create MLS group without ExternalSender"); + } + + if (selfLeafNode == null) { + throw new DAVEException("Cannot create MLS group without self leaf node"); + } + + logger.info("Creating a pending MLS group"); + + var cipherSuite = Parameters.ciphersuiteForProtocolVersion(protocolVersion); + + pendingGroupState = new Group(groupId, cipherSuite, selfHPKEPrivateKey, + cipherSuite.serializeSignaturePrivateKey(selfSigPrivateKey.getPrivate()), + selfLeafNode, + Parameters.groupExtensionsForProtocolVersion(protocolVersion, externalSender)); + + logger.info("Created a pending MLS group"); + } catch (Exception e) { + throw new DAVEException("Failed to create MLS group", e); + } + } + + public byte[] getMarshalledKeyPackage() throws DAVEException { + try { + // key packages are not meant to be re-used + // so every time the client asks for a key package we create a new one + resetJoinKeyPackage(); + + if (joinKeyPackage == null) { + throw new DAVEException("Cannot marshal an uninitialized key package"); + } + + return MLSOutputStream.encode(joinKeyPackage); + } catch (Exception e) { + throw new DAVEException("Failed to marshal join key package", e); + } + } + + public KeyRatchet getKeyRatchet(String userId) throws DAVEException { + try { + if (currentState == null) { + throw new DAVEException("Cannot get key ratchet without an established MLS group"); + } + + // change the string user ID to a little endian 64 bit user ID + byte[] userIdBytes = MLSUtil.littleEndianBytesFromString(userId); + + // generate the base secret for the hash ratchet + byte[] baseSecret = currentState.getKeySchedule().MLSExporter("User Media Key Base Label", userIdBytes, 16); + + // this assumes the MLS ciphersuite produces a 16 byte key + // would need to be updated to a different ciphersuite if there's a future mismatch + return new MLSKeyRatchet(currentState.getSuite(), baseSecret); + } catch (Exception e) { + throw new DAVEException("Failed to get key ratchet", e); + } + } + + private static final byte[] SALT = Hex.decode("24cab17a7af8ec2b82b412b92dab192e"); + + public void getPairwiseFingerprint(int version, String userId, Consumer callback) throws DAVEException { + try { + if (currentState == null || selfSigPrivateKey == null) { + throw new DAVEException("Cannot get pairwise fingerprint without an established MLS group"); + } + + long u64RemoteUserId = Long.parseLong(userId); + long u64SelfUserId = Long.parseLong(selfUserId); + + byte[] remoteUserIdBytes = MLSUtil.bigEndianBytesFrom(u64RemoteUserId); + byte[] selfUserIdBytes = MLSUtil.bigEndianBytesFrom(u64SelfUserId); + + MLSOutputStream toHash1 = new MLSOutputStream(); + MLSOutputStream toHash2 = new MLSOutputStream(); + + toHash1.write(version); + toHash1.write(remoteUserIdBytes); + + toHash2.write(version); + toHash2.write(currentState.getSuite().serializeSignaturePublicKey(selfSigPrivateKey.getPublic())); + toHash2.write(selfUserIdBytes); + + var keyData = new ArrayList(); + keyData.add(toHash1.toByteArray()); + keyData.add(toHash2.toByteArray()); + keyData.sort(Arrays::compare); + + var data = new byte[keyData.get(0).length + keyData.get(1).length]; + System.arraycopy(keyData.get(0), 0, data, 0, keyData.get(0).length); + System.arraycopy(keyData.get(1), 0, data, keyData.get(0).length, keyData.get(1).length); + + int N = 16384, r = 8, p = 2, maxMem = 32 * 1024 * 1024; + int hashLen = 64; + + var out = SCrypt.generate(data, SALT, N, r, p, hashLen); + callback.accept(out); + } catch (Exception e) { + throw new DAVEException("Failed to generate pairwise fingerprint", e); + } + } + + private void clearPendingState() { + pendingGroupState = null; + pendingGroupCommit = null; + + joinInitPrivateKey = null; + joinKeyPackage = null; + + selfHPKEPrivateKey = null; + + selfLeafNode = null; + + stateWithProposals = null; + proposalQueue.clear(); + } + + static class QueuedProposal { + private final AuthenticatedContent content; + private final byte[] ref; + + public QueuedProposal(AuthenticatedContent content, byte[] ref) { + this.content = content; + this.ref = ref; + } + + public AuthenticatedContent getContent() { + return content; + } + + public byte[] getRef() { + return ref; + } + } +} diff --git a/dave/src/main/java/moe/kyokobot/koe/dave/util/MLSUtil.java b/dave/src/main/java/moe/kyokobot/koe/dave/util/MLSUtil.java new file mode 100644 index 0000000..f96852a --- /dev/null +++ b/dave/src/main/java/moe/kyokobot/koe/dave/util/MLSUtil.java @@ -0,0 +1,46 @@ +package moe.kyokobot.koe.dave.util; + +public class MLSUtil { + /** + * Turns a long into a byte array containing a big endian u64. + * @param value The long to convert. + * @return The byte array. + */ + public static byte[] bigEndianBytesFrom(long value) { + byte[] bytes = new byte[8]; + for (int i = 7; i >= 0; i--) { + bytes[i] = (byte) (value & 0xFF); + value >>= 8; + } + return bytes; + } + + /** + * Turns a byte array containing a big endian u64 into a long. + * @param bytes The byte array to convert. + * @return The long. + */ + public static long fromBigEndianBytes(byte[] bytes) { + long value = 0; + for (int i = 0; i < 8; i++) { + value <<= 8; + value |= bytes[i] & 0xFF; + } + return value; + } + + /** + * Turns a String containing an unsigned long into a byte array containing a little endian u64. + * @param value The string to convert. + * @return The byte array. + */ + public static byte[] littleEndianBytesFromString(String value) { + long longValue = Long.parseUnsignedLong(value); + byte[] bytes = new byte[8]; + for (int i = 0; i < 8; i++) { + bytes[i] = (byte) (longValue & 0xFF); + longValue >>= 8; + } + return bytes; + } +} diff --git a/ext-udpqueue/build.gradle b/ext-udpqueue/build.gradle index a730e69..a184267 100644 --- a/ext-udpqueue/build.gradle +++ b/ext-udpqueue/build.gradle @@ -2,5 +2,5 @@ dependencies { compileOnly project(':core') implementation 'dev.arbjerg:lava-common:2.2.1' implementation group: 'club.minnced', name: 'udpqueue-api', version: '0.2.9' - compileOnly group: 'org.jetbrains', name: 'annotations', version: '13.0' + compileOnly group: 'org.jetbrains', name: 'annotations', version: "$jetbrainsAnnotationsVersion" } diff --git a/mls/LICENSE b/mls/LICENSE new file mode 100644 index 0000000..e91af8e --- /dev/null +++ b/mls/LICENSE @@ -0,0 +1,27 @@ +Please note the Bouncy Caste License should be read in the same way as the MIT license. + +Please also note this licensing model is made possible through funding from donations and the sale of support contracts. + +Bouncy Castle License + +Copyright (c) 2000 - 2024 The Legion of the Bouncy Castle Inc. (https://www.bouncycastle.org) +Copyright (c) 2019 Alula + + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/mls/README.md b/mls/README.md new file mode 100644 index 0000000..dea6a2c --- /dev/null +++ b/mls/README.md @@ -0,0 +1,31 @@ +# BouncyCastle MLS (Koe fork) + +This is a fork of [BouncyCastle's](https://www.bouncycastle.org/) MLS library. + +### Personal notes / changes + +- BC MLS lib seems to be a direct port of Cisco's MLS++ library (so what libdave uses) and most things map 1:1. + - https://github.com/bcgit/bc-java/issues/1317 +- When I noticed first roadblocks and realized I have to modify the source code, I first tried to port it to Tink. + - However Tink seems to be pretty badly documented. I had to heavily rely on the source code. + - It's missing certain HKDF functionality, so I had to either reimplement it using code from `internal` package + (which does not sound future-proof) or completely reinvent the wheel. + - I decided to give up and just stick with BouncyCastle and maybe replace Tink with it in Koe to reduce amount of + dependencies that do the same thing. +- I've removed dependency on gRPC and Protobuf + - Why does BC MLS ship with these? + - It's only purpose seems to be interop testing apparently: https://github.com/bcgit/bc-java/pull/1565 +- `Capabilities` class doesn't expose any of it's fields, so I couldn't change extensions and supported suites to what + DAVE expects without parsing a serialized message... +- `Group` was missing an equivalent of `State::unwrap` method. +- `Welcome` has no clean way to access `secrets`, `EncryptedGroupSecrets` were private. +- The library has many methods that only accept serialized messages and even does redundant re-serialization internally + in some places :( + - most notably key handling? (ctrl+f `serializePrivateKey`) +- Added `TreeKEMPublicKey.allLeaves` method. +- Added `equals` to `ExternalSender`. + +### Possible replacements + +- https://github.com/Traderjoe95/mls-kotlin (written in Kotlin, but requires Java 21) +- write our own? diff --git a/mls/build.gradle b/mls/build.gradle new file mode 100644 index 0000000..975bbae --- /dev/null +++ b/mls/build.gradle @@ -0,0 +1,6 @@ +dependencies { + implementation group: 'org.slf4j', name: 'slf4j-api', version: "$slf4jVersion" + implementation group: 'org.bouncycastle', name: 'bcutil-jdk18on', version: "$bouncyCastleVersion" + api group: 'org.bouncycastle', name: 'bcprov-jdk18on', version: "$bouncyCastleVersion" +} + diff --git a/mls/src/main/java/moe/kyokobot/koe/mls/GroupKeySet.java b/mls/src/main/java/moe/kyokobot/koe/mls/GroupKeySet.java new file mode 100644 index 0000000..81188c7 --- /dev/null +++ b/mls/src/main/java/moe/kyokobot/koe/mls/GroupKeySet.java @@ -0,0 +1,308 @@ +package moe.kyokobot.koe.mls; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.InvalidParameterException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import moe.kyokobot.koe.mls.codec.ContentType; +import moe.kyokobot.koe.mls.crypto.MlsCipherSuite; +import moe.kyokobot.koe.mls.TreeKEM.LeafIndex; +import moe.kyokobot.koe.mls.TreeKEM.NodeIndex; +import moe.kyokobot.koe.mls.crypto.Secret; + +public class GroupKeySet +{ + final MlsCipherSuite suite; + final int secretSize; + // We store a commitment to the encryption secret that was used to create this structure, so that we can compare + // for purposes of equivalence checking without violating forward secrecy. + final Secret encryptionSecretCommit; + + public SecretTree secretTree; + Map handshakeRatchets; + Map applicationRatchets; + + + public GroupKeySet(MlsCipherSuite suite, TreeSize treeSize, Secret encryptionSecret) + throws IOException, IllegalAccessException + { + this.suite = suite; + this.secretSize = suite.getKDF().getHashLength(); + this.encryptionSecretCommit = encryptionSecret.deriveSecret(suite, "commitment"); + this.secretTree = new SecretTree(treeSize, encryptionSecret); + this.handshakeRatchets = new HashMap(); + this.applicationRatchets = new HashMap(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) + { + return true; + } + if (o == null || getClass() != o.getClass()) + { + return false; + } + GroupKeySet that = (GroupKeySet)o; + return secretSize == that.secretSize && suite.equals(that.suite) && encryptionSecretCommit.equals(that.encryptionSecretCommit); + } + + void initRatchets(LeafIndex sender) + throws IOException, IllegalAccessException + { + Secret leafSecret = secretTree.get(sender); + + Secret handshakeRatchetSecret = leafSecret.expandWithLabel(suite, "handshake", new byte[]{}, secretSize); + Secret applicationRatchetSecret = leafSecret.expandWithLabel(suite, "application", new byte[]{}, secretSize); + + HashRatchet handshakeRatchet = new HashRatchet(handshakeRatchetSecret); + HashRatchet applicationRatchet = new HashRatchet(applicationRatchetSecret); + + handshakeRatchets.put(sender, handshakeRatchet); + applicationRatchets.put(sender, applicationRatchet); + } + + public KeyGeneration get(ContentType contentType, LeafIndex sender, int generation, byte[] reuseGuard) + throws IOException, IllegalAccessException + { + HashRatchet chain; + + switch (contentType) + { + case APPLICATION: + chain = applicationRatchet(sender); + break; + case PROPOSAL: + case COMMIT: + chain = handshakeRatchet(sender); + break; + default: + return null; + } + + KeyGeneration keys = chain.get(generation); + ApplyReuseGuard(reuseGuard, keys.nonce); + return keys; + } + + public KeyGeneration get(ContentType contentType, LeafIndex sender, byte[] reuseGuard) + throws IOException, IllegalAccessException + { + HashRatchet chain; + + switch (contentType) + { + case APPLICATION: + chain = applicationRatchet(sender); + break; + case PROPOSAL: + case COMMIT: + chain = handshakeRatchet(sender); + break; + default: + return null; + } + + KeyGeneration keys = chain.next(); + ApplyReuseGuard(reuseGuard, keys.nonce); + return keys; + } + + private void ApplyReuseGuard(byte[] guard, byte[] nonce) + { + for (int i = 0; i < guard.length; i++) + { + nonce[i] ^= guard[i]; + } + } + + public void erase(ContentType contentType, LeafIndex sender, int generation) + throws IOException, IllegalAccessException + { + switch (contentType) + { + + case APPLICATION: + applicationRatchet(sender).erase(generation); + break; + case PROPOSAL: + case COMMIT: + handshakeRatchet(sender).erase(generation); + break; + } + } + + public HashRatchet handshakeRatchet(LeafIndex sender) + throws IOException, IllegalAccessException + { + if (!handshakeRatchets.containsKey(sender)) + { + initRatchets(sender); + } + return handshakeRatchets.get(sender); + } + + public HashRatchet applicationRatchet(LeafIndex sender) + throws IOException, IllegalAccessException + { + if (!applicationRatchets.containsKey(sender)) + { + initRatchets(sender); + } + return applicationRatchets.get(sender); + } + + public boolean hasLeaf(LeafIndex sender) + { + return secretTree.hasLeaf(sender); + } + + public class SecretTree + { + final TreeSize treeSize; + public Map secrets; + + public SecretTree(TreeSize treeSizeIn, Secret encryptionSecret) + { + treeSize = treeSizeIn; + secrets = new HashMap(); + secrets.put(NodeIndex.root(treeSize), encryptionSecret); + } + + protected boolean hasLeaf(LeafIndex sender) + { + return sender.value() < treeSize.leafCount(); + } + + public Secret get(LeafIndex leaf) + throws IOException, IllegalAccessException + { + + final byte[] leftLabel = "left".getBytes(StandardCharsets.UTF_8); + final byte[] rightLabel = "right".getBytes(StandardCharsets.UTF_8); + + NodeIndex rootNode = NodeIndex.root(treeSize); + NodeIndex leafNode = new NodeIndex(leaf); + + // Find an ancestor that is populated + List dirpath = leaf.directPath(treeSize); + dirpath.add(0, leafNode); + dirpath.add(rootNode); + int curr = 0; + for (; curr < dirpath.size(); curr++) + { + if (secrets.containsKey(dirpath.get(curr))) + { + break; + } + } + + if (curr > dirpath.size()) + { + throw new InvalidParameterException("No secret found to derive leaf key"); + } + + // Derive down + for (; curr > 0; curr--) + { + NodeIndex currNode = dirpath.get(curr); + NodeIndex left = currNode.left(); + NodeIndex right = currNode.right(); + + Secret secret = secrets.get(currNode); + secrets.put(left, secret.expandWithLabel(suite, "tree", leftLabel, secretSize)); + secrets.put(right, secret.expandWithLabel(suite, "tree", rightLabel, secretSize)); + } + + // Get the leaf secret + Secret leafSecret = secrets.get(leafNode); + + // Forget the secrets along the direct path + for (NodeIndex i : dirpath) + { + if (i.equals(leafNode)) + { + continue; + } + + if (secrets.containsKey(i)) + { + secrets.get(i).consume(); + secrets.remove(i); + } + } + + return leafSecret; + } + } + + public class HashRatchet + { + final int keySize; + final int nonceSize; + Secret nextSecret; + int nextGeneration; + Map cache; + + HashRatchet(Secret baseSecret) + { + keySize = suite.getAEAD().getKeySize(); + nonceSize = suite.getAEAD().getNonceSize(); + nextGeneration = 0; + nextSecret = baseSecret; + cache = new HashMap(); + } + + public KeyGeneration next() + throws IOException, IllegalAccessException + { + Secret key = nextSecret.deriveTreeSecret(suite, "key", nextGeneration, keySize); + Secret nonce = nextSecret.deriveTreeSecret(suite, "nonce", nextGeneration, nonceSize); + Secret secret = nextSecret.deriveTreeSecret(suite, "secret", nextGeneration, secretSize); + + KeyGeneration generation = new KeyGeneration(nextGeneration, key, nonce); + + nextGeneration += 1; + nextSecret.consume(); + nextSecret = secret; + + cache.put(generation.generation, generation); + return generation; + } + + public KeyGeneration get(int generation) + throws IOException, IllegalAccessException + { + if (cache.containsKey(generation)) + { + return cache.get(generation); + } + + if (nextGeneration > generation) + { + throw new InvalidParameterException("Request for expired key"); + } + + while (nextGeneration < generation) + { + next(); + } + + return next(); + } + + public void erase(int generation) + { + if (cache.containsKey(generation)) + { + cache.get(generation).consume(); + cache.remove(generation); + } + } + } +} diff --git a/mls/src/main/java/moe/kyokobot/koe/mls/KeyGeneration.java b/mls/src/main/java/moe/kyokobot/koe/mls/KeyGeneration.java new file mode 100644 index 0000000..bddec7e --- /dev/null +++ b/mls/src/main/java/moe/kyokobot/koe/mls/KeyGeneration.java @@ -0,0 +1,43 @@ +package moe.kyokobot.koe.mls; + +import java.util.Arrays; + +import moe.kyokobot.koe.mls.crypto.Secret; + +public class KeyGeneration +{ + public final int generation; + public final byte[] key; + public final byte[] nonce; + + public KeyGeneration(int generation, Secret key, Secret nonce) + { + this.generation = generation; + this.key = key.value().clone(); + this.nonce = nonce.value().clone(); + + key.consume(); + nonce.consume(); + } + + @Override + public boolean equals(Object o) + { + if (this == o) + { + return true; + } + if (o == null || getClass() != o.getClass()) + { + return false; + } + KeyGeneration that = (KeyGeneration)o; + return generation == that.generation && Arrays.equals(key, that.key) && Arrays.equals(nonce, that.nonce); + } + + void consume() + { + Arrays.fill(key, (byte)0); + Arrays.fill(nonce, (byte)0); + } +} diff --git a/mls/src/main/java/moe/kyokobot/koe/mls/KeyScheduleEpoch.java b/mls/src/main/java/moe/kyokobot/koe/mls/KeyScheduleEpoch.java new file mode 100644 index 0000000..f5b953c --- /dev/null +++ b/mls/src/main/java/moe/kyokobot/koe/mls/KeyScheduleEpoch.java @@ -0,0 +1,486 @@ +package moe.kyokobot.koe.mls; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.List; + +import moe.kyokobot.koe.mls.codec.MLSOutputStream; +import moe.kyokobot.koe.mls.codec.PreSharedKeyID; +import moe.kyokobot.koe.mls.crypto.MlsCipherSuite; +import org.bouncycastle.crypto.AsymmetricCipherKeyPair; +import org.bouncycastle.crypto.hpke.HPKEContext; +import org.bouncycastle.crypto.hpke.HPKEContextWithEncapsulation; +import org.bouncycastle.crypto.params.AsymmetricKeyParameter; +import moe.kyokobot.koe.mls.crypto.Secret; +import org.bouncycastle.util.Arrays; + +public class KeyScheduleEpoch +{ + public static class PSKWithSecret + { + public PreSharedKeyID id; + public Secret secret; + + public PSKWithSecret(PreSharedKeyID id, Secret secret) + { + this.id = id; + this.secret = secret; + } + } + + public byte[] receiveExternalInit(byte[] kemOut) + throws IOException + { + // do export (private key) + final int L = suite.getKDF().getHashLength(); + final byte[] labelData = "MLS 1.0 external init secret".getBytes(StandardCharsets.UTF_8); + HPKEContext ctx = suite.getHPKE().setupBaseR(kemOut, externalKeyPair, new byte[0]); + + return ctx.export(labelData, L); + } + + public static class JoinSecrets + { + // Cached values + private final MlsCipherSuite suite; + // Public values + public final Secret joinerSecret; + + public Secret welcomeSecret; + public Secret welcomeKey; + public Secret welcomeNonce; + private Secret memberSecret; // Held to derive further secrets + + static class PSKLabel + implements MLSOutputStream.Writable + { + PreSharedKeyID id; + short index; + short count; + + public PSKLabel(PreSharedKeyID id, short index, short count) + { + this.id = id; + this.index = index; + this.count = count; + } + + @Override + public void writeTo(MLSOutputStream stream) + throws IOException + { + stream.write(id); + stream.write(index); + stream.write(count); + } + } + + /* + 0 0 = psk_secret_[0] + | | + V V + psk_[0] --> Extract --> ExpandWithLabel --> Extract = psk_secret_[1] + | + 0 | + | | + V V + psk_[1] --> Extract --> ExpandWithLabel --> Extract = psk_secret_[2] + | + 0 ... + | | + V V + psk_[n-1] --> Extract --> ExpandWithLabel --> Extract = psk_secret_[n] + */ + public static Secret pskSecret(MlsCipherSuite suite, List psks) + throws IOException + { + Secret pskSecret = Secret.zero(suite); + if (psks == null || psks.isEmpty()) + { + return pskSecret; + } + + short index = 0; + short count = (short)psks.size(); + for (PSKWithSecret psk : psks) + { + PSKLabel label = new PSKLabel(psk.id, index, count); + byte[] pskLabel = MLSOutputStream.encode(label); + index += 1; + + Secret pskExtracted = Secret.extract(suite, Secret.zero(suite), psk.secret); + Secret pskInput = pskExtracted.expandWithLabel(suite, "derived psk", pskLabel, suite.getKDF().getHashLength()); + pskSecret = Secret.extract(suite, pskInput, pskSecret); + } + + return pskSecret; + } + + /* + init_secret_[n-1] + | + | + V + commit_secret --> KDF.Extract + | + | + V + ExpandWithLabel(., "joiner", GroupContext_[n], KDF.Nh) + | + | + V + joiner_secret + + */ + public static JoinSecrets forMember(MlsCipherSuite suite, Secret initSecret, Secret commitSecret, Secret pskSecret, byte[] context) + throws IOException + { + Secret preJoinerSecret = Secret.extract(suite, initSecret, commitSecret); + Secret joinerSecret = preJoinerSecret.expandWithLabel(suite, "joiner", context, suite.getKDF().getHashLength()); + return new JoinSecrets(suite, joinerSecret, pskSecret); + } + + /* + joiner_secret + | + | + V +psk_secret (or 0) --> KDF.Extract + | + | + +--> DeriveSecret(., "welcome") + | = welcome_secret + | + V + ExpandWithLabel(., "epoch", GroupContext_[n], KDF.Nh) + | + | + V + epoch_secret + */ + public JoinSecrets(MlsCipherSuite suite, Secret joinerSecret, List psks) + throws IOException + { + this.suite = suite; + this.joinerSecret = joinerSecret; + this.memberSecret = Secret.extract(suite, joinerSecret, pskSecret(suite, psks)); + // Carry-forward values + // Held to avoid consuming joinerSecret + this.welcomeSecret = memberSecret.deriveSecret(suite, "welcome"); + this.welcomeKey = welcomeSecret.expand(suite, "key", suite.getAEAD().getKeySize()); + this.welcomeNonce = welcomeSecret.expand(suite, "nonce", suite.getAEAD().getNonceSize()); + } + + public JoinSecrets(MlsCipherSuite suite, Secret joinerSecret, Secret pskSecret) + throws IOException + { + this.suite = suite; + this.joinerSecret = joinerSecret; + this.memberSecret = Secret.extract(suite, joinerSecret, pskSecret); + // Carry-forward values + // Held to avoid consuming joinerSecret + this.welcomeSecret = memberSecret.deriveSecret(suite, "welcome"); + this.welcomeKey = welcomeSecret.expand(suite, "key", suite.getAEAD().getKeySize()); + this.welcomeNonce = welcomeSecret.expand(suite, "nonce", suite.getAEAD().getNonceSize()); + } + + public void injectPskSecret(Secret pskSecret) + throws IOException + { + this.memberSecret = Secret.extract(suite, joinerSecret, pskSecret); + this.welcomeSecret = memberSecret.deriveSecret(suite, "welcome"); + this.welcomeKey = welcomeSecret.expand(suite, "key", suite.getAEAD().getKeySize()); + this.welcomeNonce = welcomeSecret.expand(suite, "nonce", suite.getAEAD().getNonceSize()); + } + + + public KeyScheduleEpoch complete(TreeSize treeSize, byte[] context) + throws IOException, IllegalAccessException + { +// memberSecret = new Secret(suite.getKDF().extract(joinerSecret.value(), pskSecret(suite, null).value())); + Secret epochSecret = memberSecret.expandWithLabel(suite, "epoch", context, suite.getKDF().getHashLength()); + KeyScheduleEpoch keySchedule = new KeyScheduleEpoch(suite, treeSize, epochSecret); + keySchedule.setJoinerSecret(joinerSecret); + return keySchedule; + } + } + + public static class ExternalInitParams + { + public byte[] kemOutput; + public Secret initSecret; + + public ExternalInitParams(MlsCipherSuite suite, AsymmetricKeyParameter externalPub) + { + final byte[] exportContext = "MLS 1.0 external init secret".getBytes(StandardCharsets.UTF_8); + final int L = suite.getKDF().getHashLength(); + + HPKEContextWithEncapsulation ctx = suite.getHPKE().setupBaseS(externalPub, null); + kemOutput = ctx.getEncapsulation(); + initSecret = new Secret(ctx.export(exportContext, L)); + } + + public Secret getInitSecret() + { + return initSecret; + } + + public byte[] getKEMOutput() + { + return kemOutput; + } + } + + + final MlsCipherSuite suite; + + // Secrets derived from the epoch secret + public final Secret initSecret; + public Secret senderDataSecret; + public final Secret exporterSecret; + public final Secret confirmationKey; + public Secret membershipKey; + public final Secret resumptionPSK; + public final Secret epochAuthenticator; + public final Secret encryptionSecret; + public final Secret externalSecret; + + // Further dervied products + final AsymmetricCipherKeyPair externalKeyPair; + public final GroupKeySet groupKeySet; + public Secret joinerSecret; + + public Secret getJoinerSecret() + { + return joinerSecret; + } + + public void setJoinerSecret(Secret joinerSecret) + { + this.joinerSecret = joinerSecret; + } + + public GroupKeySet getEncryptionKeys(TreeSize size) + throws IOException, IllegalAccessException + { + return new GroupKeySet(suite, size, encryptionSecret); + } + + public static KeyGeneration senderDataKeys(MlsCipherSuite suite, byte[] senderDataSecretBytes, byte[] ciphertext) + throws IOException + { + Secret senderDataSecret = new Secret(senderDataSecretBytes); + int sampleSize = suite.getKDF().getHashLength(); + byte[] sample = Arrays.copyOf(ciphertext, sampleSize); + int keySize = suite.getAEAD().getKeySize(); + int nonceSize = suite.getAEAD().getNonceSize(); + Secret key = senderDataSecret.expandWithLabel(suite, "key", sample, keySize); + Secret nonce = senderDataSecret.expandWithLabel(suite, "nonce", sample, nonceSize); + return new KeyGeneration(0, key, nonce); + } + + + public static Secret welcomeSecret(MlsCipherSuite suite, byte[] joinerSecret, List psk) + throws IOException + { + Secret pskSecret = JoinSecrets.pskSecret(suite, psk); + Secret extract = new Secret(suite.getKDF().extract(joinerSecret, pskSecret.value())); + return extract.deriveSecret(suite, "welcome"); + } + + public static KeyScheduleEpoch forCreator(MlsCipherSuite suite, byte[] groupContext) + throws IOException, IllegalAccessException + { + SecureRandom random = new SecureRandom(); + byte[] initSecret = new byte[suite.getKDF().getHashLength()]; + random.nextBytes(initSecret); + + JoinSecrets joinerSecret = JoinSecrets.forMember(suite, new Secret(initSecret), Secret.zero(suite), new Secret(new byte[0]), groupContext); +// TreeSize size = TreeSize.forLeaves(1); +// return joinerSecret.complete(size, groupContext); + return KeyScheduleEpoch.joiner(suite, joinerSecret.joinerSecret.value(), new ArrayList(), groupContext); + } + + public static KeyScheduleEpoch forCreator(MlsCipherSuite suite) + throws IOException, IllegalAccessException + { + SecureRandom rng = new SecureRandom(); + return forCreator(suite, rng); + } + + public static KeyScheduleEpoch forCreator(MlsCipherSuite suite, SecureRandom rng) + throws IOException, IllegalAccessException + { + byte[] epochSecret = new byte[suite.getKDF().getHashLength()]; + rng.nextBytes(epochSecret); + TreeSize treeSize = TreeSize.forLeaves(1); + return new KeyScheduleEpoch(suite, treeSize, new Secret(epochSecret)); + } + + public static KeyScheduleEpoch forExternalJoiner(MlsCipherSuite suite, TreeSize treeSize, ExternalInitParams externalInitParams, Secret commitSecret, List psks, byte[] context) + throws IOException, IllegalAccessException + { + return JoinSecrets.forMember(suite, externalInitParams.initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context).complete(treeSize, context); + } + + public JoinSecrets startCommit(Secret commitSecret, List psks, byte[] context) + throws IOException + { + return JoinSecrets.forMember(suite, initSecret, commitSecret, JoinSecrets.pskSecret(suite, psks), context); + } + + /* + epoch_secret + | + | + +--> DeriveSecret(.,