
Commit 8673d52
Layer2NoNorm uses the same mix and delta values during eval mode as during training (useful for debugging or analysis of the Layer2NoNorm transition) (huggingface#13)

Added option to keep Layer2NoNorm transition parameters during evaluation

Authored-by: Ella Charlaix <[email protected]>
echarlaix authored Apr 8, 2021
1 parent 5214f49 commit 8673d52
Showing 5 changed files with 30 additions and 18 deletions.
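In practice the new option is switched on through SparseTrainingArguments. The snippet below is a hypothetical usage sketch rather than code from this commit: it assumes the remaining SparseTrainingArguments fields have usable defaults and that layer_norm_patch is an ordinary field (the coordinator below only checks for it with hasattr); layer_norm_patch_steps and eval_with_current_patch_params are taken from this diff.

    from nn_pruning.patch_coordinator import SparseTrainingArguments

    # Hypothetical configuration sketch (all other fields left at their defaults).
    sparse_args = SparseTrainingArguments(
        layer_norm_patch=True,                # replace LayerNorm with Layer2NoNorm during training
        layer_norm_patch_steps=50000,         # length of the LayerNorm -> NoNorm transition
        eval_with_current_patch_params=True,  # new: evaluation keeps the current mix/delta values
    )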
examples/command_line.py: 8 changes (6 additions, 2 deletions)
@@ -45,8 +45,9 @@ def cli(ctx):
     "attention_output_with_dense": 0,
     "layer_norm_patch_steps": 50000,
     "gelu_patch_steps": 50000,
-    'linear_min_parameters': 0,
-    'null_score_diff_threshold': 0.0,
+    "eval_with_current_patch_params": 0,
+    "linear_min_parameters": 0,
+    "null_score_diff_threshold": 0.0,
 }

 GLUE_TYPICAL_PARAMETERS = {
@@ -84,6 +85,9 @@ def cli(ctx):
     "distil_alpha_ce": 0.1,
     "distil_alpha_teacher": 0.90,
     "attention_output_with_dense": 0,
+    "layer_norm_patch_steps": 50000,
+    "gelu_patch_steps": 50000,
+    "eval_with_current_patch_params": 0,
 }

 QA_TASKS = {"squadv1", "squadv2"}
examples/question_answering/qa_sparse_xp.py: 1 change (1 addition, 0 deletions)
@@ -129,6 +129,7 @@ class SparseQAShortNamer(TrialShortNamer):
         'layer_norm_patch_start_delta': 0.99,
         'gelu_patch': False,
         'gelu_patch_steps': 50000,
+        'eval_with_current_patch_params': False,
         'linear_min_parameters': 0.005,
         'rewind_model_name_or_path': None,
         'lr_scheduler_type': 'SchedulerType.LINEAR',
examples/text_classification/glue_sparse_xp.py: 1 change (1 addition, 0 deletions)
@@ -124,6 +124,7 @@ class SparseGlueShortNamer(TrialShortNamer):
         'layer_norm_patch_start_delta': 0.99,
         'gelu_patch': False,
         'gelu_patch_steps': 50000,
+        'eval_with_current_patch_params': False,
         'warmup_ratio': 0.0,
         'fp16_full_eval': False,
         'label_smoothing_factor': 0.0,
nn_pruning/modules/nonorm.py: 30 changes (15 additions, 15 deletions)
@@ -29,10 +29,10 @@ def __init__(self, layerNorm,

         if self.schedule_callback is None:
             self.steps = steps
-            self.delta_step = (self.final_delta - self.delta) / self.steps
-            self.mix_step = 1 / self.steps
             self.delta = start_delta
             self.final_delta = 1.0
+            self.delta_step = (self.final_delta - self.delta) / self.steps
+            self.mix_step = 1 / self.steps
             self.mix = 1.0
         else:
             self.steps = None
@@ -47,17 +47,17 @@ def forward(self, batch):
     def forward(self, batch):
         accumulator = self.accumulator.clone()

-        if self.training:
-            if self.schedule_callback is not None:
-                d = self.schedule_callback()
-                mix = d["mix"]
-                delta = d["delta"]
-            else:
+        if self.schedule_callback is not None:
+            d = self.schedule_callback()
+            mix = d["mix"]
+            delta = d["delta"]
+        else:
+            if self.training:
                 mix = self.mix
                 delta = self.delta
+            else:
+                mix = 0
+                delta = 1.0
-        else:
-            mix = 0
-            delta = 1.0

         if mix == 0 and delta == 1.0:
             batch_mean = accumulator[0] / accumulator[2]
@@ -66,9 +66,10 @@ def forward(self, batch):
             batch_mean = batch.mean(-1, keepdim=True)
             batch_var = batch.var(-1, unbiased=False, keepdim=True)

-            one = torch.tensor(1.0, device=batch_var.device)
-            new_acc = torch.stack([batch_mean.mean(), batch_var.mean(), one])
-            accumulator = torch.lerp(new_acc, accumulator, delta)
+            if self.training:
+                one = torch.tensor(1.0, device=batch_var.device)
+                new_acc = torch.stack([batch_mean.mean(), batch_var.mean(), one])
+                accumulator = torch.lerp(new_acc, accumulator, delta)

         batch_mean = torch.lerp(accumulator[0] / accumulator[2], batch_mean, mix)
         batch_var = torch.lerp(accumulator[1] / accumulator[2], batch_var, mix)
@@ -143,4 +144,3 @@ def is_patchable(self, module_name, module, raiseError):

     def new_child_module(self, child_module_name, child_module, patch_info):
         return NoNorm(child_module.weight.detach(), child_module.bias.detach())
-
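The forward pass above is the core of the Layer2NoNorm transition: mix interpolates between live batch statistics (LayerNorm-like behaviour, mix = 1.0) and the frozen accumulator (NoNorm behaviour, mix = 0.0), while delta controls how much of each batch is folded into that accumulator (delta = 1.0 keeps it unchanged). The standalone sketch below restates that blending so the eval-mode change is easier to see; it is a simplified illustration, not the library class, with only the accumulator layout [mean, var, count] and the torch.lerp calls taken from the diff.

    import torch

    # Simplified restatement of the blending in Layer2NoNorm.forward (illustration only).
    def blend_statistics(batch, accumulator, mix, delta, training):
        # Per-position statistics over the hidden dimension, as LayerNorm would compute them.
        batch_mean = batch.mean(-1, keepdim=True)
        batch_var = batch.var(-1, unbiased=False, keepdim=True)

        if training:
            # Fold the current batch into the running accumulator; delta = 1.0 leaves it untouched.
            one = torch.tensor(1.0, device=batch_var.device)
            new_acc = torch.stack([batch_mean.mean(), batch_var.mean(), one])
            accumulator = torch.lerp(new_acc, accumulator, delta)

        # mix = 1.0 -> pure batch statistics, mix = 0.0 -> pure accumulated statistics.
        mean = torch.lerp(accumulator[0] / accumulator[2], batch_mean, mix)
        var = torch.lerp(accumulator[1] / accumulator[2], batch_var, mix)
        return mean, var, accumulator

Before this change, evaluation always used mix = 0 and delta = 1.0, i.e. only the frozen statistics; with eval_with_current_patch_params set, evaluation instead receives the training-time mix/delta through the schedule callback, which is what makes it possible to inspect a model mid-transition.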
nn_pruning/patch_coordinator.py: 8 changes (7 additions, 1 deletion)
@@ -229,6 +229,12 @@ class SparseTrainingArguments:
         },
     )

+    eval_with_current_patch_params: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to keep the transition parameters used during training for eval. Only for Layer2NoNorm."
+        },
+    )
 class ModelPatchingCoordinator:
     MODEL_STRUCTURE = BertStructure

@@ -340,7 +346,7 @@ def interp(a,b, interpf):
             return a * interpf + (1.0 - interpf) * b

         if hasattr(sparse_args, "layer_norm_patch") and sparse_args.layer_norm_patch:
-            if training:
+            if training or sparse_args.eval_with_current_patch_params:
                 interpf = 0.0
                 layer_norm_patch_steps = sparse_args.layer_norm_patch_steps
                 if step < layer_norm_patch_steps:
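The schedule above is what feeds mix and delta to Layer2NoNorm through its schedule_callback; the changed condition makes evaluation follow the same schedule whenever eval_with_current_patch_params is set, instead of snapping to the fully transitioned values. The sketch below fills in the branch for illustration only: the diff is cut off right after "if step < layer_norm_patch_steps:", so the linear ramp is an assumption consistent with interp() and the Layer2NoNorm defaults (start delta 0.99, mix starting at 1.0), not code copied from the source.

    # Assumed sketch of the scheduling branch (the ramp below the fold is inferred).
    def layer_norm_patch_schedule(step, sparse_args, training):
        def interp(a, b, interpf):
            return a * interpf + (1.0 - interpf) * b

        if training or sparse_args.eval_with_current_patch_params:
            # Follow the training schedule: interpf ramps 1 -> 0 over layer_norm_patch_steps,
            # so mix goes 1 -> 0 and delta goes layer_norm_patch_start_delta -> 1.0.
            interpf = 0.0
            if step < sparse_args.layer_norm_patch_steps:
                interpf = 1.0 - (step / sparse_args.layer_norm_patch_steps)
            mix = interpf
            delta = interp(sparse_args.layer_norm_patch_start_delta, 1.0, interpf)
        else:
            # Plain evaluation: fully transitioned Layer2NoNorm, frozen statistics only.
            mix = 0.0
            delta = 1.0
        return {"mix": mix, "delta": delta}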
