diff --git a/src/adapters/models/deberta/modeling_deberta.py b/src/adapters/models/deberta/modeling_deberta.py
index c8fda7d89..77c6117b1 100644
--- a/src/adapters/models/deberta/modeling_deberta.py
+++ b/src/adapters/models/deberta/modeling_deberta.py
@@ -145,10 +145,15 @@ def forward(
 
         if rel_att is not None:
             # >>> START AH Changes <<<
-            rel_att_padded = torch.zeros_like(attention_scores)
-            rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att
+            # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this).
+            # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores.
+            if isinstance(rel_att, torch.Tensor):
+                rel_att_padded = torch.zeros_like(attention_scores)
+                rel_att_padded[:, :, :, -rel_att.size(-1) :] = rel_att
+                attention_scores = attention_scores + rel_att_padded
+            else:
+                attention_scores = attention_scores + rel_att
             # >>> END AH Changes <<<
-            attention_scores = attention_scores + rel_att_padded
 
         # bxhxlxd
         if self.head_logits_proj is not None:
diff --git a/src/adapters/models/deberta_v2/modeling_deberta_v2.py b/src/adapters/models/deberta_v2/modeling_deberta_v2.py
index 6325fab95..2b673c491 100644
--- a/src/adapters/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/adapters/models/deberta_v2/modeling_deberta_v2.py
@@ -137,7 +137,17 @@ def forward(
             # >>> END AH Changes <<<
 
         if rel_att is not None:
-            attention_scores = attention_scores + rel_att
+            # >>> START AH Changes <<<
+            # rel_att is set to 0 by default, i.e. rel_att is always not None (don't know why HuggingFace does this).
+            # Hence, we must check whether rel_att is a tensor and if so, pad it with zeros to be able to add it to attention_scores.
+            if isinstance(rel_att, torch.Tensor):
+                rel_att_padded = torch.zeros_like(attention_scores)
+                rel_att_padded[:, :, -rel_att.size(2) :] = rel_att
+                attention_scores = attention_scores + rel_att_padded
+            else:
+                attention_scores = attention_scores + rel_att
+            # >>> END AH Changes <<<
+        attention_scores = attention_scores
         attention_scores = attention_scores.view(
             -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
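
Both hunks apply the same pattern: when rel_att is an actual tensor covering only the trailing key positions, it is right-aligned into a zero tensor shaped like attention_scores before the addition; when it is the scalar 0 default, the plain addition is kept. Below is a minimal, standalone sketch of that padding step; the shapes are made up for illustration and assume the 3-dimensional (batch * heads, query_len, key_len) layout of the DeBERTa-v2 branch, not the models' real sizes.

```python
import torch

# Illustrative shapes only (hypothetical, not taken from the models):
# attention_scores: (batch * heads, query_len, key_len) with key_len = 8
# rel_att:          (batch * heads, query_len, 5), covering only the last 5 key positions
attention_scores = torch.randn(2, 4, 8)
rel_att = torch.randn(2, 4, 5)

if isinstance(rel_att, torch.Tensor):
    # Right-align rel_att into a zero tensor of the full attention shape,
    # mirroring `rel_att_padded[:, :, -rel_att.size(2) :] = rel_att` in the diff.
    rel_att_padded = torch.zeros_like(attention_scores)
    rel_att_padded[:, :, -rel_att.size(2) :] = rel_att
    attention_scores = attention_scores + rel_att_padded
else:
    # rel_att can also be the integer 0 default, which broadcasts without any padding.
    attention_scores = attention_scores + rel_att
```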