From 5ee0d7a56d9f0db876a7ddf56711395da57c6bd4 Mon Sep 17 00:00:00 2001 From: Morty Date: Fri, 10 Jan 2025 17:20:56 +0800 Subject: [PATCH] fix: init multi task type issue --- prover/Cargo.lock | 2 +- prover/Cargo.toml | 2 +- prover/src/config.rs | 51 ------------ prover/src/main.rs | 22 +++-- prover/src/prover.rs | 90 +++++++-------------- prover/src/utils.rs | 24 ------ prover/src/version.rs | 18 ----- prover/src/zk_circuits_handler.rs | 43 ++++------ prover/src/zk_circuits_handler/darwin.rs | 7 +- prover/src/zk_circuits_handler/darwin_v2.rs | 7 +- 10 files changed, 70 insertions(+), 196 deletions(-) delete mode 100644 prover/src/config.rs delete mode 100644 prover/src/version.rs diff --git a/prover/Cargo.lock b/prover/Cargo.lock index 8ece6182b7..be3c20490d 100644 --- a/prover/Cargo.lock +++ b/prover/Cargo.lock @@ -4237,7 +4237,7 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "scroll-proving-sdk" version = "0.1.0" -source = "git+https://github.com/scroll-tech/scroll-proving-sdk.git?rev=61bbbe1#61bbbe1f5e28ae6da4ca2c161f830c8e48f9483b" +source = "git+https://github.com/scroll-tech/scroll-proving-sdk.git?rev=e29b98d#e29b98d441b4c8fdcf4c2497da83f809ca202c8e" dependencies = [ "anyhow", "async-trait", diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 65b1cef2c5..8f22ce292d 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -31,7 +31,7 @@ halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "v1. snark-verifier-sdk = { git = "https://github.com/scroll-tech/snark-verifier", branch = "develop", default-features = false, features = ["loader_halo2", "loader_evm", "halo2-pse"] } prover_darwin = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.12.2", package = "prover", default-features = false, features = ["parallel_syn", "scroll"] } prover_darwin_v2 = { git = "https://github.com/scroll-tech/zkevm-circuits.git", tag = "v0.13.1", package = "prover", default-features = false, features = ["parallel_syn", "scroll"] } -scroll-proving-sdk = { git = "https://github.com/scroll-tech/scroll-proving-sdk.git", rev = "61bbbe1"} +scroll-proving-sdk = { git = "https://github.com/scroll-tech/scroll-proving-sdk.git", rev = "e29b98d"} base64 = "0.13.1" reqwest = { version = "0.12.4", features = ["gzip"] } reqwest-middleware = "0.3" diff --git a/prover/src/config.rs b/prover/src/config.rs deleted file mode 100644 index 4effb958d0..0000000000 --- a/prover/src/config.rs +++ /dev/null @@ -1,51 +0,0 @@ -use anyhow::{bail, Result}; - -static SCROLL_PROVER_ASSETS_DIR_ENV_NAME: &str = "SCROLL_PROVER_ASSETS_DIR"; -static mut SCROLL_PROVER_ASSETS_DIRS: Vec = vec![]; - -#[derive(Debug)] -pub struct AssetsDirEnvConfig {} - -impl AssetsDirEnvConfig { - pub fn init() -> Result<()> { - let value = std::env::var(SCROLL_PROVER_ASSETS_DIR_ENV_NAME)?; - let dirs: Vec<&str> = value.split(',').collect(); - if dirs.len() != 2 { - bail!("env variable SCROLL_PROVER_ASSETS_DIR value must be 2 parts seperated by comma.") - } - unsafe { - SCROLL_PROVER_ASSETS_DIRS = dirs.into_iter().map(|s| s.to_string()).collect(); - log::info!( - "init SCROLL_PROVER_ASSETS_DIRS: {:?}", - SCROLL_PROVER_ASSETS_DIRS - ); - } - Ok(()) - } - - pub fn enable_first() { - unsafe { - log::info!( - "set env {SCROLL_PROVER_ASSETS_DIR_ENV_NAME} to {}", - &SCROLL_PROVER_ASSETS_DIRS[0] - ); - std::env::set_var( - SCROLL_PROVER_ASSETS_DIR_ENV_NAME, - &SCROLL_PROVER_ASSETS_DIRS[0], - ); - } - } - - pub fn enable_second() { - unsafe { - log::info!( - "set env {SCROLL_PROVER_ASSETS_DIR_ENV_NAME} to {}", - &SCROLL_PROVER_ASSETS_DIRS[1] - ); - std::env::set_var( - SCROLL_PROVER_ASSETS_DIR_ENV_NAME, - &SCROLL_PROVER_ASSETS_DIRS[1], - ); - } - } -} diff --git a/prover/src/main.rs b/prover/src/main.rs index d01d3144ff..c32bde5ee6 100644 --- a/prover/src/main.rs +++ b/prover/src/main.rs @@ -1,16 +1,19 @@ #![feature(lazy_cell)] #![feature(core_intrinsics)] -mod config; mod prover; mod types; mod utils; -mod version; mod zk_circuits_handler; use clap::{ArgAction, Parser}; use prover::LocalProver; -use scroll_proving_sdk::{config::Config, prover::ProverBuilder, utils::init_tracing}; +use scroll_proving_sdk::{ + config::Config, + prover::ProverBuilder, + utils::{get_version, init_tracing}, +}; +use utils::get_prover_type; #[derive(Parser, Debug)] #[clap(disable_version_flag = true)] @@ -35,18 +38,25 @@ async fn main() -> anyhow::Result<()> { let args = Args::parse(); if args.version { - println!("version is {}", version::get_version()); + println!("version is {}", get_version()); std::process::exit(0); } - utils::log_init(args.log_file); - let cfg: Config = Config::from_file(args.config_file)?; + let mut prover_types = vec![]; + cfg.prover.circuit_types.iter().for_each(|circuit_type| { + if let Some(pt) = get_prover_type(*circuit_type) { + if !prover_types.contains(&pt) { + prover_types.push(pt); + } + } + }); let local_prover = LocalProver::new( cfg.prover .local .clone() .ok_or_else(|| anyhow::anyhow!("Missing local prover configuration"))?, + prover_types, ); let prover = ProverBuilder::new(cfg) .with_proving_service(Box::new(local_prover)) diff --git a/prover/src/prover.rs b/prover/src/prover.rs index b23adfdca4..ed96c670a5 100644 --- a/prover/src/prover.rs +++ b/prover/src/prover.rs @@ -1,4 +1,5 @@ use crate::{ + types::ProverType, utils::get_prover_type, zk_circuits_handler::{CircuitsHandler, CircuitsHandlerProvider}, }; @@ -18,13 +19,14 @@ use std::{ sync::{Arc, Mutex}, time::{SystemTime, UNIX_EPOCH}, }; -use tokio::{runtime::Runtime, sync::RwLock, task::JoinHandle}; +use tokio::sync::RwLock; pub struct LocalProver { config: LocalProverConfig, + prover_types: Vec, circuits_handler_provider: RwLock, - current_task: Arc>>>>, next_task_id: Arc>, + result: Arc>>, } #[async_trait] @@ -52,26 +54,15 @@ impl ProvingService for LocalProver { GetVkResponse { vks, error: None } } async fn prove(&self, req: ProveRequest) -> ProveResponse { - let prover_type = match get_prover_type(req.circuit_type) { - Some(pt) => pt, - None => { - return build_prove_error_response( - String::new(), - TaskStatus::Failed, - None, - String::from("unsupported prover_type"), - ) - } - }; let handler = self .circuits_handler_provider .write() .await - .get_circuits_handler(&req.hard_fork_name, prover_type) + .get_circuits_handler(&req.hard_fork_name, self.prover_types.clone()) .context("failed to get circuit handler") .unwrap(); - match self.do_prove(req.clone(), handler) { + match self.do_prove(req.clone(), handler).await { Ok(resp) => resp, Err(e) => build_prove_error_response( String::new(), @@ -83,85 +74,58 @@ impl ProvingService for LocalProver { } async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse { - let mut current_task = self.current_task.lock().unwrap(); - - if let Some(handle) = current_task.take() { - if handle.is_finished() { - let result = Runtime::new().unwrap().block_on(handle).unwrap(); - match result { - Ok(proof) => { - return build_query_task_response( - req.task_id, - TaskStatus::Success, - Some(proof), - None, - ) - } - Err(e) => { - return build_query_task_response( - req.task_id, - TaskStatus::Failed, - None, - Some(e.to_string()), - ) - } - } - } else { - *current_task = Some(handle); - return build_query_task_response(req.task_id, TaskStatus::Proving, None, None); - } - } else { - let task_id = req.task_id.clone(); - return build_query_task_response( + let mut result_guard = self.result.lock().unwrap(); + let resp = match result_guard.as_ref() { + Ok(proof) => build_query_task_response( + req.task_id, + TaskStatus::Success, + Some(proof.clone()), + None, + ), + Err(e) => build_query_task_response( req.task_id, TaskStatus::Failed, None, - Some(String::from(&format!( - "failed to query task, task_id: {}", - task_id - ))), - ); - } + Some(e.to_string()), + ), + }; + *result_guard = Err(anyhow::Error::msg("prover not started")); + resp } } impl LocalProver { - pub fn new(config: LocalProverConfig) -> Self { + pub fn new(config: LocalProverConfig, prover_types: Vec) -> Self { let circuits_handler_provider = CircuitsHandlerProvider::new(config.clone()) .context("failed to create circuits handler provider") .unwrap(); Self { config, + prover_types, circuits_handler_provider: RwLock::new(circuits_handler_provider), - current_task: Arc::new(Mutex::new(None)), next_task_id: Arc::new(Mutex::new(0)), + result: Arc::new(Mutex::new(Err(anyhow::Error::msg("prover not started")))), } } - fn do_prove( + async fn do_prove( &self, req: ProveRequest, handler: Arc>, ) -> Result { - let mut current_task = self.current_task.lock().unwrap(); - if current_task.is_some() { - return Err(anyhow::Error::msg("prover working on previous task")); - } - let task_id = { let mut next_task_id = self.next_task_id.lock().unwrap(); *next_task_id += 1; *next_task_id }; - let req_clone = req.clone(); - let handle = tokio::spawn(async move { handler.get_proof_data(req_clone).await }); - *current_task = Some(handle); - let duration = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); let created_at = duration.as_secs() as f64 + duration.subsec_nanos() as f64 * 1e-9; + let result = handler.get_proof_data(req.clone()).await; + *self.result.lock().unwrap() = result; + Ok(ProveResponse { task_id: task_id.to_string(), circuit_type: req.circuit_type, diff --git a/prover/src/utils.rs b/prover/src/utils.rs index cca12bc5fc..b0602e9965 100644 --- a/prover/src/utils.rs +++ b/prover/src/utils.rs @@ -1,30 +1,6 @@ -use env_logger::Env; -use std::{fs::OpenOptions, sync::Once}; - use crate::types::ProverType; use scroll_proving_sdk::prover::types::CircuitType; -static LOG_INIT: Once = Once::new(); - -/// Initialize log -pub fn log_init(log_file: Option) { - LOG_INIT.call_once(|| { - let mut builder = env_logger::Builder::from_env(Env::default().default_filter_or("info")); - if let Some(file_path) = log_file { - let target = Box::new( - OpenOptions::new() - .write(true) - .create(true) - .truncate(false) - .open(file_path) - .expect("Can't create log file"), - ); - builder.target(env_logger::Target::Pipe(target)); - } - builder.init(); - }); -} - pub fn get_circuit_types(prover_type: ProverType) -> Vec { match prover_type { ProverType::Chunk => vec![CircuitType::Chunk], diff --git a/prover/src/version.rs b/prover/src/version.rs deleted file mode 100644 index 76249adeae..0000000000 --- a/prover/src/version.rs +++ /dev/null @@ -1,18 +0,0 @@ -use std::cell::OnceCell; - -static DEFAULT_COMMIT: &str = "unknown"; -static mut VERSION: OnceCell = OnceCell::new(); - -pub const TAG: &str = "v0.0.0"; -pub const DEFAULT_ZK_VERSION: &str = "000000-000000"; - -fn init_version() -> String { - let commit = option_env!("GIT_REV").unwrap_or(DEFAULT_COMMIT); - let tag = option_env!("GO_TAG").unwrap_or(TAG); - let zk_version = option_env!("ZK_VERSION").unwrap_or(DEFAULT_ZK_VERSION); - format!("{tag}-{commit}-{zk_version}") -} - -pub fn get_version() -> String { - unsafe { VERSION.get_or_init(init_version).clone() } -} diff --git a/prover/src/zk_circuits_handler.rs b/prover/src/zk_circuits_handler.rs index 6436473c34..5185259018 100644 --- a/prover/src/zk_circuits_handler.rs +++ b/prover/src/zk_circuits_handler.rs @@ -2,7 +2,7 @@ mod common; mod darwin; mod darwin_v2; -use crate::{config::AssetsDirEnvConfig, types::ProverType, utils::get_circuit_types}; +use crate::{types::ProverType, utils::get_circuit_types}; use anyhow::{bail, Result}; use async_trait::async_trait; use darwin::DarwinHandler; @@ -28,15 +28,15 @@ pub trait CircuitsHandler: Send + Sync { async fn get_proof_data(&self, prove_request: ProveRequest) -> Result; } -type CircuitsHandlerBuilder = - fn(prover_type: ProverType, config: &LocalProverConfig) -> Result>; +type CircuitsHandlerBuilder = fn( + prover_types: Vec, + config: &LocalProverConfig, +) -> Result>; pub struct CircuitsHandlerProvider { config: LocalProverConfig, circuits_handler_builder_map: HashMap, - current_fork_name: Option, - current_prover_type: Option, current_circuit: Option>>, } @@ -44,22 +44,16 @@ impl CircuitsHandlerProvider { pub fn new(config: LocalProverConfig) -> Result { let mut m: HashMap = HashMap::new(); - if let Err(e) = AssetsDirEnvConfig::init() { - log::error!("AssetsDirEnvConfig init failed: {:#}", e); - std::process::exit(-2); - } - fn handler_builder( - prover_type: ProverType, + prover_types: Vec, config: &LocalProverConfig, ) -> Result> { log::info!( "now init zk circuits handler, hard_fork_name: {}", &config.low_version_circuit.hard_fork_name ); - AssetsDirEnvConfig::enable_first(); DarwinHandler::new( - prover_type, + prover_types, &config.low_version_circuit.params_path, &config.low_version_circuit.assets_path, ) @@ -71,16 +65,15 @@ impl CircuitsHandlerProvider { ); fn next_handler_builder( - prover_type: ProverType, + prover_types: Vec, config: &LocalProverConfig, ) -> Result> { log::info!( "now init zk circuits handler, hard_fork_name: {}", &config.high_version_circuit.hard_fork_name ); - AssetsDirEnvConfig::enable_second(); DarwinV2Handler::new( - prover_type, + prover_types, &config.high_version_circuit.params_path, &config.high_version_circuit.assets_path, ) @@ -96,7 +89,6 @@ impl CircuitsHandlerProvider { config, circuits_handler_builder_map: m, current_fork_name: None, - current_prover_type: None, current_circuit: None, }; @@ -106,7 +98,7 @@ impl CircuitsHandlerProvider { pub fn get_circuits_handler( &mut self, hard_fork_name: &String, - prover_type: ProverType, + prover_types: Vec, ) -> Result>> { match &self.current_fork_name { Some(fork_name) if fork_name == hard_fork_name => { @@ -123,13 +115,12 @@ impl CircuitsHandlerProvider { ); if let Some(builder) = self.circuits_handler_builder_map.get(hard_fork_name) { log::info!("building circuits handler for {hard_fork_name}"); - let handler = builder(prover_type, &self.config) + let handler = builder(prover_types, &self.config) .expect("failed to build circuits handler"); self.current_fork_name = Some(hard_fork_name.clone()); - self.current_prover_type = Some(prover_type); - let rc_handler = Arc::new(handler); - self.current_circuit = Some(rc_handler.clone()); - Ok(rc_handler) + let arc_handler = Arc::new(handler); + self.current_circuit = Some(arc_handler.clone()); + Ok(arc_handler) } else { bail!("missing builder, there must be something wrong.") } @@ -144,10 +135,10 @@ impl CircuitsHandlerProvider { ) -> Vec { let mut vks: Vec = Vec::new(); for (hard_fork_name, build) in self.circuits_handler_builder_map.iter() { - for prover_type in prover_types.iter() { - let handler = - build(*prover_type, config).expect("failed to build circuits handler"); + let handler = + build(prover_types.clone(), config).expect("failed to build circuits handler"); + for prover_type in prover_types.iter() { for task_type in get_circuit_types(*prover_type).into_iter() { let vk = handler .get_vk(task_type) diff --git a/prover/src/zk_circuits_handler/darwin.rs b/prover/src/zk_circuits_handler/darwin.rs index e2ae353905..1644dabb4b 100644 --- a/prover/src/zk_circuits_handler/darwin.rs +++ b/prover/src/zk_circuits_handler/darwin.rs @@ -87,8 +87,8 @@ impl DarwinHandler { Ok(handler) } - pub fn new(prover_type: ProverType, params_dir: &str, assets_dir: &str) -> Result { - Self::new_multi(vec![prover_type], params_dir, assets_dir) + pub fn new(prover_types: Vec, params_dir: &str, assets_dir: &str) -> Result { + Self::new_multi(prover_types, params_dir, assets_dir) } async fn gen_chunk_proof_raw(&self, chunk_trace: Vec) -> Result { @@ -214,11 +214,12 @@ mod tests { use super::*; use crate::zk_circuits_handler::utils::encode_vk; use prover_darwin::utils::chunk_trace_to_witness_block; + use scroll_proving_sdk::utils::init_tracing; use std::{path::PathBuf, sync::LazyLock}; #[ctor::ctor] fn init() { - crate::utils::log_init(None); + init_tracing(); log::info!("logger initialized"); } diff --git a/prover/src/zk_circuits_handler/darwin_v2.rs b/prover/src/zk_circuits_handler/darwin_v2.rs index a81776faf8..d6e5813ff9 100644 --- a/prover/src/zk_circuits_handler/darwin_v2.rs +++ b/prover/src/zk_circuits_handler/darwin_v2.rs @@ -87,8 +87,8 @@ impl DarwinV2Handler { Ok(handler) } - pub fn new(prover_type: ProverType, params_dir: &str, assets_dir: &str) -> Result { - Self::new_multi(vec![prover_type], params_dir, assets_dir) + pub fn new(prover_types: Vec, params_dir: &str, assets_dir: &str) -> Result { + Self::new_multi(prover_types, params_dir, assets_dir) } async fn gen_chunk_proof_raw(&self, chunk_trace: Vec) -> Result { @@ -218,11 +218,12 @@ mod tests { aggregator::eip4844, utils::chunk_trace_to_witness_block, BatchData, BatchHeader, MAX_AGG_SNARKS, }; + use scroll_proving_sdk::utils::init_tracing; use std::{path::PathBuf, sync::LazyLock}; #[ctor::ctor] fn init() { - crate::utils::log_init(None); + init_tracing(); log::info!("logger initialized"); }