Skip to content

Commit

Permalink
chore: Make the TLS epoch an enum type
Browse files Browse the repository at this point in the history
  • Loading branch information
larseggert committed Jan 8, 2025
1 parent 5c36c79 commit 12e892a
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 44 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions neqo-crypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ workspace = true

[dependencies]
# Checked against https://searchfox.org/mozilla-central/source/Cargo.lock 2024-11-11
enum-map = { workspace = true }
log = { workspace = true }
neqo-common = { path = "../neqo-common" }

Expand Down
2 changes: 1 addition & 1 deletion neqo-crypto/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
assert_initialized,
auth::AuthenticationStatus,
constants::{
Alert, Cipher, Epoch, Extension, Group, SignatureScheme, Version, TLS_VERSION_1_3,
Alert, Cipher, Extension, Group, SignatureScheme, Epoch, Version, TLS_VERSION_1_3,
},
ech,
err::{is_blocked, secstatus_to_res, Error, PRErrorCode, Res},
Expand Down
2 changes: 1 addition & 1 deletion neqo-crypto/src/agentio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl RecordList {
let records = arg.cast::<Self>().as_mut().unwrap();

let slice = null_safe_slice(data, len);
records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
records.append(epoch.into(), ContentType::try_from(ct).unwrap(), slice);
ssl::SECSuccess
}

Expand Down
55 changes: 47 additions & 8 deletions neqo-crypto/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,60 @@

#![allow(dead_code)]

use enum_map::Enum;

use crate::ssl;

// Ideally all of these would be enums, but size matters and we need to allow
// for values outside of those that are defined here.

pub type Alert = u8;

pub type Epoch = u16;
// TLS doesn't really have an "initial" concept that maps to QUIC so directly,
// but this should be clear enough.
pub const TLS_EPOCH_INITIAL: Epoch = 0_u16;
pub const TLS_EPOCH_ZERO_RTT: Epoch = 1_u16;
pub const TLS_EPOCH_HANDSHAKE: Epoch = 2_u16;
// Also, we don't use TLS epochs > 3.
pub const TLS_EPOCH_APPLICATION_DATA: Epoch = 3_u16;
#[derive(Default, Debug, Enum)]
pub enum Epoch {
// TLS doesn't really have an "initial" concept that maps to QUIC so directly,
// but this should be clear enough.
#[default]
Initial = 0,
ZeroRtt,
Handshake,
ApplicationData,
// Also, we don't use TLS epochs > 3.
}

impl From<u16> for Epoch {
fn from(e: u16) -> Self {
match e {
0 => Self::Initial,
1 => Self::ZeroRtt,
2 => Self::Handshake,
3 => Self::ApplicationData,
_ => unreachable!(),
}
}
}

impl From<Epoch> for usize {
fn from(e: Epoch) -> Self {
match e {
Epoch::Initial => 0,
Epoch::ZeroRtt => 1,
Epoch::Handshake => 2,
Epoch::ApplicationData => 3,
}
}
}

impl std::fmt::Display for Epoch {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Initial => write!(f, "Initial"),
Self::ZeroRtt => write!(f, "ZeroRtt"),
Self::Handshake => write!(f, "Handshake"),
Self::ApplicationData => write!(f, "ApplicationData"),
}
}
}

/// Rather than defining a type alias and a bunch of constants, which leads to a ton of repetition,
/// use this macro.
Expand Down
23 changes: 11 additions & 12 deletions neqo-crypto/src/secrets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use std::{os::raw::c_void, pin::Pin};

use enum_map::EnumMap;
use neqo_common::qdebug;

use crate::{
Expand Down Expand Up @@ -43,23 +44,16 @@ impl From<SSLSecretDirection::Type> for SecretDirection {
#[allow(clippy::module_name_repetitions)]
pub struct DirectionalSecrets {
// We only need to maintain 3 secrets for the epochs used during the handshake.
secrets: [Option<SymKey>; 3],
secrets: EnumMap<Epoch, Option<SymKey>>,
}

impl DirectionalSecrets {
fn put(&mut self, epoch: Epoch, key: SymKey) {
assert!(epoch > 0);
let i = (epoch - 1) as usize;
assert!(i < self.secrets.len());
// assert!(self.secrets[i].is_none());
self.secrets[i] = Some(key);
self.secrets[epoch] = Some(key);
}

pub fn take(&mut self, epoch: Epoch) -> Option<SymKey> {
assert!(epoch > 0);
let i = (epoch - 1) as usize;
assert!(i < self.secrets.len());
self.secrets[i].take()
self.secrets[epoch].take()
}
}

Expand All @@ -78,10 +72,15 @@ impl Secrets {
arg: *mut c_void,
) {
let secrets = arg.cast::<Self>().as_mut().unwrap();
secrets.put_raw(epoch, dir, secret);
secrets.put_raw(epoch.into(), dir, secret);
}

fn put_raw(&mut self, epoch: Epoch, dir: SSLSecretDirection::Type, key_ptr: *mut PK11SymKey) {
fn put_raw(
&mut self,
epoch: Epoch,
dir: SSLSecretDirection::Type,
key_ptr: *mut PK11SymKey,
) {
let key_ptr = unsafe { PK11_ReferenceSymKey(key_ptr) };
let key = SymKey::from_ptr(key_ptr).expect("NSS shouldn't be passing out NULL secrets");
self.put(SecretDirection::from(dir), epoch, key);
Expand Down
37 changes: 18 additions & 19 deletions neqo-transport/src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@ use std::{

use neqo_common::{hex, hex_snip_middle, qdebug, qinfo, qtrace, Encoder, Role};
use neqo_crypto::{
hkdf, hp::HpKey, Aead, Agent, AntiReplay, Cipher, Epoch, Error as CryptoError, HandshakeState,
PrivateKey, PublicKey, Record, RecordList, ResumptionToken, SymKey, ZeroRttChecker,
hkdf, hp::HpKey, Aead, Agent, AntiReplay, Cipher, Error as CryptoError, HandshakeState,
PrivateKey, PublicKey, Record, RecordList, ResumptionToken, SymKey, Epoch, ZeroRttChecker,
TLS_AES_128_GCM_SHA256, TLS_AES_256_GCM_SHA384, TLS_CHACHA20_POLY1305_SHA256, TLS_CT_HANDSHAKE,
TLS_EPOCH_APPLICATION_DATA, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL, TLS_EPOCH_ZERO_RTT,
TLS_GRP_EC_SECP256R1, TLS_GRP_EC_SECP384R1, TLS_GRP_EC_SECP521R1, TLS_GRP_EC_X25519,
TLS_GRP_KEM_MLKEM768X25519, TLS_VERSION_1_3,
};
Expand Down Expand Up @@ -192,10 +191,10 @@ impl Crypto {
let input = data.map(|d| {
qtrace!("Handshake record received {:0x?} ", d);
let epoch = match space {
PacketNumberSpace::Initial => TLS_EPOCH_INITIAL,
PacketNumberSpace::Handshake => TLS_EPOCH_HANDSHAKE,
PacketNumberSpace::Initial => Epoch::Initial,
PacketNumberSpace::Handshake => Epoch::Handshake,
// Our epoch progresses forward, but the TLS epoch is fixed to 3.
PacketNumberSpace::ApplicationData => TLS_EPOCH_APPLICATION_DATA,
PacketNumberSpace::ApplicationData => Epoch::ApplicationData,
};
Record {
ct: TLS_CT_HANDSHAKE,
Expand Down Expand Up @@ -232,11 +231,11 @@ impl Crypto {
let (dir, secret) = match role {
Role::Client => (
CryptoDxDirection::Write,
self.tls.write_secret(TLS_EPOCH_ZERO_RTT),
self.tls.write_secret(Epoch::ZeroRtt),
),
Role::Server => (
CryptoDxDirection::Read,
self.tls.read_secret(TLS_EPOCH_ZERO_RTT),
self.tls.read_secret(Epoch::ZeroRtt),
),
};
let secret = secret.ok_or(Error::InternalError)?;
Expand Down Expand Up @@ -266,13 +265,13 @@ impl Crypto {

fn install_handshake_keys(&mut self) -> Res<bool> {
qtrace!([self], "Attempt to install handshake keys");
let Some(write_secret) = self.tls.write_secret(TLS_EPOCH_HANDSHAKE) else {
let Some(write_secret) = self.tls.write_secret(Epoch::Handshake) else {
// No keys is fine.
return Ok(false);
};
let read_secret = self
.tls
.read_secret(TLS_EPOCH_HANDSHAKE)
.read_secret(Epoch::Handshake)
.ok_or(Error::InternalError)?;
let cipher = match self.tls.info() {
None => self.tls.preinfo()?.cipher_suite(),
Expand All @@ -287,7 +286,7 @@ impl Crypto {

fn maybe_install_application_write_key(&mut self, version: Version) -> Res<()> {
qtrace!([self], "Attempt to install application write key");
if let Some(secret) = self.tls.write_secret(TLS_EPOCH_APPLICATION_DATA) {
if let Some(secret) = self.tls.write_secret(Epoch::ApplicationData) {
self.states.set_application_write_key(version, &secret)?;
qdebug!([self], "Application write key installed");
}
Expand All @@ -301,7 +300,7 @@ impl Crypto {
debug_assert!(self.states.app_write.is_some());
let read_secret = self
.tls
.read_secret(TLS_EPOCH_APPLICATION_DATA)
.read_secret(Epoch::ApplicationData)
.ok_or(Error::InternalError)?;
self.states
.set_application_read_key(version, &read_secret, expire_0rtt)?;
Expand Down Expand Up @@ -487,7 +486,7 @@ impl CryptoDxState {

let secret = hkdf::expand_label(TLS_VERSION_1_3, cipher, &initial_secret, &[], label)?;

Self::new(version, direction, TLS_EPOCH_INITIAL, &secret, cipher)
Self::new(version, direction, Epoch::Initial, &secret, cipher)
}

/// Determine the confidentiality and integrity limits for the cipher.
Expand Down Expand Up @@ -620,13 +619,13 @@ impl CryptoDxState {
// Only initiate a key update if we have processed exactly one packet
// and we are in an epoch greater than 3.
self.used_pn.start + 1 == self.used_pn.end
&& self.epoch > usize::from(TLS_EPOCH_APPLICATION_DATA)
&& self.epoch > usize::from(Epoch::ApplicationData)
}

#[must_use]
pub fn can_update(&self, largest_acknowledged: Option<PacketNumber>) -> bool {
largest_acknowledged.map_or_else(
|| self.epoch == usize::from(TLS_EPOCH_APPLICATION_DATA),
|| self.epoch == usize::from(Epoch::ApplicationData),
|la| self.used_pn.contains(&la),
)
}
Expand Down Expand Up @@ -765,7 +764,7 @@ impl CryptoDxAppData {
cipher: Cipher,
) -> Res<Self> {
Ok(Self {
dx: CryptoDxState::new(version, dir, TLS_EPOCH_APPLICATION_DATA, secret, cipher)?,
dx: CryptoDxState::new(version, dir, Epoch::ApplicationData, secret, cipher)?,
cipher,
next_secret: Self::update_secret(cipher, secret)?,
})
Expand Down Expand Up @@ -1028,7 +1027,7 @@ impl CryptoStates {
self.zero_rtt = Some(CryptoDxState::new(
version,
dir,
TLS_EPOCH_ZERO_RTT,
Epoch::ZeroRtt,
secret,
cipher,
)?);
Expand Down Expand Up @@ -1069,14 +1068,14 @@ impl CryptoStates {
tx: CryptoDxState::new(
version,
CryptoDxDirection::Write,
TLS_EPOCH_HANDSHAKE,
Epoch::Handshake,
write_secret,
cipher,
)?,
rx: CryptoDxState::new(
version,
CryptoDxDirection::Read,
TLS_EPOCH_HANDSHAKE,
Epoch::Handshake,
read_secret,
cipher,
)?,
Expand Down
6 changes: 3 additions & 3 deletions neqo-transport/src/tracking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::{

use enum_map::{enum_map, Enum, EnumMap};
use neqo_common::{qdebug, qinfo, qtrace, qwarn, IpTosEcn};
use neqo_crypto::{Epoch, TLS_EPOCH_HANDSHAKE, TLS_EPOCH_INITIAL};
use neqo_crypto::Epoch;

use crate::{
ecn::EcnCount,
Expand Down Expand Up @@ -47,8 +47,8 @@ impl PacketNumberSpace {
impl From<Epoch> for PacketNumberSpace {
fn from(epoch: Epoch) -> Self {
match epoch {
TLS_EPOCH_INITIAL => Self::Initial,
TLS_EPOCH_HANDSHAKE => Self::Handshake,
Epoch::Initial => Self::Initial,
Epoch::Handshake => Self::Handshake,
_ => Self::ApplicationData,
}
}
Expand Down

0 comments on commit 12e892a

Please sign in to comment.