
Commit

Update
pglorio committed Jul 20, 2024
1 parent 091c018 commit 5a75a13
Showing 3 changed files with 11 additions and 17 deletions.
13 changes: 6 additions & 7 deletions src/transformers/models/zamba2/attention.py
@@ -9,7 +9,7 @@

class CausalSelfAttention(nn.Module):

- def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding, **kwargs):
+ def __init__(self, config, layer_number):
super().__init__()
assert config.hidden_size % config.num_attention_heads == 0
self.config = config
@@ -61,16 +61,15 @@ def __init__(self, config, layer_number, attn_mask_type=AttnMaskType.padding, **kwargs):
self.linear_v_lora_A_list.append(linear_v_lora_A)
self.linear_v_lora_B_list.append(linear_v_lora_B)

- def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
- """Allocate memory to store kv cache during inference."""
-
+ def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype, device):
+ """Allocate memory to store kv cache during inference."""
return torch.empty(
inference_max_sequence_length,
batch_size,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head * 2,
dtype=dtype,
- device=torch.cuda.current_device(),
+ device=device,
)

def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb, layer_number):
@@ -93,10 +92,10 @@ def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb, layer_number):
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
- inf_max_seq_length, inf_max_batch_size, key.dtype
+ inf_max_seq_length, inf_max_batch_size, key.dtype, inference_params.device
)
inference_value_memory = self._allocate_memory(
- inf_max_seq_length, inf_max_batch_size, value.dtype
+ inf_max_seq_length, inf_max_batch_size, value.dtype, inference_params.device
)
inference_params.key_value_memory_dict[layer_number] = (
inference_key_memory,
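
The attention.py change above lets the KV cache live on a caller-supplied device instead of being pinned to torch.cuda.current_device(), which requires a CUDA device to be present. A minimal sketch of the new behaviour, with the method rewritten as a free function and hypothetical num_query_groups/head_dim arguments standing in for the module attributes:

import torch

def allocate_kv_memory(max_seq_len, batch_size, num_query_groups, head_dim, dtype, device):
    # Mirrors the patched _allocate_memory: the buffer is created on whatever
    # device the caller passes in, rather than on the current CUDA device.
    return torch.empty(
        max_seq_len,
        batch_size,
        num_query_groups,
        head_dim * 2,  # key and value are packed into a single buffer
        dtype=dtype,
        device=device,
    )

# Runs on CPU-only machines as well, since no CUDA call is made during allocation.
buf = allocate_kv_memory(16, 2, 4, 64, torch.float32, torch.device("cpu"))
print(buf.shape)  # torch.Size([16, 2, 4, 128])
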
2 changes: 1 addition & 1 deletion src/transformers/models/zamba2/mamba2_layer.py
@@ -181,7 +181,7 @@ def forward(self, u, from_shared_proj=None, seqlen=None, seq_idx=None, inference
else:
zxbcdt = self.in_proj[0](u) # (B, L, d_in_proj) or (B * L, d_in_proj)
if seqlen_og is not None:
- zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
+ zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
A = -torch.exp(self.A_log) # (nheads) or (d_inner, d_state)
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
if self.use_mem_eff_path and inference_params is None:
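
The mamba2_layer.py hunk appears to change only the indentation of the rearrange call, but the einops pattern it uses is worth spelling out. A toy sketch, with shapes invented for illustration, of what rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen) does: it unflattens a (B * L, d_in_proj) projection output back into (B, L, d_in_proj) once the sequence length is known.

import torch
from einops import rearrange

seqlen = 3                           # L
zxbcdt = torch.randn(2 * seqlen, 8)  # (B * L, d_in_proj) with B = 2, d_in_proj = 8
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
print(zxbcdt.shape)                  # torch.Size([2, 3, 8])
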
13 changes: 4 additions & 9 deletions src/transformers/models/zamba2/modeling_zamba2.py
@@ -233,8 +233,9 @@ def count_mem_blocks_in_config(config):

class HybridMambaAttentionDynamicCache:
### This is actually a static cache
- def __init__(self, max_batch_size, max_sequence_length): ###, dtype=torch.bfloat16):
+ def __init__(self, max_batch_size, max_sequence_length, device=None): ###, dtype=torch.bfloat16):
### self.dtype = dtype
+ self.device = device
self.max_sequence_length = max_sequence_length
self.max_batch_size = max_batch_size
self.sequence_len_offset = 0
@@ -1077,7 +1078,6 @@ def forward(
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence.
"""
-
residual = hidden_states

# `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712).
@@ -1086,7 +1086,6 @@ def forward(
hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states
)
hidden_states = self.input_layernorm(hidden_states)
-
hidden_states = self.mamba(
u=hidden_states,
inference_params=past_key_value,
@@ -1304,7 +1303,6 @@ def forward(

original_hidden_states = torch.clone(inputs_embeds)#.transpose(0, 1) ###
# original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer
-
if use_cache and past_key_values is None:
logger.warning_once(
"Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
Expand Down Expand Up @@ -1357,7 +1355,7 @@ def forward(
cache_position=cache_position,
)
block_count += 1
- transformer_hidden_states = layer_outputs[0]
+ transformer_hidden_states = layer_outputs[0]
if output_attentions:
if layer_outputs[1] is not None:
all_self_attns += (layer_outputs[1],)
@@ -1372,7 +1370,6 @@ def forward(
)
else:
transformer_hidden_states = None
-
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
next(mamba_layers).__call__,
@@ -1397,7 +1394,6 @@ def forward(
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
-
hidden_states = self.final_layernorm(hidden_states)

# add hidden states from the last decoder layer
@@ -1560,7 +1556,6 @@ def forward(
else:
logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
logits = logits.float()
-
loss = None
if labels is not None:
# Shift so that tokens < n predict n
Expand Down Expand Up @@ -1629,7 +1624,7 @@ def prepare_inputs_for_generation(
# )
max_sequence_length = max_new_tokens + input_ids.shape[1]
past_key_values = HybridMambaAttentionDynamicCache(
- input_ids.shape[0], max_sequence_length
+ input_ids.shape[0], max_sequence_length, device=self.device
)

position_ids = kwargs.get("position_ids", None)
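
Taken together, the three files thread a single device value from cache construction down to KV-buffer allocation: prepare_inputs_for_generation hands the model's device to HybridMambaAttentionDynamicCache, and the attention layer later reads it back through inference_params.device when allocating the key/value buffers. A condensed, self-contained sketch of that flow, with both pieces heavily simplified and num_query_groups/head_dim invented for illustration:

import torch

class HybridMambaAttentionDynamicCache:
    # Simplified stand-in for the patched cache: it now records the device it
    # was created for, so later allocations can follow it.
    def __init__(self, max_batch_size, max_sequence_length, device=None):
        self.device = device
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.key_value_memory_dict = {}

def allocate_kv(inference_params, dtype, num_query_groups=4, head_dim=64):
    # Stand-in for the patched _allocate_memory call sites: the buffer is placed
    # on inference_params.device instead of torch.cuda.current_device().
    return torch.empty(
        inference_params.max_sequence_length,
        inference_params.max_batch_size,
        num_query_groups,
        head_dim * 2,
        dtype=dtype,
        device=inference_params.device,
    )

# As in prepare_inputs_for_generation: the cache inherits the model's device,
# and every KV buffer allocated afterwards ends up on that same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache = HybridMambaAttentionDynamicCache(max_batch_size=2, max_sequence_length=32, device=device)
kv = allocate_kv(cache, dtype=torch.float32)
print(kv.device, kv.shape)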
