fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing

* fix: logging import

* fix: cfg passed being none

* fix: try to fix logging

* fix: refactor call to not use accelerate log

* fix: try to fix circular import issue

* fix: don't drop when skip prepare

* chore: remove duplicate line

* fix: update warning to mention that sequences will be trimmed

* fix: do not drop seq if input_ids don't exist

* fix: increase RM unittest sequence length to reduce trim warnings

* fix: solve conflicts

* fix: default min_seq_len in case of None
NanoCode012 authored Feb 4, 2025
1 parent 158330a commit a620d48
Showing 6 changed files with 76 additions and 63 deletions.
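The net effect of this commit is that over-long (and under-length) tokenized samples are now filtered out of the dataset whenever dataset preparation runs, not only when sample packing is enabled. As context for the diffs below, here is a minimal standalone sketch of that kind of length filter using the datasets library; it mirrors the idea of drop_long_seq / drop_long_seq_in_dataset but is not the exact axolotl implementation.

# Minimal sketch of batched length filtering with the `datasets` library.
# The bounds mirror cfg.sequence_len / cfg.min_sample_len from the diffs below.
from datasets import Dataset

sequence_len = 2048   # maximum allowed token count
min_sample_len = 2    # minimum allowed token count

def keep_sample(batch):
    # Batched predicate: return one boolean per sample in the batch.
    return [
        min_sample_len <= len(ids) <= sequence_len
        for ids in batch["input_ids"]
    ]

dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], list(range(5000))]})
filtered = dataset.filter(keep_sample, batched=True)
print(len(dataset), "->", len(filtered))  # 2 -> 1: the 5000-token sample is dropped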
4 changes: 2 additions & 2 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -47,7 +47,7 @@ def _tokenize_single_prompt(self, prompt):

if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)

chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -70,7 +70,7 @@ def _tokenize_single_prompt(self, prompt):

if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)

rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
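The change to this file only rewords the two warnings to make clear that over-long chosen/rejected sequences are subsequently trimmed to max_length rather than dropped. A hedged standalone sketch of that warn-then-trim pattern follows; the helper name trim_to_max_length is illustrative and not from the repository.

# Illustrative warn-then-trim pattern for an over-long tokenized sequence,
# assuming a module-level logger and a max_length limit as in the diff above.
import logging

LOG = logging.getLogger(__name__)
max_length = 2048

def trim_to_max_length(tokenized: dict) -> dict:
    if len(tokenized["input_ids"]) > max_length:
        LOG.warning(
            "To-be-trimmed sequence exceeds max sequence length: %d",
            len(tokenized["input_ids"]),
        )
        # Trim every parallel field (input_ids, attention_mask, labels, ...).
        tokenized = {key: values[:max_length] for key, values in tokenized.items()}
    return tokenized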
10 changes: 7 additions & 3 deletions src/axolotl/utils/data/sft.py
@@ -46,6 +46,7 @@
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
md5,
retry_on_request_exceptions,
)
@@ -56,7 +57,7 @@
process_datasets_for_packing,
)

- LOG = logging.getLogger("axolotl")
+ LOG = logging.getLogger(__name__)


@retry_on_request_exceptions(max_retries=3, delay=5)
@@ -339,8 +340,11 @@ def for_d_in_datasets(dataset_configs):
else:
LOG.debug("NOT shuffling merged datasets")

- if cfg.sample_packing and not cfg.skip_prepare_dataset:
-     dataset, _ = process_datasets_for_packing(cfg, dataset, None)
+ if not cfg.skip_prepare_dataset:
+     dataset = drop_long_seq_in_dataset(dataset, cfg)
+
+     if cfg.sample_packing:
+         dataset, _ = process_datasets_for_packing(cfg, dataset, None)

if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
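In sft.py the long-sequence drop now runs whenever preparation is not skipped, and packing-specific processing is applied only when sample_packing is set. A condensed sketch of that control flow is below; prepare_merged_dataset is a hypothetical wrapper name, while the two helpers are the ones imported in the diff above.

# Condensed sketch of the new preparation order (wrapper name is hypothetical).
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.trainer import process_datasets_for_packing

def prepare_merged_dataset(dataset, cfg):
    if not cfg.skip_prepare_dataset:
        # Long samples are dropped even when sample packing is disabled.
        dataset = drop_long_seq_in_dataset(dataset, cfg)

        if cfg.sample_packing:
            # Packing-specific processing still runs only for sample packing.
            dataset, _ = process_datasets_for_packing(cfg, dataset, None)
    return dataset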
60 changes: 58 additions & 2 deletions src/axolotl/utils/data/utils.py
@@ -1,15 +1,21 @@
"""data handling helpers"""

import functools
import hashlib
import logging
import time
from enum import Enum

import huggingface_hub
import numpy as np
import requests
- from datasets import Dataset
+ from datasets import Dataset, IterableDataset

from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq

- LOG = logging.getLogger("axolotl")
+ LOG = logging.getLogger(__name__)


class RetryStrategy(Enum):
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
)

return train_dataset, eval_dataset, dataset


def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset

drop_long = functools.partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
)

try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass

try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None

filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess

drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"

dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")

return dataset
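A hypothetical usage of the new helper, assuming axolotl at this commit and a DictDefault-style config; the field names match those read by the function above.

from datasets import Dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.data.utils import drop_long_seq_in_dataset

cfg = DictDefault(
    {
        "sequence_len": 2048,     # upper bound on token count
        "min_sample_len": 2,      # lower bound on token count
        "dataset_processes": 1,
        "is_preprocess": False,
    }
)

dataset = Dataset.from_dict({"input_ids": [[1, 2, 3], list(range(5000))]})
dataset = drop_long_seq_in_dataset(dataset, cfg)
print(len(dataset))  # expected 1: the 5000-token sample exceeds sequence_len

Note that for an IterableDataset the helper deliberately skips num_proc, load_from_cache_file, and the progress description, since those arguments are not supported for streaming datasets.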
1 change: 0 additions & 1 deletion src/axolotl/utils/samplers/utils.py
@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
else:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
- return lengths
return lengths
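The only change in this file removes the duplicated return lengths. For reference, the np.vectorize(len) idiom used above can be exercised in isolation; this is a standalone sketch, not axolotl code.

import numpy as np

# Per-sample token counts for a ragged list of input_ids, using the same
# np.vectorize(len) idiom as get_dataset_lengths.
input_ids = [[1, 2, 3], [4, 5], list(range(10))]
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
print(lengths)  # [ 3  2 10]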
62 changes: 8 additions & 54 deletions src/axolotl/utils/trainer.py
@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions"""

import json
import math
import os
@@ -210,6 +211,8 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
Works for both single-example (list[int]) or batched (list[list[int]]).
"""
min_sequence_len = min_sequence_len or 2

input_ids = sample["input_ids"]

# Edge case: if input_ids is empty
@@ -232,20 +235,6 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):


def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)

try:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
pass

if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
@@ -259,46 +248,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")

filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess

try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")

if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_long,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")

def drop_no_trainable_tokens(sample):
"""
Drop samples if all labels are -100 (i.e., zero trainable tokens).
@@ -325,6 +274,11 @@ def drop_no_trainable_tokens(sample):
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess

drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
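In trainer.py the removed block is the long-sequence filtering that now lives in drop_long_seq_in_dataset, while drop_long_seq itself gains a None-safe default for min_sequence_len (the "default min_seq_len in case of None" commit). A minimal standalone sketch of that defaulting-plus-bounds check follows; it is not the exact axolotl predicate.

def keep_within_bounds(input_ids, sequence_len=2048, min_sequence_len=None):
    # Fall back to a minimum of 2 tokens when the config leaves the value unset.
    min_sequence_len = min_sequence_len or 2
    return min_sequence_len <= len(input_ids) <= sequence_len

print(keep_within_bounds([1, 2, 3]))                   # True
print(keep_within_bounds(list(range(5000))))           # False: too long
print(keep_within_bounds([1], min_sequence_len=None))  # False: below the default minimum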
2 changes: 1 addition & 1 deletion tests/e2e/test_reward_model_smollm2.py
@@ -33,7 +33,7 @@ def test_rm_lora(self, temp_dir):
"num_labels": 1,
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"sequence_len": 2048,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
