Skip to content

Commit

Permalink
fix push_by_topic
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Nov 5, 2024
1 parent 241fd1a commit 4b41b2c
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ void LocalRouter::registerTopic(bcos::bytesConstRef _nodeID, std::string const&
}
for (auto const& msgInfo : msgQueue->messages)
{
LOCAL_ROUTER_LOG(INFO) << LOG_DESC("registerTopic, dispatcher the holding msg queue")
<< LOG_KV("topic", topic) << LOG_KV("nodeID", printNodeID(_nodeID));
dispatcherMessage(msgInfo.msg, msgInfo.callback, false);
}
}
Expand All @@ -80,6 +82,15 @@ bool LocalRouter::dispatcherMessage(
P2PMessage::Ptr const& msg, ReceiveMsgFunc callback, bool holding)
{
auto frontList = chooseReceiver(msg);
auto commonCallback = [](bcos::Error::Ptr error) {
if (!error || error->errorCode() == 0)
{
return;
}
LOCAL_ROUTER_LOG(WARNING) << LOG_DESC("dispatcherMessage to front failed")
<< LOG_KV("code", error->errorCode())
<< LOG_KV("msg", error->errorMessage());
};
// find the front
if (!frontList.empty())
{
Expand All @@ -93,15 +104,7 @@ bool LocalRouter::dispatcherMessage(
}
else
{
front->onReceiveMessage(msg->msg(), [](bcos::Error::Ptr error) {
if (!error || error->errorCode() == 0)
{
return;
}
LOCAL_ROUTER_LOG(WARNING) << LOG_DESC("dispatcherMessage to front failed")
<< LOG_KV("code", error->errorCode())
<< LOG_KV("msg", error->errorMessage());
});
front->onReceiveMessage(msg->msg(), commonCallback);
}
i++;
}
Expand All @@ -122,7 +125,12 @@ bool LocalRouter::dispatcherMessage(
// no connection found, cache the topic message and dispatcher later
if (msg->header()->routeType() == (uint16_t)RouteType::ROUTE_THROUGH_TOPIC && m_cache)
{
m_cache->insertCache(msg->header()->optionalField()->topic(), msg, callback);
// send response when hodling the message
if (callback)
{
callback(nullptr);
}
m_cache->insertCache(msg->header()->optionalField()->topic(), msg, commonCallback);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ppc_model.feature_engineering.vertical.active_party import VerticalFeatureEngineeringActiveParty
from ppc_model.feature_engineering.vertical.passive_party import VerticalFeatureEngineeringPassiveParty
from ppc_model.interface.task_engine import TaskEngine
import os


class FeatureEngineeringEngine(TaskEngine):
Expand All @@ -15,7 +16,11 @@ class FeatureEngineeringEngine(TaskEngine):
def run(task_id, args):
input_path = BaseContext.feature_engineering_input_path(
args['job_id'], components.config_data['JOB_TEMP_DIR'])

# try to download the model_prepare_file
BaseContext.load_file(components.storage_client,
os.path.join(
args['job_id'], BaseContext.MODEL_PREPARE_FILE),
input_path, components.logger())
if args['is_label_holder']:
field_list, label, feature = SecureDataset.read_dataset(
input_path, True)
Expand All @@ -33,7 +38,7 @@ def run(task_id, args):
field_list, _, feature = SecureDataset.read_dataset(
input_path, False)
context = FeatureEngineeringContext(
task_id = task_id,
task_id=task_id,
args=args,
components=components,
role=TaskRole.PASSIVE_PARTY,
Expand Down
34 changes: 24 additions & 10 deletions python/ppc_model/interface/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,32 @@ def __init__(self, ctx):
self.ctx = ctx
self.ctx.model_router = self.ctx.components.model_router
if self.ctx.role == TaskRole.ACTIVE_PARTY:
# handshake with all passive parties
for i in range(1, len(self.ctx.participant_id_list)):
participant = self.ctx.participant_id_list[i]
self.ctx.components.logger().info(
f"Handshake with passive party: {participant}")
self.ctx.model_router.handshake(self.ctx.task_id, participant)
self.__active_handshake__()
else:
# wait for handshake for the active party
self.__passive_handshake__()

def __active_handshake__(self):
# handshake with all passive parties
for i in range(1, len(self.ctx.participant_id_list)):
participant = self.ctx.participant_id_list[i]
self.ctx.components.logger().info(
f"Active: send handshake to passive party: {participant}")
self.ctx.model_router.handshake(self.ctx.task_id, participant)
# wait for handshake response from the passive parties
self.ctx.components.logger().info(
f"Wait for Handshake from active party: {self.ctx.participant_id_list[0]}")
self.ctx.model_router.wait_for_handshake(
self.ctx.task_id, self.ctx.participant_id_list[0])
f"Active: wait for handshake from passive party: {participant}")
self.ctx.model_router.wait_for_handshake(self.ctx.task_id)

def __passive_handshake__(self):
self.ctx.components.logger().info(
f"Passive: send handshake to active party: {self.ctx.participant_id_list[0]}")
# send handshake to the active party
self.ctx.model_router.handshake(
self.ctx.task_id, self.ctx.participant_id_list[0])
# wait for handshake for the active party
self.ctx.components.logger().info(
f"Passive: wait for Handshake from active party: {self.ctx.participant_id_list[0]}")
self.ctx.model_router.wait_for_handshake(self.ctx.task_id)

def fit(
self,
Expand Down
25 changes: 16 additions & 9 deletions python/ppc_model/network/wedpr_model_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ModelRouter(ModelRouterApi):
def __init__(self, logger, transport: ModelTransport):
self.logger = logger
self.transport = transport
# task_id=>{agency=>selectedNode}
self.router_info = {}
self._rw_lock = rwlock.RWLockWrite()

Expand All @@ -87,22 +88,27 @@ def handshake(self, task_id, participant):
topic = ModelTransport.get_topic_without_agency(
task_id, BaseMessage.Handshake.value)
self.transport.transport.register_topic(topic)
self.transport.transport.push_by_topic(topic=task_id,
self.transport.transport.push_by_topic(topic=topic,
dstInst=participant,
seq=0, payload=bytes(),
timeout=self.transport.send_msg_timeout)

def wait_for_handshake(self, task_id, from_inst):
def wait_for_handshake(self, task_id):
topic = ModelTransport.get_topic_without_agency(
task_id, BaseMessage.Handshake.value)
self.transport.transport.register_topic(topic)
result = self.transport.pop_by_topic(topic=topic, task_id=task_id)

if result is None:
raise Exception(f"wait_for_handshake failed!")
self.logger.info(
f"wait_for_handshake success, task: {task_id}, detail: {result}")
with self._rw_lock.gen_wlock():
self.router_info.update(
{task_id: result.get_src_node().decode("utf-8")})
from_inst = result.get_header().get_src_inst()
if task_id not in self.router_info.keys():
self.router_info.update({task_id: dict()})
self.router_info.get(task_id).update(
{from_inst: result.get_header().get_src_node().decode("utf-8")})

def on_task_finish(self, task_id):
topic = ModelTransport.get_topic_without_agency(
Expand All @@ -112,14 +118,15 @@ def on_task_finish(self, task_id):
if task_id in self.router_info.keys():
self.router_info.pop(task_id)

def __get_dstnode_by_task_id(self, task_id):
def __get_dstnode__(self, task_id, dst_agency):
with self._rw_lock.gen_rlock():
if task_id in self.router_info.keys():
return self.router_info.get(task_id)
raise Exception(f"No Router found for task {task_id}")
if task_id in self.router_info.keys() and dst_agency in self.router_info.get(task_id).keys():
return self.router_info.get(task_id).get(dst_agency)
raise Exception(
f"No Router found for task {task_id}, dst_agency: {dst_agency}")

def push(self, task_id: str, task_type: str, dst_agency: str, payload: bytes, seq: int = 0):
dst_node = self.__get_dstnode_by_task_id(task_id)
dst_node = self.__get_dstnode__(task_id, dst_agency)
self.transport.push_by_nodeid(
task_id=task_id, task_type=task_type, dst_node=dst_node, payload=payload, seq=seq)

Expand Down

0 comments on commit 4b41b2c

Please sign in to comment.