update dependency install (#79)
* update dependency install

* download missing model_prepare.csv and dataset; support multiple models per worker

* add handshake before establishing a job

* add push_by_topic

* fix push_by_topic

* fix ecdh-psi not returning status
cyjseagull authored Nov 5, 2024
1 parent 937ea69 commit fa96b32
Showing 26 changed files with 240 additions and 102 deletions.
12 changes: 6 additions & 6 deletions cpp/ppc-framework/protocol/Task.h
@@ -63,14 +63,10 @@ class TaskResult
virtual void setFileInfo(FileInfo::Ptr _fileInfo) { m_fileInfo = std::move(_fileInfo); }

// serialize the taskResult to json
- virtual Json::Value serializeToJson() const
+ virtual Json::Value serializeToJson()
{
Json::Value response;
response["taskID"] = m_taskID;
- if (!m_status.empty())
- {
- response["status"] = m_status;
- }
if (m_timeCost)
{
response["timeCost"] = std::to_string(m_timeCost) + "ms";
@@ -81,7 +77,7 @@ class TaskResult
response["fileID"] = m_fileInfo->fileID;
response["fileMd5"] = m_fileInfo->fileMd5;
}
- if (m_error)
+ if (m_error && m_error->errorCode() != 0)
{
response["code"] = m_error->errorCode();
response["message"] = m_error->errorMessage();
@@ -91,6 +87,10 @@ class TaskResult
response["code"] = 0;
response["message"] = "success";
}
+ if (!m_status.empty())
+ {
+ response["status"] = m_status;
+ }
return response;
}

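For orientation, here is a minimal sketch of the JSON a successfully completed task now serializes to under the reordered method above. The field names come from the diff; the values are placeholders, and the "COMPLETED" string assumes that is how ppc::protocol::toString renders TaskStatus::COMPLETED.

# Illustrative Python rendering of TaskResult::serializeToJson() output for a
# task that finished without error (all values are made up).
task_result_json = {
    "taskID": "job-0001",
    "timeCost": "1532ms",   # emitted only when m_timeCost is non-zero
    "code": 0,              # success path: code 0 and message "success"
    "message": "success",
    "status": "COMPLETED",  # m_status is now appended after code/message
}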
2 changes: 1 addition & 1 deletion cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp
@@ -540,7 +540,7 @@ void OtPIRImpl::runSenderGenerateCipher(PirTaskMessage taskMessage)

void OtPIRImpl::onReceiverTaskDone(bcos::Error::Ptr _error)
{
- if (m_taskState->taskDone())
+ if (m_taskState->taskDone() && (!_error || _error->errorCode() == 0))
{
return;
}
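The same one-line guard is applied below to CM2020PSIReceiver::onReceiverTaskDone and CM2020PSISender::onSenderTaskDone. A rough Python sketch of the intent (simplified names, not the real API):

# A repeated "task done" notification is now ignored only when the task has
# already finished AND the notification carries no error; an error that arrives
# after completion still falls through so the failure can be recorded.
def on_task_done(task_state, error=None):
    if task_state.task_done() and (error is None or error.error_code() == 0):
        return
    # otherwise continue with error handling / result finalization
    ...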
@@ -43,7 +43,7 @@ class BsEcdhResult : public protocol::TaskResult
[[nodiscard]] Json::Value const& data() const { return m_data; }

// serialize the taskResult to json
- [[nodiscard]] Json::Value serializeToJson() const override
+ [[nodiscard]] Json::Value serializeToJson() override
{
Json::Value response;
if (m_error && error()->errorCode())
@@ -931,7 +931,7 @@ void CM2020PSIReceiver::onReceiverException(const std::string& _module, const st

void CM2020PSIReceiver::onReceiverTaskDone(bcos::Error::Ptr _error)
{
- if (m_taskState->taskDone())
+ if (m_taskState->taskDone() && (!_error || _error->errorCode() == 0))
{
return;
}
@@ -832,7 +832,7 @@ void CM2020PSISender::onSenderException(const std::string& _message, const std::

void CM2020PSISender::onSenderTaskDone(bcos::Error::Ptr _error)
{
- if (m_taskState->taskDone())
+ if (m_taskState->taskDone() && (!_error || _error->errorCode() == 0))
{
return;
}
@@ -44,7 +44,7 @@ class CM2020PSIResult : public protocol::TaskResult
}

// serialize the taskResult to json
- [[nodiscard]] Json::Value serializeToJson() const override
+ [[nodiscard]] Json::Value serializeToJson() override
{
Json::Value response;
response["taskID"] = taskID();
@@ -47,7 +47,7 @@ class LabeledPSIResult : public protocol::TaskResult
const std::vector<std::vector<std::string>>& getOutputs() { return m_outputs; }

// serialize the taskResult to json
- Json::Value serializeToJson() const override
+ Json::Value serializeToJson() override
{
Json::Value response;
response["taskID"] = taskID();
12 changes: 12 additions & 0 deletions cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskState.h
@@ -225,6 +225,11 @@ class TaskState : public std::enable_shared_from_this<TaskState>
-1, "task " + m_task->id() + " failed for " +
boost::lexical_cast<std::string>(m_failedCount) + " error!");
result->setError(std::move(error));
+ result->setStatus(ppc::protocol::toString(ppc::protocol::TaskStatus::FAILED));
}
+ else
+ {
+ result->setStatus(ppc::protocol::toString(ppc::protocol::TaskStatus::COMPLETED));
+ }

// clear file
@@ -239,6 +244,7 @@ class TaskState : public std::enable_shared_from_this<TaskState>
<< LOG_KV("msg", boost::diagnostic_information(e));
auto error = std::make_shared<bcos::Error>(-1, boost::diagnostic_information(e));
result->setError(std::move(error));
+ result->setStatus(ppc::protocol::toString(ppc::protocol::TaskStatus::FAILED));
}
if (m_callback)
{
@@ -262,6 +268,11 @@ class TaskState : public std::enable_shared_from_this<TaskState>
}
try
{
+ _result->setStatus(ppc::protocol::toString(ppc::protocol::TaskStatus::COMPLETED));
+ if (_result->error() && _result->error()->errorCode() != 0)
+ {
+ _result->setStatus(ppc::protocol::toString(ppc::protocol::TaskStatus::FAILED));
+ }
// Note: we consider that the task success even if the handler exception
if (_noticePeer && !m_onlySelfRun && _result->error() &&
_result->error()->errorCode() && m_notifyPeerFinishHandler)
@@ -288,6 +299,7 @@
<< LOG_KV("msg", boost::diagnostic_information(e));
auto error = std::make_shared<bcos::Error>(-1, boost::diagnostic_information(e));
_result->setError(std::move(error));
+ _result->setStatus(ppc::protocol::toString(ppc::protocol::TaskStatus::FAILED));
}
if (m_callback)
{
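Taken together, every exit path in TaskState now stamps a status onto the task result before invoking the callback. A condensed, illustrative sketch of the rule (the "COMPLETED"/"FAILED" strings are assumed renderings of ppc::protocol::TaskStatus):

# Minimal sketch, not the actual C++ code: FAILED whenever an error with a
# non-zero code was recorded (or an exception was caught), COMPLETED otherwise.
def stamp_status(result, error=None):
    if error is not None and error.error_code() != 0:
        result.set_status("FAILED")
    else:
        result.set_status("COMPLETED")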
@@ -64,6 +64,8 @@ void LocalRouter::registerTopic(bcos::bytesConstRef _nodeID, std::string const&
}
for (auto const& msgInfo : msgQueue->messages)
{
+ LOCAL_ROUTER_LOG(INFO) << LOG_DESC("registerTopic, dispatcher the holding msg queue")
+ << LOG_KV("topic", topic) << LOG_KV("nodeID", printNodeID(_nodeID));
dispatcherMessage(msgInfo.msg, msgInfo.callback, false);
}
}
@@ -80,6 +82,15 @@ bool LocalRouter::dispatcherMessage(
P2PMessage::Ptr const& msg, ReceiveMsgFunc callback, bool holding)
{
auto frontList = chooseReceiver(msg);
+ auto commonCallback = [](bcos::Error::Ptr error) {
+ if (!error || error->errorCode() == 0)
+ {
+ return;
+ }
+ LOCAL_ROUTER_LOG(WARNING) << LOG_DESC("dispatcherMessage to front failed")
+ << LOG_KV("code", error->errorCode())
+ << LOG_KV("msg", error->errorMessage());
+ };
// find the front
if (!frontList.empty())
{
@@ -93,15 +104,7 @@
}
else
{
- front->onReceiveMessage(msg->msg(), [](bcos::Error::Ptr error) {
- if (!error || error->errorCode() == 0)
- {
- return;
- }
- LOCAL_ROUTER_LOG(WARNING) << LOG_DESC("dispatcherMessage to front failed")
- << LOG_KV("code", error->errorCode())
- << LOG_KV("msg", error->errorMessage());
- });
+ front->onReceiveMessage(msg->msg(), commonCallback);
}
i++;
}
@@ -122,7 +125,12 @@ bool LocalRouter::dispatcherMessage(
// no connection found, cache the topic message and dispatcher later
if (msg->header()->routeType() == (uint16_t)RouteType::ROUTE_THROUGH_TOPIC && m_cache)
{
- m_cache->insertCache(msg->header()->optionalField()->topic(), msg, callback);
+ // send response when holding the message
+ if (callback)
+ {
+ callback(nullptr);
+ }
+ m_cache->insertCache(msg->header()->optionalField()->topic(), msg, commonCallback);
return true;
}
return false;
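The net effect for topic-routed messages: when no local front has registered the topic yet, the sender is acknowledged immediately and the message is parked in the cache, and registerTopic later replays the held queue. A minimal Python sketch of that hold-and-replay pattern (illustrative only, not the actual C++ classes):

# Hold-and-replay sketch: messages for an unregistered topic are cached and
# acknowledged right away; registering the topic later flushes the queue.
class TopicRouterSketch:
    def __init__(self):
        self.fronts = {}   # topic -> handler registered by a local front
        self.held = {}     # topic -> messages waiting for a registration

    def dispatch(self, topic, msg, callback=None):
        if topic in self.fronts:
            self.fronts[topic](msg)
            return True
        # no receiver yet: ack the sender now and park the message
        if callback:
            callback(None)
        self.held.setdefault(topic, []).append(msg)
        return True

    def register_topic(self, topic, handler):
        self.fronts[topic] = handler
        for msg in self.held.pop(topic, []):
            handler(msg)   # replay the held queue, as registerTopic does above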
1 change: 0 additions & 1 deletion cpp/wedpr-transport/ppc-rpc/src/Rpc.cpp
@@ -226,7 +226,6 @@ void Rpc::runTask(Json::Value const& _req, RespFunc _respFunc)
_respFunc(result->error(), result->serializeToJson());
return;
}
-
_respFunc(_result->error(), _result->serializeToJson());
});
}
@@ -14,15 +14,19 @@ def stop(self):
pass

@abstractmethod
- def push_by_nodeid(topic: str, dstNode: bytes, seq: int, payload: bytes, timeout: int):
+ def push_by_nodeid(self, topic: str, dstNode: bytes, seq: int, payload: bytes, timeout: int):
pass

@abstractmethod
- def push_by_inst(topic: str, dstInst: str, seq: int, payload: bytes, timeout: int):
+ def push_by_inst(self, topic: str, dstInst: str, seq: int, payload: bytes, timeout: int):
pass

@abstractmethod
- def push_by_component(topic: str, dstInst: str, component: str, seq: int, payload: bytes, timeout: int):
+ def push_by_component(self, topic: str, dstInst: str, component: str, seq: int, payload: bytes, timeout: int):
pass

+ @abstractmethod
+ def push_by_topic(self, topic: str, dstInst: str, seq: int, payload: bytes, timeout: int):
+ pass

@abstractmethod
@@ -74,6 +74,13 @@ def push_by_nodeid(self, topic: str, dstNode: bytes, seq: int, payload: bytes, t
RouteType.ROUTE_THROUGH_NODEID.value, route_info, payload, seq, timeout)
Transport.check_result("push_by_nodeid", result)

+ def push_by_topic(self, topic: str, dstInst: str, seq: int, payload: bytes, timeout: int):
+ route_info = self.__route_info_builder.build(
+ topic=topic, dst_node=None, dst_inst=dstInst, component=None)
+ result = self._push_msg(
+ RouteType.ROUTE_THROUGH_TOPIC.value, route_info, payload, seq, timeout)
+ Transport.check_result("push_by_topic", result)

def push_by_inst(self, topic: str, dstInst: str, seq: int, payload: bytes, timeout: int):
route_info = self.__route_info_builder.build(
topic=topic, dst_node=None, dst_inst=dstInst, component=None)
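A hedged usage sketch of the new push_by_topic entry point. Only the signature is taken from the diff; the transport instance, the argument values, and the millisecond timeout unit are assumptions:

# Route a payload to whichever node in the destination agency has registered
# the topic (ROUTE_THROUGH_TOPIC); the returned status is then checked via
# Transport.check_result as shown above.
transport.push_by_topic(topic="wedpr-task-0001",
                        dstInst="partner_agency",
                        seq=0,
                        payload=b"hello",
                        timeout=60000)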
6 changes: 6 additions & 0 deletions python/ppc_model/common/base_context.py
@@ -139,6 +139,12 @@ def __init__(self, job_id: str, job_temp_dir: str):
# self.get_key_pair()
self.load_key('aes_key.bin')

+ @staticmethod
+ def load_file(storage_client, remote_path, local_path, logger):
+ if not os.path.exists(local_path):
+ logger.info(f"Download file from: {remote_path} to {local_path}")
+ storage_client.download_file(remote_path, local_path)

@staticmethod
def feature_engineering_input_path(job_id: str, job_temp_dir: str):
return os.path.join(job_temp_dir, job_id, BaseContext.MODEL_PREPARE_FILE)
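A short usage sketch of the new BaseContext.load_file helper; this is the pattern the dataset and feature-engineering changes below rely on to fetch a missing model_prepare.csv before reading it. The job id value is made up, and storage_client and logger are assumed to come from the initialized components.

import os
from ppc_model.common.base_context import BaseContext

job_id = "job-0001"            # hypothetical job id
job_temp_dir = ".cache/job"    # matches JOB_TEMP_DIR in the sample config

# Download the prepare file only when it is not already on local disk.
remote_path = os.path.join(job_id, BaseContext.MODEL_PREPARE_FILE)
local_path = BaseContext.feature_engineering_input_path(job_id, job_temp_dir)
BaseContext.load_file(storage_client, remote_path, local_path, logger)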
7 changes: 4 additions & 3 deletions python/ppc_model/common/initializer.py
@@ -9,6 +9,7 @@
from ppc_common.ppc_utils import common_func
from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
from wedpr_python_gateway_sdk.transport.impl.transport_loader import TransportLoader
+ from ppc_model.network.wedpr_model_transport import ModelRouter
from ppc_common.deps_services.mysql_storage import MySQLStorage
from ppc_common.ppc_config.sql_storage_config_loader import SQLStorageConfigLoader
from ppc_model.network.wedpr_model_transport import ModelTransport
@@ -54,6 +55,7 @@ def __init__(self, log_config_path, config_path, plot_lock=None):
self.pop_msg_timeout_ms = 60000
# for UT
self.transport = None
+ self.model_router = None
# matplotlib 线程不安全,并行任务绘图增加全局锁
self.plot_lock = plot_lock
if plot_lock is None:
@@ -90,15 +92,14 @@ def init_transport(self, task_manager: TaskManager,
transport.start()
self.logger().info(
f"Start transport success, config: {transport.get_config().desc()}")
- transport.register_component(component_type)
- self.logger().info(
- f"Register the component {component_type} success")
self.transport = ModelTransport(transport=transport,
self_agency_id=self_agency_id,
task_manager=task_manager,
component_type=component_type,
send_msg_timeout_ms=send_msg_timeout_ms,
pop_msg_timeout_ms=pop_msg_timeout_ms)
+ self.model_router = ModelRouter(logger=self.logger(),
+ transport=self.transport)

def logger(self, name=None):
if self.mock_logger is None:
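With this change the ModelRouter is built once inside init_transport and exposed as components.model_router, rather than being re-created by every model task (see model_base.py below); that is what lets a single worker process serve multiple models over one shared transport. A rough sketch of how a task now obtains the shared router (the peer name and the ctx layout beyond the diff are assumptions):

# Inside a model task: reuse the process-wide router instead of building one.
router = ctx.components.model_router
router.handshake(ctx.task_id, "peer_agency")     # announce readiness to a peer
router.wait_for_handshake(ctx.task_id)           # block until the peer answers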
6 changes: 0 additions & 6 deletions python/ppc_model/conf/application-sample.yml
@@ -8,10 +8,6 @@ PUBLIC_KEY_LENGTH: 2048
MAX_MESSAGE_LENGTH_MB: 100
TASK_TIMEOUT_H: 1800


- PEM_PATH: "/data/app/wedpr-model/wedpr-model-node/ppc_model_service/server.pem"
- SHARE_PATH: "/data/app/wedpr-model/wedpr-model-node/ppc_model_service/dataset_share/"

DB_TYPE: "mysql"
SQLALCHEMY_DATABASE_URI: "mysql://[*user_ppcsmodeladm]:[*pass_ppcsmodeladm]@[@4346-TDSQL_VIP]:[@4346-TDSQL_PORT]/ppcsmodeladm?autocommit=true&charset=utf8mb4"

@@ -33,8 +29,6 @@ gm_public_key: ""
UPLOAD_FOLDER: "./upload_data_folder"
JOB_TEMP_DIR: ".cache/job"

- FE_TIMEOUT_S: 5400

# the transport config
transport_threadpool_size: 4
transport_node_id: "MODEL_WeBank_NODE"
6 changes: 6 additions & 0 deletions python/ppc_model/datasets/dataset.py
@@ -8,6 +8,7 @@
from ppc_model.common.protocol import TaskRole
from ppc_model.common.model_result import ResultFileHandling, CommonMessage, SendMessage
from ppc_model.secure_lgbm.secure_lgbm_context import SecureLGBMContext
+ from ppc_model.common.base_context import BaseContext


class SecureDataset:
@@ -34,6 +35,11 @@ def __init__(self, ctx: SecureLGBMContext, model_data=None, delimiter: str = ' '
self.feature_name = None

if model_data is None:
+ # try to download the model_prepare_file
+ BaseContext.load_file(ctx.components.storage_client,
+ os.path.join(
+ ctx.job_id, BaseContext.MODEL_PREPARE_FILE),
+ ctx.model_prepare_file, ctx.components.logger())
self.model_data = pd.read_csv(
ctx.model_prepare_file, header=0, delimiter=delimiter)
else:
@@ -6,6 +6,7 @@
from ppc_model.feature_engineering.vertical.active_party import VerticalFeatureEngineeringActiveParty
from ppc_model.feature_engineering.vertical.passive_party import VerticalFeatureEngineeringPassiveParty
from ppc_model.interface.task_engine import TaskEngine
+ import os


class FeatureEngineeringEngine(TaskEngine):
@@ -15,7 +16,11 @@ class FeatureEngineeringEngine(TaskEngine):
def run(task_id, args):
input_path = BaseContext.feature_engineering_input_path(
args['job_id'], components.config_data['JOB_TEMP_DIR'])

+ # try to download the model_prepare_file
+ BaseContext.load_file(components.storage_client,
+ os.path.join(
+ args['job_id'], BaseContext.MODEL_PREPARE_FILE),
+ input_path, components.logger())
if args['is_label_holder']:
field_list, label, feature = SecureDataset.read_dataset(
input_path, True)
@@ -33,7 +38,7 @@ def run(task_id, args):
field_list, _, feature = SecureDataset.read_dataset(
input_path, False)
context = FeatureEngineeringContext(
- task_id = task_id,
+ task_id=task_id,
args=args,
components=components,
role=TaskRole.PASSIVE_PARTY,
33 changes: 29 additions & 4 deletions python/ppc_model/interface/model_base.py
@@ -1,17 +1,42 @@
from abc import ABC

from pandas import DataFrame
- from ppc_model.network.wedpr_model_transport import ModelRouter
+ from ppc_model.common.protocol import TaskRole


class ModelBase(ABC):
mode: str

def __init__(self, ctx):
self.ctx = ctx
- self.ctx.model_router = ModelRouter(logger=self.ctx.components.logger(),
- transport=self.ctx.components.transport,
- participant_id_list=self.ctx.participant_id_list)
+ self.ctx.model_router = self.ctx.components.model_router
+ if self.ctx.role == TaskRole.ACTIVE_PARTY:
+ self.__active_handshake__()
+ else:
+ self.__passive_handshake__()

+ def __active_handshake__(self):
+ # handshake with all passive parties
+ for i in range(1, len(self.ctx.participant_id_list)):
+ participant = self.ctx.participant_id_list[i]
+ self.ctx.components.logger().info(
+ f"Active: send handshake to passive party: {participant}")
+ self.ctx.model_router.handshake(self.ctx.task_id, participant)
+ # wait for handshake response from the passive parties
+ self.ctx.components.logger().info(
+ f"Active: wait for handshake from passive party: {participant}")
+ self.ctx.model_router.wait_for_handshake(self.ctx.task_id)

+ def __passive_handshake__(self):
+ self.ctx.components.logger().info(
+ f"Passive: send handshake to active party: {self.ctx.participant_id_list[0]}")
+ # send handshake to the active party
+ self.ctx.model_router.handshake(
+ self.ctx.task_id, self.ctx.participant_id_list[0])
+ # wait for handshake from the active party
+ self.ctx.components.logger().info(
+ f"Passive: wait for Handshake from active party: {self.ctx.participant_id_list[0]}")
+ self.ctx.model_router.wait_for_handshake(self.ctx.task_id)

def fit(
self,
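The handshake added here runs before any job traffic is exchanged: the active party (participant_id_list[0]) pings each passive party and waits for its reply, and every passive party does the same toward the active party. A minimal Python sketch of the exchange, using only the two ModelRouter calls visible above; the rest is illustrative:

from ppc_model.common.protocol import TaskRole

def establish_job(ctx):
    # ctx is assumed to carry the role, task id, participants and shared router
    router, task_id = ctx.model_router, ctx.task_id
    if ctx.role == TaskRole.ACTIVE_PARTY:
        for peer in ctx.participant_id_list[1:]:
            router.handshake(task_id, peer)       # announce readiness to the peer
            router.wait_for_handshake(task_id)    # block until the peer answers
    else:
        active_party = ctx.participant_id_list[0]
        router.handshake(task_id, active_party)
        router.wait_for_handshake(task_id)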