diff --git a/tests/trainer/test_unified_checkpoint.py b/tests/trainer/test_unified_checkpoint.py
index 4b8bcbfb8dd0..8f45e96d68c3 100644
--- a/tests/trainer/test_unified_checkpoint.py
+++ b/tests/trainer/test_unified_checkpoint.py
@@ -1109,3 +1109,39 @@ def runfrist(self, train_args):
 
     def rerun(self, train_args):
         self.run_n1c8(self.run_pretrain_file, **train_args)
+
+
+class TestUnifiedCheckpointOnN1C8SaveLoadSpeedNoOptimizer(TestUnifiedCheckpointBase):
+    def setUp(self):
+        super().setUp()
+        for config_key in self.configs:
+            self.configs[config_key]["unified_checkpoint"] = 1
+            self.configs[config_key]["unified_checkpoint_config"] = "master_weight_compatible"
+            self.configs[config_key]["ignore_load_lr_and_optim"] = 1
+            self.configs[config_key]["ignore_save_lr_and_optim"] = 1
+        self.need_allclose = False
+        self.rtol = 1e-7
+
+    def runfrist(self, train_args):
+        self.run_n1c8(self.run_pretrain_file, **train_args)
+
+    def rerun(self, train_args):
+        self.run_n1c8(self.run_pretrain_file, **train_args)
+
+
+class TestPaddleCheckpointOnN1C8SaveLoadSpeedNoOptimizer(TestUnifiedCheckpointBase):
+    def setUp(self):
+        super().setUp()
+        for config_key in self.configs:
+            self.configs[config_key]["unified_checkpoint"] = 0
+            self.configs[config_key]["ignore_load_lr_and_optim"] = 1
+            self.configs[config_key]["ignore_save_lr_and_optim"] = 1
+
+        self.need_allclose = False
+        self.rtol = 1e-7
+
+    def runfrist(self, train_args):
+        self.run_n1c8(self.run_pretrain_file, **train_args)
+
+    def rerun(self, train_args):
+        self.run_n1c8(self.run_pretrain_file, **train_args)