Skip to content

Commit

Permalink
feat(tasks): implement sqlite version of aggregation tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
petarvujovic98 committed Oct 17, 2024
1 parent b68fac3 commit d1ec67f
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 13 deletions.
11 changes: 11 additions & 0 deletions core/src/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ impl TryFrom<u8> for ProofType {
}
}

impl From<ProofType> for u8 {
fn from(val: ProofType) -> Self {
match val {
ProofType::Native => 0,
ProofType::Sp1 => 1,
ProofType::Sgx => 2,
ProofType::Risc0 => 3,
}
}
}

impl From<ProofType> for VerifierType {
fn from(val: ProofType) -> Self {
match val {
Expand Down
237 changes: 224 additions & 13 deletions tasks/src/adv_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,12 @@
use std::{
fs::File,
path::Path,
str::FromStr,
sync::{Arc, Once},
};

use chrono::{DateTime, Utc};
use raiko_core::interfaces::AggregationOnlyRequest;
use raiko_core::interfaces::{AggregationOnlyRequest, ProofType, ProverSpecificOpts};
use raiko_lib::{
primitives::B256,
prover::{IdStore, IdWrite, ProofKey, ProverError, ProverResult},
Expand Down Expand Up @@ -977,6 +978,210 @@ impl TaskDb {

Ok(query)
}

fn enqueue_aggregation_task(&self, request: &AggregationOnlyRequest) -> TaskManagerResult<()> {
let mut statement = self.conn.prepare_cached(
r#"
INSERT INTO
enqueue_aggregation_task(
proofs,
proofsys_id
)
VALUES
(
:proofs
:proofsys_id
);
"#,
)?;
let proof_type =
ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned()))
.map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?;
statement.execute(named_params! {
":proofs": serde_json::to_string(&request.proofs)?,
":proofsys_id": u8::from(proof_type),
})?;

Ok(())
}

fn get_aggregation_task_proving_status(
&self,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<TaskProvingStatusRecords> {
let mut statement = self.conn.prepare_cached(
r#"
SELECT
ts.status_id,
tp.proof,
timestamp
FROM
aggregation_task_status ts
LEFT JOIN aggregation_tasks t ON ts.task_id = t.id
LEFT JOIN aggregation_task_proofs tp ON tp.task_id = t.id
WHERE
t.proofs = :proofs
AND t.proofsys_id = :proofsys_id
ORDER BY
ts.timestamp;
"#,
)?;
let proof_type =
ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned()))
.map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?;
let query = statement.query_map(
named_params! {
":proofs": serde_json::to_string(&request.proofs)?,
":proofsys_id": u8::from(proof_type),
},
|row| {
Ok((
TaskStatus::from(row.get::<_, i32>(0)?),
row.get::<_, Option<String>>(1)?,
row.get::<_, DateTime<Utc>>(2)?,
))
},
)?;

Ok(query.collect::<Result<Vec<_>, _>>()?)
}

fn update_aggregation_task_progress(
&self,
request: &AggregationOnlyRequest,
status: TaskStatus,
proof: Option<&[u8]>,
) -> TaskManagerResult<()> {
let mut statement = self.conn.prepare_cached(
r#"
INSERT INTO
update_aggregation_task_progress(
proofs,
proofsys_id,
status_id,
proof
)
VALUES
(
:proofs,
:proofsys_id,
:status_id,
:proof
);
"#,
)?;
let proof_type =
ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned()))
.map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?;
statement.execute(named_params! {
":proofs": serde_json::to_string(&request.proofs)?,
":proofsys_id": u8::from(proof_type),
":status_id": i32::from(status),
":proof": proof.map(hex::encode)
})?;

Ok(())
}

fn get_aggregation_task_proof(
&self,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<Vec<u8>> {
let mut statement = self.conn.prepare_cached(
r#"
SELECT
proof
FROM
aggregation_task_proofs tp
LEFT JOIN aggregation_tasks t ON tp.task_id = t.id
WHERE
t.proofs = :proofs
AND t.proofsys_id = :proofsys_id
LIMIT
1;
"#,
)?;
let proof_type =
ProofType::from_str(request.proof_type.as_ref().unwrap_or(&"native".to_owned()))
.map_err(|e| anyhow::anyhow!("Conversion error {e:?}"))?;
let query = statement.query_row(
named_params! {
":proofs": serde_json::to_string(&request.proofs)?,
":proofsys_id": u8::from(proof_type),
},
|row| row.get::<_, Option<String>>(0),
)?;

let Some(proof) = query else {
return Ok(vec![]);
};

hex::decode(proof)
.map_err(|_| TaskManagerError::SqlError("couldn't decode from hex".to_owned()))
}

fn prune_aggregation_db(&self) -> TaskManagerResult<()> {
let mut statement = self.conn.prepare_cached(
r#"
DELETE FROM
aggregation_tasks;
DELETE FROM
aggregation_task_proofs;
DELETE FROM
aggregation_task_status;
"#,
)?;
statement.execute([])?;

Ok(())
}

fn list_all_aggregation_tasks(&self) -> TaskManagerResult<Vec<AggregationTaskReport>> {
let mut statement = self.conn.prepare_cached(
r#"
SELECT
proofs,
proofsys_id,
status_id
FROM
aggregation_tasks
LEFT JOIN aggregation_task_status on task.id = aggregation_task_status.task_id
JOIN (
SELECT
task_id,
MAX(timestamp) as latest_timestamp
FROM
aggregation_task_status
GROUP BY
task_id
) latest_ts ON aggregation_task_status.task_id = latest_ts.task_id
AND aggregation_task_status.timestamp = latest_ts.latest_timestamp
"#,
)?;
let query = statement
.query_map([], |row| {
Ok((
AggregationOnlyRequest {
proofs: serde_json::from_str(&row.get::<_, String>(0)?)
.map_err(|e| anyhow::anyhow!("Couldn't deserialize proofs: {e:?}"))
.unwrap(),
proof_type: Some(
ProofType::try_from(row.get::<_, u8>(1)?)
.map_err(|e| anyhow::anyhow!("Couldn't decode proof type: {e:?}"))
.unwrap()
.to_string(),
),
prover_args: ProverSpecificOpts::default(),
},
TaskStatus::from(row.get::<_, i32>(2)?),
))
})?
.collect::<Result<Vec<AggregationTaskReport>, _>>()?;

Ok(query)
}
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -1080,42 +1285,48 @@ impl TaskManager for SqliteTaskManager {

async fn enqueue_aggregation_task(
&mut self,
_request: &AggregationOnlyRequest,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<()> {
todo!()
let task_db = self.arc_task_db.lock().await;
task_db.enqueue_aggregation_task(request)
}

async fn get_aggregation_task_proving_status(
&mut self,
_request: &AggregationOnlyRequest,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<TaskProvingStatusRecords> {
todo!()
let task_db = self.arc_task_db.lock().await;
task_db.get_aggregation_task_proving_status(request)
}

async fn update_aggregation_task_progress(
&mut self,
_request: &AggregationOnlyRequest,
_status: TaskStatus,
_proof: Option<&[u8]>,
request: &AggregationOnlyRequest,
status: TaskStatus,
proof: Option<&[u8]>,
) -> TaskManagerResult<()> {
todo!()
let task_db = self.arc_task_db.lock().await;
task_db.update_aggregation_task_progress(request, status, proof)
}

async fn get_aggregation_task_proof(
&mut self,
_request: &AggregationOnlyRequest,
request: &AggregationOnlyRequest,
) -> TaskManagerResult<Vec<u8>> {
todo!()
let task_db = self.arc_task_db.lock().await;
task_db.get_aggregation_task_proof(request)
}

async fn prune_aggregation_db(&mut self) -> TaskManagerResult<()> {
todo!()
let task_db = self.arc_task_db.lock().await;
task_db.prune_aggregation_db()
}

async fn list_all_aggregation_tasks(
&mut self,
) -> TaskManagerResult<Vec<AggregationTaskReport>> {
todo!()
let task_db = self.arc_task_db.lock().await;
task_db.list_all_aggregation_tasks()
}
}

Expand Down

0 comments on commit d1ec67f

Please sign in to comment.