From adbd1e77481bbca794309e354470d39fb74dc872 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 06:21:49 -0500 Subject: [PATCH 01/11] allow flexibility in transformers version for FSDP --- requirements.txt | 2 +- src/axolotl/core/trainer_builder.py | 55 ++++++++++++++++++++--------- 2 files changed, 40 insertions(+), 17 deletions(-) diff --git a/requirements.txt b/requirements.txt index 1d21cb354c..ae1b1838d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ packaging==23.2 peft==0.14.0 -transformers==4.47.0 +transformers>=4.46.3 tokenizers>=0.20.1 bitsandbytes==0.45.0 accelerate==1.2.0 diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index baac94da80..76b093f546 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -22,6 +22,7 @@ import torch import transformers from datasets import Dataset +from packaging import version from peft.optimizers import create_loraplus_optimizer from torch import nn from torch.optim.lr_scheduler import OneCycleLR @@ -973,7 +974,10 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super().log(logs, start_time) + + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super().log(logs, start_time) + return super().log(logs) # transformers<=4.46 def store_metrics( self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" @@ -1165,9 +1169,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super(DPOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) + + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super(DPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + # transformers<=4.46 + return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): @@ -1185,9 +1193,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) + + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + # transformers<=4.46 + return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): @@ -1232,9 +1244,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non for key, metrics in self._stored_metrics[train_eval].items(): logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super(KTOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) + + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super(KTOTrainer, self).log( # pylint: disable=bad-super-call + 
logs, start_time + ) + # transformers<=4.46 + return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): @@ -1252,9 +1268,13 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non for key, metrics in self._stored_metrics[train_eval].items(): logs[key] = torch.tensor(metrics).mean().item() del self._stored_metrics[train_eval] - return super(CPOTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) + + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super(CPOTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + # transformers<=4.46 + return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): @@ -1266,9 +1286,12 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: # TODO remove once trl supports the updated to the Trainer.log method - return super(RewardTrainer, self).log( # pylint: disable=bad-super-call - logs, start_time - ) + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + return super(RewardTrainer, self).log( # pylint: disable=bad-super-call + logs, start_time + ) + # transformers<=4.46 + return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call class TrainerBuilderBase(abc.ABC): From 96c032492334c1f2579ea3cd398b19cb0b5d692a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 06:39:37 -0500 Subject: [PATCH 02/11] more flexibility with dev versions of 4.47.0.dev0 --- src/axolotl/core/trainer_builder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 76b093f546..691437bc65 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -976,7 +976,10 @@ def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> Non del self._stored_metrics[train_eval] if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): - return super().log(logs, start_time) + try: + return super().log(logs, start_time) + except TypeError: + return super().log(logs) # transformers<=4.46 return super().log(logs) # transformers<=4.46 def store_metrics( From e63cd17d0d12bc9bb191e0d2cd654923be20555f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 08:25:50 -0500 Subject: [PATCH 03/11] add patch for fsdp --- docker/Dockerfile-base | 2 +- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 80 +++++++++++++++++++ src/axolotl/utils/models.py | 7 ++ 3 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 src/axolotl/monkeypatch/trainer_fsdp_optim.py diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 7eab3b3e43..4b24bfc3ae 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -16,7 +16,7 @@ ENV PYTHON_VERSION=$PYTHON_VERSION ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST RUN apt-get update \ - && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ + && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \ && wget \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ && mkdir /root/.conda \ diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py 
b/src/axolotl/monkeypatch/trainer_fsdp_optim.py new file mode 100644 index 0000000000..b09038793f --- /dev/null +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -0,0 +1,80 @@ +""" +fix for FSDP optimizer save in trainer w 4.47.0 +""" +import inspect +import logging + +from transformers.trainer import Trainer + +from axolotl.monkeypatch.unsloth_ import detab_code + +LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") + +ORIGINAL_TRAINER_CODE = """ + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled + +""" + +PATCHED_TRAINER_CODE = """ + + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled of self.is_fsdp_enabledå + +""" + + +def get_training_loop_code() -> str: + training_loop = inspect.getsource( + Trainer._inner_training_loop # pylint: disable=protected-access + ) + return training_loop + + +def check_training_loop_is_patchable() -> bool: + training_loop = get_training_loop_code() + training_loop, _ = detab_code(training_loop) + return ORIGINAL_TRAINER_CODE in training_loop + + +def patch_training_loop_for_fsdp(): + """ + monkeypatch for fixing the training loop for fsdp with optimizer save + """ + + try: + training_loop = get_training_loop_code() + except OSError: + return + Trainer._original_inner_training_loop = ( # pylint: disable=protected-access + training_loop + ) + training_loop, _ = detab_code(training_loop) + if ORIGINAL_TRAINER_CODE not in training_loop: + return + + training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE) + training_loop = training_loop.replace( + "def _inner_training_loop(", + "def _fixed_inner_training_loop(", + 1, + ) + + # load imports necessary + import transformers.trainer + + items_to_import = [] + for item in dir(transformers.trainer): + if item in training_loop: + items_to_import.append(item) + + exec( # pylint: disable=exec-used # nosec B102 + "from transformers.trainer import (" + + ", ".join(x for x in items_to_import) + + ")", + globals(), + ) + exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 + LOG.info("patching _inner_training_loop for fsdp optimsizer save") + Trainer.training_loop = ( # pylint: disable=protected-access + _fixed_training_loop # pylint: disable=undefined-variable # noqa: F821 + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 88a8aa581f..99095c1bfc 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -380,6 +380,13 @@ def apply_patches(self) -> None: plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) + if self.cfg.fsdp: + from axolotl.monkeypatch.trainer_fsdp_optim import ( + patch_training_loop_for_fsdp, + ) + + patch_training_loop_for_fsdp() + if self.cfg.gradient_checkpointing == "unsloth": transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper From 4dc0645d77a5381f110836d6879e62d736531edf Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 08:26:54 -0500 Subject: [PATCH 04/11] fix typo --- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index b09038793f..451407c782 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -18,7 +18,7 @@ PATCHED_TRAINER_CODE = """ - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled of 
self.is_fsdp_enabledå + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabledå """ From 9a2e73cd6745e205436051d924e5666c1aeec760 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 08:28:09 -0500 Subject: [PATCH 05/11] correct fn name --- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 451407c782..dd52e92458 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -76,5 +76,5 @@ def patch_training_loop_for_fsdp(): exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 LOG.info("patching _inner_training_loop for fsdp optimsizer save") Trainer.training_loop = ( # pylint: disable=protected-access - _fixed_training_loop # pylint: disable=undefined-variable # noqa: F821 + _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821 ) From e87db0bded42992d399340f66401924cc3053b8c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 08:32:00 -0500 Subject: [PATCH 06/11] stray character --- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index dd52e92458..68f1b10a50 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -18,7 +18,7 @@ PATCHED_TRAINER_CODE = """ - delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabledå + delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled """ From 5aee743c6d2b0508f7c88a860a0e89c9e7161935 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 08:42:07 -0500 Subject: [PATCH 07/11] fix patch --- src/axolotl/monkeypatch/trainer_fsdp_optim.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/axolotl/monkeypatch/trainer_fsdp_optim.py b/src/axolotl/monkeypatch/trainer_fsdp_optim.py index 68f1b10a50..835dea69b5 100644 --- a/src/axolotl/monkeypatch/trainer_fsdp_optim.py +++ b/src/axolotl/monkeypatch/trainer_fsdp_optim.py @@ -8,7 +8,7 @@ from axolotl.monkeypatch.unsloth_ import detab_code -LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") +LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") ORIGINAL_TRAINER_CODE = """ @@ -74,7 +74,7 @@ def patch_training_loop_for_fsdp(): globals(), ) exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102 - LOG.info("patching _inner_training_loop for fsdp optimsizer save") - Trainer.training_loop = ( # pylint: disable=protected-access + LOG.info("patching _inner_training_loop for fsdp optimizer save") + Trainer._inner_training_loop = ( # pylint: disable=protected-access _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821 ) From 4c9a83f7db8f555f31b9948fb4645040ea76d6dd Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 12:24:57 -0500 Subject: [PATCH 08/11] reset Trainer too --- tests/conftest.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1295d34b64..621827effe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,18 +119,25 @@ def temp_dir(): @pytest.fixture(scope="function", autouse=True) def 
cleanup_monkeypatches(): + from transformers import Trainer from transformers.models.llama.modeling_llama import LlamaFlashAttention2 original_fa2_forward = LlamaFlashAttention2.forward + original_trainer_inner_training_loop = ( + Trainer._inner_training_loop # pylint: disable=protected-access + ) # monkey patches can happen inside the tests yield # Reset LlamaFlashAttention2 forward LlamaFlashAttention2.forward = original_fa2_forward + Trainer._inner_training_loop = ( # pylint: disable=protected-access + original_trainer_inner_training_loop + ) # Reset other known monkeypatches modules_to_reset: list[tuple[str, list[str]]] = [ ("transformers.models.llama.modeling_llama", ["LlamaFlashAttention2"]), - ("transformers.trainer",), + ("transformers.trainer", ["Trainer"]), ("transformers.loss.loss_utils",), ] for module_name_tuple in modules_to_reset: From 0009550f5b65c74abd5bf3ce00feb1bde13299b2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 12:59:00 -0500 Subject: [PATCH 09/11] also reset Trainer.training_step --- tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 621827effe..a9dde9dd88 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,6 +126,7 @@ def cleanup_monkeypatches(): original_trainer_inner_training_loop = ( Trainer._inner_training_loop # pylint: disable=protected-access ) + original_trainer_training_step = Trainer.training_step # monkey patches can happen inside the tests yield # Reset LlamaFlashAttention2 forward @@ -133,6 +134,7 @@ def cleanup_monkeypatches(): Trainer._inner_training_loop = ( # pylint: disable=protected-access original_trainer_inner_training_loop ) + Trainer.training_step = original_trainer_training_step # Reset other known monkeypatches modules_to_reset: list[tuple[str, list[str]]] = [ From c2df5b5cd9d51e82fb3a68b8f2b2fa2462be7e50 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 13:00:26 -0500 Subject: [PATCH 10/11] allow tests/patched to run more than one process on e2e runner --- cicd/cicd.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 79b3cc95e0..a6ff1c6383 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -2,6 +2,6 @@ set -e pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ -pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/ +pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ From ea71d538f8479c5690ae2b0d1b641e25f423978b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 8 Dec 2024 13:38:41 -0500 Subject: [PATCH 11/11] skip tests/patched in e2e for now since it's run in regular pytest --- cicd/cicd.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cicd/cicd.sh b/cicd/cicd.sh index a6ff1c6383..c3e46920d8 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -2,6 +2,6 @@ set -e pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ -pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ +# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ 
/workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
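
For context, the recurring change in PATCH 01 and PATCH 02 exists because transformers 4.47.0 added a start_time parameter to Trainer.log, while 4.46.x still accepts only the logs dict. A minimal sketch of the version gate, assuming a plain Trainer subclass (VersionAwareTrainer is illustrative, not one of the axolotl trainer classes):

from typing import Dict, Optional

import transformers
from packaging import version
from transformers import Trainer


class VersionAwareTrainer(Trainer):
    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
        # transformers>=4.47 expects log(logs, start_time); forward start_time when supported
        if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
            try:
                return super().log(logs, start_time)
            except TypeError:
                # some 4.47.0.dev0 builds still ship the old one-argument signature
                return super().log(logs)
        # transformers<=4.46 only accepts the logs dict
        return super().log(logs)

The try/except is what PATCH 02 adds on top of PATCH 01: dev builds that report 4.47.0.dev0 may still expose the old signature, so the TypeError fallback keeps both paths working.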
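
PATCH 03 through PATCH 07 take a different approach for the FSDP optimizer-state save: they rewrite Trainer._inner_training_loop from source so delay_optimizer_creation also accounts for self.is_fsdp_enabled. A hedged sketch of the guard that makes this safe, patching only when the expected upstream snippet is still present; inner_training_loop_is_patchable is a hypothetical helper, simplified from the real code, which also runs axolotl's detab_code to normalise indentation before comparing:

import inspect

from transformers.trainer import Trainer

# upstream fragment the monkeypatch expects to find in the training loop source
ORIGINAL_SNIPPET = (
    "delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled"
)
# fragment it is rewritten to, so optimizer creation is also delayed under plain FSDP
PATCHED_SNIPPET = ORIGINAL_SNIPPET + " or self.is_fsdp_enabled"


def inner_training_loop_is_patchable() -> bool:
    source = inspect.getsource(
        Trainer._inner_training_loop  # pylint: disable=protected-access
    )
    # patch only if the original snippet exists and the fix is not already applied
    return ORIGINAL_SNIPPET in source and PATCHED_SNIPPET not in source

If the check fails, for example on a future transformers release that restructures the loop, the monkeypatch in PATCH 03 simply returns and leaves the stock training loop in place.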