diff --git a/core/src/interfaces.rs b/core/src/interfaces.rs index 63ad41140..ca22af647 100644 --- a/core/src/interfaces.rs +++ b/core/src/interfaces.rs @@ -154,6 +154,17 @@ impl TryFrom for ProofType { } } +impl From for u8 { + fn from(val: ProofType) -> Self { + match val { + ProofType::Native => 0, + ProofType::Sp1 => 1, + ProofType::Sgx => 2, + ProofType::Risc0 => 3, + } + } +} + impl From for VerifierType { fn from(val: ProofType) -> Self { match val { diff --git a/host/src/server/api/v3/proof/aggregate/cancel.rs b/host/src/server/api/v3/proof/aggregate/cancel.rs new file mode 100644 index 000000000..018f922ca --- /dev/null +++ b/host/src/server/api/v3/proof/aggregate/cancel.rs @@ -0,0 +1,75 @@ +use std::str::FromStr; + +use axum::{debug_handler, extract::State, routing::post, Json, Router}; +use raiko_core::interfaces::{AggregationOnlyRequest, ProofType}; +use raiko_tasks::{TaskManager, TaskStatus}; +use utoipa::OpenApi; + +use crate::{ + interfaces::HostResult, + metrics::{inc_guest_req_count, inc_host_req_count}, + server::api::v2::CancelStatus, + Message, ProverState, +}; + +#[utoipa::path(post, path = "/proof/aggregate/cancel", + tag = "Proving", + request_body = AggregationOnlyRequest, + responses ( + (status = 200, description = "Successfully cancelled proof aggregation task", body = CancelStatus) + ) +)] +#[debug_handler(state = ProverState)] +/// Cancel a proof aggregation task with requested config. +/// +/// Accepts a proof aggregation request and cancels a proving task with the specified guest prover. 
+/// The guest provers currently available are: +/// - native - constructs a block and checks for equality +/// - sgx - uses the sgx environment to construct a block and produce proof of execution +/// - sp1 - uses the sp1 prover +/// - risc0 - uses the risc0 prover +async fn cancel_handler( + State(prover_state): State, + Json(mut aggregation_request): Json, +) -> HostResult { + // Override the existing proof request config from the config file and command line + // options with the request from the client. + aggregation_request.merge(&prover_state.request_config())?; + + let proof_type = ProofType::from_str( + aggregation_request + .proof_type + .as_deref() + .unwrap_or_default(), + )?; + inc_host_req_count(0); + inc_guest_req_count(&proof_type, 0); + + if aggregation_request.proofs.is_empty() { + return Err(anyhow::anyhow!("No proofs provided").into()); + } + + prover_state + .task_channel + .try_send(Message::CancelAggregate(aggregation_request.clone()))?; + + let mut manager = prover_state.task_manager(); + + manager + .update_aggregation_task_progress(&aggregation_request, TaskStatus::Cancelled, None) + .await?; + + Ok(CancelStatus::Ok) +} + +#[derive(OpenApi)] +#[openapi(paths(cancel_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + Docs::openapi() +} + +pub fn create_router() -> Router { + Router::new().route("/", post(cancel_handler)) +} diff --git a/host/src/server/api/v3/proof/aggregate.rs b/host/src/server/api/v3/proof/aggregate/mod.rs similarity index 88% rename from host/src/server/api/v3/proof/aggregate.rs rename to host/src/server/api/v3/proof/aggregate/mod.rs index 3bbffa00f..5f9fceb5e 100644 --- a/host/src/server/api/v3/proof/aggregate.rs +++ b/host/src/server/api/v3/proof/aggregate/mod.rs @@ -12,6 +12,10 @@ use crate::{ Message, ProverState, }; +pub mod cancel; +pub mod prune; +pub mod report; + #[utoipa::path(post, path = "/proof/aggregate", tag = "Proving", request_body = AggregationRequest, @@ -106,9 +110,22 @@ 
async fn aggregation_handler( struct Docs; pub fn create_docs() -> utoipa::openapi::OpenApi { - Docs::openapi() + [ + cancel::create_docs(), + report::create_docs(), + prune::create_docs(), + ] + .into_iter() + .fold(Docs::openapi(), |mut docs, curr| { + docs.merge(curr); + docs + }) } pub fn create_router() -> Router { - Router::new().route("/", post(aggregation_handler)) + Router::new() + .route("/", post(aggregation_handler)) + .nest("/cancel", cancel::create_router()) + .nest("/prune", prune::create_router()) + .nest("/report", report::create_router()) } diff --git a/host/src/server/api/v3/proof/aggregate/prune.rs b/host/src/server/api/v3/proof/aggregate/prune.rs new file mode 100644 index 000000000..f0e6dd956 --- /dev/null +++ b/host/src/server/api/v3/proof/aggregate/prune.rs @@ -0,0 +1,33 @@ +use axum::{debug_handler, extract::State, routing::post, Router}; +use raiko_tasks::TaskManager; +use utoipa::OpenApi; + +use crate::{interfaces::HostResult, server::api::v2::PruneStatus, ProverState}; + +#[utoipa::path(post, path = "/proof/aggregate/prune", + tag = "Proving", + responses ( + (status = 200, description = "Successfully pruned all aggregation tasks", body = PruneStatus) + ) +)] +#[debug_handler(state = ProverState)] +/// Prune all aggregation tasks. 
+async fn prune_handler(State(prover_state): State<ProverState>) -> HostResult<PruneStatus> { + let mut manager = prover_state.task_manager(); + + manager.prune_aggregation_db().await?; + + Ok(PruneStatus::Ok) +} + +#[derive(OpenApi)] +#[openapi(paths(prune_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + Docs::openapi() +} + +pub fn create_router() -> Router { + Router::new().route("/", post(prune_handler)) +} diff --git a/host/src/server/api/v3/proof/aggregate/report.rs b/host/src/server/api/v3/proof/aggregate/report.rs new file mode 100644 index 000000000..64d0d18bb --- /dev/null +++ b/host/src/server/api/v3/proof/aggregate/report.rs @@ -0,0 +1,37 @@ +use axum::{debug_handler, extract::State, routing::get, Json, Router}; +use raiko_tasks::{AggregationTaskReport, TaskManager}; +use utoipa::OpenApi; + +use crate::{interfaces::HostResult, ProverState}; + +#[utoipa::path(get, path = "/proof/aggregate/report", + tag = "Proving", + responses ( + (status = 200, description = "Successfully retrieved a report of all aggregation tasks", body = AggregationTaskReport) + ) +)] +#[debug_handler(state = ProverState)] +/// List all aggregation tasks. +/// +/// Retrieve a list of aggregation task reports.
+async fn report_handler( + State(prover_state): State, +) -> HostResult>> { + let mut manager = prover_state.task_manager(); + + let task_report = manager.list_all_aggregation_tasks().await?; + + Ok(Json(task_report)) +} + +#[derive(OpenApi)] +#[openapi(paths(report_handler))] +struct Docs; + +pub fn create_docs() -> utoipa::openapi::OpenApi { + Docs::openapi() +} + +pub fn create_router() -> Router { + Router::new().route("/", get(report_handler)) +} diff --git a/tasks/src/adv_sqlite.rs b/tasks/src/adv_sqlite.rs index 96f9e4bb3..5fcd0be30 100644 --- a/tasks/src/adv_sqlite.rs +++ b/tasks/src/adv_sqlite.rs @@ -155,11 +155,12 @@ use std::{ fs::File, path::Path, + str::FromStr, sync::{Arc, Once}, }; use chrono::{DateTime, Utc}; -use raiko_core::interfaces::AggregationOnlyRequest; +use raiko_core::interfaces::{AggregationOnlyRequest, ProofType, ProverSpecificOpts}; use raiko_lib::{ primitives::B256, prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult}, @@ -170,8 +171,8 @@ use rusqlite::{ use tokio::sync::Mutex; use crate::{ - TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, - TaskProvingStatus, TaskProvingStatusRecords, TaskReport, TaskStatus, + AggregationTaskReport, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, + TaskManagerResult, TaskProvingStatus, TaskProvingStatusRecords, TaskReport, TaskStatus, }; // Types @@ -244,6 +245,14 @@ impl TaskDb { UNIQUE (chain_id, blockhash, proofsys_id) ); + CREATE TABLE aggregation_store( + proofs TEXT NOT NULL, + proofsys_id INTEGER NOT NULL, + id TEXT NOT NULL, + FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), + UNIQUE (proofs, proofsys_id) + ); + -- Metadata and mappings ----------------------------------------------- CREATE TABLE metadata( @@ -308,6 +317,14 @@ impl TaskDb { FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), UNIQUE (chain_id, blockhash, proofsys_id) ); + + CREATE TABLE aggregation_tasks( + id INTEGER UNIQUE NOT NULL PRIMARY KEY, + proofs TEXT NOT 
NULL, + proofsys_id INTEGER NOT NULL, + FOREIGN KEY(proofsys_id) REFERENCES proofsys(id), + UNIQUE (proofs, proofsys_id) + ); -- Proofs might also be large, so we isolate them in a dedicated table CREATE TABLE task_proofs( @@ -315,6 +332,12 @@ impl TaskDb { proof TEXT, FOREIGN KEY(task_id) REFERENCES tasks(id) ); + + CREATE TABLE aggregation_task_proofs( + task_id INTEGER UNIQUE NOT NULL PRIMARY KEY, + proof TEXT, + FOREIGN KEY(task_id) REFERENCES aggregation_tasks(id) + ); CREATE TABLE task_status( task_id INTEGER NOT NULL, @@ -324,6 +347,15 @@ impl TaskDb { FOREIGN KEY(status_id) REFERENCES status_codes(id), UNIQUE (task_id, timestamp) ); + + CREATE TABLE aggregation_task_status( + task_id INTEGER NOT NULL, + status_id INTEGER NOT NULL, + timestamp TIMESTAMP DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) NOT NULL, + FOREIGN KEY(task_id) REFERENCES aggregation_tasks(id), + FOREIGN KEY(status_id) REFERENCES status_codes(id), + UNIQUE (task_id, timestamp) + ); "#, )?; @@ -358,6 +390,27 @@ impl TaskDb { tasks t LEFT JOIN task_status ts on ts.task_id = t.id LEFT JOIN task_proofs tpf on tpf.task_id = t.id; + + CREATE VIEW enqueue_aggregation_task AS + SELECT + t.id, + t.proofs, + t.proofsys_id + FROM + aggregation_tasks t + LEFT JOIN aggregation_task_status ts on ts.task_id = t.id; + + CREATE VIEW update_aggregation_task_progress AS + SELECT + t.id, + t.proofs, + t.proofsys_id, + ts.status_id, + tpf.proof + FROM + aggregation_tasks t + LEFT JOIN aggregation_task_status ts on ts.task_id = t.id + LEFT JOIN aggregation_task_proofs tpf on tpf.task_id = t.id; "#, )?; @@ -408,6 +461,7 @@ impl TaskDb { r#" -- PRAGMA temp_store = 'MEMORY'; CREATE TEMPORARY TABLE IF NOT EXISTS temp.current_task(task_id INTEGER); + CREATE TEMPORARY TABLE IF NOT EXISTS temp.current_aggregation_task(task_id INTEGER); CREATE TEMPORARY TRIGGER IF NOT EXISTS enqueue_task_insert_trigger INSTEAD OF INSERT @@ -447,6 +501,43 @@ impl TaskDb { DELETE FROM current_task; END; + + CREATE TEMPORARY TRIGGER IF 
NOT EXISTS enqueue_aggregation_task_insert_trigger INSTEAD OF + INSERT + ON enqueue_aggregation_task + BEGIN + INSERT INTO + aggregation_tasks(proofs, proofsys_id) + VALUES + ( + new.proofs, + new.proofsys_id + ); + + INSERT INTO + current_aggregation_task + SELECT + id + FROM + aggregation_tasks + WHERE + rowid = last_insert_rowid() + LIMIT + 1; + + -- Tasks are initialized at status 1000 - registered + -- timestamp is auto-filled with datetime('now'), see its field definition + INSERT INTO + aggregation_task_status(task_id, status_id) + SELECT + tmp.task_id, + 1000 + FROM + current_aggregation_task tmp; + + DELETE FROM + current_aggregation_task; + END; CREATE TEMPORARY TRIGGER IF NOT EXISTS update_task_progress_trigger INSTEAD OF INSERT @@ -491,6 +582,49 @@ impl TaskDb { DELETE FROM current_task; END; + + CREATE TEMPORARY TRIGGER IF NOT EXISTS update_aggregation_task_progress_trigger INSTEAD OF + INSERT + ON update_aggregation_task_progress + BEGIN + INSERT INTO + current_aggregation_task + SELECT + id + FROM + aggregation_tasks + WHERE + proofs = new.proofs + AND proofsys_id = new.proofsys_id + LIMIT + 1; + + -- timestamp is auto-filled with datetime('now'), see its field definition + INSERT INTO + aggregation_task_status(task_id, status_id) + SELECT + tmp.task_id, + new.status_id + FROM + current_aggregation_task tmp + LIMIT + 1; + + INSERT + OR REPLACE INTO aggregation_task_proofs + SELECT + task_id, + new.proof + FROM + current_aggregation_task + WHERE + new.proof IS NOT NULL + LIMIT + 1; + + DELETE FROM + current_aggregation_task; + END; "#, )?; @@ -844,6 +978,210 @@ impl TaskDb { Ok(query) } + + fn enqueue_aggregation_task(&self, request: &AggregationOnlyRequest) -> TaskManagerResult<()> { + let mut statement = self.conn.prepare_cached( + r#" + INSERT INTO + enqueue_aggregation_task( + proofs, + proofsys_id + ) + VALUES + ( + :proofs, + :proofsys_id + ); + "#, + )?; + let proof_type = +
ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned())) + .map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?; + statement.execute(named_params! { + ":proofs": serde_json::to_string(&request.proofs)?, + ":proofsys_id": u8::from(proof_type), + })?; + + Ok(()) + } + + fn get_aggregation_task_proving_status( + &self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + ts.status_id, + tp.proof, + timestamp + FROM + aggregation_task_status ts + LEFT JOIN aggregation_tasks t ON ts.task_id = t.id + LEFT JOIN aggregation_task_proofs tp ON tp.task_id = t.id + WHERE + t.proofs = :proofs + AND t.proofsys_id = :proofsys_id + ORDER BY + ts.timestamp; + "#, + )?; + let proof_type = + ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned())) + .map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?; + let query = statement.query_map( + named_params! { + ":proofs": serde_json::to_string(&request.proofs)?, + ":proofsys_id": u8::from(proof_type), + }, + |row| { + Ok(( + TaskStatus::from(row.get::<_, i32>(0)?), + row.get::<_, Option>(1)?, + row.get::<_, DateTime>(2)?, + )) + }, + )?; + + Ok(query.collect::, _>>()?) + } + + fn update_aggregation_task_progress( + &self, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, + ) -> TaskManagerResult<()> { + let mut statement = self.conn.prepare_cached( + r#" + INSERT INTO + update_aggregation_task_progress( + proofs, + proofsys_id, + status_id, + proof + ) + VALUES + ( + :proofs, + :proofsys_id, + :status_id, + :proof + ); + "#, + )?; + let proof_type = + ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned())) + .map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?; + statement.execute(named_params! 
{ + ":proofs": serde_json::to_string(&request.proofs)?, + ":proofsys_id": u8::from(proof_type), + ":status_id": i32::from(status), + ":proof": proof.map(hex::encode) + })?; + + Ok(()) + } + + fn get_aggregation_task_proof( + &self, + request: &AggregationOnlyRequest, + ) -> TaskManagerResult<Vec<u8>> { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + proof + FROM + aggregation_task_proofs tp + LEFT JOIN aggregation_tasks t ON tp.task_id = t.id + WHERE + t.proofs = :proofs + AND t.proofsys_id = :proofsys_id + LIMIT + 1; + "#, + )?; + let proof_type = + ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned())) + .map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?; + let query = statement.query_row( + named_params! { + ":proofs": serde_json::to_string(&request.proofs)?, + ":proofsys_id": u8::from(proof_type), + }, + |row| row.get::<_, Option<String>>(0), + )?; + + let Some(proof) = query else { + return Ok(vec![]); + }; + + hex::decode(proof) + .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) + } + + fn prune_aggregation_db(&self) -> TaskManagerResult<()> { + // Use execute_batch: prepare + execute would run only the first DELETE statement. + self.conn.execute_batch( + r#" + DELETE FROM + aggregation_tasks; + + DELETE FROM + aggregation_task_proofs; + + DELETE FROM + aggregation_task_status; + "#, + )?; + + Ok(()) + } + + fn list_all_aggregation_tasks(&self) -> TaskManagerResult<Vec<AggregationTaskReport>> { + let mut statement = self.conn.prepare_cached( + r#" + SELECT + proofs, + proofsys_id, + status_id + FROM + aggregation_tasks + LEFT JOIN aggregation_task_status on aggregation_tasks.id = aggregation_task_status.task_id + JOIN ( + SELECT + task_id, + MAX(timestamp) as latest_timestamp + FROM + aggregation_task_status + GROUP BY + task_id + ) latest_ts ON aggregation_task_status.task_id = latest_ts.task_id + AND aggregation_task_status.timestamp = latest_ts.latest_timestamp + "#, + )?; + let query = statement + .query_map([], |row| { + Ok(( + AggregationOnlyRequest { + proofs: 
serde_json::from_str(&row.get::<_, String>(0)?) + .map_err(|e| anyhow::anyhow!("Couldn't deserialize proofs: {e:?}")) + .unwrap(), + proof_type: Some( + ProofType::try_from(row.get::<_, u8>(1)?) + .map_err(|e| anyhow::anyhow!("Couldn't decode proof type: {e:?}")) + .unwrap() + .to_string(), + ), + prover_args: ProverSpecificOpts::default(), + }, + TaskStatus::from(row.get::<_, i32>(2)?), + )) + })? + .collect::, _>>()?; + + Ok(query) + } } #[async_trait::async_trait] @@ -947,32 +1285,48 @@ impl TaskManager for SqliteTaskManager { async fn enqueue_aggregation_task( &mut self, - _request: &AggregationOnlyRequest, + request: &AggregationOnlyRequest, ) -> TaskManagerResult<()> { - todo!() + let task_db = self.arc_task_db.lock().await; + task_db.enqueue_aggregation_task(request) } async fn get_aggregation_task_proving_status( &mut self, - _request: &AggregationOnlyRequest, + request: &AggregationOnlyRequest, ) -> TaskManagerResult { - todo!() + let task_db = self.arc_task_db.lock().await; + task_db.get_aggregation_task_proving_status(request) } async fn update_aggregation_task_progress( &mut self, - _request: &AggregationOnlyRequest, - _status: TaskStatus, - _proof: Option<&[u8]>, + request: &AggregationOnlyRequest, + status: TaskStatus, + proof: Option<&[u8]>, ) -> TaskManagerResult<()> { - todo!() + let task_db = self.arc_task_db.lock().await; + task_db.update_aggregation_task_progress(request, status, proof) } async fn get_aggregation_task_proof( &mut self, - _request: &AggregationOnlyRequest, + request: &AggregationOnlyRequest, ) -> TaskManagerResult> { - todo!() + let task_db = self.arc_task_db.lock().await; + task_db.get_aggregation_task_proof(request) + } + + async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> { + let task_db = self.arc_task_db.lock().await; + task_db.prune_aggregation_db() + } + + async fn list_all_aggregation_tasks( + &mut self, + ) -> TaskManagerResult> { + let task_db = self.arc_task_db.lock().await; + 
task_db.list_all_aggregation_tasks() } } diff --git a/tasks/src/lib.rs b/tasks/src/lib.rs index cc7523e35..195a68295 100644 --- a/tasks/src/lib.rs +++ b/tasks/src/lib.rs @@ -179,6 +179,8 @@ pub type TaskProvingStatusRecords = Vec; pub type TaskReport = (TaskDescriptor, TaskStatus); +pub type AggregationTaskReport = (AggregationOnlyRequest, TaskStatus); + #[derive(Debug, Clone, Default)] pub struct TaskManagerOpts { pub sqlite_file: PathBuf, @@ -250,6 +252,13 @@ pub trait TaskManager: IdStore + IdWrite { &mut self, request: &AggregationOnlyRequest, ) -> TaskManagerResult>; + + /// Prune old tasks. + async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()>; + + /// List all tasks in the db. + async fn list_all_aggregation_tasks(&mut self) + -> TaskManagerResult>; } pub fn ensure(expression: bool, message: &str) -> TaskManagerResult<()> { @@ -443,6 +452,26 @@ impl TaskManager for TaskManagerWrapper { } } } + + async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => manager.prune_aggregation_db().await, + TaskManagerInstance::Sqlite(ref mut manager) => manager.prune_aggregation_db().await, + } + } + + async fn list_all_aggregation_tasks( + &mut self, + ) -> TaskManagerResult> { + match &mut self.manager { + TaskManagerInstance::InMemory(ref mut manager) => { + manager.list_all_aggregation_tasks().await + } + TaskManagerInstance::Sqlite(ref mut manager) => { + manager.list_all_aggregation_tasks().await + } + } + } } pub fn get_task_manager(opts: &TaskManagerOpts) -> TaskManagerWrapper { diff --git a/tasks/src/mem_db.rs b/tasks/src/mem_db.rs index f3bee7883..223f5ba49 100644 --- a/tasks/src/mem_db.rs +++ b/tasks/src/mem_db.rs @@ -19,8 +19,8 @@ use tokio::sync::Mutex; use tracing::{debug, info}; use crate::{ - ensure, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, TaskManagerResult, - TaskProvingStatusRecords, TaskReport, TaskStatus, + ensure, 
AggregationTaskReport, TaskDescriptor, TaskManager, TaskManagerError, TaskManagerOpts, + TaskManagerResult, TaskProvingStatusRecords, TaskReport, TaskStatus, }; #[derive(Debug)] @@ -235,6 +235,23 @@ impl InMemoryTaskDb { hex::decode(proof) .map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned())) } + + fn prune_aggregation(&mut self) -> TaskManagerResult<()> { + self.aggregation_tasks_queue.clear(); + Ok(()) + } + + fn list_all_aggregation_tasks(&mut self) -> TaskManagerResult> { + Ok(self + .aggregation_tasks_queue + .iter() + .flat_map(|(request, statuses)| { + statuses + .last() + .map(|status| (request.clone(), status.0.clone())) + }) + .collect()) + } } #[async_trait::async_trait] @@ -372,6 +389,18 @@ impl TaskManager for InMemoryTaskManager { let mut db = self.db.lock().await; db.get_aggregation_task_proof(request) } + + async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> { + let mut db = self.db.lock().await; + db.prune_aggregation() + } + + async fn list_all_aggregation_tasks( + &mut self, + ) -> TaskManagerResult> { + let mut db = self.db.lock().await; + db.list_all_aggregation_tasks() + } } #[cfg(test)]