Skip to content

Commit

Permalink
Remove TssId. (#979)
Browse files Browse the repository at this point in the history
Co-authored-by: Metadata Update Bot <[email protected]>
  • Loading branch information
dvc94ch and Metadata Update Bot authored Jul 4, 2024
1 parent 7300ad9 commit 30b53f6
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 127 deletions.
4 changes: 2 additions & 2 deletions chronicle/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use time_primitives::{
sp_core, AccountId, Balance, BlockHash, BlockNumber, ChainName, ChainNetwork, Commitment,
Function, MemberStatus, NetworkId, Payload, PeerId, ProofOfKnowledge, PublicKey, Runtime,
ShardId, ShardStatus, TaskDescriptor, TaskExecution, TaskId, TaskPhase, TaskResult, TssHash,
TssId, TssSignature, TssSigningRequest,
TssSignature, TssSigningRequest,
};
use tokio::time::Duration;
use tss::{sum_commitments, VerifiableSecretSharingCommitment, VerifyingKey};
Expand Down Expand Up @@ -207,7 +207,7 @@ impl Mock {
if let Some(mut tss) = self.tss.clone() {
let (tx, rx) = oneshot::channel();
tss.send(TssSigningRequest {
request_id: TssId::new(task_id, task_phase),
request_id: TaskExecution::new(task_id, task_phase),
shard_id,
block_number,
data: payload.to_vec(),
Expand Down
4 changes: 2 additions & 2 deletions chronicle/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use std::ops::Deref;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use time_primitives::{BlockNumber, ShardId, TssId};
use time_primitives::{BlockNumber, ShardId, TaskExecution};

mod protocol;

pub use time_primitives::PeerId;

pub type TssMessage = tss::TssMessage<TssId>;
pub type TssMessage = tss::TssMessage<TaskExecution>;

pub const PROTOCOL_NAME: &str = "/analog-labs/chronicle/1";

Expand Down
49 changes: 25 additions & 24 deletions chronicle/src/shards/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ use std::{
task::Poll,
};
use time_primitives::{
BlockHash, BlockNumber, Runtime, ShardId, ShardStatus, TssId, TssSignature, TssSigningRequest,
BlockHash, BlockNumber, Runtime, ShardId, ShardStatus, TaskExecution, TssSignature,
TssSigningRequest,
};
use tokio::time::{sleep, Duration};
use tracing::{event, span, Level, Span};
Expand All @@ -41,8 +42,8 @@ pub struct TimeWorker<S, T, Tx, Rx> {
tss_states: HashMap<ShardId, Tss>,
executor_states: HashMap<ShardId, T>,
messages: BTreeMap<BlockNumber, Vec<(ShardId, PeerId, TssMessage)>>,
requests: BTreeMap<BlockNumber, Vec<(ShardId, TssId, Vec<u8>)>>,
channels: HashMap<TssId, oneshot::Sender<([u8; 32], TssSignature)>>,
requests: BTreeMap<BlockNumber, Vec<(ShardId, TaskExecution, Vec<u8>)>>,
channels: HashMap<TaskExecution, oneshot::Sender<([u8; 32], TssSignature)>>,
#[allow(clippy::type_complexity)]
outgoing_requests: FuturesUnordered<
Pin<Box<dyn Future<Output = (ShardId, PeerId, Result<()>)> + Send + 'static>>,
Expand Down Expand Up @@ -193,27 +194,6 @@ where
self.poll_actions(&span, shard_id, block_number).await;
}
}
while let Some(n) = self.messages.keys().copied().next() {
if n > block_number {
break;
}
for (shard_id, peer_id, msg) in self.messages.remove(&n).unwrap() {
let Some(tss) = self.tss_states.get_mut(&shard_id) else {
event!(
target: TW_LOG,
parent: &span,
Level::INFO,
shard_id,
"dropping message {} from {:?}",
msg,
peer_id,
);
continue;
};
tss.on_message(peer_id, msg)?;
self.poll_actions(&span, shard_id, n).await;
}
}
for shard_id in shards {
if self.substrate.get_shard_status(block, shard_id).await? != ShardStatus::Online {
continue;
Expand Down Expand Up @@ -253,6 +233,27 @@ where
for session in start_sessions {
tss.on_start(session);
}
while let Some(n) = self.messages.keys().copied().next() {
if n > block_number {
break;
}
for (shard_id, peer_id, msg) in self.messages.remove(&n).unwrap() {
let Some(tss) = self.tss_states.get_mut(&shard_id) else {
event!(
target: TW_LOG,
parent: &span,
Level::INFO,
shard_id,
"dropping message {} from {:?}",
msg,
peer_id,
);
continue;
};
tss.on_message(peer_id, msg)?;
self.poll_actions(&span, shard_id, n).await;
}
}
}
Ok(())
}
Expand Down
14 changes: 7 additions & 7 deletions chronicle/src/shards/tss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ use anyhow::Result;
use sha3::{Digest, Sha3_256};
use std::collections::BTreeSet;
use std::path::{Path, PathBuf};
pub use time_primitives::TssId;
pub use time_primitives::TaskExecution;
pub use tss::{
ProofOfKnowledge, Signature, SigningKey, VerifiableSecretSharingCommitment, VerifyingKey,
};

pub type TssMessage = tss::TssMessage<TssId>;
pub type TssMessage = tss::TssMessage<TaskExecution>;

#[derive(Clone)]
pub enum TssAction {
Send(Vec<(PeerId, TssMessage)>),
Commit(VerifiableSecretSharingCommitment, ProofOfKnowledge),
PublicKey(VerifyingKey),
Signature(TssId, [u8; 32], Signature),
Signature(TaskExecution, [u8; 32], Signature),
}

pub enum Tss {
Enabled(tss::Tss<TssId, TssPeerId>),
Enabled(tss::Tss<TaskExecution, TssPeerId>),
Disabled(SigningKey, Option<TssAction>, bool),
}

Expand Down Expand Up @@ -114,14 +114,14 @@ impl Tss {
}
}

pub fn on_start(&mut self, request_id: TssId) {
pub fn on_start(&mut self, request_id: TaskExecution) {
match self {
Self::Enabled(tss) => tss.on_start(request_id),
Self::Disabled(_, _, _) => {},
}
}

pub fn on_sign(&mut self, request_id: TssId, data: Vec<u8>) {
pub fn on_sign(&mut self, request_id: TaskExecution, data: Vec<u8>) {
match self {
Self::Enabled(tss) => tss.on_sign(request_id, data),
Self::Disabled(key, actions, _) => {
Expand All @@ -131,7 +131,7 @@ impl Tss {
}
}

pub fn on_complete(&mut self, request_id: TssId) {
pub fn on_complete(&mut self, request_id: TaskExecution) {
match self {
Self::Enabled(tss) => tss.on_complete(request_id),
Self::Disabled(_, _, _) => {},
Expand Down
52 changes: 15 additions & 37 deletions chronicle/src/tasks/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use futures::Stream;
use std::{collections::BTreeMap, pin::Pin};
use time_primitives::{
BlockHash, BlockNumber, Function, GmpParams, Message, NetworkId, Runtime, ShardId,
TaskExecution, TaskPhase, TssId,
TaskExecution, TaskPhase,
};
use tokio::task::JoinHandle;
use tracing::{event, span, Level, Span};
use tracing::{event, span, Level};

/// Set of properties we need to run our gadget
#[derive(Clone)]
Expand Down Expand Up @@ -61,23 +61,9 @@ where
block_number: BlockNumber,
shard_id: ShardId,
target_block_height: u64,
) -> Result<(Vec<TssId>, Vec<TssId>)> {
let span = span!(
target: TW_LOG,
Level::DEBUG,
"process_tasks",
block = block_hash.to_string(),
block_number,
);
TaskExecutor::process_tasks(
self,
&span,
block_hash,
block_number,
shard_id,
target_block_height,
)
.await
) -> Result<(Vec<TaskExecution>, Vec<TaskExecution>)> {
TaskExecutor::process_tasks(self, block_hash, block_number, shard_id, target_block_height)
.await
}
}

Expand Down Expand Up @@ -105,15 +91,13 @@ where
/// preprocesses the task before sending it for execution in task_spawner.rs
pub async fn process_tasks(
&mut self,
span: &Span,
block_hash: BlockHash,
block_number: BlockNumber,
shard_id: ShardId,
target_block_height: u64,
) -> Result<(Vec<TssId>, Vec<TssId>)> {
) -> Result<(Vec<TaskExecution>, Vec<TaskExecution>)> {
let span = span!(
target: TW_LOG,
parent: span,
Level::DEBUG,
"process_tasks",
block = block_hash.to_string(),
Expand All @@ -123,7 +107,7 @@ where
let mut start_sessions = vec![];
let tasks = self.substrate.get_shard_tasks(block_hash, shard_id).await?;
tracing::debug!("debug_latency Current Tasks Under processing: {:?}", tasks);
for executable_task in tasks.iter().clone() {
for executable_task in tasks.iter().copied() {
let task_id = executable_task.task_id;
event!(
target: TW_LOG,
Expand All @@ -132,7 +116,7 @@ where
task_id,
"task in execution",
);
if self.running_tasks.contains_key(executable_task) {
if self.running_tasks.contains_key(&executable_task) {
continue;
}
// gets task details
Expand Down Expand Up @@ -323,7 +307,7 @@ where
// Metrics: Increase number of running tasks
self.task_counter_metric.inc(&phase, &function_metric_clone);
let counter = self.task_counter_metric.clone();
start_sessions.push(TssId::new(task_id, phase));
start_sessions.push(executable_task);

let handle = tokio::task::spawn(async move {
match task.await {
Expand Down Expand Up @@ -353,25 +337,25 @@ where
// Metrics: Decrease number of running tasks
counter.dec(&phase, &function_metric_clone);
});
self.running_tasks.insert(executable_task.clone(), handle);
self.running_tasks.insert(executable_task, handle);
}
let mut completed_sessions = Vec::with_capacity(self.running_tasks.len());
// remove from running task if task is completed or we dont receive anymore from pallet
self.running_tasks.retain(|x, handle| {
if tasks.contains(x) {
self.running_tasks.retain(|executable_task, handle| {
if tasks.contains(executable_task) {
true
} else {
if !handle.is_finished() {
event!(
target: TW_LOG,
parent: &span,
Level::DEBUG,
x.task_id,
executable_task.task_id,
"task aborted",
);
handle.abort();
}
completed_sessions.push(TssId::new(x.task_id, x.phase));
completed_sessions.push(*executable_task);
false
}
});
Expand Down Expand Up @@ -435,14 +419,8 @@ mod tests {
while let Some((block_hash, block_number)) =
mock.finality_notification_stream().next().await
{
let span = span!(
Level::DEBUG,
"task_executor_smoke",
block = block_hash.to_string(),
block_number,
);
task_executor
.process_tasks(&span, block_hash, block_number, shard, target_block_height)
.process_tasks(block_hash, block_number, shard, target_block_height)
.await
.unwrap();
tracing::info!("Watching for result");
Expand Down
6 changes: 4 additions & 2 deletions chronicle/src/tasks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use anyhow::Result;
use async_trait::async_trait;
use futures::{Future, Stream};
use std::pin::Pin;
use time_primitives::{BlockHash, BlockNumber, Function, NetworkId, ShardId, TaskId, TssId};
use time_primitives::{
BlockHash, BlockNumber, Function, NetworkId, ShardId, TaskExecution, TaskId,
};

pub mod executor;
pub mod spawner;
Expand Down Expand Up @@ -46,5 +48,5 @@ pub trait TaskExecutor {
block_number: BlockNumber,
shard_id: ShardId,
target_block_height: u64,
) -> Result<(Vec<TssId>, Vec<TssId>)>;
) -> Result<(Vec<TaskExecution>, Vec<TaskExecution>)>;
}
6 changes: 3 additions & 3 deletions chronicle/src/tasks/spawner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use std::{
task::{Context, Poll},
};
use time_primitives::{
BlockNumber, Function, NetworkId, Payload, Runtime, ShardId, TaskId, TaskPhase, TaskResult,
TssHash, TssId, TssSignature, TssSigningRequest,
BlockNumber, Function, NetworkId, Payload, Runtime, ShardId, TaskExecution, TaskId, TaskPhase,
TaskResult, TssHash, TssSignature, TssSigningRequest,
};
use time_primitives::{IGateway, Msg};
use tokio::sync::Mutex;
Expand Down Expand Up @@ -201,7 +201,7 @@ where
self.tss
.clone()
.send(TssSigningRequest {
request_id: TssId::new(task_id, task_phase),
request_id: TaskExecution::new(task_id, task_phase),
shard_id,
block_number,
data: payload.to_vec(),
Expand Down
1 change: 1 addition & 0 deletions node/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ pub fn new_partial(
}

/// Result of [`new_full_base`].
#[allow(dead_code)]
pub struct NewFullBase {
/// The task manager of the node.
pub task_manager: TaskManager,
Expand Down
24 changes: 2 additions & 22 deletions primitives/src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures::channel::oneshot;
#[cfg(feature = "std")]
use serde::{Deserialize, Serialize};

use crate::{TaskId, TaskPhase};
use crate::TaskExecution;
use scale_codec::{Decode, Encode};
use scale_info::prelude::string::String;
use scale_info::TypeInfo;
Expand All @@ -20,26 +20,6 @@ pub type ShardId = u64;
pub type ProofOfKnowledge = [u8; 65];
pub type Commitment = Vec<TssPublicKey>;

#[cfg_attr(feature = "std", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct TssId {
task_id: TaskId,
task_phase: TaskPhase,
}

impl TssId {
pub fn new(task_id: TaskId, task_phase: TaskPhase) -> Self {
Self { task_id, task_phase }
}
}

#[cfg(feature = "std")]
impl std::fmt::Display for TssId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}-{}", self.task_id, self.task_phase)
}
}

#[derive(Debug, Clone, Eq, PartialEq, Encode, Decode, TypeInfo)]
pub enum MemberStatus {
Added,
Expand Down Expand Up @@ -87,7 +67,7 @@ impl Default for ShardStatus {

#[cfg(feature = "std")]
pub struct TssSigningRequest {
pub request_id: TssId,
pub request_id: TaskExecution,
pub shard_id: ShardId,
pub block_number: BlockNumber,
pub data: Vec<u8>,
Expand Down
Loading

0 comments on commit 30b53f6

Please sign in to comment.