Skip to content

Commit

Permalink
Cleaner approach to reverting on error when changing fields
Browse files Browse the repository at this point in the history
  • Loading branch information
AgeManning committed Oct 28, 2024
1 parent 6eeb694 commit bdfd7a9
Showing 1 changed file with 64 additions and 88 deletions.
152 changes: 64 additions & 88 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,123 +739,103 @@ impl<K: EnrKey> Enr<K> {

/// Helper function for `set_tcp_socket()` and `set_udp_socket`.
fn set_socket(&mut self, socket: SocketAddr, key: &K, is_tcp: bool) -> Result<(), Error> {
let enr_backup = self.clone();
// We work on this new version, allowing us to not mutate self on error.
let mut new_enr = self.clone();
let (port_string, port_v6_string): (Key, Key) = if is_tcp {
(TCP_ENR_KEY.into(), TCP6_ENR_KEY.into())
} else {
(UDP_ENR_KEY.into(), UDP6_ENR_KEY.into())
};

let (prev_ip, prev_port) = match socket.ip() {
match socket.ip() {
IpAddr::V4(addr) => {
let mut ip = BytesMut::new();
addr.encode(&mut ip);
let mut port = BytesMut::new();
socket.port().encode(&mut port);
(
self.content.insert(IP_ENR_KEY.into(), ip.freeze()),
self.content.insert(port_string.clone(), port.freeze()),
)
new_enr.content.insert(IP_ENR_KEY.into(), ip.freeze());
new_enr.content.insert(port_string.clone(), port.freeze());
}
IpAddr::V6(addr) => {
let mut ip6 = BytesMut::new();
addr.encode(&mut ip6);
let mut port = BytesMut::new();
socket.port().encode(&mut port);
(
self.content.insert(IP6_ENR_KEY.into(), ip6.freeze()),
self.content.insert(port_v6_string.clone(), port.freeze()),
)
new_enr.content.insert(IP6_ENR_KEY.into(), ip6.freeze());
new_enr
.content
.insert(port_v6_string.clone(), port.freeze());
}
};

let public_key = key.public();
let mut pubkey = BytesMut::new();
public_key.encode().as_ref().encode(&mut pubkey);
let previous_key = self.content.insert(public_key.enr_key(), pubkey.freeze());
new_enr
.content
.insert(public_key.enr_key(), pubkey.freeze());

// check the size and revert on failure
if self.size() > MAX_ENR_SIZE {
// if the size of the record is too large, revert and error
// revert the public key
if let Some(key) = previous_key {
self.content.insert(public_key.enr_key(), key);
} else {
self.content.remove(&public_key.enr_key());
}
// revert the content
match socket.ip() {
IpAddr::V4(_) => {
if let Some(ip) = prev_ip {
self.content.insert(IP_ENR_KEY.into(), ip);
} else {
self.content.remove(IP_ENR_KEY);
}
if let Some(udp) = prev_port {
self.content.insert(port_string, udp);
} else {
self.content.remove(&port_string);
}
}
IpAddr::V6(_) => {
if let Some(ip) = prev_ip {
self.content.insert(IP6_ENR_KEY.into(), ip);
} else {
self.content.remove(IP6_ENR_KEY);
}
if let Some(udp) = prev_port {
self.content.insert(port_v6_string, udp);
} else {
self.content.remove(&port_v6_string);
}
}
}
// check the size
if new_enr.size() > MAX_ENR_SIZE {
return Err(Error::ExceedsMaxSize);
}

// increment the sequence number
match self.seq.checked_add(1) {
Some(seq_no) => self.seq = seq_no,
None => {
*self = enr_backup;
return Err(Error::SequenceNumberTooHigh);
}
}
new_enr.seq = new_enr
.seq
.checked_add(1)
.ok_or(Error::SequenceNumberTooHigh)?;

// sign the record
self.sign(key)?;
new_enr.sign(key)?;

// update the node id
self.node_id = NodeId::from(key.public());
new_enr.node_id = NodeId::from(key.public());

if new_enr.size() > MAX_ENR_SIZE {
// in case the signature size changes, inform the user the size has exceeded the
// maximum
return Err(Error::ExceedsMaxSize);
}
// Everything passed update the record
*self = new_enr;

Ok(())
}

/// Removes a key from the ENR.
pub fn remove_key(&mut self, content_key: impl AsRef<[u8]>, enr_key: &K) -> Result<(), Error> {
let enr_backup = self.clone();
self.content.remove(content_key.as_ref());
// We work on this new version, allowing us to not mutate self on error.
let mut new_enr = self.clone();
new_enr.content.remove(content_key.as_ref());

// add the new public key.
let public_key = enr_key.public();
let mut pubkey = BytesMut::new();
public_key.encode().as_ref().encode(&mut pubkey);
self.content.insert(public_key.enr_key(), pubkey.freeze());
new_enr
.content
.insert(public_key.enr_key(), pubkey.freeze());

// increment the sequence number
match self.seq.checked_add(1) {
Some(seq_no) => self.seq = seq_no,
None => {
*self = enr_backup;
return Err(Error::SequenceNumberTooHigh);
}
}
new_enr.seq = new_enr
.seq
.checked_add(1)
.ok_or(Error::SequenceNumberTooHigh)?;

// sign the record
self.sign(enr_key)?;
new_enr.sign(enr_key)?;

// update the node id
self.node_id = NodeId::from(enr_key.public());
new_enr.node_id = NodeId::from(enr_key.public());

if new_enr.size() > MAX_ENR_SIZE {
// in case the signature size changes, inform the user the size has exceeded the
// maximum
return Err(Error::ExceedsMaxSize);
}

*self = new_enr;
Ok(())
}

Expand All @@ -871,61 +851,57 @@ impl<K: EnrKey> Enr<K> {
insert_key_values: impl Iterator<Item = (impl AsRef<[u8]>, &'a [u8])>,
enr_key: &K,
) -> Result<(PreviousRlpEncodedValues, PreviousRlpEncodedValues), Error> {
let enr_backup = self.clone();
// We work on this new version, allowing us to not mutate self on error.
let mut new_enr = self.clone();

let mut removed = Vec::new();
for key in remove_keys {
removed.push(self.content.remove(key.as_ref()));
removed.push(new_enr.content.remove(key.as_ref()));
}

// add the new public key
let public_key = enr_key.public();
let mut pubkey = BytesMut::new();
public_key.encode().as_ref().encode(&mut pubkey);
self.content.insert(public_key.enr_key(), pubkey.freeze());
new_enr
.content
.insert(public_key.enr_key(), pubkey.freeze());

let mut inserted = Vec::new();
for (key, value) in insert_key_values {
// currently only support "v4" identity schemes
if key.as_ref() == ID_ENR_KEY && value != ENR_VERSION {
*self = enr_backup;
return Err(Error::UnsupportedIdentityScheme);
}
let mut out = BytesMut::new();
value.encode(&mut out);
let value = out.freeze();
// Prevent inserting invalid RLP integers
if is_keyof_u16(key.as_ref()) {
if let Err(err) = u16::decode(&mut value.as_ref()) {
*self = enr_backup;
return Err(err.into());
}
u16::decode(&mut value.as_ref())?;
}

inserted.push(self.content.insert(key.as_ref().to_vec(), value));
inserted.push(new_enr.content.insert(key.as_ref().to_vec(), value));
}

// increment the sequence number
match self.seq.checked_add(1) {
Some(seq_no) => self.seq = seq_no,
None => {
*self = enr_backup;
return Err(Error::SequenceNumberTooHigh);
}
}
new_enr.seq = new_enr
.seq
.checked_add(1)
.ok_or(Error::SequenceNumberTooHigh)?;

// sign the record
self.sign(enr_key)?;
new_enr.sign(enr_key)?;

// update the node id
self.node_id = NodeId::from(enr_key.public());
new_enr.node_id = NodeId::from(enr_key.public());

if self.size() > MAX_ENR_SIZE {
if new_enr.size() > MAX_ENR_SIZE {
// in case the signature size changes, inform the user the size has exceeded the
// maximum
*self = enr_backup;
return Err(Error::ExceedsMaxSize);
}
*self = new_enr;

Ok((removed, inserted))
}
Expand Down

0 comments on commit bdfd7a9

Please sign in to comment.