Commit
add count trained tokens (#9800)
* add count

* fix

* fix
lugimzzz authored Feb 20, 2025
1 parent 347d77c commit b8ebe3e
Showing 2 changed files with 35 additions and 0 deletions.
27 changes: 27 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -466,6 +466,9 @@ def fn(layer):

# very last
self._memory_tracker.stop_and_update_metrics()
if self.args.count_trained_tokens:
self.trained_effective_tokens = 0
self.trained_tokens = 0

def _wrap_amp_model(self, args, model):
logger.info("Using half precision")
@@ -1122,6 +1125,9 @@ def _inner_training_loop(
is_no_sync = True

sync_context = model.no_sync() if is_no_sync else contextlib.nullcontext()
if self.args.count_trained_tokens:
self.trained_effective_tokens += (inputs["input_ids"] != self.args.pad_token_id).sum()
self.trained_tokens += inputs["input_ids"].numel()
with sync_context:
if "step_control" in inspect.signature(self.training_step).parameters:
tr_loss_step = self.training_step(model, inputs, step_control=step_control)
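
The hunk above does the per-step bookkeeping just before the forward/backward pass: effective tokens exclude padding, while the raw count is simply the tensor size. A minimal standalone sketch of that arithmetic, with a made-up batch and the pad_token_id default of 0 from the training_args.py diff below:

    import paddle

    # toy [batch, seq_len] batch; pad_token_id assumed to be 0
    input_ids = paddle.to_tensor([[5, 9, 3, 0, 0],
                                  [7, 2, 0, 0, 0]])
    pad_token_id = 0

    effective = (input_ids != pad_token_id).sum()  # non-padding positions -> 5
    total = input_ids.numel()                      # every position -> 10
    print(int(effective), int(total))              # 5 10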
@@ -1570,6 +1576,27 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
self._save_checkpoint(model, metrics=metrics)
logger.info(f"{self.runtime_timer.log()}")
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
self.log_trained_tokens()

def log_trained_tokens(self):
if self.args.count_trained_tokens:
token_list = []
for token_num in [self.trained_effective_tokens, self.trained_tokens]:
tensors = token_num.reshape([1])
if self.hcg._sharding_degree > 1:
output_tensors = []
paddle.distributed.all_gather(output_tensors, tensors, group=self.hcg._sharding_comm_group)
tensors = paddle.concat(output_tensors).sum().reshape([1])
if self.hcg._dp_degree > 1:
output_tensors = []
paddle.distributed.all_gather(output_tensors, tensors, group=self.hcg._dp_comm_group)
tensors = paddle.concat(output_tensors).sum().reshape([1])
token_list.append(tensors.item())
        if self.is_local_process_zero():
            logger.info(
                f"Up to now, trained_effective_tokens: {token_list[0]}, trained_tokens: {token_list[1]}."
            )

def _get_learning_rate(self):
return self.optimizer.get_lr()
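
log_trained_tokens aggregates the per-rank counters in two rounds, first across the sharding group and then across the data-parallel group, and only the local rank-0 process logs the result. A hedged sketch of the same reduction, under the assumption that an all_reduce sum over each group is an acceptable equivalent to the all_gather + concat + sum rounds above (this is not the committed implementation):

    import paddle
    import paddle.distributed as dist

    def sum_over_groups(count, groups):
        # sum a scalar count over each communication group in turn;
        # skip groups that are absent or trivial (single rank)
        t = paddle.to_tensor([int(count)], dtype="int64")
        for group in groups:
            if group is not None and group.nranks > 1:
                dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
        return int(t)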
8 changes: 8 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -978,6 +978,14 @@ class TrainingArguments:
default=300,
metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
)
count_trained_tokens: bool = field(
default=False,
metadata={"help": "Whether to count trained tokens."},
)
pad_token_id: int = field(
default=0,
metadata={"help": "The id of the padding token."},
)

def __post_init__(self):
if in_auto_parallel_align_mode():
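With both fields in place, enabling the counter is a matter of setting the new arguments. A hypothetical invocation (the output_dir path and the pad id here are assumptions; pad_token_id should match the tokenizer actually in use):

    from paddlenlp.trainer import TrainingArguments

    # hypothetical usage of the fields added above
    args = TrainingArguments(
        output_dir="./checkpoints",   # assumed path
        count_trained_tokens=True,    # enable the token counter
        pad_token_id=0,               # must match the tokenizer's pad id
    )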
