Skip to content

Commit

Permalink
fix ci
Browse files Browse the repository at this point in the history
  • Loading branch information
DrownFish19 committed Feb 27, 2025
1 parent 87ba5d7 commit 1a62f74
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions llm/alignment/ppo/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
)
from paddlenlp.transformers.model_utils import _add_variant
from paddlenlp.utils.env import PADDLE_WEIGHTS_NAME

from paddlenlp.trainer.utils import distributed_concat

class StepTrainer(Trainer):
"""
Expand Down Expand Up @@ -235,7 +235,8 @@ def get_train_step_vars(self, vars: Optional[Dict] = None) -> Dict:
# should be called after model is wrapped since the model field should
# use model_wrapped.

assert self.model is not self.model_wrapped
if paddle.distributed.get_world_size() > 1:
assert self.model is not self.model_wrapped
self.train_step_vars = {
# meaningless vars can pass from outter, dummy value is enough
"epoch": 0, # meaningless for step training
Expand Down Expand Up @@ -789,9 +790,8 @@ def update(self, metrics: Dict[str, paddle.Tensor]) -> Union[None, Dict[str, flo

self.counter += 1
if self.counter == self.freq:
from paddlenlp.trainer.utils import distributed_concat
metrics = distributed_concat(self.metrics) if paddle.distributed.get_world_size() > 1 else self.metrics

metrics = distributed_concat(self.metrics)
out_metrics = {}
if self.use_stack:
mean_metric = metrics.mean(0)
Expand Down Expand Up @@ -2712,18 +2712,21 @@ def normalize_batch_data(
batch_rewards = paddle.concat(batch_rewards_list, axis=0)
batch_rewards = batch_rewards.cast(paddle.float32)

hcg = fleet.get_hybrid_communicate_group()
sd_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()

if sd_group.nranks > 1:
all_gather_batch_rewards = []
dist.all_gather(all_gather_batch_rewards, batch_rewards, group=sd_group)
batch_rewards = paddle.flatten(paddle.stack(all_gather_batch_rewards))
if dp_group.nranks > 1:
all_gather_batch_rewards = []
dist.all_gather(all_gather_batch_rewards, batch_rewards, group=dp_group)
batch_rewards = paddle.flatten(paddle.stack(all_gather_batch_rewards))
try:
hcg = fleet.get_hybrid_communicate_group()
sd_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()

if sd_group.nranks > 1:
all_gather_batch_rewards = []
dist.all_gather(all_gather_batch_rewards, batch_rewards, group=sd_group)
batch_rewards = paddle.flatten(paddle.stack(all_gather_batch_rewards))
if dp_group.nranks > 1:
all_gather_batch_rewards = []
dist.all_gather(all_gather_batch_rewards, batch_rewards, group=dp_group)
batch_rewards = paddle.flatten(paddle.stack(all_gather_batch_rewards))
except AttributeError:
pass

batch_rewards_mean = batch_rewards.mean()
# batch_rewards_std = batch_rewards.std()
Expand Down Expand Up @@ -2819,9 +2822,6 @@ def normalize_batch_data(
rl_batch.pop("prompt")

if use_advantage_normalization:
hcg = fleet.get_hybrid_communicate_group()
sd_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()
all_advantages_list = []
for rl_batch in rl_batches:
sequence_mask = rl_batch["sequence_mask"].cast(paddle.int64) # length: src + tgt
Expand All @@ -2830,17 +2830,23 @@ def normalize_batch_data(
all_advantages = paddle.concat(all_advantages_list, axis=0)
all_advantages = all_advantages.cast(paddle.float32)

if sd_group.nranks > 1:
object_list = []
dist.all_gather_object(object_list, all_advantages.tolist(), group=sd_group)
flattened_data = [item for sublist in object_list for item in sublist]
all_advantages = paddle.to_tensor(flattened_data, dtype="float32")
if dp_group.nranks > 1:
object_list = []
dist.all_gather_object(object_list, all_advantages.tolist(), group=dp_group)
flattened_data = [item for sublist in object_list for item in sublist]
all_advantages = paddle.to_tensor(flattened_data, dtype="float32")

try:
hcg = fleet.get_hybrid_communicate_group()
sd_group = hcg.get_sharding_parallel_group()
dp_group = hcg.get_data_parallel_group()

if sd_group.nranks > 1:
object_list = []
dist.all_gather_object(object_list, all_advantages.tolist(), group=sd_group)
flattened_data = [item for sublist in object_list for item in sublist]
all_advantages = paddle.to_tensor(flattened_data, dtype="float32")
if dp_group.nranks > 1:
object_list = []
dist.all_gather_object(object_list, all_advantages.tolist(), group=dp_group)
flattened_data = [item for sublist in object_list for item in sublist]
all_advantages = paddle.to_tensor(flattened_data, dtype="float32")
except AttributeError:
pass
all_advantages_mean = all_advantages.mean()
all_advantages_std = all_advantages.std()
for rl_batch in rl_batches:
Expand Down

0 comments on commit 1a62f74

Please sign in to comment.