From 07e1e7dce397579bfa2e273430f6c47ef05536f5 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Fri, 1 Dec 2023 14:40:11 +0200 Subject: [PATCH 01/21] save --- contract/src/lib.rs | 154 ++++++-------- contract/src/primitives.rs | 177 ++++++++++++++++ .../src/multichain/containers.rs | 6 +- integration-tests/src/multichain/local.rs | 6 +- integration-tests/src/multichain/mod.rs | 14 +- integration-tests/tests/multichain/mod.rs | 2 +- node/src/cli.rs | 23 +- node/src/protocol/consensus.rs | 44 +++- node/src/protocol/contract.rs | 199 ------------------ node/src/protocol/contract/mod.rs | 138 ++++++++++++ node/src/protocol/contract/primitives.rs | 136 ++++++++++++ node/src/protocol/cryptography.rs | 10 +- node/src/protocol/mod.rs | 20 +- node/src/protocol/state.rs | 22 +- 14 files changed, 588 insertions(+), 363 deletions(-) create mode 100644 contract/src/primitives.rs delete mode 100644 node/src/protocol/contract.rs create mode 100644 node/src/protocol/contract/mod.rs create mode 100644 node/src/protocol/contract/primitives.rs diff --git a/contract/src/lib.rs b/contract/src/lib.rs index 4434b116f..2b2737839 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -1,65 +1,38 @@ +pub mod primitives; + use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; use near_sdk::serde::{Deserialize, Serialize}; use near_sdk::{env, near_bindgen, AccountId, PanicOnDefault, PublicKey}; -use std::collections::{BTreeMap, HashSet}; - -type ParticipantId = u32; - -pub mod hpke { - pub type PublicKey = [u8; 32]; -} - -#[derive( - Serialize, - Deserialize, - BorshDeserialize, - BorshSerialize, - Clone, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Debug, -)] -pub struct ParticipantInfo { - pub id: ParticipantId, - pub account_id: AccountId, - pub url: String, - /// The public key used for encrypting messages. - pub cipher_pk: hpke::PublicKey, - /// The public key used for verifying messages. - pub sign_pk: PublicKey, -} +use primitives::{Votes, PkVotes, Participants, CandidateInfo, Candidates}; +use std::collections::HashSet; #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct InitializingContractState { - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, - pub pk_votes: BTreeMap>, + pub pk_votes: PkVotes, } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct RunningContractState { pub epoch: u64, - // TODO: why is this account id for participants instead of participant id? - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub public_key: PublicKey, - pub candidates: BTreeMap, - pub join_votes: BTreeMap>, - pub leave_votes: BTreeMap>, + pub candidates: Candidates, + pub join_votes: Votes, + pub leave_votes: Votes, } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct ResharingContractState { pub old_epoch: u64, - pub old_participants: BTreeMap, + pub old_participants: Participants, // TODO: only store diff to save on storage - pub new_participants: BTreeMap, + pub new_participants: Participants, pub threshold: usize, pub public_key: PublicKey, - pub finished_votes: HashSet, + pub finished_votes: HashSet, } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] @@ -78,12 +51,12 @@ pub struct MpcContract { #[near_bindgen] impl MpcContract { #[init] - pub fn init(threshold: usize, participants: BTreeMap) -> Self { + pub fn init(threshold: usize, participants: Participants) -> Self { MpcContract { protocol_state: ProtocolContractState::Initializing(InitializingContractState { participants, threshold, - pk_votes: BTreeMap::new(), + pk_votes: PkVotes::new(), }), } } @@ -94,9 +67,8 @@ impl MpcContract { pub fn join( &mut self, - participant_id: ParticipantId, url: String, - cipher_pk: hpke::PublicKey, + cipher_pk: primitives::hpke::PublicKey, sign_pk: PublicKey, ) { match &mut self.protocol_state { @@ -105,15 +77,14 @@ impl MpcContract { candidates, .. }) => { - let account_id = env::signer_account_id(); - if participants.contains_key(&account_id) { + let signer_account_id = env::signer_account_id(); + if participants.contains_key(&signer_account_id) { env::panic_str("this participant is already in the participant set"); } candidates.insert( - participant_id, - ParticipantInfo { - id: participant_id, - account_id, + signer_account_id.clone(), + CandidateInfo { + account_id: signer_account_id, url, cipher_pk, sign_pk, @@ -124,7 +95,7 @@ impl MpcContract { } } - pub fn vote_join(&mut self, participant: ParticipantId) -> bool { + pub fn vote_join(&mut self, candidate_account_id: AccountId) -> bool { match &mut self.protocol_state { ProtocolContractState::Running(RunningContractState { epoch, @@ -135,19 +106,19 @@ impl MpcContract { join_votes, .. }) => { - let voting_participant = participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the participant set") - }); - let candidate = candidates - .get(&participant) + let signer_account_id = env::signer_account_id(); + if !participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the participant set"); + } + let candidate_info = candidates + .get(&candidate_account_id) .unwrap_or_else(|| env::panic_str("candidate is not registered")); - let voted = join_votes.entry(participant).or_default(); - voted.insert(voting_participant.id); + let voted = join_votes.entry(candidate_account_id.clone()); + voted.insert(signer_account_id); if voted.len() >= *threshold { let mut new_participants = participants.clone(); - new_participants.insert(candidate.account_id.clone(), candidate.clone()); + new_participants + .insert(candidate_account_id.clone(), candidate_info.clone().into()); self.protocol_state = ProtocolContractState::Resharing(ResharingContractState { old_epoch: *epoch, @@ -166,30 +137,29 @@ impl MpcContract { } } - pub fn vote_leave(&mut self, participant: ParticipantId) -> bool { + pub fn vote_leave(&mut self, acc_id_to_leave: AccountId) -> bool { match &mut self.protocol_state { ProtocolContractState::Running(RunningContractState { epoch, participants, threshold, public_key, - candidates, + candidates: _, leave_votes, .. }) => { - let voting_participant = participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the participant set") - }); - let candidate = candidates - .get(&participant) - .unwrap_or_else(|| env::panic_str("candidate is not registered")); - let voted = leave_votes.entry(participant).or_default(); - voted.insert(voting_participant.id); + let signer_account_id = env::signer_account_id(); + if !participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the participant set"); + } + if !participants.contains_key(&acc_id_to_leave) { + env::panic_str("account to leave is not in the participant set"); + } + let voted = leave_votes.entry(acc_id_to_leave.clone()); + voted.insert(signer_account_id); if voted.len() >= *threshold { let mut new_participants = participants.clone(); - new_participants.remove(&candidate.account_id); + new_participants.remove(&acc_id_to_leave); self.protocol_state = ProtocolContractState::Resharing(ResharingContractState { old_epoch: *epoch, @@ -215,22 +185,21 @@ impl MpcContract { threshold, pk_votes, }) => { - let voting_participant = participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the participant set") - }); - let voted = pk_votes.entry(public_key.clone()).or_default(); - voted.insert(voting_participant.id); + let signer_account_id = env::signer_account_id(); + if !participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the participant set"); + } + let voted = pk_votes.entry(public_key.clone()); + voted.insert(signer_account_id); if voted.len() >= *threshold { self.protocol_state = ProtocolContractState::Running(RunningContractState { epoch: 0, participants: participants.clone(), threshold: *threshold, public_key, - candidates: BTreeMap::new(), - join_votes: BTreeMap::new(), - leave_votes: BTreeMap::new(), + candidates: Candidates::new(), + join_votes: Votes::new(), + leave_votes: Votes::new(), }); true } else { @@ -256,21 +225,20 @@ impl MpcContract { if *old_epoch + 1 != epoch { env::panic_str("mismatched epochs"); } - let voting_participant = old_participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the old participant set") - }); - finished_votes.insert(voting_participant.id); + let signer_account_id = env::signer_account_id(); + if !old_participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the old participant set"); + } + finished_votes.insert(signer_account_id); if finished_votes.len() >= *threshold { self.protocol_state = ProtocolContractState::Running(RunningContractState { epoch: *old_epoch + 1, participants: new_participants.clone(), threshold: *threshold, public_key: public_key.clone(), - candidates: BTreeMap::new(), - join_votes: BTreeMap::new(), - leave_votes: BTreeMap::new(), + candidates: Candidates::new(), + join_votes: Votes::new(), + leave_votes: Votes::new(), }); true } else { diff --git a/contract/src/primitives.rs b/contract/src/primitives.rs new file mode 100644 index 000000000..719887b1f --- /dev/null +++ b/contract/src/primitives.rs @@ -0,0 +1,177 @@ +use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; +use near_sdk::serde::{Deserialize, Serialize}; +use near_sdk::{AccountId, PublicKey}; +use std::collections::{BTreeMap, HashSet}; + +pub mod hpke { + pub type PublicKey = [u8; 32]; +} + +#[derive( + Serialize, + Deserialize, + BorshDeserialize, + BorshSerialize, + Clone, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, +)] +pub struct ParticipantInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: PublicKey, +} + +impl From for ParticipantInfo { + fn from(candidate_info: CandidateInfo) -> Self { + ParticipantInfo { + account_id: candidate_info.account_id, + url: candidate_info.url, + cipher_pk: candidate_info.cipher_pk, + sign_pk: candidate_info.sign_pk, + } + } +} + +#[derive( + Serialize, + Deserialize, + BorshDeserialize, + BorshSerialize, + Clone, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, +)] +pub struct CandidateInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: PublicKey, +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] +pub struct Participants { + pub participants: BTreeMap, +} + +impl Participants { + pub fn new() -> Self { + Participants { + participants: BTreeMap::new(), + } + } + + pub fn contains_key(&self, account_id: &AccountId) -> bool { + self.participants.contains_key(account_id) + } + + pub fn insert(&mut self, account_id: AccountId, participant_info: ParticipantInfo) { + self.participants.insert(account_id, participant_info); + } + + pub fn remove(&mut self, account_id: &AccountId) { + self.participants.remove(account_id); + } + + pub fn get(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { + self.participants.get(account_id) + } + pub fn iter(&self) -> impl Iterator { + self.participants.iter() + } + + pub fn into_iter(self) -> impl Iterator { + self.participants.into_iter() + } +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] +pub struct Candidates { + pub candidates: BTreeMap, +} + +impl Candidates { + pub fn new() -> Self { + Candidates { + candidates: BTreeMap::new(), + } + } + + pub fn contains_key(&self, account_id: &AccountId) -> bool { + self.candidates.contains_key(account_id) + } + + pub fn insert(&mut self, account_id: AccountId, candidate: CandidateInfo) { + self.candidates.insert(account_id, candidate); + } + + pub fn remove(&mut self, account_id: &AccountId) { + self.candidates.remove(account_id); + } + + pub fn get(&self, account_id: &AccountId) -> Option<&CandidateInfo> { + self.candidates.get(account_id) + } + pub fn iter(&self) -> impl Iterator { + self.candidates.iter() + } + + pub fn into_iter(self) -> impl Iterator { + self.candidates.into_iter() + } +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] +pub struct Votes { + pub votes: BTreeMap>, +} + +impl Votes { + pub fn new() -> Self { + Votes { + votes: BTreeMap::new(), + } + } + + pub fn entry(&mut self, account_id: AccountId) -> &mut HashSet { + self.votes.entry(account_id).or_default() + } + + pub fn into_iter(self) -> impl Iterator)> { + self.votes.into_iter() + } +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] +pub struct PkVotes { + pub votes: BTreeMap>, +} + +impl PkVotes { + pub fn new() -> Self { + PkVotes { + votes: BTreeMap::new(), + } + } + + pub fn entry(&mut self, public_key: PublicKey) -> &mut HashSet { + self.votes.entry(public_key).or_default() + } + + pub fn into_iter(self) -> impl Iterator)> { + self.votes.into_iter() + } +} diff --git a/integration-tests/src/multichain/containers.rs b/integration-tests/src/multichain/containers.rs index 8b2670e1c..072726fff 100644 --- a/integration-tests/src/multichain/containers.rs +++ b/integration-tests/src/multichain/containers.rs @@ -32,17 +32,15 @@ impl<'a> Node<'a> { pub async fn run( ctx: &super::Context<'a>, - node_id: u32, - account: &AccountId, + account_id: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result> { tracing::info!(node_id, "running node container"); let (cipher_sk, cipher_pk) = hpke::generate(); let args = mpc_recovery_node::cli::Cli::Start { - node_id: node_id.into(), near_rpc: ctx.lake_indexer.rpc_host_address.clone(), mpc_contract_id: ctx.mpc_contract.id().clone(), - account: account.clone(), + account_id: account_id.clone(), account_sk: account_sk.to_string().parse()?, web_port: Self::CONTAINER_PORT, cipher_pk: hex::encode(cipher_pk.to_bytes()), diff --git a/integration-tests/src/multichain/local.rs b/integration-tests/src/multichain/local.rs index 4bb3451d1..c77fe3c86 100644 --- a/integration-tests/src/multichain/local.rs +++ b/integration-tests/src/multichain/local.rs @@ -20,17 +20,15 @@ pub struct Node { impl Node { pub async fn run( ctx: &super::Context<'_>, - node_id: u32, - account: &AccountId, + account_id: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result { let web_port = util::pick_unused_port().await?; let (cipher_sk, cipher_pk) = hpke::generate(); let cli = mpc_recovery_node::cli::Cli::Start { - node_id: node_id.into(), near_rpc: ctx.lake_indexer.rpc_host_address.clone(), mpc_contract_id: ctx.mpc_contract.id().clone(), - account: account.clone(), + account_id: account_id.clone(), account_sk: account_sk.to_string().parse()?, web_port, cipher_pk: hex::encode(cipher_pk.to_bytes()), diff --git a/integration-tests/src/multichain/mod.rs b/integration-tests/src/multichain/mod.rs index 49d4b89bb..4e9ed6c89 100644 --- a/integration-tests/src/multichain/mod.rs +++ b/integration-tests/src/multichain/mod.rs @@ -50,17 +50,16 @@ impl Nodes<'_> { pub async fn add_node( &mut self, - node_id: u32, account: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result<()> { tracing::info!(%account, "adding one more node"); match self { Nodes::Local { ctx, nodes } => { - nodes.push(local::Node::run(ctx, node_id, account, account_sk).await?) + nodes.push(local::Node::run(ctx, account, account_sk).await?) } Nodes::Docker { ctx, nodes } => { - nodes.push(containers::Node::run(ctx, node_id, account, account_sk).await?) + nodes.push(containers::Node::run(ctx, account, account_sk).await?) } } @@ -125,7 +124,7 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul .collect::, _>>()?; let mut node_futures = Vec::new(); for (i, account) in accounts.iter().enumerate() { - let node = containers::Node::run(&ctx, i as u32, account.id(), account.secret_key()); + let node = containers::Node::run(&ctx, account.id(), account.secret_key()); node_futures.push(node); } let nodes = futures::future::join_all(node_futures) @@ -172,12 +171,7 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< .collect::, _>>()?; let mut node_futures = Vec::with_capacity(nodes); for (i, account) in accounts.iter().enumerate().take(nodes) { - node_futures.push(local::Node::run( - &ctx, - i as u32, - account.id(), - account.secret_key(), - )); + node_futures.push(local::Node::run(&ctx, account.id(), account.secret_key())); } let nodes = futures::future::join_all(node_futures) .await diff --git a/integration-tests/tests/multichain/mod.rs b/integration-tests/tests/multichain/mod.rs index eb85a0aea..492d95419 100644 --- a/integration-tests/tests/multichain/mod.rs +++ b/integration-tests/tests/multichain/mod.rs @@ -18,7 +18,7 @@ async fn test_multichain_reshare() -> anyhow::Result<()> { let account = ctx.nodes.ctx().worker.dev_create_account().await?; ctx.nodes - .add_node(3, account.id(), account.secret_key()) + .add_node(account.id(), account.secret_key()) .await?; // Wait for network to complete key reshare diff --git a/node/src/cli.rs b/node/src/cli.rs index a65cd7d78..ad8572e8a 100644 --- a/node/src/cli.rs +++ b/node/src/cli.rs @@ -16,9 +16,6 @@ use mpc_keys::hpke; #[derive(Parser, Debug)] pub enum Cli { Start { - /// Node ID - #[arg(long, value_parser = parse_participant, env("MPC_RECOVERY_NODE_ID"))] - node_id: Participant, /// NEAR RPC address #[arg( long, @@ -30,8 +27,8 @@ pub enum Cli { #[arg(long, env("MPC_RECOVERY_CONTRACT_ID"))] mpc_contract_id: AccountId, /// This node's account id - #[arg(long, env("MPC_RECOVERY_ACCOUNT"))] - account: AccountId, + #[arg(long, env("MPC_RECOVERY_ACCOUNT_ID"))] + account_id: AccountId, /// This node's account ed25519 secret key #[arg(long, env("MPC_RECOVERY_ACCOUNT_SK"))] account_sk: SecretKey, @@ -60,10 +57,9 @@ impl Cli { pub fn into_str_args(self) -> Vec { match self { Cli::Start { - node_id, near_rpc, + account_id, mpc_contract_id, - account, account_sk, web_port, cipher_pk, @@ -72,14 +68,12 @@ impl Cli { } => { let mut args = vec![ "start".to_string(), - "--node-id".to_string(), - u32::from(node_id).to_string(), "--near-rpc".to_string(), near_rpc, "--mpc-contract-id".to_string(), mpc_contract_id.to_string(), - "--account".to_string(), - account.to_string(), + "--account-id".to_string(), + account_id.to_string(), "--account-sk".to_string(), account_sk.to_string(), "--web-port".to_string(), @@ -111,11 +105,10 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> { match cmd { Cli::Start { - node_id, near_rpc, web_port, mpc_contract_id, - account, + account_id, account_sk, cipher_pk, cipher_sk, @@ -140,11 +133,11 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> { tracing::info!(%my_address, "address detected"); let rpc_client = near_fetch::Client::new(&near_rpc); tracing::debug!(rpc_addr = rpc_client.rpc_addr(), "rpc client initialized"); - let signer = InMemorySigner::from_secret_key(account, account_sk); + let signer = InMemorySigner::from_secret_key(account_id, account_sk); let (protocol, protocol_state) = MpcSignProtocol::init( - node_id, my_address, mpc_contract_id.clone(), + account_id, rpc_client.clone(), signer.clone(), receiver, diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index ae1b5adb9..badd3061c 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -17,14 +17,14 @@ use k256::Secp256k1; use mpc_keys::hpke; use near_crypto::InMemorySigner; use near_primitives::transaction::{Action, FunctionCallAction}; -use near_primitives::types::AccountId; +use near_sdk::AccountId; use std::cmp::Ordering; use std::sync::Arc; use tokio::sync::RwLock; use url::Url; pub trait ConsensusCtx { - fn me(&self) -> Participant; + fn me(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; fn signer(&self) -> &InMemorySigner; @@ -91,11 +91,14 @@ impl ConsensusProtocol for StartedState { } Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { - if contract_state.participants.contains_key(&ctx.me()) { + if contract_state + .participants + .contains_key(&ctx.me()) + { tracing::info!( "contract state is running and we are already a participant" ); - let participants_vec: Vec = + let participants_vec: Vec = contract_state.participants.keys().cloned().collect(); Ok(NodeState::Running(RunningState { epoch, @@ -158,7 +161,10 @@ impl ConsensusProtocol for StartedState { }, None => match contract_state { ProtocolState::Initializing(contract_state) => { - if contract_state.participants.contains_key(&ctx.me()) { + if contract_state + .participants + .contains_key(&ctx.me()) + { tracing::info!("starting key generation as a part of the participant set"); let participants = contract_state.participants; let protocol = cait_sith::keygen::( @@ -347,8 +353,14 @@ impl ConsensusProtocol for WaitingForConsensusState { tracing::debug!( "waiting for resharing consensus, contract state has not been finalized yet" ); - let has_voted = contract_state.finished_votes.contains(&ctx.me()); - if !has_voted && contract_state.old_participants.contains_key(&ctx.me()) { + let has_voted = contract_state + .finished_votes + .contains(&ctx.me()); + if !has_voted + && contract_state + .old_participants + .contains_key(&ctx.me()) + { tracing::info!( epoch = self.epoch, "we haven't voted yet, voting for resharing to complete" @@ -420,8 +432,12 @@ impl ConsensusProtocol for RunningState { Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { tracing::info!("contract is resharing"); - if !contract_state.old_participants.contains_key(&ctx.me()) - || !contract_state.new_participants.contains_key(&ctx.me()) + if !contract_state + .old_participants + .contains_key(&ctx.me()) + || !contract_state + .new_participants + .contains_key(&ctx.me()) { return Err(ConsensusError::HasBeenKicked); } @@ -518,7 +534,10 @@ impl ConsensusProtocol for JoiningState { match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => { - if contract_state.candidates.contains_key(&ctx.me()) { + if contract_state + .candidates + .contains_key(&ctx.me()) + { let voted = contract_state .join_votes .get(&ctx.me()) @@ -563,7 +582,10 @@ impl ConsensusProtocol for JoiningState { } } ProtocolState::Resharing(contract_state) => { - if contract_state.new_participants.contains_key(&ctx.me()) { + if contract_state + .new_participants + .contains_key(&ctx.me()) + { tracing::info!("joining as a new participant"); start_resharing(None, ctx, contract_state) } else { diff --git a/node/src/protocol/contract.rs b/node/src/protocol/contract.rs deleted file mode 100644 index 0cb284f9a..000000000 --- a/node/src/protocol/contract.rs +++ /dev/null @@ -1,199 +0,0 @@ -use crate::types::PublicKey; -use crate::util::NearPublicKeyExt; -use cait_sith::protocol::Participant; -use mpc_contract::ProtocolContractState; -use mpc_keys::hpke; -use near_primitives::borsh::BorshDeserialize; -use near_sdk::AccountId; -use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, HashSet}; - -type ParticipantId = u32; - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct ParticipantInfo { - pub id: ParticipantId, - pub account_id: AccountId, - pub url: String, - /// The public key used for encrypting messages. - pub cipher_pk: hpke::PublicKey, - /// The public key used for verifying messages. - pub sign_pk: near_crypto::PublicKey, -} - -impl From for ParticipantInfo { - fn from(value: mpc_contract::ParticipantInfo) -> Self { - ParticipantInfo { - id: value.id, - account_id: value.account_id, - url: value.url, - cipher_pk: hpke::PublicKey::from_bytes(&value.cipher_pk), - sign_pk: BorshDeserialize::try_from_slice(value.sign_pk.as_bytes()).unwrap(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitializingContractState { - pub participants: BTreeMap, - pub threshold: usize, - pub pk_votes: BTreeMap>, -} - -impl From for InitializingContractState { - fn from(value: mpc_contract::InitializingContractState) -> Self { - InitializingContractState { - participants: contract_participants_into_cait_participants(value.participants), - threshold: value.threshold, - pk_votes: value - .pk_votes - .into_iter() - .map(|(pk, participants)| { - ( - near_crypto::PublicKey::SECP256K1( - near_crypto::Secp256K1PublicKey::try_from(&pk.as_bytes()[1..]).unwrap(), - ), - participants - .into_iter() - .map(Participant::from) - .collect::>(), - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct RunningContractState { - pub epoch: u64, - pub participants: BTreeMap, - pub threshold: usize, - pub public_key: PublicKey, - pub candidates: BTreeMap, - pub join_votes: BTreeMap>, - pub leave_votes: BTreeMap>, -} - -impl From for RunningContractState { - fn from(value: mpc_contract::RunningContractState) -> Self { - RunningContractState { - epoch: value.epoch, - participants: contract_participants_into_cait_participants(value.participants), - threshold: value.threshold, - public_key: value.public_key.into_affine_point(), - candidates: value - .candidates - .into_iter() - .map(|(p, p_info)| (Participant::from(p), p_info.into())) - .collect(), - join_votes: value - .join_votes - .into_iter() - .map(|(p, ps)| { - ( - Participant::from(p), - ps.into_iter().map(Participant::from).collect(), - ) - }) - .collect(), - leave_votes: value - .leave_votes - .into_iter() - .map(|(p, ps)| { - ( - Participant::from(p), - ps.into_iter().map(Participant::from).collect(), - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ResharingContractState { - pub old_epoch: u64, - pub old_participants: BTreeMap, - pub new_participants: BTreeMap, - pub threshold: usize, - pub public_key: PublicKey, - pub finished_votes: HashSet, -} - -impl From for ResharingContractState { - fn from(value: mpc_contract::ResharingContractState) -> Self { - ResharingContractState { - old_epoch: value.old_epoch, - old_participants: contract_participants_into_cait_participants(value.old_participants), - new_participants: contract_participants_into_cait_participants(value.new_participants), - threshold: value.threshold, - public_key: value.public_key.into_affine_point(), - finished_votes: value - .finished_votes - .into_iter() - .map(Participant::from) - .collect(), - } - } -} - -#[derive(Debug)] -pub enum ProtocolState { - Initializing(InitializingContractState), - Running(RunningContractState), - Resharing(ResharingContractState), -} - -impl ProtocolState { - pub fn participants(&self) -> &BTreeMap { - match self { - ProtocolState::Initializing(InitializingContractState { participants, .. }) => { - participants - } - ProtocolState::Running(RunningContractState { participants, .. }) => participants, - ProtocolState::Resharing(ResharingContractState { - old_participants, .. - }) => old_participants, - } - } - - pub fn public_key(&self) -> Option<&PublicKey> { - match self { - ProtocolState::Initializing { .. } => None, - ProtocolState::Running(RunningContractState { public_key, .. }) => Some(public_key), - ProtocolState::Resharing(ResharingContractState { public_key, .. }) => Some(public_key), - } - } - - pub fn threshold(&self) -> usize { - match self { - ProtocolState::Initializing(InitializingContractState { threshold, .. }) => *threshold, - ProtocolState::Running(RunningContractState { threshold, .. }) => *threshold, - ProtocolState::Resharing(ResharingContractState { threshold, .. }) => *threshold, - } - } -} - -impl TryFrom for ProtocolState { - type Error = (); - - fn try_from(value: ProtocolContractState) -> Result { - match value { - ProtocolContractState::Initializing(state) => { - Ok(ProtocolState::Initializing(state.into())) - } - ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())), - ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), - } - } -} - -fn contract_participants_into_cait_participants( - participants: BTreeMap, -) -> BTreeMap { - participants - .into_values() - .map(|p| (Participant::from(p.id), p.into())) - .collect() -} diff --git a/node/src/protocol/contract/mod.rs b/node/src/protocol/contract/mod.rs new file mode 100644 index 000000000..8d095d24b --- /dev/null +++ b/node/src/protocol/contract/mod.rs @@ -0,0 +1,138 @@ +pub mod primitives; + +use crate::types::PublicKey; +use crate::util::NearPublicKeyExt; +use mpc_contract::ProtocolContractState; +use near_primitives::types::AccountId; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +use self::primitives::{Participants, PkVotes, Candidates, Votes}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct InitializingContractState { + pub participants: Participants, + pub threshold: usize, + pub pk_votes: PkVotes, +} + +impl From for InitializingContractState { + fn from(value: mpc_contract::InitializingContractState) -> Self { + InitializingContractState { + participants: value.participants.into(), + threshold: value.threshold, + pk_votes: value.pk_votes.into(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct RunningContractState { + pub epoch: u64, + pub participants: Participants, + pub threshold: usize, + pub public_key: PublicKey, + pub candidates: Candidates, + pub join_votes: Votes, + pub leave_votes: Votes, +} + +impl From for RunningContractState { + fn from(value: mpc_contract::RunningContractState) -> Self { + RunningContractState { + epoch: value.epoch, + participants: value.participants.into(), + threshold: value.threshold, + public_key: value.public_key.into_affine_point(), + candidates: value.candidates.into(), + join_votes: value.join_votes.into(), + leave_votes: value.leave_votes.into(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ResharingContractState { + pub old_epoch: u64, + pub old_participants: Participants, + pub new_participants: Participants, + pub threshold: usize, + pub public_key: PublicKey, + pub finished_votes: HashSet, +} + +impl From for ResharingContractState { + fn from(contract_state: mpc_contract::ResharingContractState) -> Self { + ResharingContractState { + old_epoch: contract_state.old_epoch, + old_participants: contract_state.old_participants.into(), + new_participants: contract_state.new_participants.into(), + threshold: contract_state.threshold, + public_key: contract_state.public_key.into_affine_point(), + finished_votes: contract_state.finished_votes.clone(), + } + } +} + +#[derive(Debug)] +pub enum ProtocolState { + Initializing(InitializingContractState), + Running(RunningContractState), + Resharing(ResharingContractState), +} + +impl ProtocolState { + pub fn participants(&self) -> &Participants { + match self { + ProtocolState::Initializing(InitializingContractState { participants, .. }) => { + participants + } + ProtocolState::Running(RunningContractState { participants, .. }) => participants, + ProtocolState::Resharing(ResharingContractState { + old_participants, .. + }) => old_participants, + } + } + + pub fn public_key(&self) -> Option<&PublicKey> { + match self { + ProtocolState::Initializing { .. } => None, + ProtocolState::Running(RunningContractState { public_key, .. }) => Some(public_key), + ProtocolState::Resharing(ResharingContractState { public_key, .. }) => Some(public_key), + } + } + + pub fn threshold(&self) -> usize { + match self { + ProtocolState::Initializing(InitializingContractState { threshold, .. }) => *threshold, + ProtocolState::Running(RunningContractState { threshold, .. }) => *threshold, + ProtocolState::Resharing(ResharingContractState { threshold, .. }) => *threshold, + } + } +} + +impl TryFrom for ProtocolState { + type Error = (); + + fn try_from(value: ProtocolContractState) -> Result { + match value { + ProtocolContractState::Initializing(state) => { + Ok(ProtocolState::Initializing(state.into())) + } + ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())), + ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), + } + } +} + +impl From for Participants { + fn from(value: mpc_contract::primitives::Participants) -> Self { + Participants { + participants: value + .participants + .into_iter() + .map(|(account_id, participant_info)| (account_id, participant_info.into())) + .collect(), + } + } +} \ No newline at end of file diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs new file mode 100644 index 000000000..87e767309 --- /dev/null +++ b/node/src/protocol/contract/primitives.rs @@ -0,0 +1,136 @@ +use mpc_keys::hpke; +use near_primitives::borsh::BorshDeserialize; +use near_primitives::types::AccountId; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashSet}; + +type ParticipantId = u32; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct ParticipantInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: near_crypto::PublicKey, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Participants { + pub participants: BTreeMap, +} + +impl Participants { + pub fn get(&self, id: &AccountId) -> Option<&ParticipantInfo> { + self.participants.get(id) + } + + pub fn contains_key(&self, id: &AccountId) -> bool { + self.participants.contains_key(id) + } + + pub fn keys(&self) -> impl Iterator { + self.participants.keys() + } + + pub fn iter(&self) -> impl Iterator { + self.participants.iter() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct CandidateInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: near_crypto::PublicKey, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Candidates { + pub candidates: BTreeMap, +} + +impl Candidates { + pub fn get(&self, id: &AccountId) -> Option<&CandidateInfo> { + self.candidates.get(id) + } + + pub fn contains_key(&self, id: &AccountId) -> bool { + self.candidates.contains_key(id) + } + + pub fn keys(&self) -> impl Iterator { + self.candidates.keys() + } + + pub fn iter(&self) -> impl Iterator { + self.candidates.iter() + } +} + +impl From for CandidateInfo { + fn from(contract_info: mpc_contract::primitives::CandidateInfo) -> Self { + CandidateInfo { + account_id: AccountId::from(contract_info.account_id), + url: contract_info.url, + cipher_pk: hpke::PublicKey::from_bytes(&contract_info.cipher_pk), + sign_pk: BorshDeserialize::try_from_slice(contract_info.sign_pk.as_bytes()).unwrap(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PkVotes { + pub pk_votes: BTreeMap>, +} + +impl From for PkVotes { + fn from(contract_votes: mpc_contract::primitives::PkVotes) -> Self { + PkVotes { + pk_votes: contract_votes + .votes + .into_iter() + .map(|(pk, participants)| { + ( + near_crypto::PublicKey::SECP256K1( + near_crypto::Secp256K1PublicKey::try_from(&pk.as_bytes()[1..]).unwrap(), + ), + participants + .into_iter() + .map(AccountId::from) + .collect::>(), + ) + }) + .collect(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Votes { + pub votes: BTreeMap>, +} + +impl From for Votes { + fn from(contract_votes: mpc_contract::Votes) -> Self { + Votes { + votes: contract_votes + .votes + .into_iter() + .map(|(account_id, votes)| { + ( + account_id, + votes + .into_iter() + .map(|account_id| account_id) + .collect::>(), + ) + }) + .collect(), + } + } +} \ No newline at end of file diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index b77735347..9ba0e50b1 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -6,13 +6,13 @@ use crate::protocol::message::{GeneratingMessage, ResharingMessage}; use crate::protocol::state::WaitingForConsensusState; use crate::protocol::MpcMessage; use async_trait::async_trait; -use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; +use cait_sith::protocol::{Action, InitializationError, ProtocolError, Participant}; use k256::elliptic_curve::group::GroupEncoding; use near_crypto::InMemorySigner; use near_primitives::types::AccountId; pub trait CryptographicCtx { - fn me(&self) -> Participant; + fn my_near_acc_id(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; fn signer(&self) -> &InMemorySigner; @@ -25,7 +25,7 @@ pub enum CryptographicError { #[error("failed to send a message: {0}")] SendError(#[from] SendError), #[error("unknown participant: {0:?}")] - UnknownParticipant(Participant), + UnknownParticipant(near_sdk::AccountId), #[error("rpc error: {0}")] RpcError(#[from] near_fetch::Error), #[error("cait-sith initialization error: {0}")] @@ -75,8 +75,8 @@ impl CryptographicProtocol for GeneratingState { } Action::SendMany(m) => { tracing::debug!("sending a message to many participants"); - for (p, info) in &self.participants { - if p == &ctx.me() { + for (p, info) in self.participants.iter() { + if p == &self.participants.find(ctx.my_near_acc_id()) { // Skip yourself, cait-sith never sends messages to oneself continue; } diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index 0122418f1..e6cd3ab0f 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -8,6 +8,7 @@ mod triple; pub mod message; pub mod state; +use cait_sith::protocol::Participant; pub use contract::{ParticipantInfo, ProtocolState}; pub use message::MpcMessage; pub use signature::SignQueue; @@ -21,7 +22,6 @@ use crate::protocol::consensus::ConsensusProtocol; use crate::protocol::cryptography::CryptographicProtocol; use crate::protocol::message::{MessageHandler, MpcMessageQueue}; use crate::rpc_client::{self}; -use cait_sith::protocol::Participant; use near_crypto::InMemorySigner; use near_primitives::types::AccountId; use reqwest::IntoUrl; @@ -33,8 +33,8 @@ use url::Url; use mpc_keys::hpke; struct Ctx { - me: Participant, my_address: Url, + account_id: AccountId, mpc_contract_id: AccountId, signer: InMemorySigner, rpc_client: near_fetch::Client, @@ -46,7 +46,7 @@ struct Ctx { impl ConsensusCtx for &Ctx { fn me(&self) -> Participant { - self.me + &self.account_id } fn http_client(&self) -> &reqwest::Client { @@ -83,8 +83,8 @@ impl ConsensusCtx for &Ctx { } impl CryptographicCtx for &Ctx { - fn me(&self) -> Participant { - self.me + fn me(&self) -> &AccountId { + &self.account_id } fn http_client(&self) -> &reqwest::Client { @@ -109,8 +109,8 @@ impl CryptographicCtx for &Ctx { } impl MessageCtx for &Ctx { - fn me(&self) -> Participant { - self.me + fn my_near_acc_id(&self) -> AccountId { + self.account_id } } @@ -123,9 +123,9 @@ pub struct MpcSignProtocol { impl MpcSignProtocol { #![allow(clippy::too_many_arguments)] pub fn init( - me: Participant, my_address: U, mpc_contract_id: AccountId, + account_id: AccountId, rpc_client: near_fetch::Client, signer: InMemorySigner, receiver: mpsc::Receiver, @@ -134,8 +134,8 @@ impl MpcSignProtocol { ) -> (Self, Arc>) { let state = Arc::new(RwLock::new(NodeState::Starting)); let ctx = Ctx { - me, my_address: my_address.into_url().unwrap(), + account_id, mpc_contract_id, rpc_client, http_client: reqwest::Client::new(), @@ -153,7 +153,7 @@ impl MpcSignProtocol { } pub async fn run(mut self) -> anyhow::Result<()> { - let _span = tracing::info_span!("running", me = u32::from(self.ctx.me)); + let _span = tracing::info_span!("running", me = u32::from(self.ctx)); let mut queue = MpcMessageQueue::default(); loop { tracing::debug!("trying to advance mpc recovery protocol"); diff --git a/node/src/protocol/state.rs b/node/src/protocol/state.rs index 66af8d340..7120ee818 100644 --- a/node/src/protocol/state.rs +++ b/node/src/protocol/state.rs @@ -1,11 +1,11 @@ +use super::contract::Participants; use super::presignature::PresignatureManager; use super::signature::SignatureManager; use super::triple::TripleManager; use super::SignQueue; use crate::protocol::ParticipantInfo; use crate::types::{KeygenProtocol, PrivateKeyShare, PublicKey, ReshareProtocol}; -use cait_sith::protocol::Participant; -use std::collections::BTreeMap; +use near_primitives::types::AccountId; use std::sync::Arc; use tokio::sync::RwLock; @@ -21,7 +21,7 @@ pub struct StartedState(pub Option); #[derive(Clone)] pub struct GeneratingState { - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub protocol: KeygenProtocol, } @@ -29,7 +29,7 @@ pub struct GeneratingState { #[derive(Clone)] pub struct WaitingForConsensusState { pub epoch: u64, - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub private_share: PrivateKeyShare, pub public_key: PublicKey, @@ -38,7 +38,7 @@ pub struct WaitingForConsensusState { #[derive(Clone)] pub struct RunningState { pub epoch: u64, - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub private_share: PrivateKeyShare, pub public_key: PublicKey, @@ -51,8 +51,8 @@ pub struct RunningState { #[derive(Clone)] pub struct ResharingState { pub old_epoch: u64, - pub old_participants: BTreeMap, - pub new_participants: BTreeMap, + pub old_participants: Participants, + pub new_participants: Participants, pub threshold: usize, pub public_key: PublicKey, pub protocol: ReshareProtocol, @@ -77,15 +77,15 @@ pub enum NodeState { } impl NodeState { - pub fn fetch_participant(&self, p: Participant) -> Option { + pub fn fetch_participant(&self, account_id: AccountId) -> Option { let participants = match self { NodeState::Running(state) => &state.participants, NodeState::Generating(state) => &state.participants, NodeState::WaitingForConsensus(state) => &state.participants, NodeState::Resharing(state) => { - if let Some(info) = state.new_participants.get(&p) { + if let Some(info) = state.new_participants.get(&account_id) { return Some(info.clone()); - } else if let Some(info) = state.old_participants.get(&p) { + } else if let Some(info) = state.old_participants.get(&account_id) { return Some(info.clone()); } else { return None; @@ -94,6 +94,6 @@ impl NodeState { _ => return None, }; - participants.get(&p).cloned() + participants.get(&account_id).cloned() } } From 4d9e22737af1e03fa1ce92b1c5ba116b209621aa Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 13 Dec 2023 16:58:24 +0200 Subject: [PATCH 02/21] contract state maping fixed --- contract/src/lib.rs | 2 +- node/src/protocol/contract/mod.rs | 18 +--- node/src/protocol/contract/primitives.rs | 113 ++++++++++++++++++----- 3 files changed, 94 insertions(+), 39 deletions(-) diff --git a/contract/src/lib.rs b/contract/src/lib.rs index 811e803cb..89ed1f979 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -3,7 +3,7 @@ pub mod primitives; use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; use near_sdk::serde::{Deserialize, Serialize}; use near_sdk::{env, near_bindgen, AccountId, PanicOnDefault, PublicKey}; -use primitives::{Votes, PkVotes, Participants, CandidateInfo, Candidates}; +use primitives::{CandidateInfo, Candidates, Participants, PkVotes, Votes}; use std::collections::HashSet; #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] diff --git a/node/src/protocol/contract/mod.rs b/node/src/protocol/contract/mod.rs index 8d095d24b..9b6171f85 100644 --- a/node/src/protocol/contract/mod.rs +++ b/node/src/protocol/contract/mod.rs @@ -5,7 +5,7 @@ use crate::util::NearPublicKeyExt; use mpc_contract::ProtocolContractState; use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; +use std::{collections::HashSet, str::FromStr}; use self::primitives::{Participants, PkVotes, Candidates, Votes}; @@ -69,7 +69,9 @@ impl From for ResharingContractState { new_participants: contract_state.new_participants.into(), threshold: contract_state.threshold, public_key: contract_state.public_key.into_affine_point(), - finished_votes: contract_state.finished_votes.clone(), + finished_votes: contract_state.finished_votes.into_iter().map(|acc_id| { + AccountId::from_str(&acc_id.to_string()).unwrap() // TODO: code duplication + }).collect(), } } } @@ -123,16 +125,4 @@ impl TryFrom for ProtocolState { ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), } } -} - -impl From for Participants { - fn from(value: mpc_contract::primitives::Participants) -> Self { - Participants { - participants: value - .participants - .into_iter() - .map(|(account_id, participant_info)| (account_id, participant_info.into())) - .collect(), - } - } } \ No newline at end of file diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 87e767309..b9d2ce25e 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -1,13 +1,17 @@ +use cait_sith::protocol::Participant; use mpc_keys::hpke; -use near_primitives::borsh::BorshDeserialize; use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, HashSet}; +use std::{ + collections::{BTreeMap, HashSet}, + str::FromStr, +}; type ParticipantId = u32; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct ParticipantInfo { + pub id: ParticipantId, // TODO: do we need this parameter? pub account_id: AccountId, pub url: String, /// The public key used for encrypting messages. @@ -18,23 +22,59 @@ pub struct ParticipantInfo { #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Participants { - pub participants: BTreeMap, + pub participants: BTreeMap, +} + +impl From for Participants { + fn from(contract_participants: mpc_contract::primitives::Participants) -> Self { + Participants { + // take position of participant in contract_participants as id for participants + participants: contract_participants + .participants + .into_iter() + .enumerate() + .map(|(participant_id, participant)| { + let contract_participant_info = participant.1; + ( + Participant::from(participant_id as ParticipantId), + ParticipantInfo { + id: participant_id as ParticipantId, + account_id: AccountId::from_str( + &contract_participant_info.account_id.to_string(), + ) + .unwrap(), // TODO: remove unwrap + url: contract_participant_info.url, + cipher_pk: hpke::PublicKey::from_bytes( + &contract_participant_info.cipher_pk, + ), + sign_pk: near_crypto::PublicKey::SECP256K1( + near_crypto::Secp256K1PublicKey::try_from( + &contract_participant_info.sign_pk.as_bytes()[1..], + ) + .unwrap(), + ), + }, + ) + }) + .collect(), + } + } } impl Participants { - pub fn get(&self, id: &AccountId) -> Option<&ParticipantInfo> { + pub fn get(&self, id: &Participant) -> Option<&ParticipantInfo> { self.participants.get(id) } - pub fn contains_key(&self, id: &AccountId) -> bool { + pub fn contains_key(&self, id: &Participant) -> bool { self.participants.contains_key(id) } - pub fn keys(&self) -> impl Iterator { + pub fn keys(&self) -> impl Iterator { self.participants.keys() } - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.participants.iter() } } @@ -72,13 +112,32 @@ impl Candidates { } } -impl From for CandidateInfo { - fn from(contract_info: mpc_contract::primitives::CandidateInfo) -> Self { - CandidateInfo { - account_id: AccountId::from(contract_info.account_id), - url: contract_info.url, - cipher_pk: hpke::PublicKey::from_bytes(&contract_info.cipher_pk), - sign_pk: BorshDeserialize::try_from_slice(contract_info.sign_pk.as_bytes()).unwrap(), +impl From for Candidates { + fn from(contract_candidates: mpc_contract::primitives::Candidates) -> Self { + Candidates { + candidates: contract_candidates + .candidates + .into_iter() + .map(|(account_id, candidate_info)| { + ( + AccountId::from_str(&account_id.to_string()).unwrap(), // TODO: fix unwrap + CandidateInfo { + account_id: AccountId::from_str( + &candidate_info.account_id.to_string(), + ) + .unwrap(), // TODO: fix unwrap + url: candidate_info.url, + cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk), + sign_pk: near_crypto::PublicKey::SECP256K1( + near_crypto::Secp256K1PublicKey::try_from( + &candidate_info.sign_pk.as_bytes()[1..], + ) + .unwrap(), + ), + }, + ) + }) + .collect(), } } } @@ -101,8 +160,11 @@ impl From for PkVotes { ), participants .into_iter() - .map(AccountId::from) - .collect::>(), + .map(|acc_id: near_sdk::AccountId| { + AccountId::from_str(&acc_id.to_string()).unwrap() + // TODO: fix unwrap + }) + .collect(), ) }) .collect(), @@ -115,22 +177,25 @@ pub struct Votes { pub votes: BTreeMap>, } -impl From for Votes { - fn from(contract_votes: mpc_contract::Votes) -> Self { +impl From for Votes { + fn from(contract_votes: mpc_contract::primitives::Votes) -> Self { Votes { votes: contract_votes .votes .into_iter() - .map(|(account_id, votes)| { + .map(|(accountId, participants)| { ( - account_id, - votes + AccountId::from_str(&accountId.to_string()).unwrap(), // TODO: fix unwrap + participants .into_iter() - .map(|account_id| account_id) - .collect::>(), + .map(|acc_id: near_sdk::AccountId| { + AccountId::from_str(&acc_id.to_string()).unwrap() + // TODO: fix unwrap + }) + .collect(), // TODO: remove code duplication ) }) .collect(), } } -} \ No newline at end of file +} From 5b45c500ebf84b409d79d09ffbe23f37794d7bfd Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 13 Dec 2023 20:16:26 +0200 Subject: [PATCH 03/21] min changes --- node/src/protocol/state.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/node/src/protocol/state.rs b/node/src/protocol/state.rs index fa0e5ae1d..6a5453d08 100644 --- a/node/src/protocol/state.rs +++ b/node/src/protocol/state.rs @@ -1,12 +1,10 @@ -use super::contract::Participants; +use super::contract::primitives::{Participants, ParticipantInfo}; use super::cryptography::CryptographicError; use super::presignature::PresignatureManager; use super::signature::SignatureManager; use super::triple::TripleManager; use super::SignQueue; -use crate::protocol::ParticipantInfo; use crate::types::{KeygenProtocol, PrivateKeyShare, PublicKey, ReshareProtocol}; -use near_primitives::types::AccountId; use std::sync::Arc; use tokio::sync::RwLock; From fb3ec6d1683d1a1a4c6205f93c7abf4d24bd3721 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 10 Jan 2024 19:19:03 +0200 Subject: [PATCH 04/21] node id refactoring --- node/src/http_client.rs | 2 +- node/src/protocol/consensus.rs | 46 ++---- node/src/protocol/contract.rs | 200 ----------------------- node/src/protocol/contract/mod.rs | 15 +- node/src/protocol/contract/primitives.rs | 6 +- node/src/protocol/cryptography.rs | 4 +- node/src/protocol/message.rs | 2 +- node/src/protocol/mod.rs | 12 +- node/src/protocol/state.rs | 7 +- 9 files changed, 35 insertions(+), 259 deletions(-) delete mode 100644 node/src/protocol/contract.rs diff --git a/node/src/http_client.rs b/node/src/http_client.rs index 10574a362..c0a3792be 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -1,6 +1,6 @@ +use crate::protocol::contract::primitives::ParticipantInfo; use crate::protocol::message::SignedMessage; use crate::protocol::MpcMessage; -use crate::protocol::ParticipantInfo; use cait_sith::protocol::Participant; use mpc_keys::hpke; use reqwest::{Client, IntoUrl}; diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 13e7e6278..7391e6f72 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -1,9 +1,8 @@ -use super::contract::{ProtocolState, ResharingContractState}; use super::state::{ JoiningState, NodeState, PersistentNodeData, RunningState, StartedState, WaitingForConsensusState, }; -use super::SignQueue; +use super::{ProtocolState, SignQueue}; use crate::protocol::presignature::PresignatureManager; use crate::protocol::signature::SignatureManager; use crate::protocol::state::{GeneratingState, ResharingState}; @@ -15,17 +14,18 @@ use crate::{http_client, rpc_client}; use async_trait::async_trait; use cait_sith::protocol::{InitializationError, Participant}; use k256::Secp256k1; +use mpc_contract::ResharingContractState; use mpc_keys::hpke; use near_crypto::InMemorySigner; use near_primitives::transaction::{Action, FunctionCallAction}; -use near_sdk::AccountId; +use near_primitives::types::AccountId; use std::cmp::Ordering; use std::sync::Arc; use tokio::sync::RwLock; use url::Url; pub trait ConsensusCtx { - fn me(&self) -> &AccountId; + fn my_account_id(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; fn signer(&self) -> &InMemorySigner; @@ -99,10 +99,7 @@ impl ConsensusProtocol for StartedState { } Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { - if contract_state - .participants - .contains_key(&ctx.me()) - { + if contract_state.participants.contains_key(&ctx.me()) { tracing::info!( "contract state is running and we are already a participant" ); @@ -176,10 +173,7 @@ impl ConsensusProtocol for StartedState { }, None => match contract_state { ProtocolState::Initializing(contract_state) => { - if contract_state - .participants - .contains_key(&ctx.me()) - { + if contract_state.participants.contains_key(&ctx.me()) { tracing::info!("starting key generation as a part of the participant set"); let participants = contract_state.participants; let protocol = cait_sith::keygen::( @@ -376,14 +370,8 @@ impl ConsensusProtocol for WaitingForConsensusState { tracing::debug!( "waiting for resharing consensus, contract state has not been finalized yet" ); - let has_voted = contract_state - .finished_votes - .contains(&ctx.me()); - if !has_voted - && contract_state - .old_participants - .contains_key(&ctx.me()) - { + let has_voted = contract_state.finished_votes.contains(&ctx.me()); + if !has_voted && contract_state.old_participants.contains_key(&ctx.me()) { tracing::info!( epoch = self.epoch, "we haven't voted yet, voting for resharing to complete" @@ -457,12 +445,8 @@ impl ConsensusProtocol for RunningState { Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { tracing::info!("contract is resharing"); - if !contract_state - .old_participants - .contains_key(&ctx.me()) - || !contract_state - .new_participants - .contains_key(&ctx.me()) + if !contract_state.old_participants.contains_key(&ctx.me()) + || !contract_state.new_participants.contains_key(&ctx.me()) { return Err(ConsensusError::HasBeenKicked); } @@ -561,10 +545,7 @@ impl ConsensusProtocol for JoiningState { match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => { - if contract_state - .candidates - .contains_key(&ctx.me()) - { + if contract_state.candidates.contains_key(&ctx.me()) { let voted = contract_state .join_votes .get(&ctx.me()) @@ -609,10 +590,7 @@ impl ConsensusProtocol for JoiningState { } } ProtocolState::Resharing(contract_state) => { - if contract_state - .new_participants - .contains_key(&ctx.me()) - { + if contract_state.new_participants.contains_key(&ctx.me()) { tracing::info!("joining as a new participant"); start_resharing(None, ctx, contract_state) } else { diff --git a/node/src/protocol/contract.rs b/node/src/protocol/contract.rs deleted file mode 100644 index 08e606450..000000000 --- a/node/src/protocol/contract.rs +++ /dev/null @@ -1,200 +0,0 @@ -use crate::types::PublicKey; -use crate::util::NearPublicKeyExt; -use cait_sith::protocol::Participant; -use mpc_contract::ProtocolContractState; -use mpc_keys::hpke; -use near_primitives::borsh::BorshDeserialize; -use near_sdk::AccountId; -use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, HashSet}; - -type ParticipantId = u32; - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct ParticipantInfo { - pub id: ParticipantId, - pub account_id: AccountId, - pub url: String, - /// The public key used for encrypting messages. - pub cipher_pk: hpke::PublicKey, - /// The public key used for verifying messages. - pub sign_pk: near_crypto::PublicKey, -} - -impl From for ParticipantInfo { - fn from(value: mpc_contract::ParticipantInfo) -> Self { - ParticipantInfo { - id: value.id, - account_id: value.account_id, - url: value.url, - cipher_pk: hpke::PublicKey::from_bytes(&value.cipher_pk), - sign_pk: BorshDeserialize::try_from_slice(value.sign_pk.as_bytes()).unwrap(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitializingContractState { - pub participants: BTreeMap, - pub threshold: usize, - pub pk_votes: BTreeMap>, -} - -impl From for InitializingContractState { - fn from(value: mpc_contract::InitializingContractState) -> Self { - InitializingContractState { - participants: contract_participants_into_cait_participants(value.participants), - threshold: value.threshold, - pk_votes: value - .pk_votes - .into_iter() - .map(|(pk, participants)| { - ( - near_crypto::PublicKey::SECP256K1( - near_crypto::Secp256K1PublicKey::try_from(&pk.as_bytes()[1..]).unwrap(), - ), - participants - .into_iter() - .map(Participant::from) - .collect::>(), - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct RunningContractState { - pub epoch: u64, - pub participants: BTreeMap, - pub threshold: usize, - pub public_key: PublicKey, - pub candidates: BTreeMap, - pub join_votes: BTreeMap>, - pub leave_votes: BTreeMap>, -} - -impl From for RunningContractState { - fn from(value: mpc_contract::RunningContractState) -> Self { - RunningContractState { - epoch: value.epoch, - participants: contract_participants_into_cait_participants(value.participants), - threshold: value.threshold, - public_key: value.public_key.into_affine_point(), - candidates: value - .candidates - .into_iter() - .map(|(p, p_info)| (Participant::from(p), p_info.into())) - .collect(), - join_votes: value - .join_votes - .into_iter() - .map(|(p, ps)| { - ( - Participant::from(p), - ps.into_iter().map(Participant::from).collect(), - ) - }) - .collect(), - leave_votes: value - .leave_votes - .into_iter() - .map(|(p, ps)| { - ( - Participant::from(p), - ps.into_iter().map(Participant::from).collect(), - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ResharingContractState { - pub old_epoch: u64, - pub old_participants: BTreeMap, - pub new_participants: BTreeMap, - pub threshold: usize, - pub public_key: PublicKey, - pub finished_votes: HashSet, -} - -impl From for ResharingContractState { - fn from(value: mpc_contract::ResharingContractState) -> Self { - ResharingContractState { - old_epoch: value.old_epoch, - old_participants: contract_participants_into_cait_participants(value.old_participants), - new_participants: contract_participants_into_cait_participants(value.new_participants), - threshold: value.threshold, - public_key: value.public_key.into_affine_point(), - finished_votes: value - .finished_votes - .into_iter() - .map(Participant::from) - .collect(), - } - } -} - -#[derive(Debug)] -pub enum ProtocolState { - Initializing(InitializingContractState), - Running(RunningContractState), - Resharing(ResharingContractState), -} - -impl ProtocolState { - pub fn participants(&self) -> &BTreeMap { - match self { - ProtocolState::Initializing(InitializingContractState { participants, .. }) => { - participants - } - ProtocolState::Running(RunningContractState { participants, .. }) => participants, - ProtocolState::Resharing(ResharingContractState { - old_participants, .. - }) => old_participants, - } - } - - pub fn public_key(&self) -> Option<&PublicKey> { - match self { - ProtocolState::Initializing { .. } => None, - ProtocolState::Running(RunningContractState { public_key, .. }) => Some(public_key), - ProtocolState::Resharing(ResharingContractState { public_key, .. }) => Some(public_key), - } - } - - pub fn threshold(&self) -> usize { - match self { - ProtocolState::Initializing(InitializingContractState { threshold, .. }) => *threshold, - ProtocolState::Running(RunningContractState { threshold, .. }) => *threshold, - ProtocolState::Resharing(ResharingContractState { threshold, .. }) => *threshold, - } - } -} - -impl TryFrom for ProtocolState { - type Error = (); - - fn try_from(value: ProtocolContractState) -> Result { - match value { - ProtocolContractState::NotInitialized => Err(()), - ProtocolContractState::Initializing(state) => { - Ok(ProtocolState::Initializing(state.into())) - } - ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())), - ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), - } - } -} - -fn contract_participants_into_cait_participants( - participants: BTreeMap, -) -> BTreeMap { - participants - .into_values() - .map(|p| (Participant::from(p.id), p.into())) - .collect() -} diff --git a/node/src/protocol/contract/mod.rs b/node/src/protocol/contract/mod.rs index 9b6171f85..e24b153b5 100644 --- a/node/src/protocol/contract/mod.rs +++ b/node/src/protocol/contract/mod.rs @@ -7,7 +7,7 @@ use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; use std::{collections::HashSet, str::FromStr}; -use self::primitives::{Participants, PkVotes, Candidates, Votes}; +use self::primitives::{Candidates, Participants, PkVotes, Votes}; #[derive(Serialize, Deserialize, Debug)] pub struct InitializingContractState { @@ -69,9 +69,13 @@ impl From for ResharingContractState { new_participants: contract_state.new_participants.into(), threshold: contract_state.threshold, public_key: contract_state.public_key.into_affine_point(), - finished_votes: contract_state.finished_votes.into_iter().map(|acc_id| { - AccountId::from_str(&acc_id.to_string()).unwrap() // TODO: code duplication - }).collect(), + finished_votes: contract_state + .finished_votes + .into_iter() + .map(|acc_id| { + AccountId::from_str(&acc_id.to_string()).unwrap() // TODO: code duplication + }) + .collect(), } } } @@ -123,6 +127,7 @@ impl TryFrom for ProtocolState { } ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())), ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), + ProtocolContractState::NotInitialized => Err(()), } } -} \ No newline at end of file +} diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index b9d2ce25e..afec0d744 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -122,10 +122,8 @@ impl From for Candidates { ( AccountId::from_str(&account_id.to_string()).unwrap(), // TODO: fix unwrap CandidateInfo { - account_id: AccountId::from_str( - &candidate_info.account_id.to_string(), - ) - .unwrap(), // TODO: fix unwrap + account_id: AccountId::from_str(&candidate_info.account_id.to_string()) + .unwrap(), // TODO: fix unwrap url: candidate_info.url, cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk), sign_pk: near_crypto::PublicKey::SECP256K1( diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index c1170a3f8..58daf5868 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -7,7 +7,7 @@ use crate::protocol::state::{PersistentNodeData, WaitingForConsensusState}; use crate::protocol::MpcMessage; use crate::storage::{SecretNodeStorageBox, SecretStorageError}; use async_trait::async_trait; -use cait_sith::protocol::{Action, InitializationError, ProtocolError}; +use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use k256::elliptic_curve::group::GroupEncoding; use mpc_keys::hpke; use near_crypto::InMemorySigner; @@ -29,7 +29,7 @@ pub enum CryptographicError { #[error("failed to send a message: {0}")] SendError(#[from] SendError), #[error("unknown participant: {0:?}")] - UnknownParticipant(near_sdk::AccountId), + UnknownParticipant(Participant), #[error("rpc error: {0}")] RpcError(#[from] near_fetch::Error), #[error("cait-sith initialization error: {0}")] diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index a2925eca4..b3df08fd5 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -10,7 +10,7 @@ use k256::Scalar; use mpc_keys::hpke::{self, Ciphered}; use near_crypto::Signature; use near_primitives::hash::CryptoHash; -use near_sdk::AccountId; +use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, VecDeque}; use std::sync::Arc; diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index db21d994c..11625eadc 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -1,4 +1,4 @@ -mod contract; +pub mod contract; mod cryptography; mod presignature; mod signature; @@ -8,9 +8,8 @@ pub mod consensus; pub mod message; pub mod state; -use cait_sith::protocol::Participant; pub use consensus::ConsensusError; -pub use contract::{ParticipantInfo, ProtocolState}; +pub use contract::ProtocolState; pub use cryptography::CryptographicError; pub use message::MpcMessage; pub use signature::SignQueue; @@ -25,7 +24,6 @@ use crate::protocol::cryptography::CryptographicProtocol; use crate::protocol::message::{MessageHandler, MpcMessageQueue}; use crate::rpc_client::{self}; use crate::storage::SecretNodeStorageBox; -use cait_sith::protocol::Participant; use near_crypto::InMemorySigner; use near_primitives::types::AccountId; use reqwest::IntoUrl; @@ -50,7 +48,7 @@ struct Ctx { } impl ConsensusCtx for &Ctx { - fn me(&self) -> Participant { + fn my_account_id(&self) -> &AccountId { &self.account_id } @@ -176,7 +174,7 @@ impl MpcSignProtocol { } pub async fn run(mut self) -> anyhow::Result<()> { - let _span = tracing::info_span!("running", me = u32::from(self.ctx)); + let _span = tracing::info_span!("running", me = self.ctx.account_id.to_string()); let mut queue = MpcMessageQueue::default(); loop { tracing::debug!("trying to advance mpc recovery protocol"); @@ -216,7 +214,7 @@ impl MpcSignProtocol { let guard = self.state.read().await; guard.clone() }; - let state = match state.progress(&mut self.ctx).await { + let state = match state.progress(&self.ctx).await { Ok(state) => state, Err(err) => { tracing::info!("protocol unable to progress: {err:?}"); diff --git a/node/src/protocol/state.rs b/node/src/protocol/state.rs index f295d6e1c..8b5637a0a 100644 --- a/node/src/protocol/state.rs +++ b/node/src/protocol/state.rs @@ -5,12 +5,9 @@ use super::signature::SignatureManager; use super::triple::TripleManager; use super::SignQueue; use crate::http_client::MessageQueue; -use crate::protocol::ParticipantInfo; -use crate::types::{KeygenProtocol, PrivateKeyShare, PublicKey, ReshareProtocol}; use crate::types::{KeygenProtocol, PublicKey, ReshareProtocol, SecretKeyShare}; use cait_sith::protocol::Participant; use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; use std::sync::Arc; use tokio::sync::RwLock; @@ -106,7 +103,7 @@ impl ResharingState { #[derive(Clone)] pub struct JoiningState { - pub participants: BTreeMap, + pub participants: Participants, pub public_key: PublicKey, } @@ -150,7 +147,7 @@ impl NodeState { fn fetch_participant<'a>( p: &Participant, - participants: &'a BTreeMap, + participants: &'a Participants, ) -> Result<&'a ParticipantInfo, CryptographicError> { participants .get(p) From 660b142a564c1def281f10fd8f307b05b6f0d8f2 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Wed, 10 Jan 2024 22:58:45 +0200 Subject: [PATCH 05/21] id refactoring --- contract/src/primitives.rs | 4 ++ node/src/protocol/consensus.rs | 64 ++++++++++++++++-------- node/src/protocol/contract/primitives.rs | 19 +++++++ 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/contract/src/primitives.rs b/contract/src/primitives.rs index 719887b1f..9d2f4f765 100644 --- a/contract/src/primitives.rs +++ b/contract/src/primitives.rs @@ -96,6 +96,10 @@ impl Participants { pub fn into_iter(self) -> impl Iterator { self.participants.into_iter() } + + pub fn keys(&self) -> impl Iterator { + self.participants.keys() + } } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 7391e6f72..7efce0d9d 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -1,3 +1,4 @@ +use super::contract::ResharingContractState; use super::state::{ JoiningState, NodeState, PersistentNodeData, RunningState, StartedState, WaitingForConsensusState, @@ -14,7 +15,6 @@ use crate::{http_client, rpc_client}; use async_trait::async_trait; use cait_sith::protocol::{InitializationError, Participant}; use k256::Secp256k1; -use mpc_contract::ResharingContractState; use mpc_keys::hpke; use near_crypto::InMemorySigner; use near_primitives::transaction::{Action, FunctionCallAction}; @@ -74,6 +74,10 @@ impl ConsensusProtocol for StartedState { ctx: C, contract_state: ProtocolState, ) -> Result { + let me = contract_state + .participants() + .find_participant(ctx.my_account_id()) + .unwrap(); // TODO: remove unwrap match self.0 { Some(PersistentNodeData { epoch, @@ -99,12 +103,16 @@ impl ConsensusProtocol for StartedState { } Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { - if contract_state.participants.contains_key(&ctx.me()) { + if contract_state + .participants + .contains_account_id(&ctx.my_account_id()) + { tracing::info!( "contract state is running and we are already a participant" ); - let participants_vec: Vec = + let participants_vec: Vec = contract_state.participants.keys().cloned().collect(); + Ok(NodeState::Running(RunningState { epoch, participants: contract_state.participants, @@ -114,14 +122,14 @@ impl ConsensusProtocol for StartedState { sign_queue: ctx.sign_queue(), triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - ctx.me(), + me, contract_state.threshold, epoch, ))), presignature_manager: Arc::new(RwLock::new( PresignatureManager::new( participants_vec.clone(), - ctx.me(), + me, contract_state.threshold, epoch, ), @@ -129,7 +137,7 @@ impl ConsensusProtocol for StartedState { signature_manager: Arc::new(RwLock::new( SignatureManager::new( participants_vec, - ctx.me(), + me, contract_state.public_key, epoch, ), @@ -173,12 +181,15 @@ impl ConsensusProtocol for StartedState { }, None => match contract_state { ProtocolState::Initializing(contract_state) => { - if contract_state.participants.contains_key(&ctx.me()) { + if contract_state + .participants + .contains_account_id(&ctx.my_account_id()) + { tracing::info!("starting key generation as a part of the participant set"); let participants = contract_state.participants; let protocol = cait_sith::keygen::( &participants.keys().cloned().collect::>(), - ctx.me(), + me, contract_state.threshold, )?; Ok(NodeState::Generating(GeneratingState { @@ -262,6 +273,10 @@ impl ConsensusProtocol for WaitingForConsensusState { ctx: C, contract_state: ProtocolState, ) -> Result { + let me = self + .participants + .find_participant(ctx.my_account_id()) + .unwrap(); // TODO: remove unwrap match contract_state { ProtocolState::Initializing(contract_state) => { tracing::debug!("waiting for consensus, contract state has not been finalized yet"); @@ -269,7 +284,7 @@ impl ConsensusProtocol for WaitingForConsensusState { let has_voted = contract_state .pk_votes .get(&public_key) - .map(|ps| ps.contains(&ctx.me())) + .map(|ps| ps.contains(ctx.my_account_id())) .unwrap_or_default(); if !has_voted { tracing::info!("we haven't voted yet, voting for the generated public key"); @@ -319,19 +334,19 @@ impl ConsensusProtocol for WaitingForConsensusState { sign_queue: ctx.sign_queue(), triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - ctx.me(), + me, self.threshold, self.epoch, ))), presignature_manager: Arc::new(RwLock::new(PresignatureManager::new( participants_vec.clone(), - ctx.me(), + me, self.threshold, self.epoch, ))), signature_manager: Arc::new(RwLock::new(SignatureManager::new( participants_vec, - ctx.me(), + me, self.public_key, self.epoch, ))), @@ -370,8 +385,8 @@ impl ConsensusProtocol for WaitingForConsensusState { tracing::debug!( "waiting for resharing consensus, contract state has not been finalized yet" ); - let has_voted = contract_state.finished_votes.contains(&ctx.me()); - if !has_voted && contract_state.old_participants.contains_key(&ctx.me()) { + let has_voted = contract_state.finished_votes.contains(ctx.my_account_id()); + if !has_voted && contract_state.old_participants.contains_key(&me) { tracing::info!( epoch = self.epoch, "we haven't voted yet, voting for resharing to complete" @@ -400,6 +415,10 @@ impl ConsensusProtocol for RunningState { ctx: C, contract_state: ProtocolState, ) -> Result { + let me = contract_state + .participants() + .find_participant(ctx.my_account_id()) + .unwrap(); // TODO: remove unwrap match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => match contract_state.epoch.cmp(&self.epoch) { @@ -445,8 +464,8 @@ impl ConsensusProtocol for RunningState { Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { tracing::info!("contract is resharing"); - if !contract_state.old_participants.contains_key(&ctx.me()) - || !contract_state.new_participants.contains_key(&ctx.me()) + if !contract_state.old_participants.contains_key(&me) + || !contract_state.new_participants.contains_key(&me) { return Err(ConsensusError::HasBeenKicked); } @@ -548,7 +567,7 @@ impl ConsensusProtocol for JoiningState { if contract_state.candidates.contains_key(&ctx.me()) { let voted = contract_state .join_votes - .get(&ctx.me()) + .get(&ctx.my_account_id()) .cloned() .unwrap_or_default(); tracing::info!( @@ -642,16 +661,19 @@ fn start_resharing( .cloned() .collect::>(), contract_state.threshold, - ctx.me(), + contract_state + .old_participants + .find_participant(ctx.my_account_id()) + .unwrap(), // TODO: remove unwrap private_share, contract_state.public_key, )?; Ok(NodeState::Resharing(ResharingState { old_epoch: contract_state.old_epoch, - old_participants: contract_state.old_participants, - new_participants: contract_state.new_participants, + old_participants: contract_state.old_participants.into(), + new_participants: contract_state.new_participants.into(), threshold: contract_state.threshold, - public_key: contract_state.public_key, + public_key: contract_state.public_key.into(), protocol: Arc::new(RwLock::new(protocol)), messages: Default::default(), })) diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index afec0d744..740f27d2a 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -77,6 +77,19 @@ impl Participants { pub fn iter(&self) -> impl Iterator { self.participants.iter() } + + pub fn find_participant(&self, account_id: &AccountId) -> Option { + self.participants + .iter() + .find(|(_, participant_info)| participant_info.account_id == *account_id) + .map(|(participant, _)| *participant) + } + + pub fn contains_account_id(&self, account_id: &AccountId) -> bool { + self.participants + .iter() + .any(|(_, participant_info)| participant_info.account_id == *account_id) + } } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] @@ -145,6 +158,12 @@ pub struct PkVotes { pub pk_votes: BTreeMap>, } +impl PkVotes { + pub fn get(&self, id: &near_crypto::PublicKey) -> Option<&HashSet> { + self.pk_votes.get(id) + } +} + impl From for PkVotes { fn from(contract_votes: mpc_contract::primitives::PkVotes) -> Self { PkVotes { From 685a6aa59675d05a95c87431e658d0d759089f9b Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 11 Jan 2024 20:44:12 +0200 Subject: [PATCH 06/21] id refactoring --- node/src/cli.rs | 2 +- node/src/http_client.rs | 11 ++- node/src/protocol/consensus.rs | 14 +-- node/src/protocol/contract/primitives.rs | 31 ++++++- node/src/protocol/cryptography.rs | 39 ++++---- node/src/protocol/message.rs | 108 +++++++++++++++++------ node/src/protocol/mod.rs | 6 +- node/src/protocol/presignature.rs | 7 +- node/src/protocol/signature.rs | 7 +- node/src/protocol/state.rs | 18 ++++ node/src/protocol/triple.rs | 9 +- 11 files changed, 188 insertions(+), 64 deletions(-) diff --git a/node/src/cli.rs b/node/src/cli.rs index a4ca786a1..79b9f0db6 100644 --- a/node/src/cli.rs +++ b/node/src/cli.rs @@ -149,7 +149,7 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> { tracing::info!(%my_address, "address detected"); let rpc_client = near_fetch::Client::new(&near_rpc); tracing::debug!(rpc_addr = rpc_client.rpc_addr(), "rpc client initialized"); - let signer = InMemorySigner::from_secret_key(account_id, account_sk); + let signer = InMemorySigner::from_secret_key(account_id.clone(), account_sk); let (protocol, protocol_state) = MpcSignProtocol::init( my_address, mpc_contract_id.clone(), diff --git a/node/src/http_client.rs b/node/src/http_client.rs index c0a3792be..78ad15153 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -8,6 +8,7 @@ use std::collections::VecDeque; use std::str::Utf8Error; use tokio_retry::strategy::{jitter, ExponentialBackoff}; use tokio_retry::Retry; +use near_primitives::types::AccountId; #[derive(Debug, thiserror::Error)] pub enum SendError { @@ -26,7 +27,7 @@ pub enum SendError { } async fn send_encrypted( - from: Participant, + from: &AccountId, cipher_pk: &hpke::PublicKey, sign_sk: &near_crypto::SecretKey, client: &Client, @@ -127,7 +128,7 @@ impl MessageQueue { pub async fn send_encrypted( &mut self, - from: Participant, + from: &AccountId, sign_sk: &near_crypto::SecretKey, client: &Client, ) -> Result<(), SendError> { @@ -151,6 +152,10 @@ impl MessageQueue { #[cfg(test)] mod tests { + use std::str::FromStr; + + use near_lake_primitives::AccountId; + use crate::protocol::message::GeneratingMessage; use crate::protocol::MpcMessage; @@ -159,7 +164,7 @@ mod tests { let associated_data = b""; let (sk, pk) = mpc_keys::hpke::generate(); let starting_message = MpcMessage::Generating(GeneratingMessage { - from: cait_sith::protocol::Participant::from(0), + from: AccountId::from_str("alice.near").unwrap(), data: vec![], }); diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 7efce0d9d..3daef57ea 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -564,7 +564,7 @@ impl ConsensusProtocol for JoiningState { match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => { - if contract_state.candidates.contains_key(&ctx.me()) { + if contract_state.candidates.contains_key(&ctx.my_account_id()) { let voted = contract_state .join_votes .get(&ctx.my_account_id()) @@ -575,11 +575,14 @@ impl ConsensusProtocol for JoiningState { votes_to_go = contract_state.threshold - voted.len(), "trying to get participants to vote for us" ); - for (p, info) in contract_state.participants { - if voted.contains(&p) { + for account_id in contract_state.participants.account_ids() { + if voted.contains(&account_id) { continue; } - http_client::join(ctx.http_client(), info.url, &ctx.me()) + let info = contract_state.participants.find_participant_info(&account_id).unwrap(); + let my_participant_info = contract_state.participants.find_participant_info(ctx.my_account_id()).unwrap(); + let participant: Participant = my_participant_info.id.clone().into(); + http_client::join(ctx.http_client(), info.url.clone(), &participant) // TODO: should this be account_id? .await .unwrap() } @@ -587,7 +590,6 @@ impl ConsensusProtocol for JoiningState { } else { tracing::info!("sending a transaction to join the participant set"); let args = serde_json::json!({ - "participant_id": ctx.me(), "url": ctx.my_address(), "cipher_pk": ctx.cipher_pk().to_bytes(), "sign_pk": ctx.sign_pk(), @@ -609,7 +611,7 @@ impl ConsensusProtocol for JoiningState { } } ProtocolState::Resharing(contract_state) => { - if contract_state.new_participants.contains_key(&ctx.me()) { + if contract_state.new_participants.contains_account_id(&ctx.my_account_id()) { tracing::info!("joining as a new participant"); start_resharing(None, ctx, contract_state) } else { diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 740f27d2a..b0cd2b52b 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -20,7 +20,7 @@ pub struct ParticipantInfo { pub sign_pk: near_crypto::PublicKey, } -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub struct Participants { pub participants: BTreeMap, } @@ -61,6 +61,15 @@ impl From for Participants { } } +impl IntoIterator for Participants { + type Item = (Participant, ParticipantInfo); + type IntoIter = std::collections::btree_map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.participants.into_iter() + } +} + impl Participants { pub fn get(&self, id: &Participant) -> Option<&ParticipantInfo> { self.participants.get(id) @@ -85,11 +94,25 @@ impl Participants { .map(|(participant, _)| *participant) } + pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { + self.participants + .iter() + .find(|(_, participant_info)| participant_info.account_id == *account_id) + .map(|(_, participant_info)| participant_info) + } + pub fn contains_account_id(&self, account_id: &AccountId) -> bool { self.participants .iter() .any(|(_, participant_info)| participant_info.account_id == *account_id) } + + pub fn account_ids(&self) -> Vec { + self.participants + .iter() + .map(|(_, participant_info)| participant_info.account_id.clone()) + .collect() + } } #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] @@ -194,6 +217,12 @@ pub struct Votes { pub votes: BTreeMap>, } +impl Votes { + pub fn get(&self, id: &AccountId) -> Option<&HashSet> { + self.votes.get(id) + } +} + impl From for Votes { fn from(contract_votes: mpc_contract::primitives::Votes) -> Self { Votes { diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index 58daf5868..0b3e8c0a0 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -81,10 +81,10 @@ impl CryptographicProtocol for GeneratingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { - tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message"); + tracing::warn!(?err, "generating: failed to send encrypted message"); } return Ok(NodeState::Generating(self)); @@ -92,15 +92,15 @@ impl CryptographicProtocol for GeneratingState { Action::SendMany(m) => { tracing::debug!("sending a message to many participants"); let mut messages = self.messages.write().await; - for (p, info) in &self.participants { - if p == &self.participants.find(ctx.my_near_acc_id()) { + for (p, info) in self.participants.clone() { + if p == self.participants.find_participant(ctx.my_near_acc_id()).unwrap() { // Skip yourself, cait-sith never sends messages to oneself continue; } messages.push( info.clone(), MpcMessage::Generating(GeneratingMessage { - from: ctx.me(), + from: ctx.my_near_acc_id().clone(), data: m.clone(), }), ); @@ -112,7 +112,7 @@ impl CryptographicProtocol for GeneratingState { self.messages.write().await.push( info.clone(), MpcMessage::Generating(GeneratingMessage { - from: ctx.me(), + from: ctx.my_near_acc_id().clone(), data: m.clone(), }), ); @@ -134,7 +134,7 @@ impl CryptographicProtocol for GeneratingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message"); @@ -163,7 +163,7 @@ impl CryptographicProtocol for WaitingForConsensusState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "waiting: failed to send encrypted message"); @@ -192,10 +192,10 @@ impl CryptographicProtocol for ResharingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { - tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(wait): failed to send encrypted message"); + tracing::warn!(?err, "resharing(wait): failed to send encrypted message"); } return Ok(NodeState::Resharing(self)); @@ -203,8 +203,8 @@ impl CryptographicProtocol for ResharingState { Action::SendMany(m) => { tracing::debug!("sending a message to all participants"); let mut messages = self.messages.write().await; - for (p, info) in &self.new_participants { - if p == &ctx.me() { + for (_, info) in self.new_participants.clone() { + if &info.account_id == ctx.my_near_acc_id() { // Skip yourself, cait-sith never sends messages to oneself continue; } @@ -213,7 +213,7 @@ impl CryptographicProtocol for ResharingState { info.clone(), MpcMessage::Resharing(ResharingMessage { epoch: self.old_epoch, - from: ctx.me(), + from: ctx.my_near_acc_id().clone(), data: m.clone(), }), ) @@ -226,7 +226,7 @@ impl CryptographicProtocol for ResharingState { info.clone(), MpcMessage::Resharing(ResharingMessage { epoch: self.old_epoch, - from: ctx.me(), + from: ctx.my_near_acc_id().clone(), data: m.clone(), }), ), @@ -241,7 +241,7 @@ impl CryptographicProtocol for ResharingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(return): failed to send encrypted message"); @@ -267,10 +267,11 @@ impl CryptographicProtocol for RunningState { mut self, ctx: C, ) -> Result { + let me = self.participants.find_participant(ctx.my_near_acc_id()).unwrap(); let mut messages = self.messages.write().await; // Try sending any leftover messages donated to RunningState. if let Err(err) = messages - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "running(pre): failed to send encrypted message"); @@ -309,8 +310,8 @@ impl CryptographicProtocol for RunningState { let mut sign_queue = self.sign_queue.write().await; let mut signature_manager = self.signature_manager.write().await; - sign_queue.organize(&self, ctx.me()); - let my_requests = sign_queue.my_requests(ctx.me()); + sign_queue.organize(&self, me); + let my_requests = sign_queue.my_requests(me); while presignature_manager.my_len() > 0 { let Some((receipt_id, _)) = my_requests.iter().next() else { break; @@ -340,7 +341,7 @@ impl CryptographicProtocol for RunningState { .await?; drop(signature_manager); if let Err(err) = messages - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "running(post): failed to send encrypted message"); diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index b3df08fd5..b23aa2f02 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -17,19 +17,19 @@ use std::sync::Arc; use tokio::sync::RwLock; pub trait MessageCtx { - fn my_near_acc_id(&self) -> AccountId; + fn my_near_acc_id(&self) -> &AccountId; } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct GeneratingMessage { - pub from: Participant, + pub from: AccountId, pub data: MessageData, } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct ResharingMessage { pub epoch: u64, - pub from: Participant, + pub from: AccountId, pub data: MessageData, } @@ -37,7 +37,7 @@ pub struct ResharingMessage { pub struct TripleMessage { pub id: u64, pub epoch: u64, - pub from: Participant, + pub from: AccountId, pub data: MessageData, } @@ -47,20 +47,20 @@ pub struct PresignatureMessage { pub triple0: TripleId, pub triple1: TripleId, pub epoch: u64, - pub from: Participant, + pub from: AccountId, pub data: MessageData, } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct SignatureMessage { pub receipt_id: CryptoHash, - pub proposer: Participant, + pub proposer: Participant, // TODO: should it be node AccountId? pub presignature_id: PresignatureId, pub msg_hash: [u8; 32], pub epsilon: Scalar, pub delta: Scalar, pub epoch: u64, - pub from: Participant, + pub from: AccountId, pub data: MessageData, } @@ -172,13 +172,25 @@ pub trait MessageHandler { impl MessageHandler for GeneratingState { async fn handle( &mut self, - _ctx: C, + ctx: C, queue: &mut MpcMessageQueue, ) -> Result<(), MessageHandleError> { let mut protocol = self.protocol.write().await; while let Some(msg) = queue.generating.pop_front() { tracing::debug!("handling new generating message"); - protocol.message(msg.from, msg.data); + let participant = self + .participants + .find_participant(&msg.from) + .map(|participant| { + tracing::debug!(from = %msg.from, "handling message from"); + protocol.message(participant, msg.data); + }) + .unwrap_or_else(|| { + tracing::warn!( + participant = %msg.from, + "received message from unknown participant" + ); + }); } Ok(()) } @@ -194,8 +206,19 @@ impl MessageHandler for ResharingState { let q = queue.resharing_bins.entry(self.old_epoch).or_default(); let mut protocol = self.protocol.write().await; while let Some(msg) = q.pop_front() { - tracing::debug!("handling new resharing message"); - protocol.message(msg.from, msg.data); + let participant = self + .old_participants + .find_participant(&msg.from) + .map(|participant| { + tracing::debug!(from = %msg.from, "handling resharing message from"); + protocol.message(participant, msg.data); + }) + .unwrap_or_else(|| { + tracing::warn!( + participant = %msg.from, + "received message from unknown participant in resharing state" + ); + }); } Ok(()) } @@ -215,7 +238,19 @@ impl MessageHandler for RunningState { .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; while let Some(message) = queue.pop_front() { - protocol.message(message.from, message.data); + let participant = self + .participants + .find_participant(&message.from) + .map(|participant| { + tracing::debug!(from = %message.from, "running state, triple message, handling message from"); + protocol.message(participant, message.data); + }) + .unwrap_or_else(|| { + tracing::warn!( + participant = %message.from, + "received message from unknown participant in running state" + ); + }); } } } @@ -236,7 +271,19 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - protocol.message(message.from, message.data) + let participant = self + .participants + .find_participant(&message.from) + .map(|participant| { + tracing::debug!(from = %message.from, "running state, presignature message, handling message from"); + protocol.message(participant, message.data); + }) + .unwrap_or_else(|| { + tracing::warn!( + participant = %message.from, + "received message from unknown participant in running state" + ); + }); } Err(presignature::GenerationError::AlreadyGenerated) => { tracing::info!(id, "presignature already generated, nothing left to do") @@ -290,7 +337,17 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - protocol.message(message.from, message.data) + let participant = match self.participants.find_participant(&message.from) { + Some(participant) => participant, + None => { + tracing::warn!( + participant = %message.from, + "received message from unknown participant in running state" + ); + continue; + } + }; + protocol.message(participant, message.data) } None => { // Store the message until we are ready to process it @@ -338,7 +395,7 @@ pub struct SignedMessage { /// The signature used to verify the authenticity of the encrypted message. pub sig: Signature, /// From which particpant the message was sent. - pub from: Participant, + pub from: AccountId, } impl SignedMessage { @@ -351,13 +408,13 @@ where { pub fn encrypt( msg: T, - from: Participant, + from: &AccountId, sign_sk: &near_crypto::SecretKey, cipher_pk: &hpke::PublicKey, ) -> Result { let msg = serde_json::to_vec(&msg)?; let sig = sign_sk.sign(&msg); - let msg = SignedMessage { msg, sig, from }; + let msg = SignedMessage { msg, sig, from: from.clone() }; let msg = serde_json::to_vec(&msg)?; let ciphered = cipher_pk .encrypt(&msg, SignedMessage::::ASSOCIATED_DATA) @@ -379,14 +436,15 @@ where .decrypt(&encrypted, SignedMessage::::ASSOCIATED_DATA) .map_err(|err| CryptographicError::Encryption(err.to_string()))?; let SignedMessage::> { msg, sig, from } = serde_json::from_slice(&message)?; - if !sig.verify( - &msg, - &protocol_state - .read() - .await - .fetch_participant(&from)? - .sign_pk, - ) { + let sign_pk = match protocol_state.read().await.find_participant_info(&from) { + Some(info) => info.sign_pk.clone(), + None => { + return Err(CryptographicError::Encryption( + "unknown participant".to_string(), + )); + } + }; + if !sig.verify(&msg, &sign_pk) { return Err(CryptographicError::Encryption( "invalid signature while verifying authenticity of encrypted ".to_string(), )); diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index 11625eadc..f7e07ffae 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -128,8 +128,8 @@ impl CryptographicCtx for &Ctx { } impl MessageCtx for &Ctx { - fn my_near_acc_id(&self) -> AccountId { - self.account_id + fn my_near_acc_id(&self) -> &AccountId { + &self.account_id } } @@ -214,7 +214,7 @@ impl MpcSignProtocol { let guard = self.state.read().await; guard.clone() }; - let state = match state.progress(&self.ctx).await { + let state = match state.progress(&mut self.ctx).await { Ok(state) => state, Err(err) => { tracing::info!("protocol unable to progress: {err:?}"); diff --git a/node/src/protocol/presignature.rs b/node/src/protocol/presignature.rs index d20ed74d3..30ae8624b 100644 --- a/node/src/protocol/presignature.rs +++ b/node/src/protocol/presignature.rs @@ -7,7 +7,9 @@ use cait_sith::{KeygenOutput, PresignArguments, PresignOutput}; use k256::Secp256k1; use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; +use std::str::FromStr; use std::sync::Arc; +use near_primitives::types::AccountId; /// Unique number used to identify a specific ongoing presignature generation protocol. /// Without `PresignatureId` it would be unclear where to route incoming cait-sith presignature @@ -230,6 +232,7 @@ impl PresignatureManager { break false; } }; + let my_account_id: AccountId = AccountId::from_str("acc.near").unwrap(); // TODO: account id is not available in this context match action { Action::Wait => { tracing::debug!("waiting"); @@ -245,7 +248,7 @@ impl PresignatureManager { triple0: generator.triple0, triple1: generator.triple1, epoch: self.epoch, - from: self.me, + from: my_account_id.clone(), data: data.clone(), }, )) @@ -258,7 +261,7 @@ impl PresignatureManager { triple0: generator.triple0, triple1: generator.triple1, epoch: self.epoch, - from: self.me, + from: my_account_id, data: data.clone(), }, )), diff --git a/node/src/protocol/signature.rs b/node/src/protocol/signature.rs index b6933e6e8..6d0435bc4 100644 --- a/node/src/protocol/signature.rs +++ b/node/src/protocol/signature.rs @@ -254,6 +254,9 @@ impl SignatureManager { break false; } }; + + let my_near_account_id: AccountId = "acc.near".parse().unwrap(); // TODO: account id is not available in this context + match action { Action::Wait => { tracing::debug!("waiting"); @@ -272,7 +275,7 @@ impl SignatureManager { epsilon: generator.epsilon, delta: generator.delta, epoch: self.epoch, - from: self.me, + from: my_near_account_id.clone(), data: data.clone(), }, )) @@ -288,7 +291,7 @@ impl SignatureManager { epsilon: generator.epsilon, delta: generator.delta, epoch: self.epoch, - from: self.me, + from: my_near_account_id, data: data.clone(), }, )), diff --git a/node/src/protocol/state.rs b/node/src/protocol/state.rs index 8b5637a0a..80cae1245 100644 --- a/node/src/protocol/state.rs +++ b/node/src/protocol/state.rs @@ -7,6 +7,7 @@ use super::SignQueue; use crate::http_client::MessageQueue; use crate::types::{KeygenProtocol, PublicKey, ReshareProtocol, SecretKeyShare}; use cait_sith::protocol::Participant; +use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tokio::sync::RwLock; @@ -143,6 +144,23 @@ impl NodeState { _ => Err(CryptographicError::UnknownParticipant(*p)), } } + + pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { + match self { + NodeState::Starting => None, + NodeState::Started(_) => None, + NodeState::Generating(state) => state.participants.find_participant_info(account_id), + NodeState::WaitingForConsensus(state) => { + state.participants.find_participant_info(account_id) + } + NodeState::Running(state) => state.participants.find_participant_info(account_id), + NodeState::Resharing(state) => state + .new_participants + .find_participant_info(account_id) + .or_else(|| state.old_participants.find_participant_info(account_id)), + NodeState::Joining(state) => state.participants.find_participant_info(account_id), + } + } } fn fetch_participant<'a>( diff --git a/node/src/protocol/triple.rs b/node/src/protocol/triple.rs index 7357f28ab..bbe9d2600 100644 --- a/node/src/protocol/triple.rs +++ b/node/src/protocol/triple.rs @@ -5,8 +5,10 @@ use crate::util::AffinePointExt; use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use cait_sith::triples::{TriplePub, TripleShare}; use k256::Secp256k1; +use near_lake_primitives::AccountId; use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; +use std::str::FromStr; use std::sync::Arc; /// Unique number used to identify a specific ongoing triple generation protocol. @@ -196,6 +198,9 @@ impl TripleManager { } }; + // TODO: in this context we don't have access to the account id of the node. + let my_account_id: AccountId = AccountId::from_str("acc.near").unwrap(); + match action { Action::Wait => { tracing::debug!("waiting"); @@ -209,7 +214,7 @@ impl TripleManager { TripleMessage { id: *id, epoch: self.epoch, - from: self.me, + from: my_account_id.clone(), data: data.clone(), }, )) @@ -220,7 +225,7 @@ impl TripleManager { TripleMessage { id: *id, epoch: self.epoch, - from: self.me, + from: my_account_id, data: data.clone(), }, )), From fa2e745cbc518f8f8bdb0ccdad586a32d5fb2cd2 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 11 Jan 2024 21:00:18 +0200 Subject: [PATCH 07/21] id refactoring --- integration-tests/src/multichain/containers.rs | 8 ++++++-- integration-tests/src/multichain/local.rs | 10 ++++------ integration-tests/src/multichain/mod.rs | 12 +++++------- node/src/http_client.rs | 2 +- node/src/protocol/consensus.rs | 15 ++++++++++++--- node/src/protocol/cryptography.rs | 11 +++++++++-- node/src/protocol/message.rs | 6 +++++- node/src/protocol/mod.rs | 2 +- node/src/protocol/presignature.rs | 2 +- node/src/protocol/triple.rs | 2 +- 10 files changed, 45 insertions(+), 25 deletions(-) diff --git a/integration-tests/src/multichain/containers.rs b/integration-tests/src/multichain/containers.rs index 36f65b375..6040bb57a 100644 --- a/integration-tests/src/multichain/containers.rs +++ b/integration-tests/src/multichain/containers.rs @@ -35,7 +35,7 @@ impl<'a> Node<'a> { account_id: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result> { - tracing::info!(node_id, "running node container"); + tracing::info!("running node container, account_id={}", account_id); let (cipher_sk, cipher_pk) = hpke::generate(); let args = mpc_recovery_node::cli::Cli::Start { near_rpc: ctx.lake_indexer.rpc_host_address.clone(), @@ -78,7 +78,11 @@ impl<'a> Node<'a> { }); let full_address = format!("http://{ip_address}:{}", Self::CONTAINER_PORT); - tracing::info!(node_id, full_address, "node container is running"); + tracing::info!( + full_address, + "node container is running, account_id={}", + account_id + ); Ok(Node { container, address: full_address, diff --git a/integration-tests/src/multichain/local.rs b/integration-tests/src/multichain/local.rs index d835b5265..e5f1249c9 100644 --- a/integration-tests/src/multichain/local.rs +++ b/integration-tests/src/multichain/local.rs @@ -6,8 +6,7 @@ use near_workspaces::AccountId; #[allow(dead_code)] pub struct Node { pub address: String, - node_id: usize, - account: AccountId, + account_id: AccountId, pub account_sk: near_workspaces::types::SecretKey, pub cipher_pk: hpke::PublicKey, cipher_sk: hpke::SecretKey, @@ -46,17 +45,16 @@ impl Node { }, }; - let mpc_node_id = format!("multichain/{node_id}"); + let mpc_node_id = format!("multichain/{account_id}", account_id = account_id); let process = mpc::spawn_multichain(ctx.release, &mpc_node_id, cli)?; let address = format!("http://127.0.0.1:{web_port}"); tracing::info!("node is starting at {}", address); util::ping_until_ok(&address, 60).await?; - tracing::info!("node started [node_id={node_id}, {address}]"); + tracing::info!("node started [node_account_id={account_id}, {address}]"); Ok(Self { address, - node_id: node_id as usize, - account: account.clone(), + account_id: account_id.clone(), account_sk: account_sk.clone(), cipher_pk, cipher_sk, diff --git a/integration-tests/src/multichain/mod.rs b/integration-tests/src/multichain/mod.rs index 4e9ed6c89..cd241be35 100644 --- a/integration-tests/src/multichain/mod.rs +++ b/integration-tests/src/multichain/mod.rs @@ -3,7 +3,7 @@ pub mod local; use crate::env::containers::DockerClient; use crate::{initialize_lake_indexer, LakeIndexerCtx}; -use mpc_contract::ParticipantInfo; +use mpc_contract::primitives::ParticipantInfo; use near_workspaces::network::Sandbox; use near_workspaces::{AccountId, Contract, Worker}; use serde_json::json; @@ -123,7 +123,7 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul .into_iter() .collect::, _>>()?; let mut node_futures = Vec::new(); - for (i, account) in accounts.iter().enumerate() { + for (_, account) in accounts.iter().enumerate() { let node = containers::Node::run(&ctx, account.id(), account.secret_key()); node_futures.push(node); } @@ -136,11 +136,10 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul .cloned() .enumerate() .zip(&nodes) - .map(|((i, account), node)| { + .map(|((_, account), node)| { ( account.id().clone(), ParticipantInfo { - id: i as u32, account_id: account.id().to_string().parse().unwrap(), url: node.address.clone(), cipher_pk: node.cipher_pk.to_bytes(), @@ -170,7 +169,7 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< .into_iter() .collect::, _>>()?; let mut node_futures = Vec::with_capacity(nodes); - for (i, account) in accounts.iter().enumerate().take(nodes) { + for (_, account) in accounts.iter().enumerate().take(nodes) { node_futures.push(local::Node::run(&ctx, account.id(), account.secret_key())); } let nodes = futures::future::join_all(node_futures) @@ -182,11 +181,10 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< .cloned() .enumerate() .zip(&nodes) - .map(|((i, account), node)| { + .map(|((_, account), node)| { ( account.id().clone(), ParticipantInfo { - id: i as u32, account_id: account.id().to_string().parse().unwrap(), url: node.address.clone(), cipher_pk: node.cipher_pk.to_bytes(), diff --git a/node/src/http_client.rs b/node/src/http_client.rs index 78ad15153..b10684a83 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -3,12 +3,12 @@ use crate::protocol::message::SignedMessage; use crate::protocol::MpcMessage; use cait_sith::protocol::Participant; use mpc_keys::hpke; +use near_primitives::types::AccountId; use reqwest::{Client, IntoUrl}; use std::collections::VecDeque; use std::str::Utf8Error; use tokio_retry::strategy::{jitter, ExponentialBackoff}; use tokio_retry::Retry; -use near_primitives::types::AccountId; #[derive(Debug, thiserror::Error)] pub enum SendError { diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 3daef57ea..34b6fae49 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -579,8 +579,14 @@ impl ConsensusProtocol for JoiningState { if voted.contains(&account_id) { continue; } - let info = contract_state.participants.find_participant_info(&account_id).unwrap(); - let my_participant_info = contract_state.participants.find_participant_info(ctx.my_account_id()).unwrap(); + let info = contract_state + .participants + .find_participant_info(&account_id) + .unwrap(); + let my_participant_info = contract_state + .participants + .find_participant_info(ctx.my_account_id()) + .unwrap(); let participant: Participant = my_participant_info.id.clone().into(); http_client::join(ctx.http_client(), info.url.clone(), &participant) // TODO: should this be account_id? .await @@ -611,7 +617,10 @@ impl ConsensusProtocol for JoiningState { } } ProtocolState::Resharing(contract_state) => { - if contract_state.new_participants.contains_account_id(&ctx.my_account_id()) { + if contract_state + .new_participants + .contains_account_id(&ctx.my_account_id()) + { tracing::info!("joining as a new participant"); start_resharing(None, ctx, contract_state) } else { diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index 0b3e8c0a0..e956301a4 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -93,7 +93,11 @@ impl CryptographicProtocol for GeneratingState { tracing::debug!("sending a message to many participants"); let mut messages = self.messages.write().await; for (p, info) in self.participants.clone() { - if p == self.participants.find_participant(ctx.my_near_acc_id()).unwrap() { + if p == self + .participants + .find_participant(ctx.my_near_acc_id()) + .unwrap() + { // Skip yourself, cait-sith never sends messages to oneself continue; } @@ -267,7 +271,10 @@ impl CryptographicProtocol for RunningState { mut self, ctx: C, ) -> Result { - let me = self.participants.find_participant(ctx.my_near_acc_id()).unwrap(); + let me = self + .participants + .find_participant(ctx.my_near_acc_id()) + .unwrap(); let mut messages = self.messages.write().await; // Try sending any leftover messages donated to RunningState. if let Err(err) = messages diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index b23aa2f02..237daf42d 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -414,7 +414,11 @@ where ) -> Result { let msg = serde_json::to_vec(&msg)?; let sig = sign_sk.sign(&msg); - let msg = SignedMessage { msg, sig, from: from.clone() }; + let msg = SignedMessage { + msg, + sig, + from: from.clone(), + }; let msg = serde_json::to_vec(&msg)?; let ciphered = cipher_pk .encrypt(&msg, SignedMessage::::ASSOCIATED_DATA) diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index f7e07ffae..5aed63dd8 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -93,7 +93,7 @@ impl ConsensusCtx for &Ctx { } } -impl CryptographicCtx for &Ctx { +impl CryptographicCtx for &mut Ctx { fn my_near_acc_id(&self) -> &AccountId { &self.account_id } diff --git a/node/src/protocol/presignature.rs b/node/src/protocol/presignature.rs index 30ae8624b..8ada14998 100644 --- a/node/src/protocol/presignature.rs +++ b/node/src/protocol/presignature.rs @@ -5,11 +5,11 @@ use crate::util::AffinePointExt; use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use cait_sith::{KeygenOutput, PresignArguments, PresignOutput}; use k256::Secp256k1; +use near_primitives::types::AccountId; use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; use std::str::FromStr; use std::sync::Arc; -use near_primitives::types::AccountId; /// Unique number used to identify a specific ongoing presignature generation protocol. /// Without `PresignatureId` it would be unclear where to route incoming cait-sith presignature diff --git a/node/src/protocol/triple.rs b/node/src/protocol/triple.rs index bbe9d2600..1fa15b9b2 100644 --- a/node/src/protocol/triple.rs +++ b/node/src/protocol/triple.rs @@ -199,7 +199,7 @@ impl TripleManager { }; // TODO: in this context we don't have access to the account id of the node. - let my_account_id: AccountId = AccountId::from_str("acc.near").unwrap(); + let my_account_id: AccountId = AccountId::from_str("acc.near").unwrap(); match action { Action::Wait => { From 3a15ebb86cb784f1f723116b852461ad408fa1b9 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Thu, 11 Jan 2024 21:31:28 +0200 Subject: [PATCH 08/21] can build roject and run tests (failing) --- contract/src/primitives.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/contract/src/primitives.rs b/contract/src/primitives.rs index 9d2f4f765..d19caad1d 100644 --- a/contract/src/primitives.rs +++ b/contract/src/primitives.rs @@ -100,6 +100,10 @@ impl Participants { pub fn keys(&self) -> impl Iterator { self.participants.keys() } + + pub fn len(&self) -> usize { + self.participants.len() + } } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] From 8ee6d637481484708b10873cfc29de16413f3769 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Fri, 12 Jan 2024 00:10:05 +0200 Subject: [PATCH 09/21] warnings fixed --- node/src/cli.rs | 6 ------ node/src/protocol/contract/primitives.rs | 4 ++-- node/src/protocol/message.rs | 12 +++++------- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/node/src/cli.rs b/node/src/cli.rs index 79b9f0db6..1fe8da51e 100644 --- a/node/src/cli.rs +++ b/node/src/cli.rs @@ -1,6 +1,5 @@ use crate::protocol::{MpcSignProtocol, SignQueue}; use crate::{indexer, storage, web}; -use cait_sith::protocol::Participant; use clap::Parser; use local_ip_address::local_ip; use near_crypto::{InMemorySigner, SecretKey}; @@ -54,11 +53,6 @@ pub enum Cli { }, } -fn parse_participant(arg: &str) -> Result { - let participant_id: u32 = arg.parse()?; - Ok(participant_id.into()) -} - impl Cli { pub fn into_str_args(self) -> Vec { match self { diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index b0cd2b52b..6e8f3a161 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -229,9 +229,9 @@ impl From for Votes { votes: contract_votes .votes .into_iter() - .map(|(accountId, participants)| { + .map(|(account_id, participants)| { ( - AccountId::from_str(&accountId.to_string()).unwrap(), // TODO: fix unwrap + AccountId::from_str(&account_id.to_string()).unwrap(), // TODO: fix unwrap participants .into_iter() .map(|acc_id: near_sdk::AccountId| { diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index 237daf42d..72b0dbefc 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -172,14 +172,13 @@ pub trait MessageHandler { impl MessageHandler for GeneratingState { async fn handle( &mut self, - ctx: C, + _ctx: C, queue: &mut MpcMessageQueue, ) -> Result<(), MessageHandleError> { let mut protocol = self.protocol.write().await; while let Some(msg) = queue.generating.pop_front() { tracing::debug!("handling new generating message"); - let participant = self - .participants + self.participants .find_participant(&msg.from) .map(|participant| { tracing::debug!(from = %msg.from, "handling message from"); @@ -206,8 +205,7 @@ impl MessageHandler for ResharingState { let q = queue.resharing_bins.entry(self.old_epoch).or_default(); let mut protocol = self.protocol.write().await; while let Some(msg) = q.pop_front() { - let participant = self - .old_participants + self.old_participants .find_participant(&msg.from) .map(|participant| { tracing::debug!(from = %msg.from, "handling resharing message from"); @@ -238,7 +236,7 @@ impl MessageHandler for RunningState { .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; while let Some(message) = queue.pop_front() { - let participant = self + self .participants .find_participant(&message.from) .map(|participant| { @@ -271,7 +269,7 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - let participant = self + self .participants .find_participant(&message.from) .map(|participant| { From 5eca1b40c665c2eaf0cdf74a60aa1bfe368cfd4c Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 14:08:21 +0200 Subject: [PATCH 10/21] async me --- node/src/http_client.rs | 11 +-- node/src/protocol/consensus.rs | 96 +++++++++++--------------- node/src/protocol/cryptography.rs | 47 ++++++------- node/src/protocol/message.rs | 109 +++++++----------------------- node/src/protocol/mod.rs | 79 ++++++++++++++-------- node/src/protocol/presignature.rs | 7 +- node/src/protocol/signature.rs | 6 +- node/src/protocol/triple.rs | 9 +-- 8 files changed, 141 insertions(+), 223 deletions(-) diff --git a/node/src/http_client.rs b/node/src/http_client.rs index b10684a83..c0a3792be 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -3,7 +3,6 @@ use crate::protocol::message::SignedMessage; use crate::protocol::MpcMessage; use cait_sith::protocol::Participant; use mpc_keys::hpke; -use near_primitives::types::AccountId; use reqwest::{Client, IntoUrl}; use std::collections::VecDeque; use std::str::Utf8Error; @@ -27,7 +26,7 @@ pub enum SendError { } async fn send_encrypted( - from: &AccountId, + from: Participant, cipher_pk: &hpke::PublicKey, sign_sk: &near_crypto::SecretKey, client: &Client, @@ -128,7 +127,7 @@ impl MessageQueue { pub async fn send_encrypted( &mut self, - from: &AccountId, + from: Participant, sign_sk: &near_crypto::SecretKey, client: &Client, ) -> Result<(), SendError> { @@ -152,10 +151,6 @@ impl MessageQueue { #[cfg(test)] mod tests { - use std::str::FromStr; - - use near_lake_primitives::AccountId; - use crate::protocol::message::GeneratingMessage; use crate::protocol::MpcMessage; @@ -164,7 +159,7 @@ mod tests { let associated_data = b""; let (sk, pk) = mpc_keys::hpke::generate(); let starting_message = MpcMessage::Generating(GeneratingMessage { - from: AccountId::from_str("alice.near").unwrap(), + from: cait_sith::protocol::Participant::from(0), data: vec![], }); diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 34b6fae49..51cc9e5df 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -1,9 +1,9 @@ -use super::contract::ResharingContractState; +use super::contract::{ProtocolState, ResharingContractState}; use super::state::{ JoiningState, NodeState, PersistentNodeData, RunningState, StartedState, WaitingForConsensusState, }; -use super::{ProtocolState, SignQueue}; +use super::SignQueue; use crate::protocol::presignature::PresignatureManager; use crate::protocol::signature::SignatureManager; use crate::protocol::state::{GeneratingState, ResharingState}; @@ -24,7 +24,9 @@ use std::sync::Arc; use tokio::sync::RwLock; use url::Url; +#[async_trait::async_trait] pub trait ConsensusCtx { + async fn me(&self) -> Participant; fn my_account_id(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; @@ -74,10 +76,6 @@ impl ConsensusProtocol for StartedState { ctx: C, contract_state: ProtocolState, ) -> Result { - let me = contract_state - .participants() - .find_participant(ctx.my_account_id()) - .unwrap(); // TODO: remove unwrap match self.0 { Some(PersistentNodeData { epoch, @@ -103,16 +101,12 @@ impl ConsensusProtocol for StartedState { } Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { - if contract_state - .participants - .contains_account_id(&ctx.my_account_id()) - { + if contract_state.participants.contains_key(&ctx.me().await) { tracing::info!( "contract state is running and we are already a participant" ); let participants_vec: Vec = contract_state.participants.keys().cloned().collect(); - Ok(NodeState::Running(RunningState { epoch, participants: contract_state.participants, @@ -122,14 +116,14 @@ impl ConsensusProtocol for StartedState { sign_queue: ctx.sign_queue(), triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - me, + ctx.me().await, contract_state.threshold, epoch, ))), presignature_manager: Arc::new(RwLock::new( PresignatureManager::new( participants_vec.clone(), - me, + ctx.me().await, contract_state.threshold, epoch, ), @@ -137,7 +131,7 @@ impl ConsensusProtocol for StartedState { signature_manager: Arc::new(RwLock::new( SignatureManager::new( participants_vec, - me, + ctx.me().await, contract_state.public_key, epoch, ), @@ -174,22 +168,19 @@ impl ConsensusProtocol for StartedState { tracing::info!( "contract state is resharing with us, joining as a participant" ); - start_resharing(Some(private_share), ctx, contract_state) + start_resharing(Some(private_share), ctx, contract_state).await } } } }, None => match contract_state { ProtocolState::Initializing(contract_state) => { - if contract_state - .participants - .contains_account_id(&ctx.my_account_id()) - { + if contract_state.participants.contains_key(&ctx.me().await) { tracing::info!("starting key generation as a part of the participant set"); let participants = contract_state.participants; let protocol = cait_sith::keygen::( &participants.keys().cloned().collect::>(), - me, + ctx.me().await, contract_state.threshold, )?; Ok(NodeState::Generating(GeneratingState { @@ -273,10 +264,6 @@ impl ConsensusProtocol for WaitingForConsensusState { ctx: C, contract_state: ProtocolState, ) -> Result { - let me = self - .participants - .find_participant(ctx.my_account_id()) - .unwrap(); // TODO: remove unwrap match contract_state { ProtocolState::Initializing(contract_state) => { tracing::debug!("waiting for consensus, contract state has not been finalized yet"); @@ -334,19 +321,19 @@ impl ConsensusProtocol for WaitingForConsensusState { sign_queue: ctx.sign_queue(), triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - me, + ctx.me().await, self.threshold, self.epoch, ))), presignature_manager: Arc::new(RwLock::new(PresignatureManager::new( participants_vec.clone(), - me, + ctx.me().await, self.threshold, self.epoch, ))), signature_manager: Arc::new(RwLock::new(SignatureManager::new( participants_vec, - me, + ctx.me().await, self.public_key, self.epoch, ))), @@ -367,7 +354,7 @@ impl ConsensusProtocol for WaitingForConsensusState { if contract_state.public_key != self.public_key { return Err(ConsensusError::MismatchedPublicKey); } - start_resharing(Some(self.private_share), ctx, contract_state) + start_resharing(Some(self.private_share), ctx, contract_state).await } Ordering::Greater => { tracing::warn!( @@ -386,7 +373,11 @@ impl ConsensusProtocol for WaitingForConsensusState { "waiting for resharing consensus, contract state has not been finalized yet" ); let has_voted = contract_state.finished_votes.contains(ctx.my_account_id()); - if !has_voted && contract_state.old_participants.contains_key(&me) { + if !has_voted + && contract_state + .old_participants + .contains_key(&ctx.me().await) + { tracing::info!( epoch = self.epoch, "we haven't voted yet, voting for resharing to complete" @@ -415,10 +406,6 @@ impl ConsensusProtocol for RunningState { ctx: C, contract_state: ProtocolState, ) -> Result { - let me = contract_state - .participants() - .find_participant(ctx.my_account_id()) - .unwrap(); // TODO: remove unwrap match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => match contract_state.epoch.cmp(&self.epoch) { @@ -464,15 +451,19 @@ impl ConsensusProtocol for RunningState { Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { tracing::info!("contract is resharing"); - if !contract_state.old_participants.contains_key(&me) - || !contract_state.new_participants.contains_key(&me) + if !contract_state + .old_participants + .contains_key(&ctx.me().await) + || !contract_state + .new_participants + .contains_key(&ctx.me().await) { return Err(ConsensusError::HasBeenKicked); } if contract_state.public_key != self.public_key { return Err(ConsensusError::MismatchedPublicKey); } - start_resharing(Some(self.private_share), ctx, contract_state) + start_resharing(Some(self.private_share), ctx, contract_state).await } } } @@ -575,20 +566,11 @@ impl ConsensusProtocol for JoiningState { votes_to_go = contract_state.threshold - voted.len(), "trying to get participants to vote for us" ); - for account_id in contract_state.participants.account_ids() { - if voted.contains(&account_id) { + for (_, info) in contract_state.participants { + if voted.contains(&info.account_id) { continue; } - let info = contract_state - .participants - .find_participant_info(&account_id) - .unwrap(); - let my_participant_info = contract_state - .participants - .find_participant_info(ctx.my_account_id()) - .unwrap(); - let participant: Participant = my_participant_info.id.clone().into(); - http_client::join(ctx.http_client(), info.url.clone(), &participant) // TODO: should this be account_id? + http_client::join(ctx.http_client(), info.url, &ctx.me().await) .await .unwrap() } @@ -596,6 +578,7 @@ impl ConsensusProtocol for JoiningState { } else { tracing::info!("sending a transaction to join the participant set"); let args = serde_json::json!({ + "participant_id": ctx.me().await, "url": ctx.my_address(), "cipher_pk": ctx.cipher_pk().to_bytes(), "sign_pk": ctx.sign_pk(), @@ -619,10 +602,10 @@ impl ConsensusProtocol for JoiningState { ProtocolState::Resharing(contract_state) => { if contract_state .new_participants - .contains_account_id(&ctx.my_account_id()) + .contains_key(&ctx.me().await) { tracing::info!("joining as a new participant"); - start_resharing(None, ctx, contract_state) + start_resharing(None, ctx, contract_state).await } else { tracing::debug!("network is resharing without us, waiting for them to finish"); Ok(NodeState::Joining(self)) @@ -654,7 +637,7 @@ impl ConsensusProtocol for NodeState { } } -fn start_resharing( +async fn start_resharing( private_share: Option, ctx: C, contract_state: ResharingContractState, @@ -672,19 +655,16 @@ fn start_resharing( .cloned() .collect::>(), contract_state.threshold, - contract_state - .old_participants - .find_participant(ctx.my_account_id()) - .unwrap(), // TODO: remove unwrap + ctx.me().await, private_share, contract_state.public_key, )?; Ok(NodeState::Resharing(ResharingState { old_epoch: contract_state.old_epoch, - old_participants: contract_state.old_participants.into(), - new_participants: contract_state.new_participants.into(), + old_participants: contract_state.old_participants, + new_participants: contract_state.new_participants, threshold: contract_state.threshold, - public_key: contract_state.public_key.into(), + public_key: contract_state.public_key, protocol: Arc::new(RwLock::new(protocol)), messages: Default::default(), })) diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index e956301a4..96615fc9e 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -13,8 +13,9 @@ use mpc_keys::hpke; use near_crypto::InMemorySigner; use near_primitives::types::AccountId; +#[async_trait::async_trait] pub trait CryptographicCtx { - fn my_near_acc_id(&self) -> &AccountId; + async fn me(&self) -> Participant; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; fn signer(&self) -> &InMemorySigner; @@ -81,10 +82,10 @@ impl CryptographicProtocol for GeneratingState { .messages .write() .await - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { - tracing::warn!(?err, "generating: failed to send encrypted message"); + tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message"); } return Ok(NodeState::Generating(self)); @@ -93,18 +94,14 @@ impl CryptographicProtocol for GeneratingState { tracing::debug!("sending a message to many participants"); let mut messages = self.messages.write().await; for (p, info) in self.participants.clone() { - if p == self - .participants - .find_participant(ctx.my_near_acc_id()) - .unwrap() - { + if p == ctx.me().await { // Skip yourself, cait-sith never sends messages to oneself continue; } messages.push( info.clone(), MpcMessage::Generating(GeneratingMessage { - from: ctx.my_near_acc_id().clone(), + from: ctx.me().await, data: m.clone(), }), ); @@ -116,7 +113,7 @@ impl CryptographicProtocol for GeneratingState { self.messages.write().await.push( info.clone(), MpcMessage::Generating(GeneratingMessage { - from: ctx.my_near_acc_id().clone(), + from: ctx.me().await, data: m.clone(), }), ); @@ -138,7 +135,7 @@ impl CryptographicProtocol for GeneratingState { .messages .write() .await - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message"); @@ -167,7 +164,7 @@ impl CryptographicProtocol for WaitingForConsensusState { .messages .write() .await - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "waiting: failed to send encrypted message"); @@ -196,10 +193,10 @@ impl CryptographicProtocol for ResharingState { .messages .write() .await - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { - tracing::warn!(?err, "resharing(wait): failed to send encrypted message"); + tracing::warn!(?err, new = ?self.new_participants.clone(), old = ?self.old_participants.clone(), "resharing(wait): failed to send encrypted message"); } return Ok(NodeState::Resharing(self)); @@ -207,8 +204,8 @@ impl CryptographicProtocol for ResharingState { Action::SendMany(m) => { tracing::debug!("sending a message to all participants"); let mut messages = self.messages.write().await; - for (_, info) in self.new_participants.clone() { - if &info.account_id == ctx.my_near_acc_id() { + for (p, info) in self.new_participants.clone() { + if p == ctx.me().await { // Skip yourself, cait-sith never sends messages to oneself continue; } @@ -217,7 +214,7 @@ impl CryptographicProtocol for ResharingState { info.clone(), MpcMessage::Resharing(ResharingMessage { epoch: self.old_epoch, - from: ctx.my_near_acc_id().clone(), + from: ctx.me().await, data: m.clone(), }), ) @@ -230,7 +227,7 @@ impl CryptographicProtocol for ResharingState { info.clone(), MpcMessage::Resharing(ResharingMessage { epoch: self.old_epoch, - from: ctx.my_near_acc_id().clone(), + from: ctx.me().await, data: m.clone(), }), ), @@ -245,7 +242,7 @@ impl CryptographicProtocol for ResharingState { .messages .write() .await - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(return): failed to send encrypted message"); @@ -271,14 +268,10 @@ impl CryptographicProtocol for RunningState { mut self, ctx: C, ) -> Result { - let me = self - .participants - .find_participant(ctx.my_near_acc_id()) - .unwrap(); let mut messages = self.messages.write().await; // Try sending any leftover messages donated to RunningState. if let Err(err) = messages - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "running(pre): failed to send encrypted message"); @@ -317,8 +310,8 @@ impl CryptographicProtocol for RunningState { let mut sign_queue = self.sign_queue.write().await; let mut signature_manager = self.signature_manager.write().await; - sign_queue.organize(&self, me); - let my_requests = sign_queue.my_requests(me); + sign_queue.organize(&self, ctx.me().await); + let my_requests = sign_queue.my_requests(ctx.me().await); while presignature_manager.my_len() > 0 { let Some((receipt_id, _)) = my_requests.iter().next() else { break; @@ -348,7 +341,7 @@ impl CryptographicProtocol for RunningState { .await?; drop(signature_manager); if let Err(err) = messages - .send_encrypted(ctx.my_near_acc_id(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "running(post): failed to send encrypted message"); diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index 72b0dbefc..b3d387185 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -10,26 +10,26 @@ use k256::Scalar; use mpc_keys::hpke::{self, Ciphered}; use near_crypto::Signature; use near_primitives::hash::CryptoHash; -use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use tokio::sync::RwLock; +#[async_trait::async_trait] pub trait MessageCtx { - fn my_near_acc_id(&self) -> &AccountId; + async fn me(&self) -> Participant; } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct GeneratingMessage { - pub from: AccountId, + pub from: Participant, pub data: MessageData, } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct ResharingMessage { pub epoch: u64, - pub from: AccountId, + pub from: Participant, pub data: MessageData, } @@ -37,7 +37,7 @@ pub struct ResharingMessage { pub struct TripleMessage { pub id: u64, pub epoch: u64, - pub from: AccountId, + pub from: Participant, pub data: MessageData, } @@ -47,20 +47,20 @@ pub struct PresignatureMessage { pub triple0: TripleId, pub triple1: TripleId, pub epoch: u64, - pub from: AccountId, + pub from: Participant, pub data: MessageData, } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] pub struct SignatureMessage { pub receipt_id: CryptoHash, - pub proposer: Participant, // TODO: should it be node AccountId? + pub proposer: Participant, pub presignature_id: PresignatureId, pub msg_hash: [u8; 32], pub epsilon: Scalar, pub delta: Scalar, pub epoch: u64, - pub from: AccountId, + pub from: Participant, pub data: MessageData, } @@ -178,18 +178,7 @@ impl MessageHandler for GeneratingState { let mut protocol = self.protocol.write().await; while let Some(msg) = queue.generating.pop_front() { tracing::debug!("handling new generating message"); - self.participants - .find_participant(&msg.from) - .map(|participant| { - tracing::debug!(from = %msg.from, "handling message from"); - protocol.message(participant, msg.data); - }) - .unwrap_or_else(|| { - tracing::warn!( - participant = %msg.from, - "received message from unknown participant" - ); - }); + protocol.message(msg.from, msg.data); } Ok(()) } @@ -205,18 +194,7 @@ impl MessageHandler for ResharingState { let q = queue.resharing_bins.entry(self.old_epoch).or_default(); let mut protocol = self.protocol.write().await; while let Some(msg) = q.pop_front() { - self.old_participants - .find_participant(&msg.from) - .map(|participant| { - tracing::debug!(from = %msg.from, "handling resharing message from"); - protocol.message(participant, msg.data); - }) - .unwrap_or_else(|| { - tracing::warn!( - participant = %msg.from, - "received message from unknown participant in resharing state" - ); - }); + protocol.message(msg.from, msg.data); } Ok(()) } @@ -236,19 +214,7 @@ impl MessageHandler for RunningState { .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; while let Some(message) = queue.pop_front() { - self - .participants - .find_participant(&message.from) - .map(|participant| { - tracing::debug!(from = %message.from, "running state, triple message, handling message from"); - protocol.message(participant, message.data); - }) - .unwrap_or_else(|| { - tracing::warn!( - participant = %message.from, - "received message from unknown participant in running state" - ); - }); + protocol.message(message.from, message.data); } } } @@ -269,19 +235,7 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - self - .participants - .find_participant(&message.from) - .map(|participant| { - tracing::debug!(from = %message.from, "running state, presignature message, handling message from"); - protocol.message(participant, message.data); - }) - .unwrap_or_else(|| { - tracing::warn!( - participant = %message.from, - "received message from unknown participant in running state" - ); - }); + protocol.message(message.from, message.data); } Err(presignature::GenerationError::AlreadyGenerated) => { tracing::info!(id, "presignature already generated, nothing left to do") @@ -335,17 +289,7 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - let participant = match self.participants.find_participant(&message.from) { - Some(participant) => participant, - None => { - tracing::warn!( - participant = %message.from, - "received message from unknown participant in running state" - ); - continue; - } - }; - protocol.message(participant, message.data) + protocol.message(message.from, message.data); } None => { // Store the message until we are ready to process it @@ -393,7 +337,7 @@ pub struct SignedMessage { /// The signature used to verify the authenticity of the encrypted message. pub sig: Signature, /// From which particpant the message was sent. - pub from: AccountId, + pub from: Participant, } impl SignedMessage { @@ -406,17 +350,13 @@ where { pub fn encrypt( msg: T, - from: &AccountId, + from: Participant, sign_sk: &near_crypto::SecretKey, cipher_pk: &hpke::PublicKey, ) -> Result { let msg = serde_json::to_vec(&msg)?; let sig = sign_sk.sign(&msg); - let msg = SignedMessage { - msg, - sig, - from: from.clone(), - }; + let msg = SignedMessage { msg, sig, from }; let msg = serde_json::to_vec(&msg)?; let ciphered = cipher_pk .encrypt(&msg, SignedMessage::::ASSOCIATED_DATA) @@ -438,15 +378,14 @@ where .decrypt(&encrypted, SignedMessage::::ASSOCIATED_DATA) .map_err(|err| CryptographicError::Encryption(err.to_string()))?; let SignedMessage::> { msg, sig, from } = serde_json::from_slice(&message)?; - let sign_pk = match protocol_state.read().await.find_participant_info(&from) { - Some(info) => info.sign_pk.clone(), - None => { - return Err(CryptographicError::Encryption( - "unknown participant".to_string(), - )); - } - }; - if !sig.verify(&msg, &sign_pk) { + if !sig.verify( + &msg, + &protocol_state + .read() + .await + .fetch_participant(&from)? + .sign_pk, + ) { return Err(CryptographicError::Encryption( "invalid signature while verifying authenticity of encrypted ".to_string(), )); diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index 5aed63dd8..c7ea3a617 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -9,6 +9,7 @@ pub mod message; pub mod state; pub use consensus::ConsensusError; +pub use contract::primitives::ParticipantInfo; pub use contract::ProtocolState; pub use cryptography::CryptographicError; pub use message::MpcMessage; @@ -24,6 +25,7 @@ use crate::protocol::cryptography::CryptographicProtocol; use crate::protocol::message::{MessageHandler, MpcMessageQueue}; use crate::rpc_client::{self}; use crate::storage::SecretNodeStorageBox; +use cait_sith::protocol::Participant; use near_crypto::InMemorySigner; use near_primitives::types::AccountId; use reqwest::IntoUrl; @@ -47,89 +49,96 @@ struct Ctx { secret_storage: SecretNodeStorageBox, } -impl ConsensusCtx for &Ctx { +#[async_trait::async_trait] +impl ConsensusCtx for &MpcSignProtocol { + async fn me(&self) -> Participant { + get_my_participant(self).await + } + fn my_account_id(&self) -> &AccountId { - &self.account_id + &self.ctx.account_id } fn http_client(&self) -> &reqwest::Client { - &self.http_client + &self.ctx.http_client } fn rpc_client(&self) -> &near_fetch::Client { - &self.rpc_client + &self.ctx.rpc_client } fn signer(&self) -> &InMemorySigner { - &self.signer + &self.ctx.signer } fn mpc_contract_id(&self) -> &AccountId { - &self.mpc_contract_id + &self.ctx.mpc_contract_id } fn my_address(&self) -> &Url { - &self.my_address + &self.ctx.my_address } fn sign_queue(&self) -> Arc> { - self.sign_queue.clone() + self.ctx.sign_queue.clone() } fn cipher_pk(&self) -> &hpke::PublicKey { - &self.cipher_pk + &self.ctx.cipher_pk } fn sign_pk(&self) -> near_crypto::PublicKey { - self.sign_sk.public_key() + self.ctx.sign_sk.public_key() } fn sign_sk(&self) -> &near_crypto::SecretKey { - &self.sign_sk + &self.ctx.sign_sk } fn secret_storage(&self) -> &SecretNodeStorageBox { - &self.secret_storage + &self.ctx.secret_storage } } -impl CryptographicCtx for &mut Ctx { - fn my_near_acc_id(&self) -> &AccountId { - &self.account_id +#[async_trait::async_trait] +impl CryptographicCtx for &mut MpcSignProtocol { + async fn me(&self) -> Participant { + get_my_participant(self).await } fn http_client(&self) -> &reqwest::Client { - &self.http_client + &self.ctx.http_client } fn rpc_client(&self) -> &near_fetch::Client { - &self.rpc_client + &self.ctx.rpc_client } fn signer(&self) -> &InMemorySigner { - &self.signer + &self.ctx.signer } fn mpc_contract_id(&self) -> &AccountId { - &self.mpc_contract_id + &self.ctx.mpc_contract_id } fn cipher_pk(&self) -> &hpke::PublicKey { - &self.cipher_pk + &self.ctx.cipher_pk } fn sign_sk(&self) -> &near_crypto::SecretKey { - &self.sign_sk + &self.ctx.sign_sk } fn secret_storage(&mut self) -> &mut SecretNodeStorageBox { - &mut self.secret_storage + &mut self.ctx.secret_storage } } -impl MessageCtx for &Ctx { - fn my_near_acc_id(&self) -> &AccountId { - &self.account_id +#[async_trait::async_trait] +impl MessageCtx for &MpcSignProtocol { + async fn me(&self) -> Participant { + get_my_participant(self).await } } @@ -174,7 +183,7 @@ impl MpcSignProtocol { } pub async fn run(mut self) -> anyhow::Result<()> { - let _span = tracing::info_span!("running", me = self.ctx.account_id.to_string()); + let _span = tracing::info_span!("running", my_account_id = self.ctx.account_id.to_string()); let mut queue = MpcMessageQueue::default(); loop { tracing::debug!("trying to advance mpc recovery protocol"); @@ -214,21 +223,21 @@ impl MpcSignProtocol { let guard = self.state.read().await; guard.clone() }; - let state = match state.progress(&mut self.ctx).await { + let state = match state.progress(&mut self).await { Ok(state) => state, Err(err) => { tracing::info!("protocol unable to progress: {err:?}"); continue; } }; - let mut state = match state.advance(&self.ctx, contract_state).await { + let mut state = match state.advance(&self, contract_state).await { Ok(state) => state, Err(err) => { tracing::info!("protocol unable to advance: {err:?}"); continue; } }; - if let Err(err) = state.handle(&self.ctx, &mut queue).await { + if let Err(err) = state.handle(&self, &mut queue).await { tracing::info!("protocol unable to handle messages: {err:?}"); continue; } @@ -241,3 +250,15 @@ impl MpcSignProtocol { } } } + +async fn get_my_participant(protocol: &MpcSignProtocol) -> Participant { + let my_near_acc_id = protocol.ctx.account_id.clone(); + let state = protocol.state.read().await; + let participant_info = state + .find_participant_info(&my_near_acc_id) + .unwrap_or_else(|| { + tracing::error!("could not find participant info for {my_near_acc_id}"); + panic!("could not find participant info for {my_near_acc_id}"); // TOOD: probably we should not panic here + }); + participant_info.id.into() +} diff --git a/node/src/protocol/presignature.rs b/node/src/protocol/presignature.rs index 8ada14998..d20ed74d3 100644 --- a/node/src/protocol/presignature.rs +++ b/node/src/protocol/presignature.rs @@ -5,10 +5,8 @@ use crate::util::AffinePointExt; use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use cait_sith::{KeygenOutput, PresignArguments, PresignOutput}; use k256::Secp256k1; -use near_primitives::types::AccountId; use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; -use std::str::FromStr; use std::sync::Arc; /// Unique number used to identify a specific ongoing presignature generation protocol. @@ -232,7 +230,6 @@ impl PresignatureManager { break false; } }; - let my_account_id: AccountId = AccountId::from_str("acc.near").unwrap(); // TODO: account id is not available in this context match action { Action::Wait => { tracing::debug!("waiting"); @@ -248,7 +245,7 @@ impl PresignatureManager { triple0: generator.triple0, triple1: generator.triple1, epoch: self.epoch, - from: my_account_id.clone(), + from: self.me, data: data.clone(), }, )) @@ -261,7 +258,7 @@ impl PresignatureManager { triple0: generator.triple0, triple1: generator.triple1, epoch: self.epoch, - from: my_account_id, + from: self.me, data: data.clone(), }, )), diff --git a/node/src/protocol/signature.rs b/node/src/protocol/signature.rs index 6d0435bc4..e44338cb3 100644 --- a/node/src/protocol/signature.rs +++ b/node/src/protocol/signature.rs @@ -255,8 +255,6 @@ impl SignatureManager { } }; - let my_near_account_id: AccountId = "acc.near".parse().unwrap(); // TODO: account id is not available in this context - match action { Action::Wait => { tracing::debug!("waiting"); @@ -275,7 +273,7 @@ impl SignatureManager { epsilon: generator.epsilon, delta: generator.delta, epoch: self.epoch, - from: my_near_account_id.clone(), + from: self.me, data: data.clone(), }, )) @@ -291,7 +289,7 @@ impl SignatureManager { epsilon: generator.epsilon, delta: generator.delta, epoch: self.epoch, - from: my_near_account_id, + from: self.me, data: data.clone(), }, )), diff --git a/node/src/protocol/triple.rs b/node/src/protocol/triple.rs index 1fa15b9b2..7357f28ab 100644 --- a/node/src/protocol/triple.rs +++ b/node/src/protocol/triple.rs @@ -5,10 +5,8 @@ use crate::util::AffinePointExt; use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError}; use cait_sith::triples::{TriplePub, TripleShare}; use k256::Secp256k1; -use near_lake_primitives::AccountId; use std::collections::hash_map::Entry; use std::collections::{HashMap, VecDeque}; -use std::str::FromStr; use std::sync::Arc; /// Unique number used to identify a specific ongoing triple generation protocol. @@ -198,9 +196,6 @@ impl TripleManager { } }; - // TODO: in this context we don't have access to the account id of the node. - let my_account_id: AccountId = AccountId::from_str("acc.near").unwrap(); - match action { Action::Wait => { tracing::debug!("waiting"); @@ -214,7 +209,7 @@ impl TripleManager { TripleMessage { id: *id, epoch: self.epoch, - from: my_account_id.clone(), + from: self.me, data: data.clone(), }, )) @@ -225,7 +220,7 @@ impl TripleManager { TripleMessage { id: *id, epoch: self.epoch, - from: my_account_id, + from: self.me, data: data.clone(), }, )), From 654527f4337d9a2f21ad2c4d2140cd86e060cf1e Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 14:27:39 +0200 Subject: [PATCH 11/21] clippy --- contract/src/primitives.rs | 44 ++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/contract/src/primitives.rs b/contract/src/primitives.rs index d19caad1d..b1706b9c2 100644 --- a/contract/src/primitives.rs +++ b/contract/src/primitives.rs @@ -67,6 +67,12 @@ pub struct Participants { pub participants: BTreeMap, } +impl Default for Participants { + fn default() -> Self { + Self::new() + } +} + impl Participants { pub fn new() -> Self { Participants { @@ -93,10 +99,6 @@ impl Participants { self.participants.iter() } - pub fn into_iter(self) -> impl Iterator { - self.participants.into_iter() - } - pub fn keys(&self) -> impl Iterator { self.participants.keys() } @@ -104,6 +106,10 @@ impl Participants { pub fn len(&self) -> usize { self.participants.len() } + + pub fn is_empty(&self) -> bool { + self.participants.is_empty() + } } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] @@ -111,6 +117,12 @@ pub struct Candidates { pub candidates: BTreeMap, } +impl Default for Candidates { + fn default() -> Self { + Self::new() + } +} + impl Candidates { pub fn new() -> Self { Candidates { @@ -136,10 +148,6 @@ impl Candidates { pub fn iter(&self) -> impl Iterator { self.candidates.iter() } - - pub fn into_iter(self) -> impl Iterator { - self.candidates.into_iter() - } } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] @@ -147,6 +155,12 @@ pub struct Votes { pub votes: BTreeMap>, } +impl Default for Votes { + fn default() -> Self { + Self::new() + } +} + impl Votes { pub fn new() -> Self { Votes { @@ -157,10 +171,6 @@ impl Votes { pub fn entry(&mut self, account_id: AccountId) -> &mut HashSet { self.votes.entry(account_id).or_default() } - - pub fn into_iter(self) -> impl Iterator)> { - self.votes.into_iter() - } } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] @@ -168,6 +178,12 @@ pub struct PkVotes { pub votes: BTreeMap>, } +impl Default for PkVotes { + fn default() -> Self { + Self::new() + } +} + impl PkVotes { pub fn new() -> Self { PkVotes { @@ -178,8 +194,4 @@ impl PkVotes { pub fn entry(&mut self, public_key: PublicKey) -> &mut HashSet { self.votes.entry(public_key).or_default() } - - pub fn into_iter(self) -> impl Iterator)> { - self.votes.into_iter() - } } From 99a9203d725c1835f031c1b169f8257203a6d195 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 15:09:08 +0200 Subject: [PATCH 12/21] init fn inteface fixed --- contract/src/lib.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/contract/src/lib.rs b/contract/src/lib.rs index 8175cb46b..0dc44c656 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -3,8 +3,8 @@ pub mod primitives; use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; use near_sdk::serde::{Deserialize, Serialize}; use near_sdk::{env, near_bindgen, AccountId, PanicOnDefault, PublicKey}; -use primitives::{CandidateInfo, Candidates, Participants, PkVotes, Votes}; -use std::collections::HashSet; +use primitives::{CandidateInfo, Candidates, ParticipantInfo, Participants, PkVotes, Votes}; +use std::collections::{BTreeMap, HashSet}; #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct InitializingContractState { @@ -52,10 +52,10 @@ pub struct MpcContract { #[near_bindgen] impl MpcContract { #[init(ignore_state)] - pub fn init(threshold: usize, participants: Participants) -> Self { + pub fn init(threshold: usize, participants: BTreeMap) -> Self { MpcContract { protocol_state: ProtocolContractState::Initializing(InitializingContractState { - participants, + participants: Participants { participants }, threshold, pk_votes: PkVotes::new(), }), From 962e75e0a1d6a1ffb235cd27f7cb114f6b1ba95a Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 16:30:51 +0200 Subject: [PATCH 13/21] pk conversion fixed --- node/src/protocol/contract/primitives.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 6e8f3a161..d691de7ee 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -1,6 +1,6 @@ use cait_sith::protocol::Participant; use mpc_keys::hpke; -use near_primitives::types::AccountId; +use near_primitives::{borsh::BorshDeserialize, types::AccountId}; use serde::{Deserialize, Serialize}; use std::{ collections::{BTreeMap, HashSet}, @@ -42,17 +42,15 @@ impl From for Participants { account_id: AccountId::from_str( &contract_participant_info.account_id.to_string(), ) - .unwrap(), // TODO: remove unwrap + .unwrap(), url: contract_participant_info.url, cipher_pk: hpke::PublicKey::from_bytes( &contract_participant_info.cipher_pk, ), - sign_pk: near_crypto::PublicKey::SECP256K1( - near_crypto::Secp256K1PublicKey::try_from( - &contract_participant_info.sign_pk.as_bytes()[1..], - ) - .unwrap(), - ), + sign_pk: BorshDeserialize::try_from_slice( + contract_participant_info.sign_pk.as_bytes(), + ) + .unwrap(), }, ) }) From e4093e08553d58653364af830fef002d4d0f3bdb Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 19:44:58 +0200 Subject: [PATCH 14/21] me removed from consensus context --- node/src/protocol/consensus.rs | 202 +++++++++++++++++++-------------- node/src/protocol/mod.rs | 4 - 2 files changed, 119 insertions(+), 87 deletions(-) diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 51cc9e5df..de0ed23ae 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -26,7 +26,6 @@ use url::Url; #[async_trait::async_trait] pub trait ConsensusCtx { - async fn me(&self) -> Participant; fn my_account_id(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; @@ -101,48 +100,52 @@ impl ConsensusProtocol for StartedState { } Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { - if contract_state.participants.contains_key(&ctx.me().await) { - tracing::info!( - "contract state is running and we are already a participant" - ); - let participants_vec: Vec = - contract_state.participants.keys().cloned().collect(); - Ok(NodeState::Running(RunningState { - epoch, - participants: contract_state.participants, - threshold: contract_state.threshold, - private_share, - public_key, - sign_queue: ctx.sign_queue(), - triple_manager: Arc::new(RwLock::new(TripleManager::new( - participants_vec.clone(), - ctx.me().await, - contract_state.threshold, + match contract_state + .participants + .find_participant(&ctx.my_account_id()) + { + Some(me) => { + tracing::info!( + "contract state is running and we are already a participant" + ); + let participants_vec: Vec = + contract_state.participants.keys().cloned().collect(); + Ok(NodeState::Running(RunningState { epoch, - ))), - presignature_manager: Arc::new(RwLock::new( - PresignatureManager::new( + participants: contract_state.participants, + threshold: contract_state.threshold, + private_share, + public_key, + sign_queue: ctx.sign_queue(), + triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - ctx.me().await, + me, contract_state.threshold, epoch, - ), - )), - signature_manager: Arc::new(RwLock::new( - SignatureManager::new( - participants_vec, - ctx.me().await, - contract_state.public_key, - epoch, - ), - )), - messages: Default::default(), - })) - } else { - Ok(NodeState::Joining(JoiningState { + ))), + presignature_manager: Arc::new(RwLock::new( + PresignatureManager::new( + participants_vec.clone(), + me, + contract_state.threshold, + epoch, + ), + )), + signature_manager: Arc::new(RwLock::new( + SignatureManager::new( + participants_vec, + me, + contract_state.public_key, + epoch, + ), + )), + messages: Default::default(), + })) + } + None => Ok(NodeState::Joining(JoiningState { participants: contract_state.participants, public_key, - })) + })), } } } @@ -175,23 +178,31 @@ impl ConsensusProtocol for StartedState { }, None => match contract_state { ProtocolState::Initializing(contract_state) => { - if contract_state.participants.contains_key(&ctx.me().await) { - tracing::info!("starting key generation as a part of the participant set"); - let participants = contract_state.participants; - let protocol = cait_sith::keygen::( - &participants.keys().cloned().collect::>(), - ctx.me().await, - contract_state.threshold, - )?; - Ok(NodeState::Generating(GeneratingState { - participants, - threshold: contract_state.threshold, - protocol: Arc::new(RwLock::new(protocol)), - messages: Default::default(), - })) - } else { - tracing::info!("we are not a part of the initial participant set, waiting for key generation to complete"); - Ok(NodeState::Started(self)) + match contract_state + .participants + .find_participant(&ctx.my_account_id()) + { + Some(me) => { + tracing::info!( + "starting key generation as a part of the participant set" + ); + let participants = contract_state.participants; + let protocol = cait_sith::keygen::( + &participants.keys().cloned().collect::>(), + me, + contract_state.threshold, + )?; + Ok(NodeState::Generating(GeneratingState { + participants, + threshold: contract_state.threshold, + protocol: Arc::new(RwLock::new(protocol)), + messages: Default::default(), + })) + } + None => { + tracing::info!("we are not a part of the initial participant set, waiting for key generation to complete"); + Ok(NodeState::Started(self)) + } } } ProtocolState::Running(contract_state) => Ok(NodeState::Joining(JoiningState { @@ -312,6 +323,12 @@ impl ConsensusProtocol for WaitingForConsensusState { } let participants_vec: Vec = self.participants.keys().cloned().collect(); + + let me = contract_state + .participants + .find_participant(&ctx.my_account_id()) + .unwrap(); + Ok(NodeState::Running(RunningState { epoch: self.epoch, participants: self.participants, @@ -321,19 +338,19 @@ impl ConsensusProtocol for WaitingForConsensusState { sign_queue: ctx.sign_queue(), triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - ctx.me().await, + me, self.threshold, self.epoch, ))), presignature_manager: Arc::new(RwLock::new(PresignatureManager::new( participants_vec.clone(), - ctx.me().await, + me, self.threshold, self.epoch, ))), signature_manager: Arc::new(RwLock::new(SignatureManager::new( participants_vec, - ctx.me().await, + me, self.public_key, self.epoch, ))), @@ -373,23 +390,34 @@ impl ConsensusProtocol for WaitingForConsensusState { "waiting for resharing consensus, contract state has not been finalized yet" ); let has_voted = contract_state.finished_votes.contains(ctx.my_account_id()); - if !has_voted - && contract_state - .old_participants - .contains_key(&ctx.me().await) + match contract_state + .old_participants + .find_participant(&ctx.my_account_id()) { - tracing::info!( - epoch = self.epoch, - "we haven't voted yet, voting for resharing to complete" - ); - rpc_client::vote_reshared( - ctx.rpc_client(), - ctx.signer(), - ctx.mpc_contract_id(), - self.epoch, - ) - .await - .unwrap(); + Some(_) => { + if !has_voted { + tracing::info!( + epoch = self.epoch, + "we haven't voted yet, voting for resharing to complete" + ); + rpc_client::vote_reshared( + ctx.rpc_client(), + ctx.signer(), + ctx.mpc_contract_id(), + self.epoch, + ) + .await + .unwrap(); + } else { + tracing::info!( + epoch = self.epoch, + "we have voted for resharing to complete" + ); + } + } + None => { + tracing::info!("we are not a part of the old participant set"); + } } Ok(NodeState::WaitingForConsensus(self)) } @@ -451,13 +479,13 @@ impl ConsensusProtocol for RunningState { Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { tracing::info!("contract is resharing"); - if !contract_state + let is_in_old_participant_set = contract_state .old_participants - .contains_key(&ctx.me().await) - || !contract_state - .new_participants - .contains_key(&ctx.me().await) - { + .contains_account_id(ctx.my_account_id()); + let is_in_new_participant_set = contract_state + .new_participants + .contains_account_id(ctx.my_account_id()); + if !is_in_old_participant_set || !is_in_new_participant_set { return Err(ConsensusError::HasBeenKicked); } if contract_state.public_key != self.public_key { @@ -555,6 +583,10 @@ impl ConsensusProtocol for JoiningState { match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => { + let me = contract_state + .participants + .find_participant(&ctx.my_account_id()) + .unwrap(); if contract_state.candidates.contains_key(&ctx.my_account_id()) { let voted = contract_state .join_votes @@ -570,7 +602,7 @@ impl ConsensusProtocol for JoiningState { if voted.contains(&info.account_id) { continue; } - http_client::join(ctx.http_client(), info.url, &ctx.me().await) + http_client::join(ctx.http_client(), info.url, &me) .await .unwrap() } @@ -578,7 +610,7 @@ impl ConsensusProtocol for JoiningState { } else { tracing::info!("sending a transaction to join the participant set"); let args = serde_json::json!({ - "participant_id": ctx.me().await, + "participant_id": me, "url": ctx.my_address(), "cipher_pk": ctx.cipher_pk().to_bytes(), "sign_pk": ctx.sign_pk(), @@ -602,7 +634,7 @@ impl ConsensusProtocol for JoiningState { ProtocolState::Resharing(contract_state) => { if contract_state .new_participants - .contains_key(&ctx.me().await) + .contains_account_id(&ctx.my_account_id()) { tracing::info!("joining as a new participant"); start_resharing(None, ctx, contract_state).await @@ -642,6 +674,10 @@ async fn start_resharing( ctx: C, contract_state: ResharingContractState, ) -> Result { + let me = contract_state + .new_participants + .find_participant(&ctx.my_account_id()) + .unwrap(); let protocol = cait_sith::reshare::( &contract_state .old_participants @@ -655,7 +691,7 @@ async fn start_resharing( .cloned() .collect::>(), contract_state.threshold, - ctx.me().await, + me, private_share, contract_state.public_key, )?; diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index c7ea3a617..759c90bdd 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -51,10 +51,6 @@ struct Ctx { #[async_trait::async_trait] impl ConsensusCtx for &MpcSignProtocol { - async fn me(&self) -> Participant { - get_my_participant(self).await - } - fn my_account_id(&self) -> &AccountId { &self.ctx.account_id } From 3c3a9472199d468a53ca0e288614375b8bc3009f Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 19:59:18 +0200 Subject: [PATCH 15/21] formatting, todos --- node/src/protocol/contract/mod.rs | 4 +--- node/src/protocol/contract/primitives.rs | 10 ++++------ node/src/protocol/signature.rs | 1 - 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/node/src/protocol/contract/mod.rs b/node/src/protocol/contract/mod.rs index e24b153b5..867d8a450 100644 --- a/node/src/protocol/contract/mod.rs +++ b/node/src/protocol/contract/mod.rs @@ -72,9 +72,7 @@ impl From for ResharingContractState { finished_votes: contract_state .finished_votes .into_iter() - .map(|acc_id| { - AccountId::from_str(&acc_id.to_string()).unwrap() // TODO: code duplication - }) + .map(|acc_id| AccountId::from_str(&acc_id.to_string()).unwrap()) .collect(), } } diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index d691de7ee..92e1e895a 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -154,10 +154,10 @@ impl From for Candidates { .into_iter() .map(|(account_id, candidate_info)| { ( - AccountId::from_str(&account_id.to_string()).unwrap(), // TODO: fix unwrap + AccountId::from_str(&account_id.to_string()).unwrap(), CandidateInfo { account_id: AccountId::from_str(&candidate_info.account_id.to_string()) - .unwrap(), // TODO: fix unwrap + .unwrap(), url: candidate_info.url, cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk), sign_pk: near_crypto::PublicKey::SECP256K1( @@ -200,7 +200,6 @@ impl From for PkVotes { .into_iter() .map(|acc_id: near_sdk::AccountId| { AccountId::from_str(&acc_id.to_string()).unwrap() - // TODO: fix unwrap }) .collect(), ) @@ -229,14 +228,13 @@ impl From for Votes { .into_iter() .map(|(account_id, participants)| { ( - AccountId::from_str(&account_id.to_string()).unwrap(), // TODO: fix unwrap + AccountId::from_str(&account_id.to_string()).unwrap(), participants .into_iter() .map(|acc_id: near_sdk::AccountId| { AccountId::from_str(&acc_id.to_string()).unwrap() - // TODO: fix unwrap }) - .collect(), // TODO: remove code duplication + .collect(), ) }) .collect(), diff --git a/node/src/protocol/signature.rs b/node/src/protocol/signature.rs index e44338cb3..b6933e6e8 100644 --- a/node/src/protocol/signature.rs +++ b/node/src/protocol/signature.rs @@ -254,7 +254,6 @@ impl SignatureManager { break false; } }; - match action { Action::Wait => { tracing::debug!("waiting"); From 9c2cfed23215aa3f43347515d47d0dac00e9216c Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 23:45:39 +0200 Subject: [PATCH 16/21] clippy, participants fn refactored --- node/src/protocol/consensus.rs | 18 +++++++++--------- node/src/protocol/contract/mod.rs | 2 +- node/src/protocol/contract/primitives.rs | 20 ++++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index de0ed23ae..39b79ffa3 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -102,7 +102,7 @@ impl ConsensusProtocol for StartedState { Ordering::Equal => { match contract_state .participants - .find_participant(&ctx.my_account_id()) + .find_participant(ctx.my_account_id()) { Some(me) => { tracing::info!( @@ -180,7 +180,7 @@ impl ConsensusProtocol for StartedState { ProtocolState::Initializing(contract_state) => { match contract_state .participants - .find_participant(&ctx.my_account_id()) + .find_participant(ctx.my_account_id()) { Some(me) => { tracing::info!( @@ -326,7 +326,7 @@ impl ConsensusProtocol for WaitingForConsensusState { let me = contract_state .participants - .find_participant(&ctx.my_account_id()) + .find_participant(ctx.my_account_id()) .unwrap(); Ok(NodeState::Running(RunningState { @@ -392,7 +392,7 @@ impl ConsensusProtocol for WaitingForConsensusState { let has_voted = contract_state.finished_votes.contains(ctx.my_account_id()); match contract_state .old_participants - .find_participant(&ctx.my_account_id()) + .find_participant(ctx.my_account_id()) { Some(_) => { if !has_voted { @@ -585,12 +585,12 @@ impl ConsensusProtocol for JoiningState { ProtocolState::Running(contract_state) => { let me = contract_state .participants - .find_participant(&ctx.my_account_id()) + .find_participant(ctx.my_account_id()) .unwrap(); - if contract_state.candidates.contains_key(&ctx.my_account_id()) { + if contract_state.candidates.contains_key(ctx.my_account_id()) { let voted = contract_state .join_votes - .get(&ctx.my_account_id()) + .get(ctx.my_account_id()) .cloned() .unwrap_or_default(); tracing::info!( @@ -634,7 +634,7 @@ impl ConsensusProtocol for JoiningState { ProtocolState::Resharing(contract_state) => { if contract_state .new_participants - .contains_account_id(&ctx.my_account_id()) + .contains_account_id(ctx.my_account_id()) { tracing::info!("joining as a new participant"); start_resharing(None, ctx, contract_state).await @@ -676,7 +676,7 @@ async fn start_resharing( ) -> Result { let me = contract_state .new_participants - .find_participant(&ctx.my_account_id()) + .find_participant(ctx.my_account_id()) .unwrap(); let protocol = cait_sith::reshare::( &contract_state diff --git a/node/src/protocol/contract/mod.rs b/node/src/protocol/contract/mod.rs index 867d8a450..56441439a 100644 --- a/node/src/protocol/contract/mod.rs +++ b/node/src/protocol/contract/mod.rs @@ -72,7 +72,7 @@ impl From for ResharingContractState { finished_votes: contract_state .finished_votes .into_iter() - .map(|acc_id| AccountId::from_str(&acc_id.to_string()).unwrap()) + .map(|acc_id| AccountId::from_str(acc_id.as_ref()).unwrap()) .collect(), } } diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 92e1e895a..786a78e19 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -40,7 +40,7 @@ impl From for Participants { ParticipantInfo { id: participant_id as ParticipantId, account_id: AccountId::from_str( - &contract_participant_info.account_id.to_string(), + contract_participant_info.account_id.as_ref(), ) .unwrap(), url: contract_participant_info.url, @@ -94,21 +94,21 @@ impl Participants { pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { self.participants - .iter() - .find(|(_, participant_info)| participant_info.account_id == *account_id) - .map(|(_, participant_info)| participant_info) + .values() + .find(|participant_info| participant_info.account_id == *account_id) + .map(|participant_info| participant_info) } pub fn contains_account_id(&self, account_id: &AccountId) -> bool { self.participants - .iter() - .any(|(_, participant_info)| participant_info.account_id == *account_id) + .values() + .any(|participant_info| participant_info.account_id == *account_id) } pub fn account_ids(&self) -> Vec { self.participants - .iter() - .map(|(_, participant_info)| participant_info.account_id.clone()) + .values() + .map(|participant_info| participant_info.account_id.clone()) .collect() } } @@ -199,7 +199,7 @@ impl From for PkVotes { participants .into_iter() .map(|acc_id: near_sdk::AccountId| { - AccountId::from_str(&acc_id.to_string()).unwrap() + AccountId::from_str(acc_id.as_ref()).unwrap() }) .collect(), ) @@ -232,7 +232,7 @@ impl From for Votes { participants .into_iter() .map(|acc_id: near_sdk::AccountId| { - AccountId::from_str(&acc_id.to_string()).unwrap() + AccountId::from_str(acc_id.as_ref()).unwrap() }) .collect(), ) From 8c5f1cf7360fda480f1422312d8e65b319a2ca70 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Mon, 15 Jan 2024 23:54:55 +0200 Subject: [PATCH 17/21] clippy, participants fn refactored --- node/src/protocol/contract/primitives.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 786a78e19..86aa27fc6 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -96,7 +96,6 @@ impl Participants { self.participants .values() .find(|participant_info| participant_info.account_id == *account_id) - .map(|participant_info| participant_info) } pub fn contains_account_id(&self, account_id: &AccountId) -> bool { @@ -154,9 +153,9 @@ impl From for Candidates { .into_iter() .map(|(account_id, candidate_info)| { ( - AccountId::from_str(&account_id.to_string()).unwrap(), + AccountId::from_str(account_id.as_ref()).unwrap(), CandidateInfo { - account_id: AccountId::from_str(&candidate_info.account_id.to_string()) + account_id: AccountId::from_str(candidate_info.account_id.as_ref()) .unwrap(), url: candidate_info.url, cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk), @@ -228,7 +227,7 @@ impl From for Votes { .into_iter() .map(|(account_id, participants)| { ( - AccountId::from_str(&account_id.to_string()).unwrap(), + AccountId::from_str(account_id.as_ref()).unwrap(), participants .into_iter() .map(|acc_id: near_sdk::AccountId| { From c653dc5d22ef4437f36e187147a0c1f0dbba7d80 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Tue, 16 Jan 2024 00:12:32 +0200 Subject: [PATCH 18/21] clippy --- integration-tests/src/multichain/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration-tests/src/multichain/mod.rs b/integration-tests/src/multichain/mod.rs index cd241be35..4f6e642ab 100644 --- a/integration-tests/src/multichain/mod.rs +++ b/integration-tests/src/multichain/mod.rs @@ -123,7 +123,7 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul .into_iter() .collect::, _>>()?; let mut node_futures = Vec::new(); - for (_, account) in accounts.iter().enumerate() { + for account in accounts.iter() { let node = containers::Node::run(&ctx, account.id(), account.secret_key()); node_futures.push(node); } From a7fdbf11aef411b4556423e11648a287718e0687 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Tue, 16 Jan 2024 00:57:25 +0200 Subject: [PATCH 19/21] vote_join accepts account id --- node/src/http_client.rs | 11 ++- node/src/protocol/consensus.rs | 91 +++++++++++++----------- node/src/protocol/contract/primitives.rs | 4 ++ node/src/web/mod.rs | 8 +-- 4 files changed, 64 insertions(+), 50 deletions(-) diff --git a/node/src/http_client.rs b/node/src/http_client.rs index c0a3792be..4d4e91a6c 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -3,6 +3,7 @@ use crate::protocol::message::SignedMessage; use crate::protocol::MpcMessage; use cait_sith::protocol::Participant; use mpc_keys::hpke; +use near_primitives::types::AccountId; use reqwest::{Client, IntoUrl}; use std::collections::VecDeque; use std::str::Utf8Error; @@ -73,8 +74,12 @@ async fn send_encrypted( Retry::spawn(retry_strategy, action).await } -pub async fn join(client: &Client, url: U, me: &Participant) -> Result<(), SendError> { - let _span = tracing::info_span!("join_request", ?me); +pub async fn join( + client: &Client, + url: U, + account_id: &AccountId, +) -> Result<(), SendError> { + let _span = tracing::info_span!("join_request", ?account_id); let mut url = url.into_url()?; url.set_path("join"); tracing::debug!(%url, "making http request"); @@ -82,7 +87,7 @@ pub async fn join(client: &Client, url: U, me: &Participant) -> Resu let response = client .post(url.clone()) .header("content-type", "application/json") - .json(&me) + .json(&account_id) .send() .await .map_err(SendError::ReqwestClientError)?; diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index 39b79ffa3..d30f92fc6 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -583,52 +583,57 @@ impl ConsensusProtocol for JoiningState { match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => { - let me = contract_state - .participants - .find_participant(ctx.my_account_id()) - .unwrap(); - if contract_state.candidates.contains_key(ctx.my_account_id()) { - let voted = contract_state - .join_votes - .get(ctx.my_account_id()) - .cloned() - .unwrap_or_default(); - tracing::info!( - already_voted = voted.len(), - votes_to_go = contract_state.threshold - voted.len(), - "trying to get participants to vote for us" - ); - for (_, info) in contract_state.participants { - if voted.contains(&info.account_id) { - continue; - } - http_client::join(ctx.http_client(), info.url, &me) + match contract_state + .candidates + .find_candidate(ctx.my_account_id()) + { + Some(candidate_info) => { + let voted = contract_state + .join_votes + .get(ctx.my_account_id()) + .cloned() + .unwrap_or_default(); + tracing::info!( + already_voted = voted.len(), + votes_to_go = contract_state.threshold - voted.len(), + "trying to get participants to vote for us" + ); + for (_, info) in contract_state.participants { + if voted.contains(&info.account_id) { + continue; + } + http_client::join( + ctx.http_client(), + info.url, + &candidate_info.account_id, + ) .await .unwrap() + } + Ok(NodeState::Joining(self)) + } + None => { + tracing::info!("sending a transaction to join the participant set"); + let args = serde_json::json!({ + "url": ctx.my_address(), + "cipher_pk": ctx.cipher_pk().to_bytes(), + "sign_pk": ctx.sign_pk(), + }); + ctx.rpc_client() + .send_tx( + ctx.signer(), + ctx.mpc_contract_id(), + vec![Action::FunctionCall(FunctionCallAction { + method_name: "join".to_string(), + args: args.to_string().into_bytes(), + gas: 300_000_000_000_000, + deposit: 0, + })], + ) + .await + .unwrap(); + Ok(NodeState::Joining(self)) } - Ok(NodeState::Joining(self)) - } else { - tracing::info!("sending a transaction to join the participant set"); - let args = serde_json::json!({ - "participant_id": me, - "url": ctx.my_address(), - "cipher_pk": ctx.cipher_pk().to_bytes(), - "sign_pk": ctx.sign_pk(), - }); - ctx.rpc_client() - .send_tx( - ctx.signer(), - ctx.mpc_contract_id(), - vec![Action::FunctionCall(FunctionCallAction { - method_name: "join".to_string(), - args: args.to_string().into_bytes(), - gas: 300_000_000_000_000, - deposit: 0, - })], - ) - .await - .unwrap(); - Ok(NodeState::Joining(self)) } } ProtocolState::Resharing(contract_state) => { diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 86aa27fc6..8a8462f1b 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -143,6 +143,10 @@ impl Candidates { pub fn iter(&self) -> impl Iterator { self.candidates.iter() } + + pub fn find_candidate(&self, account_id: &AccountId) -> Option<&CandidateInfo> { + self.candidates.get(account_id) + } } impl From for Candidates { diff --git a/node/src/web/mod.rs b/node/src/web/mod.rs index aca70012c..022d2f113 100644 --- a/node/src/web/mod.rs +++ b/node/src/web/mod.rs @@ -100,13 +100,13 @@ async fn msg( #[tracing::instrument(level = "debug", skip_all)] async fn join( Extension(state): Extension>, - WithRejection(Json(participant), _): WithRejection, Error>, + WithRejection(Json(account_id), _): WithRejection, Error>, ) -> Result<()> { let protocol_state = state.protocol_state.read().await; match &*protocol_state { NodeState::Running { .. } => { let args = serde_json::json!({ - "participant": participant + "candidate_account_id": account_id }); match state .rpc_client @@ -123,7 +123,7 @@ async fn join( .await { Ok(_) => { - tracing::info!(?participant, "successfully voted for a node to join"); + tracing::info!(?account_id, "successfully voted for a node to join"); Ok(()) } Err(e) => { @@ -133,7 +133,7 @@ async fn join( } } _ => { - tracing::debug!(?participant, "not ready to accept join requests yet"); + tracing::debug!(?account_id, "not ready to accept join requests yet"); Err(Error::NotRunning) } } From bea5bff61246fc226db48be1aef859866489d7c6 Mon Sep 17 00:00:00 2001 From: Serhii Volovyk Date: Tue, 16 Jan 2024 01:06:39 +0200 Subject: [PATCH 20/21] sign_pk conversion fix --- node/src/protocol/contract/primitives.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 8a8462f1b..0d75a2bcb 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -163,12 +163,10 @@ impl From for Candidates { .unwrap(), url: candidate_info.url, cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk), - sign_pk: near_crypto::PublicKey::SECP256K1( - near_crypto::Secp256K1PublicKey::try_from( - &candidate_info.sign_pk.as_bytes()[1..], - ) - .unwrap(), - ), + sign_pk: BorshDeserialize::try_from_slice( + candidate_info.sign_pk.as_bytes(), + ) + .unwrap(), }, ) }) From b176b213c00f8e3ed11013417ecb1ec71132ffa1 Mon Sep 17 00:00:00 2001 From: Daniyar Itegulov Date: Tue, 16 Jan 2024 21:26:33 +1100 Subject: [PATCH 21/21] small refactorings --- contract/src/lib.rs | 1 - contract/src/primitives.rs | 2 ++ integration-tests/src/multichain/mod.rs | 10 ++++------ node/src/protocol/consensus.rs | 1 - node/src/protocol/contract/primitives.rs | 2 +- node/src/protocol/cryptography.rs | 6 +++--- node/src/protocol/message.rs | 4 ++-- node/src/protocol/mod.rs | 3 +-- 8 files changed, 13 insertions(+), 16 deletions(-) diff --git a/contract/src/lib.rs b/contract/src/lib.rs index 0dc44c656..771444465 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -145,7 +145,6 @@ impl MpcContract { participants, threshold, public_key, - candidates: _, leave_votes, .. }) => { diff --git a/contract/src/primitives.rs b/contract/src/primitives.rs index b1706b9c2..595b08127 100644 --- a/contract/src/primitives.rs +++ b/contract/src/primitives.rs @@ -95,6 +95,7 @@ impl Participants { pub fn get(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { self.participants.get(account_id) } + pub fn iter(&self) -> impl Iterator { self.participants.iter() } @@ -145,6 +146,7 @@ impl Candidates { pub fn get(&self, account_id: &AccountId) -> Option<&CandidateInfo> { self.candidates.get(account_id) } + pub fn iter(&self) -> impl Iterator { self.candidates.iter() } diff --git a/integration-tests/src/multichain/mod.rs b/integration-tests/src/multichain/mod.rs index 4f6e642ab..c91dab7d8 100644 --- a/integration-tests/src/multichain/mod.rs +++ b/integration-tests/src/multichain/mod.rs @@ -123,7 +123,7 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul .into_iter() .collect::, _>>()?; let mut node_futures = Vec::new(); - for account in accounts.iter() { + for account in &accounts { let node = containers::Node::run(&ctx, account.id(), account.secret_key()); node_futures.push(node); } @@ -134,9 +134,8 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul let participants: HashMap = accounts .iter() .cloned() - .enumerate() .zip(&nodes) - .map(|((_, account), node)| { + .map(|(account, node)| { ( account.id().clone(), ParticipantInfo { @@ -169,7 +168,7 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< .into_iter() .collect::, _>>()?; let mut node_futures = Vec::with_capacity(nodes); - for (_, account) in accounts.iter().enumerate().take(nodes) { + for account in accounts.iter().take(nodes) { node_futures.push(local::Node::run(&ctx, account.id(), account.secret_key())); } let nodes = futures::future::join_all(node_futures) @@ -179,9 +178,8 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< let participants: HashMap = accounts .iter() .cloned() - .enumerate() .zip(&nodes) - .map(|((_, account), node)| { + .map(|(account, node)| { ( account.id().clone(), ParticipantInfo { diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index d30f92fc6..1834b6a9a 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -24,7 +24,6 @@ use std::sync::Arc; use tokio::sync::RwLock; use url::Url; -#[async_trait::async_trait] pub trait ConsensusCtx { fn my_account_id(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs index 0d75a2bcb..7c99cbd8d 100644 --- a/node/src/protocol/contract/primitives.rs +++ b/node/src/protocol/contract/primitives.rs @@ -11,7 +11,7 @@ type ParticipantId = u32; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct ParticipantInfo { - pub id: ParticipantId, // TODO: do we need this parameter? + pub id: ParticipantId, pub account_id: AccountId, pub url: String, /// The public key used for encrypting messages. diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index 96615fc9e..5ba3b2235 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -93,8 +93,8 @@ impl CryptographicProtocol for GeneratingState { Action::SendMany(m) => { tracing::debug!("sending a message to many participants"); let mut messages = self.messages.write().await; - for (p, info) in self.participants.clone() { - if p == ctx.me().await { + for (p, info) in self.participants.iter() { + if p == &ctx.me().await { // Skip yourself, cait-sith never sends messages to oneself continue; } @@ -196,7 +196,7 @@ impl CryptographicProtocol for ResharingState { .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { - tracing::warn!(?err, new = ?self.new_participants.clone(), old = ?self.old_participants.clone(), "resharing(wait): failed to send encrypted message"); + tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(wait): failed to send encrypted message"); } return Ok(NodeState::Resharing(self)); diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index b3d387185..7f6dab3e0 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -235,7 +235,7 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - protocol.message(message.from, message.data); + protocol.message(message.from, message.data) } Err(presignature::GenerationError::AlreadyGenerated) => { tracing::info!(id, "presignature already generated, nothing left to do") @@ -289,7 +289,7 @@ impl MessageHandler for RunningState { let mut protocol = protocol .write() .map_err(|err| MessageHandleError::SyncError(err.to_string()))?; - protocol.message(message.from, message.data); + protocol.message(message.from, message.data) } None => { // Store the message until we are ready to process it diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index 759c90bdd..15ee4396f 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -49,7 +49,6 @@ struct Ctx { secret_storage: SecretNodeStorageBox, } -#[async_trait::async_trait] impl ConsensusCtx for &MpcSignProtocol { fn my_account_id(&self) -> &AccountId { &self.ctx.account_id @@ -254,7 +253,7 @@ async fn get_my_participant(protocol: &MpcSignProtocol) -> Participant { .find_participant_info(&my_near_acc_id) .unwrap_or_else(|| { tracing::error!("could not find participant info for {my_near_acc_id}"); - panic!("could not find participant info for {my_near_acc_id}"); // TOOD: probably we should not panic here + panic!("could not find participant info for {my_near_acc_id}"); }); participant_info.id.into() }