Skip to content

Commit

Permalink
the fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wangweij committed Jan 30, 2025
1 parent 3a564ed commit 0e5796f
Show file tree
Hide file tree
Showing 18 changed files with 1,197 additions and 124 deletions.
11 changes: 6 additions & 5 deletions src/java.base/share/classes/com/sun/crypto/provider/ML_KEM.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -217,7 +217,7 @@ protected Object checkPrivateKey(byte[] sk) throws InvalidKeyException {
/*
Main internal algorithms from Section 6 of specification
*/
protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d_z) {
MessageDigest mlKemH;
try {
mlKemH = MessageDigest.getInstance(HASH_H_NAME);
Expand All @@ -227,7 +227,7 @@ protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
}

//Generate K-PKE keys
var kPkeKeyPair = generateK_PkeKeyPair(kem_d);
var kPkeKeyPair = generateK_PkeKeyPair(kem_d_z);
//encaps key = kPke encryption key
byte[] encapsKey = kPkeKeyPair.publicKey.keyBytes;

Expand All @@ -246,7 +246,7 @@ protected ML_KEM_KeyPair generateKemKeyPair(byte[] kem_d, byte[] kem_z) {
// This should never happen.
throw new RuntimeException(e);
}
System.arraycopy(kem_z, 0, decapsKey,
System.arraycopy(kem_d_z, 32, decapsKey,
kPkePrivateKey.length + encapsKey.length + 32, 32);

return new ML_KEM_KeyPair(
Expand Down Expand Up @@ -367,10 +367,11 @@ private K_PKE_KeyPair generateK_PkeKeyPair(byte[] seed) {
throw new RuntimeException(e);
}

mlKemG.update(seed);
mlKemG.update(seed, 0, 32);
mlKemG.update((byte)mlKem_k);

var rhoSigma = mlKemG.digest();
mlKemG.reset();
var rho = Arrays.copyOfRange(rhoSigma, 0, 32);
var sigma = Arrays.copyOfRange(rhoSigma, 32, 64);
Arrays.fill(rhoSigma, (byte)0);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand Down Expand Up @@ -37,6 +37,12 @@

public final class ML_KEM_Impls {

public static byte[] seedToExpandedPrivate(String pname, byte[] seed) {
return new ML_KEM(pname).generateKemKeyPair(seed)
.decapsulationKey()
.keyBytes();
}

public sealed static class KPG
extends NamedKeyPairGenerator permits KPG2, KPG3, KPG5 {

Expand All @@ -51,22 +57,16 @@ protected KPG(String pname) {

@Override
protected byte[][] implGenerateKeyPair(String name, SecureRandom random) {
byte[] seed = new byte[32];
byte[] seedAndZ = new byte[64];
var r = random != null ? random : JCAUtil.getDefSecureRandom();
r.nextBytes(seed);
byte[] z = new byte[32];
r.nextBytes(z);
r.nextBytes(seedAndZ);

ML_KEM mlKem = new ML_KEM(name);
ML_KEM.ML_KEM_KeyPair kp;
try {
kp = mlKem.generateKemKeyPair(seed, z);
} finally {
Arrays.fill(seed, (byte)0);
Arrays.fill(z, (byte)0);
}
return new byte[][] {
kp = mlKem.generateKemKeyPair(seedAndZ);
return new byte[][]{
kp.encapsulationKey().keyBytes(),
seedAndZ,
kp.decapsulationKey().keyBytes()
};
}
Expand Down Expand Up @@ -97,6 +97,15 @@ public KF() {
public KF(String name) {
super("ML-KEM", name);
}

@Override
protected byte[] implGenAlt(String name, byte[] key) {
if (key.length == 64) {
return seedToExpandedPrivate(name, key);
} else {
return null;
}
}
}

public final static class KF2 extends KF {
Expand Down Expand Up @@ -183,11 +192,11 @@ protected Object implCheckPrivateKey(String name, byte[] sk)
}

public K() {
super("ML-KEM", "ML-KEM-512", "ML-KEM-768", "ML-KEM-1024");
super("ML-KEM", new KF(), "ML-KEM-512", "ML-KEM-768", "ML-KEM-1024");
}

public K(String name) {
super("ML-KEM", name);
super("ML-KEM", new KF(name), name);
}
}

Expand Down
92 changes: 66 additions & 26 deletions src/java.base/share/classes/sun/security/pkcs/NamedPKCS8Key.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -25,7 +25,6 @@

package sun.security.pkcs;

import sun.security.util.DerInputStream;
import sun.security.util.DerValue;
import sun.security.x509.AlgorithmId;

Expand All @@ -39,6 +38,7 @@
import java.security.ProviderException;
import java.security.spec.NamedParameterSpec;
import java.util.Arrays;
import java.util.function.BiFunction;

/// Represents a private key from an algorithm family that is specialized
/// with a named parameter set.
Expand All @@ -50,49 +50,71 @@
/// identifier in the PKCS #8 encoding of the key is always a single OID derived
/// from the parameter set name.
///
/// Besides the [PKCS8Key#key] field, this class might contain an optional
/// alternative key stored in [#alt].
///
/// 1. If there is only `key`, there is only one private key encoding.
/// 2. If both `key` and `alt` exist. `key` is used in encoding,
/// and `alt` is used in calculation.
///
/// This allows ML-KEM or ML-DSA to encode the seed used in key pair
/// generation as the private key. In this case, `alt` will be the
/// expanded key as described in the FIPS documents. If the seed is
/// lost, `key` will be the expanded key and `alt` will be null.
///
/// For algorithms that do not have this "alternative" key format,
/// only `key` will be included and `alt` must be `null`.
///
/// @see sun.security.provider.NamedKeyPairGenerator
public final class NamedPKCS8Key extends PKCS8Key {
@Serial
private static final long serialVersionUID = 1L;

private final String fname;
private final transient NamedParameterSpec paramSpec;
private final byte[] rawBytes;
private final byte[] alt;

private transient boolean destroyed = false;

/// Ctor from family name, parameter set name, raw key bytes.
/// Key bytes won't be cloned, caller must relinquish ownership
public NamedPKCS8Key(String fname, String pname, byte[] rawBytes) {
/// Ctor from raw key bytes.
///
/// `rawBytes` and `alt` won't be cloned, caller
/// must relinquish ownership.
///
/// @param fname family name
/// @param pname parameter set name
/// @param rawBytes raw key bytes
/// @param alt alternative key format, can be `null`.
public NamedPKCS8Key(String fname, String pname, byte[] rawBytes, byte[] alt) {
this.fname = fname;
this.paramSpec = new NamedParameterSpec(pname);
this.alt = alt;
try {
this.algid = AlgorithmId.get(pname);
} catch (NoSuchAlgorithmException e) {
throw new ProviderException(e);
}
this.rawBytes = rawBytes;

DerValue val = new DerValue(DerValue.tag_OctetString, rawBytes);
try {
this.key = val.toByteArray();
} finally {
val.clear();
}
this.key = rawBytes;
}

/// Ctor from family name, and PKCS #8 bytes
public NamedPKCS8Key(String fname, byte[] encoded) throws InvalidKeyException {
/// Ctor from family name and PKCS #8 encoding
///
/// @param fname family name
/// @param encoded PKCS #8 encoding. It is copied so caller can modify
/// it after the method call.
/// @param genAlt a function that is able to calculate the alternative
/// key from raw key inside `encoded`. In the case of seed/expanded,
/// the function will calculate expanded from seed. If it recognizes
/// the input being already the expanded key, it must return `null`.
/// If there is no alternative key format, `getAlt` must be `null`.
public NamedPKCS8Key(String fname, byte[] encoded,
BiFunction<String, byte[], byte[]> genAlt) throws InvalidKeyException {
super(encoded);
this.fname = fname;
try {
paramSpec = new NamedParameterSpec(algid.getName());
if (algid.getEncodedParams() != null) {
throw new InvalidKeyException("algorithm identifier has params");
}
rawBytes = new DerInputStream(key).getOctetString();
} catch (IOException e) {
throw new InvalidKeyException("Cannot parse input", e);
this.alt = genAlt == null ? null : genAlt.apply(algid.getName(), this.key);
paramSpec = new NamedParameterSpec(algid.getName());
if (algid.getEncodedParams() != null) {
throw new InvalidKeyException("algorithm identifier has params");
}
}

Expand All @@ -106,7 +128,23 @@ public String toString() {
/// Returns the reference to the internal key. Caller must not modify
/// the content or keep a reference.
public byte[] getRawBytes() {
return rawBytes;
return key;
}

/// Returns the reference to the key that will be used in computations
/// inside `NamedKEM` or `NamedSignature` between `alt` (if exists)
/// and `key`.
///
/// This method currently simply chooses the longer one, where it is the
/// expanded format. If the key used in computations is not the longer
/// one for an algorithm, consider adding overridable methods to
/// `NamedKEM` and `NamedSignature` to extract it.
public byte[] getExpanded() {
if (alt == null) {
return key;
} else {
return alt.length > key.length ? alt : key;
}
}

@Override
Expand All @@ -128,8 +166,10 @@ private void readObject(ObjectInputStream stream)

@Override
public void destroy() throws DestroyFailedException {
Arrays.fill(rawBytes, (byte)0);
Arrays.fill(key, (byte)0);
if (alt != null) {
Arrays.fill(alt, (byte)0);
}
if (encodedKey != null) {
Arrays.fill(encodedKey, (byte)0);
}
Expand Down
36 changes: 23 additions & 13 deletions src/java.base/share/classes/sun/security/provider/ML_DSA_Impls.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2024, 2025, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
Expand All @@ -26,12 +26,17 @@
package sun.security.provider;

import sun.security.jca.JCAUtil;

import java.security.*;
import java.security.SecureRandom;
import java.util.Arrays;

public class ML_DSA_Impls {

public static byte[] seedToExpandedPrivate(String pname, byte[] seed) {
var impl = new ML_DSA(name2int(pname));
return impl.skEncode(impl.generateKeyPairInternal(seed).privateKey());
}

public enum Version {
DRAFT, FINAL
}
Expand Down Expand Up @@ -75,15 +80,11 @@ protected byte[][] implGenerateKeyPair(String name, SecureRandom sr) {
r.nextBytes(seed);
ML_DSA mlDsa = new ML_DSA(name2int(name));
ML_DSA.ML_DSA_KeyPair kp = mlDsa.generateKeyPairInternal(seed);
try {
return new byte[][]{
mlDsa.pkEncode(kp.publicKey()),
mlDsa.skEncode(kp.privateKey())
};
} finally {
kp.privateKey().destroy();
Arrays.fill(seed, (byte)0);
}
return new byte[][]{
mlDsa.pkEncode(kp.publicKey()),
seed,
mlDsa.skEncode(kp.privateKey())
};
}
}

Expand Down Expand Up @@ -112,6 +113,15 @@ public KF() {
public KF(String name) {
super("ML-DSA", name);
}

@Override
protected byte[] implGenAlt(String name, byte[] key) {
if (key.length == 32) {
return seedToExpandedPrivate(name, key);
} else {
return null;
}
}
}

public final static class KF2 extends KF {
Expand All @@ -134,10 +144,10 @@ public KF5() {

public sealed static class SIG extends NamedSignature permits SIG2, SIG3, SIG5 {
public SIG() {
super("ML-DSA", "ML-DSA-44", "ML-DSA-65", "ML-DSA-87");
super("ML-DSA", new KF(), "ML-DSA-44", "ML-DSA-65", "ML-DSA-87");
}
public SIG(String name) {
super("ML-DSA", name);
super("ML-DSA", new KF(name), name);
}

@Override
Expand Down
Loading

0 comments on commit 0e5796f

Please sign in to comment.