diff --git a/examples/command_line.py b/examples/command_line.py
index 5ab1d28b..14568d24 100644
--- a/examples/command_line.py
+++ b/examples/command_line.py
@@ -45,8 +45,9 @@ def cli(ctx):
     "attention_output_with_dense": 0,
     "layer_norm_patch_steps": 50000,
     "gelu_patch_steps": 50000,
-    'linear_min_parameters': 0,
-    'null_score_diff_threshold': 0.0,
+    "eval_with_current_patch_params": 0,
+    "linear_min_parameters": 0,
+    "null_score_diff_threshold": 0.0,
 }
 
 GLUE_TYPICAL_PARAMETERS = {
@@ -84,6 +85,9 @@ def cli(ctx):
     "distil_alpha_ce": 0.1,
     "distil_alpha_teacher": 0.90,
     "attention_output_with_dense": 0,
+    "layer_norm_patch_steps": 50000,
+    "gelu_patch_steps": 50000,
+    "eval_with_current_patch_params": 0,
 }
 
 QA_TASKS = {"squadv1", "squadv2"}
diff --git a/examples/question_answering/qa_sparse_xp.py b/examples/question_answering/qa_sparse_xp.py
index de33c0a7..31b93e48 100644
--- a/examples/question_answering/qa_sparse_xp.py
+++ b/examples/question_answering/qa_sparse_xp.py
@@ -129,6 +129,7 @@ class SparseQAShortNamer(TrialShortNamer):
         'layer_norm_patch_start_delta': 0.99,
         'gelu_patch':False,
         'gelu_patch_steps': 50000,
+        'eval_with_current_patch_params': False,
         'linear_min_parameters': 0.005,
         'rewind_model_name_or_path': None,
         'lr_scheduler_type': 'SchedulerType.LINEAR',
diff --git a/examples/text_classification/glue_sparse_xp.py b/examples/text_classification/glue_sparse_xp.py
index b56c3b56..7f3a621a 100644
--- a/examples/text_classification/glue_sparse_xp.py
+++ b/examples/text_classification/glue_sparse_xp.py
@@ -124,6 +124,7 @@ class SparseGlueShortNamer(TrialShortNamer):
         'layer_norm_patch_start_delta': 0.99,
         'gelu_patch': False,
         'gelu_patch_steps': 50000,
+        'eval_with_current_patch_params': False,
         'warmup_ratio': 0.0,
         'fp16_full_eval': False,
         'label_smoothing_factor': 0.0,
diff --git a/nn_pruning/modules/nonorm.py b/nn_pruning/modules/nonorm.py
index 05267c96..59795c85 100644
--- a/nn_pruning/modules/nonorm.py
+++ b/nn_pruning/modules/nonorm.py
@@ -29,10 +29,10 @@ def __init__(self, layerNorm,
 
         if self.schedule_callback is None:
             self.steps = steps
-            self.delta_step = (self.final_delta - self.delta) / self.steps
-            self.mix_step = 1 / self.steps
             self.delta = start_delta
             self.final_delta = 1.0
+            self.delta_step = (self.final_delta - self.delta) / self.steps
+            self.mix_step = 1 / self.steps
             self.mix = 1.0
         else:
             self.steps = None
@@ -47,17 +47,17 @@ def __init__(self, layerNorm,
     def forward(self, batch):
         accumulator = self.accumulator.clone()
 
-        if self.training:
-            if self.schedule_callback is not None:
-                d = self.schedule_callback()
-                mix = d["mix"]
-                delta = d["delta"]
-            else:
+        if self.schedule_callback is not None:
+            d = self.schedule_callback()
+            mix = d["mix"]
+            delta = d["delta"]
+        else:
+            if self.training:
                 mix = self.mix
                 delta = self.delta
-        else:
-            mix = 0
-            delta = 1.0
+            else:
+                mix = 0
+                delta = 1.0
 
         if mix == 0 and delta == 1.0:
             batch_mean = accumulator[0] / accumulator[2]
@@ -66,9 +66,10 @@
             batch_mean = batch.mean(-1, keepdim=True)
             batch_var = batch.var(-1, unbiased=False, keepdim=True)
 
-            one = torch.tensor(1.0, device=batch_var.device)
-            new_acc = torch.stack([batch_mean.mean(), batch_var.mean(), one])
-            accumulator = torch.lerp(new_acc, accumulator, delta)
+            if self.training:
+                one = torch.tensor(1.0, device=batch_var.device)
+                new_acc = torch.stack([batch_mean.mean(), batch_var.mean(), one])
+                accumulator = torch.lerp(new_acc, accumulator, delta)
 
             batch_mean = torch.lerp(accumulator[0] / accumulator[2], batch_mean, mix)
             batch_var = torch.lerp(accumulator[1] / accumulator[2], batch_var, mix)
@@ -143,4 +144,3 @@ def is_patchable(self, module_name, module, raiseError):
 
     def new_child_module(self, child_module_name, child_module, patch_info):
         return NoNorm(child_module.weight.detach(), child_module.bias.detach())
-
diff --git a/nn_pruning/patch_coordinator.py b/nn_pruning/patch_coordinator.py
index 96c5c148..5c7bc91b 100644
--- a/nn_pruning/patch_coordinator.py
+++ b/nn_pruning/patch_coordinator.py
@@ -229,6 +229,12 @@ class SparseTrainingArguments:
         },
     )
 
+    eval_with_current_patch_params: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to keep the transition parameters used during training for eval. Only for Layer2NoNorm."
+        },
+    )
 
 class ModelPatchingCoordinator:
     MODEL_STRUCTURE = BertStructure
@@ -340,7 +346,7 @@ def interp(a,b, interpf):
             return a * interpf + (1.0 - interpf) * b
 
         if hasattr(sparse_args, "layer_norm_patch") and sparse_args.layer_norm_patch:
-            if training:
+            if training or sparse_args.eval_with_current_patch_params:
                 interpf = 0.0
                 layer_norm_patch_steps = sparse_args.layer_norm_patch_steps
                 if step < layer_norm_patch_steps: