diff --git a/contract/src/lib.rs b/contract/src/lib.rs index f933decba..771444465 100644 --- a/contract/src/lib.rs +++ b/contract/src/lib.rs @@ -1,65 +1,38 @@ +pub mod primitives; + use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; use near_sdk::serde::{Deserialize, Serialize}; use near_sdk::{env, near_bindgen, AccountId, PanicOnDefault, PublicKey}; +use primitives::{CandidateInfo, Candidates, ParticipantInfo, Participants, PkVotes, Votes}; use std::collections::{BTreeMap, HashSet}; -type ParticipantId = u32; - -pub mod hpke { - pub type PublicKey = [u8; 32]; -} - -#[derive( - Serialize, - Deserialize, - BorshDeserialize, - BorshSerialize, - Clone, - Hash, - PartialEq, - Eq, - PartialOrd, - Ord, - Debug, -)] -pub struct ParticipantInfo { - pub id: ParticipantId, - pub account_id: AccountId, - pub url: String, - /// The public key used for encrypting messages. - pub cipher_pk: hpke::PublicKey, - /// The public key used for verifying messages. - pub sign_pk: PublicKey, -} - #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct InitializingContractState { - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, - pub pk_votes: BTreeMap>, + pub pk_votes: PkVotes, } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct RunningContractState { pub epoch: u64, - // TODO: why is this account id for participants instead of participant id? - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub public_key: PublicKey, - pub candidates: BTreeMap, - pub join_votes: BTreeMap>, - pub leave_votes: BTreeMap>, + pub candidates: Candidates, + pub join_votes: Votes, + pub leave_votes: Votes, } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] pub struct ResharingContractState { pub old_epoch: u64, - pub old_participants: BTreeMap, + pub old_participants: Participants, // TODO: only store diff to save on storage - pub new_participants: BTreeMap, + pub new_participants: Participants, pub threshold: usize, pub public_key: PublicKey, - pub finished_votes: HashSet, + pub finished_votes: HashSet, } #[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] @@ -82,9 +55,9 @@ impl MpcContract { pub fn init(threshold: usize, participants: BTreeMap) -> Self { MpcContract { protocol_state: ProtocolContractState::Initializing(InitializingContractState { - participants, + participants: Participants { participants }, threshold, - pk_votes: BTreeMap::new(), + pk_votes: PkVotes::new(), }), } } @@ -95,9 +68,8 @@ impl MpcContract { pub fn join( &mut self, - participant_id: ParticipantId, url: String, - cipher_pk: hpke::PublicKey, + cipher_pk: primitives::hpke::PublicKey, sign_pk: PublicKey, ) { match &mut self.protocol_state { @@ -106,15 +78,14 @@ impl MpcContract { candidates, .. }) => { - let account_id = env::signer_account_id(); - if participants.contains_key(&account_id) { + let signer_account_id = env::signer_account_id(); + if participants.contains_key(&signer_account_id) { env::panic_str("this participant is already in the participant set"); } candidates.insert( - participant_id, - ParticipantInfo { - id: participant_id, - account_id, + signer_account_id.clone(), + CandidateInfo { + account_id: signer_account_id, url, cipher_pk, sign_pk, @@ -125,7 +96,7 @@ impl MpcContract { } } - pub fn vote_join(&mut self, participant: ParticipantId) -> bool { + pub fn vote_join(&mut self, candidate_account_id: AccountId) -> bool { match &mut self.protocol_state { ProtocolContractState::Running(RunningContractState { epoch, @@ -136,19 +107,19 @@ impl MpcContract { join_votes, .. }) => { - let voting_participant = participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the participant set") - }); - let candidate = candidates - .get(&participant) + let signer_account_id = env::signer_account_id(); + if !participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the participant set"); + } + let candidate_info = candidates + .get(&candidate_account_id) .unwrap_or_else(|| env::panic_str("candidate is not registered")); - let voted = join_votes.entry(participant).or_default(); - voted.insert(voting_participant.id); + let voted = join_votes.entry(candidate_account_id.clone()); + voted.insert(signer_account_id); if voted.len() >= *threshold { let mut new_participants = participants.clone(); - new_participants.insert(candidate.account_id.clone(), candidate.clone()); + new_participants + .insert(candidate_account_id.clone(), candidate_info.clone().into()); self.protocol_state = ProtocolContractState::Resharing(ResharingContractState { old_epoch: *epoch, @@ -167,30 +138,28 @@ impl MpcContract { } } - pub fn vote_leave(&mut self, participant: ParticipantId) -> bool { + pub fn vote_leave(&mut self, acc_id_to_leave: AccountId) -> bool { match &mut self.protocol_state { ProtocolContractState::Running(RunningContractState { epoch, participants, threshold, public_key, - candidates, leave_votes, .. }) => { - let voting_participant = participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the participant set") - }); - let candidate = candidates - .get(&participant) - .unwrap_or_else(|| env::panic_str("candidate is not registered")); - let voted = leave_votes.entry(participant).or_default(); - voted.insert(voting_participant.id); + let signer_account_id = env::signer_account_id(); + if !participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the participant set"); + } + if !participants.contains_key(&acc_id_to_leave) { + env::panic_str("account to leave is not in the participant set"); + } + let voted = leave_votes.entry(acc_id_to_leave.clone()); + voted.insert(signer_account_id); if voted.len() >= *threshold { let mut new_participants = participants.clone(); - new_participants.remove(&candidate.account_id); + new_participants.remove(&acc_id_to_leave); self.protocol_state = ProtocolContractState::Resharing(ResharingContractState { old_epoch: *epoch, @@ -216,22 +185,21 @@ impl MpcContract { threshold, pk_votes, }) => { - let voting_participant = participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the participant set") - }); - let voted = pk_votes.entry(public_key.clone()).or_default(); - voted.insert(voting_participant.id); + let signer_account_id = env::signer_account_id(); + if !participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the participant set"); + } + let voted = pk_votes.entry(public_key.clone()); + voted.insert(signer_account_id); if voted.len() >= *threshold { self.protocol_state = ProtocolContractState::Running(RunningContractState { epoch: 0, participants: participants.clone(), threshold: *threshold, public_key, - candidates: BTreeMap::new(), - join_votes: BTreeMap::new(), - leave_votes: BTreeMap::new(), + candidates: Candidates::new(), + join_votes: Votes::new(), + leave_votes: Votes::new(), }); true } else { @@ -257,21 +225,20 @@ impl MpcContract { if *old_epoch + 1 != epoch { env::panic_str("mismatched epochs"); } - let voting_participant = old_participants - .get(&env::signer_account_id()) - .unwrap_or_else(|| { - env::panic_str("calling account is not in the old participant set") - }); - finished_votes.insert(voting_participant.id); + let signer_account_id = env::signer_account_id(); + if !old_participants.contains_key(&signer_account_id) { + env::panic_str("calling account is not in the old participant set"); + } + finished_votes.insert(signer_account_id); if finished_votes.len() >= *threshold { self.protocol_state = ProtocolContractState::Running(RunningContractState { epoch: *old_epoch + 1, participants: new_participants.clone(), threshold: *threshold, public_key: public_key.clone(), - candidates: BTreeMap::new(), - join_votes: BTreeMap::new(), - leave_votes: BTreeMap::new(), + candidates: Candidates::new(), + join_votes: Votes::new(), + leave_votes: Votes::new(), }); true } else { diff --git a/contract/src/primitives.rs b/contract/src/primitives.rs new file mode 100644 index 000000000..595b08127 --- /dev/null +++ b/contract/src/primitives.rs @@ -0,0 +1,199 @@ +use near_sdk::borsh::{self, BorshDeserialize, BorshSerialize}; +use near_sdk::serde::{Deserialize, Serialize}; +use near_sdk::{AccountId, PublicKey}; +use std::collections::{BTreeMap, HashSet}; + +pub mod hpke { + pub type PublicKey = [u8; 32]; +} + +#[derive( + Serialize, + Deserialize, + BorshDeserialize, + BorshSerialize, + Clone, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, +)] +pub struct ParticipantInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: PublicKey, +} + +impl From for ParticipantInfo { + fn from(candidate_info: CandidateInfo) -> Self { + ParticipantInfo { + account_id: candidate_info.account_id, + url: candidate_info.url, + cipher_pk: candidate_info.cipher_pk, + sign_pk: candidate_info.sign_pk, + } + } +} + +#[derive( + Serialize, + Deserialize, + BorshDeserialize, + BorshSerialize, + Clone, + Hash, + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, +)] +pub struct CandidateInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: PublicKey, +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] +pub struct Participants { + pub participants: BTreeMap, +} + +impl Default for Participants { + fn default() -> Self { + Self::new() + } +} + +impl Participants { + pub fn new() -> Self { + Participants { + participants: BTreeMap::new(), + } + } + + pub fn contains_key(&self, account_id: &AccountId) -> bool { + self.participants.contains_key(account_id) + } + + pub fn insert(&mut self, account_id: AccountId, participant_info: ParticipantInfo) { + self.participants.insert(account_id, participant_info); + } + + pub fn remove(&mut self, account_id: &AccountId) { + self.participants.remove(account_id); + } + + pub fn get(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { + self.participants.get(account_id) + } + + pub fn iter(&self) -> impl Iterator { + self.participants.iter() + } + + pub fn keys(&self) -> impl Iterator { + self.participants.keys() + } + + pub fn len(&self) -> usize { + self.participants.len() + } + + pub fn is_empty(&self) -> bool { + self.participants.is_empty() + } +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug, Clone)] +pub struct Candidates { + pub candidates: BTreeMap, +} + +impl Default for Candidates { + fn default() -> Self { + Self::new() + } +} + +impl Candidates { + pub fn new() -> Self { + Candidates { + candidates: BTreeMap::new(), + } + } + + pub fn contains_key(&self, account_id: &AccountId) -> bool { + self.candidates.contains_key(account_id) + } + + pub fn insert(&mut self, account_id: AccountId, candidate: CandidateInfo) { + self.candidates.insert(account_id, candidate); + } + + pub fn remove(&mut self, account_id: &AccountId) { + self.candidates.remove(account_id); + } + + pub fn get(&self, account_id: &AccountId) -> Option<&CandidateInfo> { + self.candidates.get(account_id) + } + + pub fn iter(&self) -> impl Iterator { + self.candidates.iter() + } +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] +pub struct Votes { + pub votes: BTreeMap>, +} + +impl Default for Votes { + fn default() -> Self { + Self::new() + } +} + +impl Votes { + pub fn new() -> Self { + Votes { + votes: BTreeMap::new(), + } + } + + pub fn entry(&mut self, account_id: AccountId) -> &mut HashSet { + self.votes.entry(account_id).or_default() + } +} + +#[derive(BorshDeserialize, BorshSerialize, Serialize, Deserialize, Debug)] +pub struct PkVotes { + pub votes: BTreeMap>, +} + +impl Default for PkVotes { + fn default() -> Self { + Self::new() + } +} + +impl PkVotes { + pub fn new() -> Self { + PkVotes { + votes: BTreeMap::new(), + } + } + + pub fn entry(&mut self, public_key: PublicKey) -> &mut HashSet { + self.votes.entry(public_key).or_default() + } +} diff --git a/integration-tests/src/multichain/containers.rs b/integration-tests/src/multichain/containers.rs index 1c440fa15..6040bb57a 100644 --- a/integration-tests/src/multichain/containers.rs +++ b/integration-tests/src/multichain/containers.rs @@ -32,17 +32,15 @@ impl<'a> Node<'a> { pub async fn run( ctx: &super::Context<'a>, - node_id: u32, - account: &AccountId, + account_id: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result> { - tracing::info!(node_id, "running node container"); + tracing::info!("running node container, account_id={}", account_id); let (cipher_sk, cipher_pk) = hpke::generate(); let args = mpc_recovery_node::cli::Cli::Start { - node_id: node_id.into(), near_rpc: ctx.lake_indexer.rpc_host_address.clone(), mpc_contract_id: ctx.mpc_contract.id().clone(), - account: account.clone(), + account_id: account_id.clone(), account_sk: account_sk.to_string().parse()?, web_port: Self::CONTAINER_PORT, cipher_pk: hex::encode(cipher_pk.to_bytes()), @@ -80,7 +78,11 @@ impl<'a> Node<'a> { }); let full_address = format!("http://{ip_address}:{}", Self::CONTAINER_PORT); - tracing::info!(node_id, full_address, "node container is running"); + tracing::info!( + full_address, + "node container is running, account_id={}", + account_id + ); Ok(Node { container, address: full_address, diff --git a/integration-tests/src/multichain/local.rs b/integration-tests/src/multichain/local.rs index eed87114a..e5f1249c9 100644 --- a/integration-tests/src/multichain/local.rs +++ b/integration-tests/src/multichain/local.rs @@ -6,8 +6,7 @@ use near_workspaces::AccountId; #[allow(dead_code)] pub struct Node { pub address: String, - node_id: usize, - account: AccountId, + account_id: AccountId, pub account_sk: near_workspaces::types::SecretKey, pub cipher_pk: hpke::PublicKey, cipher_sk: hpke::SecretKey, @@ -20,17 +19,15 @@ pub struct Node { impl Node { pub async fn run( ctx: &super::Context<'_>, - node_id: u32, - account: &AccountId, + account_id: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result { let web_port = util::pick_unused_port().await?; let (cipher_sk, cipher_pk) = hpke::generate(); let cli = mpc_recovery_node::cli::Cli::Start { - node_id: node_id.into(), near_rpc: ctx.lake_indexer.rpc_host_address.clone(), mpc_contract_id: ctx.mpc_contract.id().clone(), - account: account.clone(), + account_id: account_id.clone(), account_sk: account_sk.to_string().parse()?, web_port, cipher_pk: hex::encode(cipher_pk.to_bytes()), @@ -48,17 +45,16 @@ impl Node { }, }; - let mpc_node_id = format!("multichain/{node_id}"); + let mpc_node_id = format!("multichain/{account_id}", account_id = account_id); let process = mpc::spawn_multichain(ctx.release, &mpc_node_id, cli)?; let address = format!("http://127.0.0.1:{web_port}"); tracing::info!("node is starting at {}", address); util::ping_until_ok(&address, 60).await?; - tracing::info!("node started [node_id={node_id}, {address}]"); + tracing::info!("node started [node_account_id={account_id}, {address}]"); Ok(Self { address, - node_id: node_id as usize, - account: account.clone(), + account_id: account_id.clone(), account_sk: account_sk.clone(), cipher_pk, cipher_sk, diff --git a/integration-tests/src/multichain/mod.rs b/integration-tests/src/multichain/mod.rs index 49d4b89bb..c91dab7d8 100644 --- a/integration-tests/src/multichain/mod.rs +++ b/integration-tests/src/multichain/mod.rs @@ -3,7 +3,7 @@ pub mod local; use crate::env::containers::DockerClient; use crate::{initialize_lake_indexer, LakeIndexerCtx}; -use mpc_contract::ParticipantInfo; +use mpc_contract::primitives::ParticipantInfo; use near_workspaces::network::Sandbox; use near_workspaces::{AccountId, Contract, Worker}; use serde_json::json; @@ -50,17 +50,16 @@ impl Nodes<'_> { pub async fn add_node( &mut self, - node_id: u32, account: &AccountId, account_sk: &near_workspaces::types::SecretKey, ) -> anyhow::Result<()> { tracing::info!(%account, "adding one more node"); match self { Nodes::Local { ctx, nodes } => { - nodes.push(local::Node::run(ctx, node_id, account, account_sk).await?) + nodes.push(local::Node::run(ctx, account, account_sk).await?) } Nodes::Docker { ctx, nodes } => { - nodes.push(containers::Node::run(ctx, node_id, account, account_sk).await?) + nodes.push(containers::Node::run(ctx, account, account_sk).await?) } } @@ -124,8 +123,8 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul .into_iter() .collect::, _>>()?; let mut node_futures = Vec::new(); - for (i, account) in accounts.iter().enumerate() { - let node = containers::Node::run(&ctx, i as u32, account.id(), account.secret_key()); + for account in &accounts { + let node = containers::Node::run(&ctx, account.id(), account.secret_key()); node_futures.push(node); } let nodes = futures::future::join_all(node_futures) @@ -135,13 +134,11 @@ pub async fn docker(nodes: usize, docker_client: &DockerClient) -> anyhow::Resul let participants: HashMap = accounts .iter() .cloned() - .enumerate() .zip(&nodes) - .map(|((i, account), node)| { + .map(|(account, node)| { ( account.id().clone(), ParticipantInfo { - id: i as u32, account_id: account.id().to_string().parse().unwrap(), url: node.address.clone(), cipher_pk: node.cipher_pk.to_bytes(), @@ -171,13 +168,8 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< .into_iter() .collect::, _>>()?; let mut node_futures = Vec::with_capacity(nodes); - for (i, account) in accounts.iter().enumerate().take(nodes) { - node_futures.push(local::Node::run( - &ctx, - i as u32, - account.id(), - account.secret_key(), - )); + for account in accounts.iter().take(nodes) { + node_futures.push(local::Node::run(&ctx, account.id(), account.secret_key())); } let nodes = futures::future::join_all(node_futures) .await @@ -186,13 +178,11 @@ pub async fn host(nodes: usize, docker_client: &DockerClient) -> anyhow::Result< let participants: HashMap = accounts .iter() .cloned() - .enumerate() .zip(&nodes) - .map(|((i, account), node)| { + .map(|(account, node)| { ( account.id().clone(), ParticipantInfo { - id: i as u32, account_id: account.id().to_string().parse().unwrap(), url: node.address.clone(), cipher_pk: node.cipher_pk.to_bytes(), diff --git a/integration-tests/tests/multichain/mod.rs b/integration-tests/tests/multichain/mod.rs index 12a098631..d6d7c99ed 100644 --- a/integration-tests/tests/multichain/mod.rs +++ b/integration-tests/tests/multichain/mod.rs @@ -19,7 +19,7 @@ async fn test_multichain_reshare() -> anyhow::Result<()> { let account = ctx.nodes.ctx().worker.dev_create_account().await?; ctx.nodes - .add_node(3, account.id(), account.secret_key()) + .add_node(account.id(), account.secret_key()) .await?; // Wait for network to complete key reshare diff --git a/node/src/cli.rs b/node/src/cli.rs index 9eb7884a6..1fe8da51e 100644 --- a/node/src/cli.rs +++ b/node/src/cli.rs @@ -1,6 +1,5 @@ use crate::protocol::{MpcSignProtocol, SignQueue}; use crate::{indexer, storage, web}; -use cait_sith::protocol::Participant; use clap::Parser; use local_ip_address::local_ip; use near_crypto::{InMemorySigner, SecretKey}; @@ -16,9 +15,6 @@ use mpc_keys::hpke; #[derive(Parser, Debug)] pub enum Cli { Start { - /// Node ID - #[arg(long, value_parser = parse_participant, env("MPC_RECOVERY_NODE_ID"))] - node_id: Participant, /// NEAR RPC address #[arg( long, @@ -30,8 +26,8 @@ pub enum Cli { #[arg(long, env("MPC_RECOVERY_CONTRACT_ID"))] mpc_contract_id: AccountId, /// This node's account id - #[arg(long, env("MPC_RECOVERY_ACCOUNT"))] - account: AccountId, + #[arg(long, env("MPC_RECOVERY_ACCOUNT_ID"))] + account_id: AccountId, /// This node's account ed25519 secret key #[arg(long, env("MPC_RECOVERY_ACCOUNT_SK"))] account_sk: SecretKey, @@ -57,19 +53,13 @@ pub enum Cli { }, } -fn parse_participant(arg: &str) -> Result { - let participant_id: u32 = arg.parse()?; - Ok(participant_id.into()) -} - impl Cli { pub fn into_str_args(self) -> Vec { match self { Cli::Start { - node_id, near_rpc, + account_id, mpc_contract_id, - account, account_sk, web_port, cipher_pk, @@ -80,14 +70,12 @@ impl Cli { } => { let mut args = vec![ "start".to_string(), - "--node-id".to_string(), - u32::from(node_id).to_string(), "--near-rpc".to_string(), near_rpc, "--mpc-contract-id".to_string(), mpc_contract_id.to_string(), - "--account".to_string(), - account.to_string(), + "--account-id".to_string(), + account_id.to_string(), "--account-sk".to_string(), account_sk.to_string(), "--web-port".to_string(), @@ -123,11 +111,10 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> { match cmd { Cli::Start { - node_id, near_rpc, web_port, mpc_contract_id, - account, + account_id, account_sk, cipher_pk, cipher_sk, @@ -156,11 +143,11 @@ pub fn run(cmd: Cli) -> anyhow::Result<()> { tracing::info!(%my_address, "address detected"); let rpc_client = near_fetch::Client::new(&near_rpc); tracing::debug!(rpc_addr = rpc_client.rpc_addr(), "rpc client initialized"); - let signer = InMemorySigner::from_secret_key(account, account_sk); + let signer = InMemorySigner::from_secret_key(account_id.clone(), account_sk); let (protocol, protocol_state) = MpcSignProtocol::init( - node_id, my_address, mpc_contract_id.clone(), + account_id, rpc_client.clone(), signer.clone(), receiver, diff --git a/node/src/http_client.rs b/node/src/http_client.rs index 10574a362..4d4e91a6c 100644 --- a/node/src/http_client.rs +++ b/node/src/http_client.rs @@ -1,8 +1,9 @@ +use crate::protocol::contract::primitives::ParticipantInfo; use crate::protocol::message::SignedMessage; use crate::protocol::MpcMessage; -use crate::protocol::ParticipantInfo; use cait_sith::protocol::Participant; use mpc_keys::hpke; +use near_primitives::types::AccountId; use reqwest::{Client, IntoUrl}; use std::collections::VecDeque; use std::str::Utf8Error; @@ -73,8 +74,12 @@ async fn send_encrypted( Retry::spawn(retry_strategy, action).await } -pub async fn join(client: &Client, url: U, me: &Participant) -> Result<(), SendError> { - let _span = tracing::info_span!("join_request", ?me); +pub async fn join( + client: &Client, + url: U, + account_id: &AccountId, +) -> Result<(), SendError> { + let _span = tracing::info_span!("join_request", ?account_id); let mut url = url.into_url()?; url.set_path("join"); tracing::debug!(%url, "making http request"); @@ -82,7 +87,7 @@ pub async fn join(client: &Client, url: U, me: &Participant) -> Resu let response = client .post(url.clone()) .header("content-type", "application/json") - .json(&me) + .json(&account_id) .send() .await .map_err(SendError::ReqwestClientError)?; diff --git a/node/src/protocol/consensus.rs b/node/src/protocol/consensus.rs index c0ffeca7e..1834b6a9a 100644 --- a/node/src/protocol/consensus.rs +++ b/node/src/protocol/consensus.rs @@ -25,7 +25,7 @@ use tokio::sync::RwLock; use url::Url; pub trait ConsensusCtx { - fn me(&self) -> Participant; + fn my_account_id(&self) -> &AccountId; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; fn signer(&self) -> &InMemorySigner; @@ -99,48 +99,52 @@ impl ConsensusProtocol for StartedState { } Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { - if contract_state.participants.contains_key(&ctx.me()) { - tracing::info!( - "contract state is running and we are already a participant" - ); - let participants_vec: Vec = - contract_state.participants.keys().cloned().collect(); - Ok(NodeState::Running(RunningState { - epoch, - participants: contract_state.participants, - threshold: contract_state.threshold, - private_share, - public_key, - sign_queue: ctx.sign_queue(), - triple_manager: Arc::new(RwLock::new(TripleManager::new( - participants_vec.clone(), - ctx.me(), - contract_state.threshold, + match contract_state + .participants + .find_participant(ctx.my_account_id()) + { + Some(me) => { + tracing::info!( + "contract state is running and we are already a participant" + ); + let participants_vec: Vec = + contract_state.participants.keys().cloned().collect(); + Ok(NodeState::Running(RunningState { epoch, - ))), - presignature_manager: Arc::new(RwLock::new( - PresignatureManager::new( + participants: contract_state.participants, + threshold: contract_state.threshold, + private_share, + public_key, + sign_queue: ctx.sign_queue(), + triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - ctx.me(), + me, contract_state.threshold, epoch, - ), - )), - signature_manager: Arc::new(RwLock::new( - SignatureManager::new( - participants_vec, - ctx.me(), - contract_state.public_key, - epoch, - ), - )), - messages: Default::default(), - })) - } else { - Ok(NodeState::Joining(JoiningState { + ))), + presignature_manager: Arc::new(RwLock::new( + PresignatureManager::new( + participants_vec.clone(), + me, + contract_state.threshold, + epoch, + ), + )), + signature_manager: Arc::new(RwLock::new( + SignatureManager::new( + participants_vec, + me, + contract_state.public_key, + epoch, + ), + )), + messages: Default::default(), + })) + } + None => Ok(NodeState::Joining(JoiningState { participants: contract_state.participants, public_key, - })) + })), } } } @@ -166,30 +170,38 @@ impl ConsensusProtocol for StartedState { tracing::info!( "contract state is resharing with us, joining as a participant" ); - start_resharing(Some(private_share), ctx, contract_state) + start_resharing(Some(private_share), ctx, contract_state).await } } } }, None => match contract_state { ProtocolState::Initializing(contract_state) => { - if contract_state.participants.contains_key(&ctx.me()) { - tracing::info!("starting key generation as a part of the participant set"); - let participants = contract_state.participants; - let protocol = cait_sith::keygen::( - &participants.keys().cloned().collect::>(), - ctx.me(), - contract_state.threshold, - )?; - Ok(NodeState::Generating(GeneratingState { - participants, - threshold: contract_state.threshold, - protocol: Arc::new(RwLock::new(protocol)), - messages: Default::default(), - })) - } else { - tracing::info!("we are not a part of the initial participant set, waiting for key generation to complete"); - Ok(NodeState::Started(self)) + match contract_state + .participants + .find_participant(ctx.my_account_id()) + { + Some(me) => { + tracing::info!( + "starting key generation as a part of the participant set" + ); + let participants = contract_state.participants; + let protocol = cait_sith::keygen::( + &participants.keys().cloned().collect::>(), + me, + contract_state.threshold, + )?; + Ok(NodeState::Generating(GeneratingState { + participants, + threshold: contract_state.threshold, + protocol: Arc::new(RwLock::new(protocol)), + messages: Default::default(), + })) + } + None => { + tracing::info!("we are not a part of the initial participant set, waiting for key generation to complete"); + Ok(NodeState::Started(self)) + } } } ProtocolState::Running(contract_state) => Ok(NodeState::Joining(JoiningState { @@ -269,7 +281,7 @@ impl ConsensusProtocol for WaitingForConsensusState { let has_voted = contract_state .pk_votes .get(&public_key) - .map(|ps| ps.contains(&ctx.me())) + .map(|ps| ps.contains(ctx.my_account_id())) .unwrap_or_default(); if !has_voted { tracing::info!("we haven't voted yet, voting for the generated public key"); @@ -310,6 +322,12 @@ impl ConsensusProtocol for WaitingForConsensusState { } let participants_vec: Vec = self.participants.keys().cloned().collect(); + + let me = contract_state + .participants + .find_participant(ctx.my_account_id()) + .unwrap(); + Ok(NodeState::Running(RunningState { epoch: self.epoch, participants: self.participants, @@ -319,19 +337,19 @@ impl ConsensusProtocol for WaitingForConsensusState { sign_queue: ctx.sign_queue(), triple_manager: Arc::new(RwLock::new(TripleManager::new( participants_vec.clone(), - ctx.me(), + me, self.threshold, self.epoch, ))), presignature_manager: Arc::new(RwLock::new(PresignatureManager::new( participants_vec.clone(), - ctx.me(), + me, self.threshold, self.epoch, ))), signature_manager: Arc::new(RwLock::new(SignatureManager::new( participants_vec, - ctx.me(), + me, self.public_key, self.epoch, ))), @@ -352,7 +370,7 @@ impl ConsensusProtocol for WaitingForConsensusState { if contract_state.public_key != self.public_key { return Err(ConsensusError::MismatchedPublicKey); } - start_resharing(Some(self.private_share), ctx, contract_state) + start_resharing(Some(self.private_share), ctx, contract_state).await } Ordering::Greater => { tracing::warn!( @@ -370,20 +388,35 @@ impl ConsensusProtocol for WaitingForConsensusState { tracing::debug!( "waiting for resharing consensus, contract state has not been finalized yet" ); - let has_voted = contract_state.finished_votes.contains(&ctx.me()); - if !has_voted && contract_state.old_participants.contains_key(&ctx.me()) { - tracing::info!( - epoch = self.epoch, - "we haven't voted yet, voting for resharing to complete" - ); - rpc_client::vote_reshared( - ctx.rpc_client(), - ctx.signer(), - ctx.mpc_contract_id(), - self.epoch, - ) - .await - .unwrap(); + let has_voted = contract_state.finished_votes.contains(ctx.my_account_id()); + match contract_state + .old_participants + .find_participant(ctx.my_account_id()) + { + Some(_) => { + if !has_voted { + tracing::info!( + epoch = self.epoch, + "we haven't voted yet, voting for resharing to complete" + ); + rpc_client::vote_reshared( + ctx.rpc_client(), + ctx.signer(), + ctx.mpc_contract_id(), + self.epoch, + ) + .await + .unwrap(); + } else { + tracing::info!( + epoch = self.epoch, + "we have voted for resharing to complete" + ); + } + } + None => { + tracing::info!("we are not a part of the old participant set"); + } } Ok(NodeState::WaitingForConsensus(self)) } @@ -445,15 +478,19 @@ impl ConsensusProtocol for RunningState { Ordering::Less => Err(ConsensusError::EpochRollback), Ordering::Equal => { tracing::info!("contract is resharing"); - if !contract_state.old_participants.contains_key(&ctx.me()) - || !contract_state.new_participants.contains_key(&ctx.me()) - { + let is_in_old_participant_set = contract_state + .old_participants + .contains_account_id(ctx.my_account_id()); + let is_in_new_participant_set = contract_state + .new_participants + .contains_account_id(ctx.my_account_id()); + if !is_in_old_participant_set || !is_in_new_participant_set { return Err(ConsensusError::HasBeenKicked); } if contract_state.public_key != self.public_key { return Err(ConsensusError::MismatchedPublicKey); } - start_resharing(Some(self.private_share), ctx, contract_state) + start_resharing(Some(self.private_share), ctx, contract_state).await } } } @@ -545,54 +582,66 @@ impl ConsensusProtocol for JoiningState { match contract_state { ProtocolState::Initializing(_) => Err(ConsensusError::ContractStateRollback), ProtocolState::Running(contract_state) => { - if contract_state.candidates.contains_key(&ctx.me()) { - let voted = contract_state - .join_votes - .get(&ctx.me()) - .cloned() - .unwrap_or_default(); - tracing::info!( - already_voted = voted.len(), - votes_to_go = contract_state.threshold - voted.len(), - "trying to get participants to vote for us" - ); - for (p, info) in contract_state.participants { - if voted.contains(&p) { - continue; - } - http_client::join(ctx.http_client(), info.url, &ctx.me()) + match contract_state + .candidates + .find_candidate(ctx.my_account_id()) + { + Some(candidate_info) => { + let voted = contract_state + .join_votes + .get(ctx.my_account_id()) + .cloned() + .unwrap_or_default(); + tracing::info!( + already_voted = voted.len(), + votes_to_go = contract_state.threshold - voted.len(), + "trying to get participants to vote for us" + ); + for (_, info) in contract_state.participants { + if voted.contains(&info.account_id) { + continue; + } + http_client::join( + ctx.http_client(), + info.url, + &candidate_info.account_id, + ) .await .unwrap() + } + Ok(NodeState::Joining(self)) + } + None => { + tracing::info!("sending a transaction to join the participant set"); + let args = serde_json::json!({ + "url": ctx.my_address(), + "cipher_pk": ctx.cipher_pk().to_bytes(), + "sign_pk": ctx.sign_pk(), + }); + ctx.rpc_client() + .send_tx( + ctx.signer(), + ctx.mpc_contract_id(), + vec![Action::FunctionCall(FunctionCallAction { + method_name: "join".to_string(), + args: args.to_string().into_bytes(), + gas: 300_000_000_000_000, + deposit: 0, + })], + ) + .await + .unwrap(); + Ok(NodeState::Joining(self)) } - Ok(NodeState::Joining(self)) - } else { - tracing::info!("sending a transaction to join the participant set"); - let args = serde_json::json!({ - "participant_id": ctx.me(), - "url": ctx.my_address(), - "cipher_pk": ctx.cipher_pk().to_bytes(), - "sign_pk": ctx.sign_pk(), - }); - ctx.rpc_client() - .send_tx( - ctx.signer(), - ctx.mpc_contract_id(), - vec![Action::FunctionCall(FunctionCallAction { - method_name: "join".to_string(), - args: args.to_string().into_bytes(), - gas: 300_000_000_000_000, - deposit: 0, - })], - ) - .await - .unwrap(); - Ok(NodeState::Joining(self)) } } ProtocolState::Resharing(contract_state) => { - if contract_state.new_participants.contains_key(&ctx.me()) { + if contract_state + .new_participants + .contains_account_id(ctx.my_account_id()) + { tracing::info!("joining as a new participant"); - start_resharing(None, ctx, contract_state) + start_resharing(None, ctx, contract_state).await } else { tracing::debug!("network is resharing without us, waiting for them to finish"); Ok(NodeState::Joining(self)) @@ -624,11 +673,15 @@ impl ConsensusProtocol for NodeState { } } -fn start_resharing( +async fn start_resharing( private_share: Option, ctx: C, contract_state: ResharingContractState, ) -> Result { + let me = contract_state + .new_participants + .find_participant(ctx.my_account_id()) + .unwrap(); let protocol = cait_sith::reshare::( &contract_state .old_participants @@ -642,7 +695,7 @@ fn start_resharing( .cloned() .collect::>(), contract_state.threshold, - ctx.me(), + me, private_share, contract_state.public_key, )?; diff --git a/node/src/protocol/contract.rs b/node/src/protocol/contract.rs deleted file mode 100644 index 08e606450..000000000 --- a/node/src/protocol/contract.rs +++ /dev/null @@ -1,200 +0,0 @@ -use crate::types::PublicKey; -use crate::util::NearPublicKeyExt; -use cait_sith::protocol::Participant; -use mpc_contract::ProtocolContractState; -use mpc_keys::hpke; -use near_primitives::borsh::BorshDeserialize; -use near_sdk::AccountId; -use serde::{Deserialize, Serialize}; -use std::collections::{BTreeMap, HashSet}; - -type ParticipantId = u32; - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct ParticipantInfo { - pub id: ParticipantId, - pub account_id: AccountId, - pub url: String, - /// The public key used for encrypting messages. - pub cipher_pk: hpke::PublicKey, - /// The public key used for verifying messages. - pub sign_pk: near_crypto::PublicKey, -} - -impl From for ParticipantInfo { - fn from(value: mpc_contract::ParticipantInfo) -> Self { - ParticipantInfo { - id: value.id, - account_id: value.account_id, - url: value.url, - cipher_pk: hpke::PublicKey::from_bytes(&value.cipher_pk), - sign_pk: BorshDeserialize::try_from_slice(value.sign_pk.as_bytes()).unwrap(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct InitializingContractState { - pub participants: BTreeMap, - pub threshold: usize, - pub pk_votes: BTreeMap>, -} - -impl From for InitializingContractState { - fn from(value: mpc_contract::InitializingContractState) -> Self { - InitializingContractState { - participants: contract_participants_into_cait_participants(value.participants), - threshold: value.threshold, - pk_votes: value - .pk_votes - .into_iter() - .map(|(pk, participants)| { - ( - near_crypto::PublicKey::SECP256K1( - near_crypto::Secp256K1PublicKey::try_from(&pk.as_bytes()[1..]).unwrap(), - ), - participants - .into_iter() - .map(Participant::from) - .collect::>(), - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct RunningContractState { - pub epoch: u64, - pub participants: BTreeMap, - pub threshold: usize, - pub public_key: PublicKey, - pub candidates: BTreeMap, - pub join_votes: BTreeMap>, - pub leave_votes: BTreeMap>, -} - -impl From for RunningContractState { - fn from(value: mpc_contract::RunningContractState) -> Self { - RunningContractState { - epoch: value.epoch, - participants: contract_participants_into_cait_participants(value.participants), - threshold: value.threshold, - public_key: value.public_key.into_affine_point(), - candidates: value - .candidates - .into_iter() - .map(|(p, p_info)| (Participant::from(p), p_info.into())) - .collect(), - join_votes: value - .join_votes - .into_iter() - .map(|(p, ps)| { - ( - Participant::from(p), - ps.into_iter().map(Participant::from).collect(), - ) - }) - .collect(), - leave_votes: value - .leave_votes - .into_iter() - .map(|(p, ps)| { - ( - Participant::from(p), - ps.into_iter().map(Participant::from).collect(), - ) - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ResharingContractState { - pub old_epoch: u64, - pub old_participants: BTreeMap, - pub new_participants: BTreeMap, - pub threshold: usize, - pub public_key: PublicKey, - pub finished_votes: HashSet, -} - -impl From for ResharingContractState { - fn from(value: mpc_contract::ResharingContractState) -> Self { - ResharingContractState { - old_epoch: value.old_epoch, - old_participants: contract_participants_into_cait_participants(value.old_participants), - new_participants: contract_participants_into_cait_participants(value.new_participants), - threshold: value.threshold, - public_key: value.public_key.into_affine_point(), - finished_votes: value - .finished_votes - .into_iter() - .map(Participant::from) - .collect(), - } - } -} - -#[derive(Debug)] -pub enum ProtocolState { - Initializing(InitializingContractState), - Running(RunningContractState), - Resharing(ResharingContractState), -} - -impl ProtocolState { - pub fn participants(&self) -> &BTreeMap { - match self { - ProtocolState::Initializing(InitializingContractState { participants, .. }) => { - participants - } - ProtocolState::Running(RunningContractState { participants, .. }) => participants, - ProtocolState::Resharing(ResharingContractState { - old_participants, .. - }) => old_participants, - } - } - - pub fn public_key(&self) -> Option<&PublicKey> { - match self { - ProtocolState::Initializing { .. } => None, - ProtocolState::Running(RunningContractState { public_key, .. }) => Some(public_key), - ProtocolState::Resharing(ResharingContractState { public_key, .. }) => Some(public_key), - } - } - - pub fn threshold(&self) -> usize { - match self { - ProtocolState::Initializing(InitializingContractState { threshold, .. }) => *threshold, - ProtocolState::Running(RunningContractState { threshold, .. }) => *threshold, - ProtocolState::Resharing(ResharingContractState { threshold, .. }) => *threshold, - } - } -} - -impl TryFrom for ProtocolState { - type Error = (); - - fn try_from(value: ProtocolContractState) -> Result { - match value { - ProtocolContractState::NotInitialized => Err(()), - ProtocolContractState::Initializing(state) => { - Ok(ProtocolState::Initializing(state.into())) - } - ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())), - ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), - } - } -} - -fn contract_participants_into_cait_participants( - participants: BTreeMap, -) -> BTreeMap { - participants - .into_values() - .map(|p| (Participant::from(p.id), p.into())) - .collect() -} diff --git a/node/src/protocol/contract/mod.rs b/node/src/protocol/contract/mod.rs new file mode 100644 index 000000000..56441439a --- /dev/null +++ b/node/src/protocol/contract/mod.rs @@ -0,0 +1,131 @@ +pub mod primitives; + +use crate::types::PublicKey; +use crate::util::NearPublicKeyExt; +use mpc_contract::ProtocolContractState; +use near_primitives::types::AccountId; +use serde::{Deserialize, Serialize}; +use std::{collections::HashSet, str::FromStr}; + +use self::primitives::{Candidates, Participants, PkVotes, Votes}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct InitializingContractState { + pub participants: Participants, + pub threshold: usize, + pub pk_votes: PkVotes, +} + +impl From for InitializingContractState { + fn from(value: mpc_contract::InitializingContractState) -> Self { + InitializingContractState { + participants: value.participants.into(), + threshold: value.threshold, + pk_votes: value.pk_votes.into(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct RunningContractState { + pub epoch: u64, + pub participants: Participants, + pub threshold: usize, + pub public_key: PublicKey, + pub candidates: Candidates, + pub join_votes: Votes, + pub leave_votes: Votes, +} + +impl From for RunningContractState { + fn from(value: mpc_contract::RunningContractState) -> Self { + RunningContractState { + epoch: value.epoch, + participants: value.participants.into(), + threshold: value.threshold, + public_key: value.public_key.into_affine_point(), + candidates: value.candidates.into(), + join_votes: value.join_votes.into(), + leave_votes: value.leave_votes.into(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct ResharingContractState { + pub old_epoch: u64, + pub old_participants: Participants, + pub new_participants: Participants, + pub threshold: usize, + pub public_key: PublicKey, + pub finished_votes: HashSet, +} + +impl From for ResharingContractState { + fn from(contract_state: mpc_contract::ResharingContractState) -> Self { + ResharingContractState { + old_epoch: contract_state.old_epoch, + old_participants: contract_state.old_participants.into(), + new_participants: contract_state.new_participants.into(), + threshold: contract_state.threshold, + public_key: contract_state.public_key.into_affine_point(), + finished_votes: contract_state + .finished_votes + .into_iter() + .map(|acc_id| AccountId::from_str(acc_id.as_ref()).unwrap()) + .collect(), + } + } +} + +#[derive(Debug)] +pub enum ProtocolState { + Initializing(InitializingContractState), + Running(RunningContractState), + Resharing(ResharingContractState), +} + +impl ProtocolState { + pub fn participants(&self) -> &Participants { + match self { + ProtocolState::Initializing(InitializingContractState { participants, .. }) => { + participants + } + ProtocolState::Running(RunningContractState { participants, .. }) => participants, + ProtocolState::Resharing(ResharingContractState { + old_participants, .. + }) => old_participants, + } + } + + pub fn public_key(&self) -> Option<&PublicKey> { + match self { + ProtocolState::Initializing { .. } => None, + ProtocolState::Running(RunningContractState { public_key, .. }) => Some(public_key), + ProtocolState::Resharing(ResharingContractState { public_key, .. }) => Some(public_key), + } + } + + pub fn threshold(&self) -> usize { + match self { + ProtocolState::Initializing(InitializingContractState { threshold, .. }) => *threshold, + ProtocolState::Running(RunningContractState { threshold, .. }) => *threshold, + ProtocolState::Resharing(ResharingContractState { threshold, .. }) => *threshold, + } + } +} + +impl TryFrom for ProtocolState { + type Error = (); + + fn try_from(value: ProtocolContractState) -> Result { + match value { + ProtocolContractState::Initializing(state) => { + Ok(ProtocolState::Initializing(state.into())) + } + ProtocolContractState::Running(state) => Ok(ProtocolState::Running(state.into())), + ProtocolContractState::Resharing(state) => Ok(ProtocolState::Resharing(state.into())), + ProtocolContractState::NotInitialized => Err(()), + } + } +} diff --git a/node/src/protocol/contract/primitives.rs b/node/src/protocol/contract/primitives.rs new file mode 100644 index 000000000..7c99cbd8d --- /dev/null +++ b/node/src/protocol/contract/primitives.rs @@ -0,0 +1,244 @@ +use cait_sith::protocol::Participant; +use mpc_keys::hpke; +use near_primitives::{borsh::BorshDeserialize, types::AccountId}; +use serde::{Deserialize, Serialize}; +use std::{ + collections::{BTreeMap, HashSet}, + str::FromStr, +}; + +type ParticipantId = u32; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct ParticipantInfo { + pub id: ParticipantId, + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: near_crypto::PublicKey, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] +pub struct Participants { + pub participants: BTreeMap, +} + +impl From for Participants { + fn from(contract_participants: mpc_contract::primitives::Participants) -> Self { + Participants { + // take position of participant in contract_participants as id for participants + participants: contract_participants + .participants + .into_iter() + .enumerate() + .map(|(participant_id, participant)| { + let contract_participant_info = participant.1; + ( + Participant::from(participant_id as ParticipantId), + ParticipantInfo { + id: participant_id as ParticipantId, + account_id: AccountId::from_str( + contract_participant_info.account_id.as_ref(), + ) + .unwrap(), + url: contract_participant_info.url, + cipher_pk: hpke::PublicKey::from_bytes( + &contract_participant_info.cipher_pk, + ), + sign_pk: BorshDeserialize::try_from_slice( + contract_participant_info.sign_pk.as_bytes(), + ) + .unwrap(), + }, + ) + }) + .collect(), + } + } +} + +impl IntoIterator for Participants { + type Item = (Participant, ParticipantInfo); + type IntoIter = std::collections::btree_map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.participants.into_iter() + } +} + +impl Participants { + pub fn get(&self, id: &Participant) -> Option<&ParticipantInfo> { + self.participants.get(id) + } + + pub fn contains_key(&self, id: &Participant) -> bool { + self.participants.contains_key(id) + } + + pub fn keys(&self) -> impl Iterator { + self.participants.keys() + } + + pub fn iter(&self) -> impl Iterator { + self.participants.iter() + } + + pub fn find_participant(&self, account_id: &AccountId) -> Option { + self.participants + .iter() + .find(|(_, participant_info)| participant_info.account_id == *account_id) + .map(|(participant, _)| *participant) + } + + pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { + self.participants + .values() + .find(|participant_info| participant_info.account_id == *account_id) + } + + pub fn contains_account_id(&self, account_id: &AccountId) -> bool { + self.participants + .values() + .any(|participant_info| participant_info.account_id == *account_id) + } + + pub fn account_ids(&self) -> Vec { + self.participants + .values() + .map(|participant_info| participant_info.account_id.clone()) + .collect() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct CandidateInfo { + pub account_id: AccountId, + pub url: String, + /// The public key used for encrypting messages. + pub cipher_pk: hpke::PublicKey, + /// The public key used for verifying messages. + pub sign_pk: near_crypto::PublicKey, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct Candidates { + pub candidates: BTreeMap, +} + +impl Candidates { + pub fn get(&self, id: &AccountId) -> Option<&CandidateInfo> { + self.candidates.get(id) + } + + pub fn contains_key(&self, id: &AccountId) -> bool { + self.candidates.contains_key(id) + } + + pub fn keys(&self) -> impl Iterator { + self.candidates.keys() + } + + pub fn iter(&self) -> impl Iterator { + self.candidates.iter() + } + + pub fn find_candidate(&self, account_id: &AccountId) -> Option<&CandidateInfo> { + self.candidates.get(account_id) + } +} + +impl From for Candidates { + fn from(contract_candidates: mpc_contract::primitives::Candidates) -> Self { + Candidates { + candidates: contract_candidates + .candidates + .into_iter() + .map(|(account_id, candidate_info)| { + ( + AccountId::from_str(account_id.as_ref()).unwrap(), + CandidateInfo { + account_id: AccountId::from_str(candidate_info.account_id.as_ref()) + .unwrap(), + url: candidate_info.url, + cipher_pk: hpke::PublicKey::from_bytes(&candidate_info.cipher_pk), + sign_pk: BorshDeserialize::try_from_slice( + candidate_info.sign_pk.as_bytes(), + ) + .unwrap(), + }, + ) + }) + .collect(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct PkVotes { + pub pk_votes: BTreeMap>, +} + +impl PkVotes { + pub fn get(&self, id: &near_crypto::PublicKey) -> Option<&HashSet> { + self.pk_votes.get(id) + } +} + +impl From for PkVotes { + fn from(contract_votes: mpc_contract::primitives::PkVotes) -> Self { + PkVotes { + pk_votes: contract_votes + .votes + .into_iter() + .map(|(pk, participants)| { + ( + near_crypto::PublicKey::SECP256K1( + near_crypto::Secp256K1PublicKey::try_from(&pk.as_bytes()[1..]).unwrap(), + ), + participants + .into_iter() + .map(|acc_id: near_sdk::AccountId| { + AccountId::from_str(acc_id.as_ref()).unwrap() + }) + .collect(), + ) + }) + .collect(), + } + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub struct Votes { + pub votes: BTreeMap>, +} + +impl Votes { + pub fn get(&self, id: &AccountId) -> Option<&HashSet> { + self.votes.get(id) + } +} + +impl From for Votes { + fn from(contract_votes: mpc_contract::primitives::Votes) -> Self { + Votes { + votes: contract_votes + .votes + .into_iter() + .map(|(account_id, participants)| { + ( + AccountId::from_str(account_id.as_ref()).unwrap(), + participants + .into_iter() + .map(|acc_id: near_sdk::AccountId| { + AccountId::from_str(acc_id.as_ref()).unwrap() + }) + .collect(), + ) + }) + .collect(), + } + } +} diff --git a/node/src/protocol/cryptography.rs b/node/src/protocol/cryptography.rs index c212bb665..5ba3b2235 100644 --- a/node/src/protocol/cryptography.rs +++ b/node/src/protocol/cryptography.rs @@ -13,8 +13,9 @@ use mpc_keys::hpke; use near_crypto::InMemorySigner; use near_primitives::types::AccountId; +#[async_trait::async_trait] pub trait CryptographicCtx { - fn me(&self) -> Participant; + async fn me(&self) -> Participant; fn http_client(&self) -> &reqwest::Client; fn rpc_client(&self) -> &near_fetch::Client; fn signer(&self) -> &InMemorySigner; @@ -81,7 +82,7 @@ impl CryptographicProtocol for GeneratingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message"); @@ -92,15 +93,15 @@ impl CryptographicProtocol for GeneratingState { Action::SendMany(m) => { tracing::debug!("sending a message to many participants"); let mut messages = self.messages.write().await; - for (p, info) in &self.participants { - if p == &ctx.me() { + for (p, info) in self.participants.iter() { + if p == &ctx.me().await { // Skip yourself, cait-sith never sends messages to oneself continue; } messages.push( info.clone(), MpcMessage::Generating(GeneratingMessage { - from: ctx.me(), + from: ctx.me().await, data: m.clone(), }), ); @@ -112,7 +113,7 @@ impl CryptographicProtocol for GeneratingState { self.messages.write().await.push( info.clone(), MpcMessage::Generating(GeneratingMessage { - from: ctx.me(), + from: ctx.me().await, data: m.clone(), }), ); @@ -134,7 +135,7 @@ impl CryptographicProtocol for GeneratingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "generating: failed to send encrypted message"); @@ -163,7 +164,7 @@ impl CryptographicProtocol for WaitingForConsensusState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "waiting: failed to send encrypted message"); @@ -192,7 +193,7 @@ impl CryptographicProtocol for ResharingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(wait): failed to send encrypted message"); @@ -203,8 +204,8 @@ impl CryptographicProtocol for ResharingState { Action::SendMany(m) => { tracing::debug!("sending a message to all participants"); let mut messages = self.messages.write().await; - for (p, info) in &self.new_participants { - if p == &ctx.me() { + for (p, info) in self.new_participants.clone() { + if p == ctx.me().await { // Skip yourself, cait-sith never sends messages to oneself continue; } @@ -213,7 +214,7 @@ impl CryptographicProtocol for ResharingState { info.clone(), MpcMessage::Resharing(ResharingMessage { epoch: self.old_epoch, - from: ctx.me(), + from: ctx.me().await, data: m.clone(), }), ) @@ -226,7 +227,7 @@ impl CryptographicProtocol for ResharingState { info.clone(), MpcMessage::Resharing(ResharingMessage { epoch: self.old_epoch, - from: ctx.me(), + from: ctx.me().await, data: m.clone(), }), ), @@ -241,7 +242,7 @@ impl CryptographicProtocol for ResharingState { .messages .write() .await - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, new = ?self.new_participants, old = ?self.old_participants, "resharing(return): failed to send encrypted message"); @@ -270,7 +271,7 @@ impl CryptographicProtocol for RunningState { let mut messages = self.messages.write().await; // Try sending any leftover messages donated to RunningState. if let Err(err) = messages - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "running(pre): failed to send encrypted message"); @@ -309,8 +310,8 @@ impl CryptographicProtocol for RunningState { let mut sign_queue = self.sign_queue.write().await; let mut signature_manager = self.signature_manager.write().await; - sign_queue.organize(&self, ctx.me()); - let my_requests = sign_queue.my_requests(ctx.me()); + sign_queue.organize(&self, ctx.me().await); + let my_requests = sign_queue.my_requests(ctx.me().await); while presignature_manager.my_len() > 0 { let Some((receipt_id, _)) = my_requests.iter().next() else { break; @@ -340,7 +341,7 @@ impl CryptographicProtocol for RunningState { .await?; drop(signature_manager); if let Err(err) = messages - .send_encrypted(ctx.me(), ctx.sign_sk(), ctx.http_client()) + .send_encrypted(ctx.me().await, ctx.sign_sk(), ctx.http_client()) .await { tracing::warn!(?err, participants = ?self.participants, "running(post): failed to send encrypted message"); diff --git a/node/src/protocol/message.rs b/node/src/protocol/message.rs index 5dd853874..7f6dab3e0 100644 --- a/node/src/protocol/message.rs +++ b/node/src/protocol/message.rs @@ -15,8 +15,9 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use tokio::sync::RwLock; +#[async_trait::async_trait] pub trait MessageCtx { - fn me(&self) -> Participant; + async fn me(&self) -> Participant; } #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)] @@ -193,7 +194,6 @@ impl MessageHandler for ResharingState { let q = queue.resharing_bins.entry(self.old_epoch).or_default(); let mut protocol = self.protocol.write().await; while let Some(msg) = q.pop_front() { - tracing::debug!("handling new resharing message"); protocol.message(msg.from, msg.data); } Ok(()) diff --git a/node/src/protocol/mod.rs b/node/src/protocol/mod.rs index e61558de8..15ee4396f 100644 --- a/node/src/protocol/mod.rs +++ b/node/src/protocol/mod.rs @@ -1,4 +1,4 @@ -mod contract; +pub mod contract; mod cryptography; mod presignature; mod signature; @@ -9,7 +9,8 @@ pub mod message; pub mod state; pub use consensus::ConsensusError; -pub use contract::{ParticipantInfo, ProtocolState}; +pub use contract::primitives::ParticipantInfo; +pub use contract::ProtocolState; pub use cryptography::CryptographicError; pub use message::MpcMessage; pub use signature::SignQueue; @@ -36,8 +37,8 @@ use url::Url; use mpc_keys::hpke; struct Ctx { - me: Participant, my_address: Url, + account_id: AccountId, mpc_contract_id: AccountId, signer: InMemorySigner, rpc_client: near_fetch::Client, @@ -48,89 +49,91 @@ struct Ctx { secret_storage: SecretNodeStorageBox, } -impl ConsensusCtx for &Ctx { - fn me(&self) -> Participant { - self.me +impl ConsensusCtx for &MpcSignProtocol { + fn my_account_id(&self) -> &AccountId { + &self.ctx.account_id } fn http_client(&self) -> &reqwest::Client { - &self.http_client + &self.ctx.http_client } fn rpc_client(&self) -> &near_fetch::Client { - &self.rpc_client + &self.ctx.rpc_client } fn signer(&self) -> &InMemorySigner { - &self.signer + &self.ctx.signer } fn mpc_contract_id(&self) -> &AccountId { - &self.mpc_contract_id + &self.ctx.mpc_contract_id } fn my_address(&self) -> &Url { - &self.my_address + &self.ctx.my_address } fn sign_queue(&self) -> Arc> { - self.sign_queue.clone() + self.ctx.sign_queue.clone() } fn cipher_pk(&self) -> &hpke::PublicKey { - &self.cipher_pk + &self.ctx.cipher_pk } fn sign_pk(&self) -> near_crypto::PublicKey { - self.sign_sk.public_key() + self.ctx.sign_sk.public_key() } fn sign_sk(&self) -> &near_crypto::SecretKey { - &self.sign_sk + &self.ctx.sign_sk } fn secret_storage(&self) -> &SecretNodeStorageBox { - &self.secret_storage + &self.ctx.secret_storage } } -impl CryptographicCtx for &mut Ctx { - fn me(&self) -> Participant { - self.me +#[async_trait::async_trait] +impl CryptographicCtx for &mut MpcSignProtocol { + async fn me(&self) -> Participant { + get_my_participant(self).await } fn http_client(&self) -> &reqwest::Client { - &self.http_client + &self.ctx.http_client } fn rpc_client(&self) -> &near_fetch::Client { - &self.rpc_client + &self.ctx.rpc_client } fn signer(&self) -> &InMemorySigner { - &self.signer + &self.ctx.signer } fn mpc_contract_id(&self) -> &AccountId { - &self.mpc_contract_id + &self.ctx.mpc_contract_id } fn cipher_pk(&self) -> &hpke::PublicKey { - &self.cipher_pk + &self.ctx.cipher_pk } fn sign_sk(&self) -> &near_crypto::SecretKey { - &self.sign_sk + &self.ctx.sign_sk } fn secret_storage(&mut self) -> &mut SecretNodeStorageBox { - &mut self.secret_storage + &mut self.ctx.secret_storage } } -impl MessageCtx for &Ctx { - fn me(&self) -> Participant { - self.me +#[async_trait::async_trait] +impl MessageCtx for &MpcSignProtocol { + async fn me(&self) -> Participant { + get_my_participant(self).await } } @@ -143,9 +146,9 @@ pub struct MpcSignProtocol { impl MpcSignProtocol { #![allow(clippy::too_many_arguments)] pub fn init( - me: Participant, my_address: U, mpc_contract_id: AccountId, + account_id: AccountId, rpc_client: near_fetch::Client, signer: InMemorySigner, receiver: mpsc::Receiver, @@ -155,8 +158,8 @@ impl MpcSignProtocol { ) -> (Self, Arc>) { let state = Arc::new(RwLock::new(NodeState::Starting)); let ctx = Ctx { - me, my_address: my_address.into_url().unwrap(), + account_id, mpc_contract_id, rpc_client, http_client: reqwest::Client::new(), @@ -175,7 +178,7 @@ impl MpcSignProtocol { } pub async fn run(mut self) -> anyhow::Result<()> { - let _span = tracing::info_span!("running", me = u32::from(self.ctx.me)); + let _span = tracing::info_span!("running", my_account_id = self.ctx.account_id.to_string()); let mut queue = MpcMessageQueue::default(); loop { tracing::debug!("trying to advance mpc recovery protocol"); @@ -215,21 +218,21 @@ impl MpcSignProtocol { let guard = self.state.read().await; guard.clone() }; - let state = match state.progress(&mut self.ctx).await { + let state = match state.progress(&mut self).await { Ok(state) => state, Err(err) => { tracing::info!("protocol unable to progress: {err:?}"); continue; } }; - let mut state = match state.advance(&self.ctx, contract_state).await { + let mut state = match state.advance(&self, contract_state).await { Ok(state) => state, Err(err) => { tracing::info!("protocol unable to advance: {err:?}"); continue; } }; - if let Err(err) = state.handle(&self.ctx, &mut queue).await { + if let Err(err) = state.handle(&self, &mut queue).await { tracing::info!("protocol unable to handle messages: {err:?}"); continue; } @@ -242,3 +245,15 @@ impl MpcSignProtocol { } } } + +async fn get_my_participant(protocol: &MpcSignProtocol) -> Participant { + let my_near_acc_id = protocol.ctx.account_id.clone(); + let state = protocol.state.read().await; + let participant_info = state + .find_participant_info(&my_near_acc_id) + .unwrap_or_else(|| { + tracing::error!("could not find participant info for {my_near_acc_id}"); + panic!("could not find participant info for {my_near_acc_id}"); + }); + participant_info.id.into() +} diff --git a/node/src/protocol/state.rs b/node/src/protocol/state.rs index da69ed2de..80cae1245 100644 --- a/node/src/protocol/state.rs +++ b/node/src/protocol/state.rs @@ -1,14 +1,14 @@ +use super::contract::primitives::{ParticipantInfo, Participants}; use super::cryptography::CryptographicError; use super::presignature::PresignatureManager; use super::signature::SignatureManager; use super::triple::TripleManager; use super::SignQueue; use crate::http_client::MessageQueue; -use crate::protocol::ParticipantInfo; use crate::types::{KeygenProtocol, PublicKey, ReshareProtocol, SecretKeyShare}; use cait_sith::protocol::Participant; +use near_primitives::types::AccountId; use serde::{Deserialize, Serialize}; -use std::collections::BTreeMap; use std::sync::Arc; use tokio::sync::RwLock; @@ -24,7 +24,7 @@ pub struct StartedState(pub Option); #[derive(Clone)] pub struct GeneratingState { - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub protocol: KeygenProtocol, pub messages: Arc>, @@ -42,7 +42,7 @@ impl GeneratingState { #[derive(Clone)] pub struct WaitingForConsensusState { pub epoch: u64, - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub private_share: SecretKeyShare, pub public_key: PublicKey, @@ -61,7 +61,7 @@ impl WaitingForConsensusState { #[derive(Clone)] pub struct RunningState { pub epoch: u64, - pub participants: BTreeMap, + pub participants: Participants, pub threshold: usize, pub private_share: SecretKeyShare, pub public_key: PublicKey, @@ -84,8 +84,8 @@ impl RunningState { #[derive(Clone)] pub struct ResharingState { pub old_epoch: u64, - pub old_participants: BTreeMap, - pub new_participants: BTreeMap, + pub old_participants: Participants, + pub new_participants: Participants, pub threshold: usize, pub public_key: PublicKey, pub protocol: ReshareProtocol, @@ -104,7 +104,7 @@ impl ResharingState { #[derive(Clone)] pub struct JoiningState { - pub participants: BTreeMap, + pub participants: Participants, pub public_key: PublicKey, } @@ -144,11 +144,28 @@ impl NodeState { _ => Err(CryptographicError::UnknownParticipant(*p)), } } + + pub fn find_participant_info(&self, account_id: &AccountId) -> Option<&ParticipantInfo> { + match self { + NodeState::Starting => None, + NodeState::Started(_) => None, + NodeState::Generating(state) => state.participants.find_participant_info(account_id), + NodeState::WaitingForConsensus(state) => { + state.participants.find_participant_info(account_id) + } + NodeState::Running(state) => state.participants.find_participant_info(account_id), + NodeState::Resharing(state) => state + .new_participants + .find_participant_info(account_id) + .or_else(|| state.old_participants.find_participant_info(account_id)), + NodeState::Joining(state) => state.participants.find_participant_info(account_id), + } + } } fn fetch_participant<'a>( p: &Participant, - participants: &'a BTreeMap, + participants: &'a Participants, ) -> Result<&'a ParticipantInfo, CryptographicError> { participants .get(p) diff --git a/node/src/web/mod.rs b/node/src/web/mod.rs index aca70012c..022d2f113 100644 --- a/node/src/web/mod.rs +++ b/node/src/web/mod.rs @@ -100,13 +100,13 @@ async fn msg( #[tracing::instrument(level = "debug", skip_all)] async fn join( Extension(state): Extension>, - WithRejection(Json(participant), _): WithRejection, Error>, + WithRejection(Json(account_id), _): WithRejection, Error>, ) -> Result<()> { let protocol_state = state.protocol_state.read().await; match &*protocol_state { NodeState::Running { .. } => { let args = serde_json::json!({ - "participant": participant + "candidate_account_id": account_id }); match state .rpc_client @@ -123,7 +123,7 @@ async fn join( .await { Ok(_) => { - tracing::info!(?participant, "successfully voted for a node to join"); + tracing::info!(?account_id, "successfully voted for a node to join"); Ok(()) } Err(e) => { @@ -133,7 +133,7 @@ async fn join( } } _ => { - tracing::debug!(?participant, "not ready to accept join requests yet"); + tracing::debug!(?account_id, "not ready to accept join requests yet"); Err(Error::NotRunning) } }