Skip to content

Commit

Permalink
fix handshake bug (#80)
Browse files Browse the repository at this point in the history
* remove handle_local_psi

* fix handshake bug
  • Loading branch information
cyjseagull authored Nov 5, 2024
1 parent fa96b32 commit 2c1aebb
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 67 deletions.
10 changes: 8 additions & 2 deletions cpp/wedpr-computing/ppc-psi/src/cm2020-psi/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,14 @@ inline uint32_t dedupDataBatch(ppc::io::DataBatch::Ptr dataBatch)
return 0;
}
auto& data = dataBatch->mutableData();
tbb::parallel_sort(data->begin(), data->end());
auto unique_end = std::unique(data->begin(), data->end());
// Note: the header field should not be sorted
auto it = data->begin() + 1;
if (it >= data->end())
{
return data->size();
}
tbb::parallel_sort(it, data->end());
auto unique_end = std::unique(it, data->end());
data->erase(unique_end, data->end());
return data->size();
}
Expand Down
1 change: 1 addition & 0 deletions python/ppc_model/common/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class Context(BaseContext):

def __init__(self, job_id: str, task_id: str, components: Initializer, role: TaskRole = None):
    """Task-scoped context.

    Binds a task to its job, the shared component registry, the role this
    party plays in the task, and this party's agency id (read from the
    AGENCY_ID config entry).
    """
    # Base context owns the job id and the job's temporary working directory.
    super().__init__(job_id, components.config_data['JOB_TEMP_DIR'])
    self.components = components
    self.task_id = task_id
    self.role = role
    # Our own agency id, used e.g. to skip ourselves during handshakes.
    self.my_agency_id = components.config_data['AGENCY_ID']
30 changes: 10 additions & 20 deletions python/ppc_model/interface/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,23 @@ class ModelBase(ABC):
def __init__(self, ctx):
    """Initialize the model base and perform the task handshake.

    The scraped span interleaves the pre-commit role-dependent branch
    (active/passive handshake) with the post-commit unified call; this is
    the reconstructed post-commit body: every party runs the same
    symmetric handshake regardless of role.

    :param ctx: task context carrying components, router and participants
    """
    self.ctx = ctx
    # Expose the shared model router on the context for convenience.
    self.ctx.model_router = self.ctx.components.model_router
    # Symmetric handshake: same protocol for active and passive parties.
    self.__handshake__()

def __handshake__(self):
    """Handshake with every other participant of the task.

    Sends a handshake message to each participant except ourselves, then
    blocks until a handshake has been received from all of them.
    Reconstructed from diff residue (old active/passive variants were
    interleaved with the new unified method).
    """
    # handshake with all parties except ourselves
    for participant in self.ctx.participant_id_list:
        if participant == self.ctx.my_agency_id:
            continue
        self.ctx.components.logger().info(
            f"Send handshake to party: {participant}")
        self.ctx.model_router.handshake(self.ctx.task_id, participant)
    # wait for handshake responses from all the other parties
    # (typo fix: original log message said "parities")
    self.ctx.components.logger().info(
        f"Wait for handshake from all parties")
    self.ctx.model_router.wait_for_handshake(
        self.ctx.task_id, self.ctx.participant_id_list, self.ctx.my_agency_id)

def fit(
self,
Expand Down
45 changes: 30 additions & 15 deletions python/ppc_model/network/wedpr_model_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,37 @@ def handshake(self, task_id, participant):
seq=0, payload=bytes(),
timeout=self.transport.send_msg_timeout)

def wait_for_handshake(self, task_id):
topic = ModelTransport.get_topic_without_agency(
task_id, BaseMessage.Handshake.value)
self.transport.transport.register_topic(topic)
result = self.transport.pop_by_topic(topic=topic, task_id=task_id)

if result is None:
raise Exception(f"wait_for_handshake failed!")
self.logger.info(
f"wait_for_handshake success, task: {task_id}, detail: {result}")
with self._rw_lock.gen_wlock():
from_inst = result.get_header().get_src_inst()
def __all_connected__(self, task_id, participant_id_list, self_agency_id):
    """Return True when routing info is recorded for every other participant.

    Reconstructed post-commit body: the scraped span interleaved removed
    lines of the old ``wait_for_handshake`` (the ``from_inst`` write-lock
    block) with this new read-only check; those residue lines are dropped.

    :param task_id: id of the task being checked
    :param participant_id_list: agency ids of all task participants
    :param self_agency_id: our own agency id, excluded from the check
    :return: True iff router_info[task_id] has an entry for every peer
    """
    with self._rw_lock.gen_rlock():
        # No entry for the task at all: nobody has handshaked yet.
        if task_id not in self.router_info.keys():
            return False
        for participant in participant_id_list:
            if participant == self_agency_id:
                continue
            if participant not in self.router_info.get(task_id).keys():
                return False
        self.logger.info(
            f"__all_connected__, task: {task_id}, participant_id_list: {participant_id_list}")
        return True

def wait_for_handshake(self, task_id, participant_id_list: list, self_agency_id):
    """Block until every other participant of ``task_id`` is connected,
    then consume one pending handshake message and record its sender's
    routing information.

    :param task_id: id of the task whose handshake is awaited
    :param participant_id_list: agency ids of all task participants
    :param self_agency_id: our own agency id (skipped by the connectivity check)
    :raises Exception: when no handshake message could be popped for the topic
    """
    # Poll (40 ms interval) until __all_connected__ sees routing info for
    # every peer.  NOTE(review): router_info is only written further down in
    # this method — presumably another receive path also records incoming
    # handshakes, otherwise the first caller could spin here forever; confirm.
    while not self.__all_connected__(task_id, participant_id_list, self_agency_id):
        time.sleep(0.04)
    topic = ModelTransport.get_topic_without_agency(
        task_id, BaseMessage.Handshake.value)
    self.transport.transport.register_topic(topic)
    # Pop one handshake message addressed to this task.
    result = self.transport.pop_by_topic(topic=topic, task_id=task_id)

    if result is None:
        raise Exception(f"wait_for_handshake failed!")
    self.logger.info(
        f"wait_for_handshake success, task: {task_id}, detail: {result}")
    # Record the sender's node id under its agency (src_inst) for routing.
    with self._rw_lock.gen_wlock():
        from_inst = result.get_header().get_src_inst()
        if task_id not in self.router_info.keys():
            self.router_info.update({task_id: dict()})
        self.router_info.get(task_id).update(
            {from_inst: result.get_header().get_src_node().decode("utf-8")})

def on_task_finish(self, task_id):
topic = ModelTransport.get_topic_without_agency(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ def processing(self):
if need_psi and (not utils.file_exists(psi_result_path)):
storage_client.download_file(
self.ctx.remote_psi_result_path, psi_result_path)
self.handle_local_psi_result(
self.ctx.remote_psi_result_path, psi_result_path)
log.info(
f"prepare_xgb_after_psi, make_dataset_to_xgb_data_plus_psi_data, dataset_file_path={dataset_file_path}, "
f"psi_result_path={dataset_file_path}, model_prepare_file={model_prepare_file}")
f"psi_result_path={psi_result_path}, model_prepare_file={model_prepare_file}, "
f"remote_psi_result_path: {self.ctx.remote_psi_result_path}")
self.make_dataset_to_xgb_data()
storage_client.upload_file(
model_prepare_file, job_id + os.sep + BaseContext.MODEL_PREPARE_FILE)
Expand All @@ -48,25 +47,6 @@ def processing(self):
log.info(
f"call prepare_xgb_after_psi success, job_id={job_id}, timecost: {time.time() - start}")

def handle_local_psi_result(self, remote_psi_result_path, local_psi_result_path):
    """Prepend the ``id`` header line to the local PSI result file in place,
    then upload the patched file to remote storage.

    :param remote_psi_result_path: remote storage path to upload the result to
    :param local_psi_result_path: local PSI result file patched in place
    :raises BaseException: re-raises any failure after logging it
    """
    # Bind the logger OUTSIDE the try: in the original, a failure inside
    # logger() would make the except clause itself raise NameError on `log`.
    log = self.ctx.components.logger()
    try:
        log.info(
            f"handle_local_psi_result: start handle_local_psi_result, psi_result_path={local_psi_result_path}")
        # Prepend the header row expected by downstream dataset preparation.
        with open(local_psi_result_path, 'r+', encoding='utf-8') as psi_result_file:
            content = psi_result_file.read()
            psi_result_file.seek(0, 0)
            psi_result_file.write('id\n' + content)
        log.info(
            f"handle_local_psi_result: call handle_local_psi_result success, psi_result_path={local_psi_result_path}")
        # upload to remote so other services see the patched result
        self.ctx.components.storage_client.upload_file(
            local_psi_result_path, remote_psi_result_path)
    except BaseException as e:
        log.exception(
            f"handle_local_psi_result: handle_local_psi_result, psi_result_path={local_psi_result_path}, error:{e}")
        # Bare raise preserves the original traceback (raise e would rebuild it).
        raise

def make_dataset_to_xgb_data(self):
log = self.ctx.components.logger()
dataset_file_path = self.ctx.dataset_file_path
Expand Down
22 changes: 14 additions & 8 deletions python/ppc_model/secure_lr/vertical/active_party.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def fit(
self.log.info(
f'task {self.ctx.task_id}: Starting the lr on the active party.')
self._init_active_data()
max_iter = self._init_iter(self.dataset.train_X.shape[0],

max_iter = self._init_iter(self.dataset.train_X.shape[0],
self.params.epochs, self.params.batch_size)
self.log.info(f"task: {self.ctx.task_id}, max_iter: {max_iter}")
for _ in range(max_iter):
self._iter_id += 1
start_time = time.time()
Expand All @@ -59,7 +60,8 @@ def fit(
self._build_iter(feature_select, idx)

# 预测
self._train_praba = self._predict_tree(self.dataset.train_X, LRMessage.PREDICT_LEAF_MASK.value)
self._train_praba = self._predict_tree(
self.dataset.train_X, LRMessage.PREDICT_LEAF_MASK.value)
# print('train_praba', set(self._train_praba))

# 评估
Expand All @@ -69,10 +71,11 @@ def fit(
self.log.info(
f'task {self.ctx.task_id}: iter-{self._iter_id}, auc: {auc}.')
self.log.info(f'task {self.ctx.task_id}: Ending iter-{self._iter_id}, '
f'time_costs: {time.time() - start_time}s.')
f'time_costs: {time.time() - start_time}s.')

# 预测验证集
self._test_praba = self._predict_tree(self.dataset.test_X, LRMessage.TEST_LEAF_MASK.value)
self._test_praba = self._predict_tree(
self.dataset.test_X, LRMessage.TEST_LEAF_MASK.value)
if not self.params.silent and self.dataset.test_y is not None:
auc = Evaluation.fevaluation(
self.dataset.test_y, self._test_praba)['auc']
Expand All @@ -89,7 +92,8 @@ def predict(self, dataset: SecureDataset = None) -> np.ndarray:
if dataset is None:
dataset = self.dataset

test_praba = self._predict_tree(dataset.test_X, LRMessage.VALID_LEAF_MASK.value)
test_praba = self._predict_tree(
dataset.test_X, LRMessage.VALID_LEAF_MASK.value)
self._test_praba = test_praba

if dataset.test_y is not None:
Expand Down Expand Up @@ -139,8 +143,10 @@ def _build_iter(self, feature_select, idx):
public_key_list, d_other_list, partner_index_list = self._receive_d_instance_list()
deriv = self._calculate_deriv(x_, d, partner_index_list, d_other_list)

self._train_weights -= self.params.learning_rate * deriv.astype('float')
self._train_weights[~np.isin(np.arange(len(self._train_weights)), feature_select)] = 0
self._train_weights -= self.params.learning_rate * \
deriv.astype('float')
self._train_weights[~np.isin(
np.arange(len(self._train_weights)), feature_select)] = 0

def _predict_tree(self, X, key_type):
train_g = self._loss_func.dot_product(X, self._train_weights)
Expand Down
1 change: 1 addition & 0 deletions python/tools/install.sh
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# Build prerequisites: pkg-config, Python C headers, MySQL client headers and
# the compiler toolchain (needed to build native Python extensions such as mysqlclient).
apt-get install pkg-config python3-dev default-libmysqlclient-dev build-essential
# graphviz — presumably for model/tree visualization output; confirm which component uses it
apt-get install graphviz

0 comments on commit 2c1aebb

Please sign in to comment.