From 6a0c53eb67ce54dbb89970f6ce3533e031d80c40 Mon Sep 17 00:00:00 2001
From: co63oc
Date: Wed, 8 Jan 2025 16:58:01 +0800
Subject: [PATCH] Remove redundant dynamic_static_unified_comm branches from
 collective global gather/scatter tests

Both tests now always initialize communication through
paddle.distributed.collective._init_parallel_env(args["backend"]), making
the paddle.distributed.init_parallel_env() fallback dead code. The
static-mode ternary is also collapsed: both of its arms called
self.get_model(train_prog, startup_prog, rank) identically.
---
 test/collective/collective_global_gather.py  | 13 ++++---------
 test/collective/collective_global_scatter.py | 13 ++++---------
 2 files changed, 8 insertions(+), 18 deletions(-)

diff --git a/test/collective/collective_global_gather.py b/test/collective/collective_global_gather.py
index 77d5df10c5fdd5..70c1abd6b3e338 100644
--- a/test/collective/collective_global_gather.py
+++ b/test/collective/collective_global_gather.py
@@ -62,10 +62,8 @@ def run_trainer(self, args):
         endpoints = args["endpoints"].split(",")
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
-        if args["dynamic_static_unified_comm"]:
-            paddle.distributed.collective._init_parallel_env(args["backend"])
-        else:
-            paddle.distributed.init_parallel_env()
+
+        paddle.distributed.collective._init_parallel_env(args["backend"])
         nranks = 2
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
@@ -112,11 +110,8 @@ def run_trainer(self, args):
         )
 
         if args['static_mode']:
-            result = (
-                self.get_model(train_prog, startup_prog, rank)
-                if args["dynamic_static_unified_comm"]
-                else self.get_model(train_prog, startup_prog, rank)
-            )
+            result = self.get_model(train_prog, startup_prog, rank)
+
             fetch_list = []
             for elem in result:
                 fetch_list.append(elem.name)
diff --git a/test/collective/collective_global_scatter.py b/test/collective/collective_global_scatter.py
index 2987c30e34f28d..b63a0e564f09d3 100644
--- a/test/collective/collective_global_scatter.py
+++ b/test/collective/collective_global_scatter.py
@@ -63,10 +63,8 @@ def run_trainer(self, args):
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
         nranks = 2
-        if args["dynamic_static_unified_comm"]:
-            paddle.distributed.collective._init_parallel_env(args["backend"])
-        else:
-            paddle.distributed.init_parallel_env()
+
+        paddle.distributed.collective._init_parallel_env(args["backend"])
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = base.CUDAPlace(
@@ -90,11 +88,8 @@ def run_trainer(self, args):
                 "float32"
             )
         if args['static_mode']:
-            result = (
-                self.get_model(train_prog, startup_prog, rank)
-                if args["dynamic_static_unified_comm"]
-                else self.get_model(train_prog, startup_prog, rank)
-            )
+            result = self.get_model(train_prog, startup_prog, rank)
+
             exe = base.Executor(place)
             exe.run(startup_prog)
             fetch_list = []