Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: complete v3 Aggregation APIs #424

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions host/src/server/api/v3/proof/aggregate/cancel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use crate::{
interfaces::{HostError, HostResult},
server::api::v2::CancelStatus,
Message, ProverState,
};
use axum::{debug_handler, extract::State, routing::post, Json, Router};
use raiko_core::interfaces::AggregationOnlyRequest;
use raiko_tasks::{TaskManager, TaskStatus};
use utoipa::OpenApi;

#[utoipa::path(post, path = "/proof/aggregate/cancel",
tag = "Proving",
request_body = AggregationOnlyRequest,
responses (
(status = 200, description = "Successfully cancelled proof aggregation task", body = CancelStatus)
)
)]
#[debug_handler(state = ProverState)]
/// Cancel a proof aggregation task with requested config.
///
/// Accepts a proof aggregation request and cancels a proving task with the specified guest prover.
/// The guest provers currently available are:
/// - native - constructs a block and checks for equality
/// - sgx - uses the sgx environment to construct a block and produce proof of execution
/// - sp1 - uses the sp1 prover
/// - risc0 - uses the risc0 prover
async fn cancel_handler(
State(prover_state): State<ProverState>,
Json(mut aggregation_request): Json<AggregationOnlyRequest>,
) -> HostResult<CancelStatus> {
// Override the existing proof request config from the config file and command line
// options with the request from the client.
aggregation_request.merge(&prover_state.request_config())?;

let status = prover_state
.task_manager()
.get_aggregation_task_proving_status(&aggregation_request)
.await?;

let Some((latest_status, ..)) = status.0.last() else {
return Err(HostError::Io(std::io::ErrorKind::NotFound.into()));
};

let mut should_signal_cancel = false;
let returning_cancel_status = match latest_status {
/* Task is already cancelled, so we don't need further action */
TaskStatus::Cancelled
| TaskStatus::Cancelled_Aborted
| TaskStatus::Cancelled_NeverStarted
| TaskStatus::CancellationInProgress => CancelStatus::Ok,

/* Task is not completed, so we need to signal the prover to cancel */
TaskStatus::Registered | TaskStatus::WorkInProgress => {
should_signal_cancel = true;
CancelStatus::Ok
}

/* Task is completed with failure, so we don't need further action, but in case of
* retry we safe to signal the prover to cancel */
TaskStatus::ProofFailure_Generic
| TaskStatus::ProofFailure_OutOfMemory
| TaskStatus::NetworkFailure(_)
| TaskStatus::IoFailure(_)
| TaskStatus::AnyhowError(_)
| TaskStatus::GuestProverFailure(_)
| TaskStatus::InvalidOrUnsupportedBlock
| TaskStatus::UnspecifiedFailureReason
| TaskStatus::TaskDbCorruption(_) => {
should_signal_cancel = true;
CancelStatus::Error {
error: "Task already completed".to_string(),
message: format!("Task already completed, status: {:?}", latest_status),
}
}

/* Task is completed with success, so we return an error */
TaskStatus::Success => CancelStatus::Error {
error: "Task already completed".to_string(),
message: format!("Task already completed, status: {:?}", latest_status),
},
};

if should_signal_cancel {
prover_state
.task_channel
.try_send(Message::CancelAggregate(aggregation_request.clone()))?;

let mut manager = prover_state.task_manager();

manager
.update_aggregation_task_progress(&aggregation_request, TaskStatus::Cancelled, None)
.await?;
}

Ok(returning_cancel_status)
}

// OpenAPI document holder for the cancel endpoint; the derive registers `cancel_handler`
// as its only path.
#[derive(OpenApi)]
#[openapi(paths(cancel_handler))]
struct Docs;

pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
}

/// Returns a router that serves `cancel_handler` for `POST` requests on `/`.
pub fn create_router() -> Router<ProverState> {
    let cancel_route = post(cancel_handler);
    Router::new().route("/", cancel_route)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ use crate::{
Message, ProverState,
};

pub mod cancel;
pub mod prune;
pub mod report;

#[utoipa::path(post, path = "/proof/aggregate",
tag = "Proving",
request_body = AggregationRequest,
Expand Down Expand Up @@ -109,9 +113,22 @@ async fn aggregation_handler(
struct Docs;

/// Builds the OpenAPI document for the aggregation API.
///
/// Starts from `Docs::openapi()` and merges in the documents produced by the `cancel`,
/// `report` and `prune` sub-modules.
pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
[
cancel::create_docs(),
report::create_docs(),
prune::create_docs(),
]
.into_iter()
.fold(Docs::openapi(), |mut docs, curr| {
docs.merge(curr);
docs
})
}

/// Builds the router for the aggregation API.
///
/// `POST /` is served by `aggregation_handler`; the `cancel`, `prune` and `report`
/// sub-routers are nested under `/cancel`, `/prune` and `/report`.
pub fn create_router() -> Router<ProverState> {
Router::new().route("/", post(aggregation_handler))
Router::new()
.route("/", post(aggregation_handler))
.nest("/cancel", cancel::create_router())
.nest("/prune", prune::create_router())
.nest("/report", report::create_router())
}
33 changes: 33 additions & 0 deletions host/src/server/api/v3/proof/aggregate/prune.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use axum::{debug_handler, extract::State, routing::post, Router};
use raiko_tasks::TaskManager;
use utoipa::OpenApi;

use crate::{interfaces::HostResult, server::api::v2::PruneStatus, ProverState};

#[utoipa::path(post, path = "/proof/aggregate/prune",
    tag = "Proving",
    responses (
        (status = 200, description = "Successfully pruned all aggregation tasks", body = PruneStatus)
    )
)]
#[debug_handler(state = ProverState)]
/// Prune all aggregation tasks.
///
/// Asks the task manager to prune its aggregation task store and reports `Ok` once that
/// succeeds; a task manager error is propagated to the caller.
async fn prune_handler(State(prover_state): State<ProverState>) -> HostResult<PruneStatus> {
    prover_state.task_manager().prune_aggregation_db().await?;

    Ok(PruneStatus::Ok)
}

// OpenAPI document holder for the prune endpoint; the derive registers `prune_handler`
// as its only path.
#[derive(OpenApi)]
#[openapi(paths(prune_handler))]
struct Docs;

pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
}

/// Returns a router that serves `prune_handler` for `POST` requests on `/`.
pub fn create_router() -> Router<ProverState> {
    let prune_route = post(prune_handler);
    Router::new().route("/", prune_route)
}
37 changes: 37 additions & 0 deletions host/src/server/api/v3/proof/aggregate/report.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use axum::{debug_handler, extract::State, routing::get, Json, Router};
use raiko_tasks::{AggregationTaskReport, TaskManager};
use utoipa::OpenApi;

use crate::{interfaces::HostResult, ProverState};

// The documented method must match the router below, which registers this handler with
// `get(...)`; it was previously documented as `post`.
// NOTE(review): the handler returns a list of reports while `body` names a single
// `AggregationTaskReport` — confirm whether the schema should describe an array.
#[utoipa::path(get, path = "/proof/aggregate/report",
    tag = "Proving",
    responses (
        (status = 200, description = "Successfully retrieved a report of all aggregation tasks", body = AggregationTaskReport)
    )
)]
#[debug_handler(state = ProverState)]
/// List all aggregation tasks.
///
/// Retrieve a list of aggregation task reports.
async fn report_handler(
    State(prover_state): State<ProverState>,
) -> HostResult<Json<Vec<AggregationTaskReport>>> {
    let mut manager = prover_state.task_manager();

    // Errors from the task manager are propagated to the caller.
    let task_report = manager.list_all_aggregation_tasks().await?;

    Ok(Json(task_report))
}

// OpenAPI document holder for the report endpoint; the derive registers `report_handler`
// as its only path.
#[derive(OpenApi)]
#[openapi(paths(report_handler))]
struct Docs;

pub fn create_docs() -> utoipa::openapi::OpenApi {
Docs::openapi()
}

/// Returns a router that serves `report_handler` for `GET` requests on `/`.
pub fn create_router() -> Router<ProverState> {
    let report_route = get(report_handler);
    Router::new().route("/", report_route)
}
19 changes: 19 additions & 0 deletions taskdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ pub enum TaskDescriptor {

/// A task descriptor paired with a task status.
pub type TaskReport = (TaskDescriptor, TaskStatus);

/// An aggregation request paired with a task status.
pub type AggregationTaskReport = (AggregationOnlyRequest, TaskStatus);

#[derive(Debug, Clone, Default)]
pub struct TaskManagerOpts {
pub max_db_size: usize,
Expand Down Expand Up @@ -283,6 +285,13 @@ pub trait TaskManager: IdStore + IdWrite + Send + Sync {
&mut self,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<Vec<u8>>;

/// Prune the aggregation task store.
///
/// NOTE(review): the in-memory backend removes every aggregation task rather than only
/// "old" ones — confirm that other backends behave the same way.
async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()>;

/// List all aggregation tasks in the db.
///
/// Each entry pairs an aggregation request with a task status.
async fn list_all_aggregation_tasks(&mut self)
-> TaskManagerResult<Vec<AggregationTaskReport>>;
}

pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> {
Expand Down Expand Up @@ -397,6 +406,16 @@ impl<T: TaskManager> TaskManager for TaskManagerWrapper<T> {
) -> TaskManagerResult<Vec<u8>> {
self.manager.get_aggregation_task_proof(request).await
}

/// Forwards the aggregation prune request to the wrapped task manager.
async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
self.manager.prune_aggregation_db().await
}

/// Forwards the aggregation task listing request to the wrapped task manager.
async fn list_all_aggregation_tasks(
&mut self,
) -> TaskManagerResult<Vec<AggregationTaskReport>> {
self.manager.list_all_aggregation_tasks().await
}
}

#[cfg(feature = "in-memory")]
Expand Down
36 changes: 33 additions & 3 deletions taskdb/src/mem_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use tokio::sync::Mutex;
use tracing::{info, warn};

use crate::{
ensure, AggregationTaskDescriptor, ProofTaskDescriptor, TaskDescriptor, TaskManager,
TaskManagerError, TaskManagerOpts, TaskManagerResult, TaskProvingStatusRecords, TaskReport,
TaskStatus,
ensure, AggregationTaskDescriptor, AggregationTaskReport, ProofTaskDescriptor, TaskDescriptor,
TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, TaskProvingStatusRecords,
TaskReport, TaskStatus,
};

#[derive(Debug)]
Expand Down Expand Up @@ -254,6 +254,24 @@ impl InMemoryTaskDb {
hex::decode(proof)
.map_err(|_| TaskManagerError::Anyhow("couldn't decode from hex".to_owned()))
}

/// Removes every entry from the aggregation task queue.
///
/// Always returns `Ok(())`.
fn prune_aggregation(&mut self) -> TaskManagerResult<()> {
self.aggregation_tasks_queue.clear();
Ok(())
}

/// Reports every queued aggregation request together with its most recent status.
///
/// Requests whose status history is empty are left out of the result.
fn list_all_aggregation_tasks(&mut self) -> TaskManagerResult<Vec<AggregationTaskReport>> {
    let reports: Vec<AggregationTaskReport> = self
        .aggregation_tasks_queue
        .iter()
        .filter_map(|(request, statuses)| {
            // The last record in the history carries the latest status.
            let latest = statuses.0.last()?;
            Some((request.clone(), latest.0.clone()))
        })
        .collect();

    Ok(reports)
}
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -391,6 +409,18 @@ impl TaskManager for InMemoryTaskManager {
let mut db = self.db.lock().await;
db.get_aggregation_task_proof(request)
}

/// Locks the in-memory db and prunes its aggregation tasks.
async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
    let mut guard = self.db.lock().await;
    guard.prune_aggregation()
}

/// Locks the in-memory db and lists its aggregation task reports.
async fn list_all_aggregation_tasks(
    &mut self,
) -> TaskManagerResult<Vec<AggregationTaskReport>> {
    let mut guard = self.db.lock().await;
    guard.list_all_aggregation_tasks()
}
}

#[cfg(test)]
Expand Down
Loading