Skip to content

Commit

Permalink
feat: complete v3 Aggregation APIs (#424)
Browse files Browse the repository at this point in the history
* feat(host): complete the aggregation API

* feat(tasks): complete aggregation task APIs

Write in-memory implementation
Stub out sqlite implementation

* feat(host): update v3 Aggregation Cancel API

All of these changes were commented on and discussed in #387

1. Remove checking of empty proof of Aggregation Cancel request.

  For this case, querying the task status will return None, and we will
respond with a NotFound error directly. See also #387 (comment)

2. Remove metrics recording to maintain consistency with v1 and v2

  See also #387 (comment)

3. Check task status before signal cancel to prover

---------

Co-authored-by: Petar Vujovic <[email protected]>
Co-authored-by: smtmfft <[email protected]>
  • Loading branch information
3 people authored Dec 3, 2024
1 parent 32bc6a9 commit 5dade7a
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 5 deletions.
108 changes: 108 additions & 0 deletions host/src/server/api/v3/proof/aggregate/cancel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::{
interfaces::{HostError, HostResult},
server::api::v2::CancelStatus,
Message, ProverState,
};
use axum::{debug_handler, extract::State, routing::post, Json, Router};
use raiko_core::interfaces::AggregationOnlyRequest;
use raiko_tasks::{TaskManager, TaskStatus};
use utoipa::OpenApi;

#[utoipa::path(post, path = "/proof/aggregate/cancel",
tag = "Proving",
request_body = AggregationOnlyRequest,
responses (
(status = 200, description = "Successfully cancelled proof aggregation task", body = CancelStatus)
)
)]
#[debug_handler(state = ProverState)]
/// Cancel a proof aggregation task with requested config.
///
/// Accepts a proof aggregation request and cancels a proving task with the specified guest prover.
/// The guest provers currently available are:
/// - native - constructs a block and checks for equality
/// - sgx - uses the sgx environment to construct a block and produce proof of execution
/// - sp1 - uses the sp1 prover
/// - risc0 - uses the risc0 prover
async fn cancel_handler(
State(prover_state): State<ProverState>,
Json(mut aggregation_request): Json<AggregationOnlyRequest>,
) -> HostResult<CancelStatus> {
// Override the existing proof request config from the config file and command line
// options with the request from the client.
aggregation_request.merge(&prover_state.request_config())?;

let status = prover_state
.task_manager()
.get_aggregation_task_proving_status(&aggregation_request)
.await?;

let Some((latest_status, ..)) = status.0.last() else {
return Err(HostError::Io(std::io::ErrorKind::NotFound.into()));
};

let mut should_signal_cancel = false;
let returning_cancel_status = match latest_status {
/* Task is already cancelled, so we don't need further action */
TaskStatus::Cancelled
| TaskStatus::Cancelled_Aborted
| TaskStatus::Cancelled_NeverStarted
| TaskStatus::CancellationInProgress => CancelStatus::Ok,

/* Task is not completed, so we need to signal the prover to cancel */
TaskStatus::Registered | TaskStatus::WorkInProgress => {
should_signal_cancel = true;
CancelStatus::Ok
}

/* Task is completed with failure, so we don't need further action, but in case of
* retry we safe to signal the prover to cancel */
TaskStatus::ProofFailure_Generic
| TaskStatus::ProofFailure_OutOfMemory
| TaskStatus::NetworkFailure(_)
| TaskStatus::IoFailure(_)
| TaskStatus::AnyhowError(_)
| TaskStatus::GuestProverFailure(_)
| TaskStatus::InvalidOrUnsupportedBlock
| TaskStatus::UnspecifiedFailureReason
| TaskStatus::TaskDbCorruption(_) => {
should_signal_cancel = true;
CancelStatus::Error {
error: "Task already completed".to_string(),
message: format!("Task already completed, status: {:?}", latest_status),
}
}

/* Task is completed with success, so we return an error */
TaskStatus::Success => CancelStatus::Error {
error: "Task already completed".to_string(),
message: format!("Task already completed, status: {:?}", latest_status),
},
};

if should_signal_cancel {
prover_state
.task_channel
.try_send(Message::CancelAggregate(aggregation_request.clone()))?;

let mut manager = prover_state.task_manager();

manager
.update_aggregation_task_progress(&aggregation_request, TaskStatus::Cancelled, None)
.await?;
}

Ok(returning_cancel_status)
}

// OpenAPI document for this module; `cancel_handler` is the only path it registers.
#[derive(OpenApi)]
#[openapi(paths(cancel_handler))]
struct Docs;

/// Returns the OpenAPI document generated from this module's `Docs` definition.
pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
}

/// Returns a router that serves `cancel_handler` for `POST` requests on `/`.
///
/// The path is relative; the caller decides where this router is mounted.
pub fn create_router() -> Router<ProverState> {
Router::new().route("/", post(cancel_handler))
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ use crate::{
Message, ProverState,
};

pub mod cancel;
pub mod prune;
pub mod report;

#[utoipa::path(post, path = "/proof/aggregate",
tag = "Proving",
request_body = AggregationRequest,
Expand Down Expand Up @@ -109,9 +113,22 @@ async fn aggregation_handler(
struct Docs;

pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
[
cancel::create_docs(),
report::create_docs(),
prune::create_docs(),
]
.into_iter()
.fold(Docs::openapi(), |mut docs, curr| {
docs.merge(curr);
docs
})
}

pub fn create_router() -> Router<ProverState> {
Router::new().route("/", post(aggregation_handler))
Router::new()
.route("/", post(aggregation_handler))
.nest("/cancel", cancel::create_router())
.nest("/prune", prune::create_router())
.nest("/report", report::create_router())
}
33 changes: 33 additions & 0 deletions host/src/server/api/v3/proof/aggregate/prune.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use axum::{debug_handler, extract::State, routing::post, Router};
use raiko_tasks::TaskManager;
use utoipa::OpenApi;

use crate::{interfaces::HostResult, server::api::v2::PruneStatus, ProverState};

#[utoipa::path(post, path = "/proof/aggregate/prune",
    tag = "Proving",
    responses (
        (status = 200, description = "Successfully pruned all aggregation tasks", body = PruneStatus)
    )
)]
#[debug_handler(state = ProverState)]
/// Prune all aggregation tasks.
async fn prune_handler(State(prover_state): State<ProverState>) -> HostResult<PruneStatus> {
    // Any task-manager failure is propagated through `?`; otherwise the prune succeeded.
    prover_state.task_manager().prune_aggregation_db().await?;

    Ok(PruneStatus::Ok)
}

// OpenAPI document for this module; `prune_handler` is the only path it registers.
#[derive(OpenApi)]
#[openapi(paths(prune_handler))]
struct Docs;

/// Returns the OpenAPI document generated from this module's `Docs` definition.
pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
}

/// Returns a router that serves `prune_handler` for `POST` requests on `/`.
///
/// The path is relative; the caller decides where this router is mounted.
pub fn create_router() -> Router<ProverState> {
Router::new().route("/", post(prune_handler))
}
37 changes: 37 additions & 0 deletions host/src/server/api/v3/proof/aggregate/report.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use axum::{debug_handler, extract::State, routing::get, Json, Router};
use raiko_tasks::{AggregationTaskReport, TaskManager};
use utoipa::OpenApi;

use crate::{interfaces::HostResult, ProverState};

// The documented HTTP method must match the route registration in `create_router`, which
// serves this handler with `get`; the annotation previously advertised `post`.
// NOTE(review): the handler returns a `Vec<AggregationTaskReport>` while the documented body
// is a single `AggregationTaskReport` — confirm whether the schema should be an array.
#[utoipa::path(get, path = "/proof/aggregate/report",
    tag = "Proving",
    responses (
        (status = 200, description = "Successfully retrieved a report of all aggregation tasks", body = AggregationTaskReport)
    )
)]
#[debug_handler(state = ProverState)]
/// List all aggregation tasks.
///
/// Retrieve a list of aggregation task reports.
async fn report_handler(
    State(prover_state): State<ProverState>,
) -> HostResult<Json<Vec<AggregationTaskReport>>> {
    let mut manager = prover_state.task_manager();

    // Task-manager errors are propagated via `?`.
    let task_report = manager.list_all_aggregation_tasks().await?;

    Ok(Json(task_report))
}

// OpenAPI document for this module; `report_handler` is the only path it registers.
#[derive(OpenApi)]
#[openapi(paths(report_handler))]
struct Docs;

/// Returns the OpenAPI document generated from this module's `Docs` definition.
pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
}

/// Returns a router that serves `report_handler` for `GET` requests on `/`.
///
/// The path is relative; the caller decides where this router is mounted.
pub fn create_router() -> Router<ProverState> {
Router::new().route("/", get(report_handler))
}
19 changes: 19 additions & 0 deletions taskdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ pub enum TaskDescriptor {

/// A task report: a task descriptor paired with a task status.
pub type TaskReport = (TaskDescriptor, TaskStatus);

/// An aggregation task report: the aggregation request paired with a task status.
pub type AggregationTaskReport = (AggregationOnlyRequest, TaskStatus);

#[derive(Debug, Clone, Default)]
pub struct TaskManagerOpts {
pub max_db_size: usize,
Expand Down Expand Up @@ -283,6 +285,13 @@ pub trait TaskManager: IdStore + IdWrite + Send + Sync {
&mut self,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<Vec<u8>>;

/// Prune the aggregation tasks stored in the db.
async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()>;

/// List all aggregation tasks in the db.
///
/// Each entry pairs an aggregation request with a [`TaskStatus`].
async fn list_all_aggregation_tasks(&mut self)
-> TaskManagerResult<Vec<AggregationTaskReport>>;
}

pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> {
Expand Down Expand Up @@ -397,6 +406,16 @@ impl<T: TaskManager> TaskManager for TaskManagerWrapper<T> {
) -> TaskManagerResult<Vec<u8>> {
self.manager.get_aggregation_task_proof(request).await
}

/// Forwards to the wrapped manager's `prune_aggregation_db`.
async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
self.manager.prune_aggregation_db().await
}

/// Forwards to the wrapped manager's `list_all_aggregation_tasks`.
async fn list_all_aggregation_tasks(
&mut self,
) -> TaskManagerResult<Vec<AggregationTaskReport>> {
self.manager.list_all_aggregation_tasks().await
}
}

#[cfg(feature = "in-memory")]
Expand Down
36 changes: 33 additions & 3 deletions taskdb/src/mem_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use tokio::sync::Mutex;
use tracing::{info, warn};

use crate::{
ensure, AggregationTaskDescriptor, ProofTaskDescriptor, TaskDescriptor, TaskManager,
TaskManagerError, TaskManagerOpts, TaskManagerResult, TaskProvingStatusRecords, TaskReport,
TaskStatus,
ensure, AggregationTaskDescriptor, AggregationTaskReport, ProofTaskDescriptor, TaskDescriptor,
TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, TaskProvingStatusRecords,
TaskReport, TaskStatus,
};

#[derive(Debug)]
Expand Down Expand Up @@ -254,6 +254,24 @@ impl InMemoryTaskDb {
hex::decode(proof)
.map_err(|_| TaskManagerError::Anyhow("couldn't decode from hex".to_owned()))
}

/// Clears `aggregation_tasks_queue`; always returns `Ok(())`.
fn prune_aggregation(&mut self) -> TaskManagerResult<()> {
self.aggregation_tasks_queue.clear();
Ok(())
}

/// Report every queued aggregation task together with its most recent status.
///
/// Entries whose status history is empty are left out of the report.
fn list_all_aggregation_tasks(&mut self) -> TaskManagerResult<Vec<AggregationTaskReport>> {
    let reports = self
        .aggregation_tasks_queue
        .iter()
        .filter_map(|(request, records)| {
            // The last record is the latest one; its first field is the status itself.
            let latest = records.0.last()?;
            Some((request.clone(), latest.0.clone()))
        })
        .collect();

    Ok(reports)
}
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -391,6 +409,18 @@ impl TaskManager for InMemoryTaskManager {
let mut db = self.db.lock().await;
db.get_aggregation_task_proof(request)
}

/// Acquire the db lock and prune the aggregation tasks it holds.
async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
    // The guard is released when it goes out of scope at the end of this function.
    let mut guard = self.db.lock().await;
    guard.prune_aggregation()
}

/// Acquire the db lock and list the aggregation tasks it holds.
async fn list_all_aggregation_tasks(
    &mut self,
) -> TaskManagerResult<Vec<AggregationTaskReport>> {
    // The guard is released when it goes out of scope at the end of this function.
    let mut guard = self.db.lock().await;
    guard.list_all_aggregation_tasks()
}
}

#[cfg(test)]
Expand Down

0 comments on commit 5dade7a

Please sign in to comment.