Skip to content

Commit

Permalink
test(trainer): add test save&load without optimizer weights
Browse files Browse the repository at this point in the history
  • Loading branch information
DrownFish19 committed Jan 11, 2024
1 parent 28bc5bf commit 783399b
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/trainer/test_unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,3 +1109,39 @@ def runfrist(self, train_args):

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)


class TestUnifiedCheckpointOnN1C8SaveLoadSpeedNoOptimizer(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 1
self.configs[config_key]["unified_checkpoint_config"] = "master_weight_compatible"
self.configs[config_key]["ignore_load_lr_and_optim"] = 1
self.configs[config_key]["ignore_save_lr_and_optim"] = 1
self.need_allclose = False
self.rtol = 1e-7

def runfrist(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)


class TestPaddleCheckpointOnN1C8SaveLoadSpeedNoOptimizer(TestUnifiedCheckpointBase):
def setUp(self):
super().setUp()
for config_key in self.configs:
self.configs[config_key]["unified_checkpoint"] = 0
self.configs[config_key]["ignore_load_lr_and_optim"] = 1
self.configs[config_key]["ignore_save_lr_and_optim"] = 1

self.need_allclose = False
self.rtol = 1e-7

def runfrist(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

def rerun(self, train_args):
self.run_n1c8(self.run_pretrain_file, **train_args)

0 comments on commit 783399b

Please sign in to comment.