
Commit 3ed9519
handle attention_mask=None
SeanLee97 committed May 25, 2024
1 parent d64f1bc commit 3ed9519
Showing 4 changed files with 4 additions and 4 deletions.
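Each file applies the same guard: the all-zeros bidirectional mask is only built when the underlying causal/attention mask exists, so calling the model with attention_mask=None no longer crashes torch.zeros_like(None). A minimal sketch of the pattern, assuming an additive float mask; the helper name build_bi_attention_mask is hypothetical and not part of this repository:

    import torch
    from typing import Optional

    def build_bi_attention_mask(causal_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # An all-zeros additive mask masks nothing, i.e. fully bidirectional attention.
        # If no mask was supplied there is nothing to mirror, so propagate None
        # and let downstream attention handle the missing mask as usual.
        return torch.zeros_like(causal_mask) if causal_mask is not None else None

    # Usage sketch: -inf above the diagonal is a standard additive causal mask.
    causal = torch.full((1, 1, 4, 4), float("-inf")).triu(1)
    assert build_bi_attention_mask(causal).abs().sum() == 0
    assert build_bi_attention_mask(None) is None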
src/billm/modeling_llama.py (1 addition, 1 deletion)
@@ -103,7 +103,7 @@ def forward(
     causal_mask = self._update_causal_mask(
         attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1]
     )
-    bi_attention_mask = torch.zeros_like(causal_mask)
+    bi_attention_mask = torch.zeros_like(causal_mask) if causal_mask is not None else None

     # embed positions
     hidden_states = inputs_embeds
src/billm/modeling_mistral.py (1 addition, 1 deletion)
@@ -138,7 +138,7 @@ def forward(
         past_key_values_length,
         sliding_window=self.config.sliding_window,
     )
-    bi_attention_mask = torch.zeros_like(attention_mask)
+    bi_attention_mask = torch.zeros_like(attention_mask) if attention_mask is not None else None

     hidden_states = inputs_embeds

src/billm/modeling_openelm.py (1 addition, 1 deletion)
@@ -658,7 +658,7 @@ def forward(
     position_ids = cache_position.unsqueeze(0)

     causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
-    bi_attention_mask = torch.zeros_like(causal_mask)
+    bi_attention_mask = torch.zeros_like(causal_mask) if causal_mask is not None else None

     # embed positions
     hidden_states = inputs_embeds
src/billm/modeling_qwen2.py (1 addition, 1 deletion)
@@ -138,7 +138,7 @@ def forward(
         past_key_values_length,
         sliding_window=self.config.sliding_window,
     )
-    bi_attention_mask = torch.zeros_like(attention_mask)
+    bi_attention_mask = torch.zeros_like(attention_mask) if attention_mask is not None else None

     hidden_states = inputs_embeds

