Commit 2370091

ShashankMosaicML committed Jan 13, 2025
1 parent 8781b2c
Showing 4 changed files with 98 additions and 30 deletions.
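In brief: the commit adds a reuse_kv_x_layer_idx attention option. The existing reuse_kv_layer_idx lets a layer reuse the key/value tensors computed by an earlier layer; reuse_kv_x_layer_idx instead lets a layer apply its own Wk/Wv projections to the earlier layer's input x, which the forward pass now caches and hands to the block as x_prev.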
31 changes: 29 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -459,6 +459,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
reuse_kv_x_layer_idx: Optional[int] = None,
attn_logit_softcapping: Optional[float] = None,
kv_dim: Optional[int] = None,
):
@@ -474,7 +475,17 @@ def __init__(
self.n_heads = n_heads
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size
if reuse_kv_x_layer_idx is not None:
if reuse_kv_layer_idx is not None:
raise ValueError(
'Only one of reuse_kv_layer_idx and reuse_kv_x_layer_idx can be set.',
)
if self.fused_qkv:
raise ValueError(
'reuse_kv_x_layer_idx is not supported with fused_qkv.',
)
self.reuse_kv_layer_idx = reuse_kv_layer_idx
self.reuse_kv_x_layer_idx = reuse_kv_x_layer_idx
self.attn_logit_softcapping = attn_logit_softcapping

self.kv_dim = kv_dim if kv_dim is not None else self.d_model
@@ -600,11 +611,14 @@ def forward(
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
x_prev: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
if x_prev is not None:
extra_kwargs['x_prev'] = x_prev
query, key, value = self.get_qkv(
x=x,
key_value_states=key_value_states,
@@ -648,6 +662,7 @@ def get_qkv(
def get_qkv(
self,
x: torch.Tensor,
x_prev: Optional[torch.Tensor] = None,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
@@ -656,6 +671,7 @@ def get_qkv(
Args:
x (torch.Tensor): The input query tensor.
            x_prev (Optional[torch.Tensor]): The input tensor of the previous layer, used to compute the keys and values when reuse_kv_x_layer_idx is set.
prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
@@ -709,11 +725,22 @@ def get_qkv(
else:
query = self.Wq(x)
if key_value_states is not None:
if self.reuse_kv_x_layer_idx is not None:
raise NotImplementedError(
'reuse_kv_x_layer_idx is not supported with key_value_states.',
)
key = self.Wk(key_value_states)
value = self.Wv(key_value_states)
else:
-                key = self.Wk(x)
-                value = self.Wv(x)
+                kv_input = x
+                if self.reuse_kv_x_layer_idx is not None:
+                    if x_prev is None:
+                        raise ValueError(
+                            'x_prev is None; cannot reuse the x of a previous layer to compute keys and values.',
+                        )
+                    kv_input = x_prev
+                key = self.Wk(kv_input)
+                value = self.Wv(kv_input)

if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
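The essence of the get_qkv change above, as a standalone sketch (illustrative names, not the llmfoundry API): the query is always projected from the current input x, while the key/value projections consume x_prev whenever a previous layer's input is being reused.

from typing import Optional

import torch
import torch.nn as nn

d_model = 16
Wq, Wk, Wv = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))

def get_qkv_sketch(
    x: torch.Tensor,
    x_prev: Optional[torch.Tensor] = None,
):
    # The query is always projected from the current layer's input.
    query = Wq(x)
    # Keys and values come from the cached input of the reused layer, if any.
    kv_input = x_prev if x_prev is not None else x
    return query, Wk(kv_input), Wv(kv_input)

x = torch.randn(2, 4, d_model)       # current layer's input
x_prev = torch.randn(2, 4, d_model)  # x cached from an earlier layer
q, k, v = get_qkv_sketch(x, x_prev)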
6 changes: 6 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -165,11 +165,14 @@ def forward(
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
x_prev: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
if x_prev is not None:
extra_kwargs['x_prev'] = x_prev
if key_value_states is not None:
extra_kwargs['key_value_states'] = key_value_states

@@ -332,12 +335,15 @@ def forward(
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
x_prev: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[tuple[torch.Tensor, torch.Tensor]]]:
a = self.norm_1(x)
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
if x_prev is not None:
extra_kwargs['x_prev'] = x_prev
if key_value_states is not None:
extra_kwargs['key_value_states'] = key_value_states

1 change: 1 addition & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -406,5 +406,6 @@ def allowed_block_overrides(self):
'attn_config': {
'sliding_window_size': None,
'reuse_kv_layer_idx': None,
'reuse_kv_x_layer_idx': None,
},
}
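For reference, a hypothetical block_overrides snippet that would exercise the new key (it mirrors the established reuse_kv_layer_idx override pattern; the override name reuse_x_layer is illustrative). A relative index of -1 makes each overridden layer compute keys and values from the input of the block directly before it.

block_overrides = {
    'order': [
        {'name': 'default'},
        {'name': 'reuse_x_layer', 'repeat': 2},
    ],
    'overrides': {
        'reuse_x_layer': {
            'attn_config': {
                # Relative, negative index: -1 means the previous block's x.
                'reuse_kv_x_layer_idx': -1,
            },
        },
    },
}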
90 changes: 62 additions & 28 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -495,7 +495,10 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
nn.ModuleList: The list of Transformer blocks.
"""
block_args = self.extract_block_args(config.to_dict())
-        self.kv_cache_layers = set()
+        self.state_cache_layers = {
+            'reuse_kv_layer_idx': set(),
+            'reuse_kv_x_layer_idx': set(),
+        }
self.blocks_fuse_norm_attn_norm = block_args.get(
'fuse_norm_attn_norm',
False,
@@ -537,30 +540,46 @@ def _get_override_block_args_list(
new_block_args_list = []
layer_description_list = []

-        reuse_kv_layer_idx_dict = {}
+        reuse_state_layer_idx_dicts = {
+            'reuse_kv_layer_idx': {},
+            'reuse_kv_x_layer_idx': {},
+        }
for b_idx in range(config.n_layers):
module_name = model_modules_order_expanded[b_idx]
override_config = {}
if module_name != 'default':
override_config = copy.deepcopy(
config.block_overrides['overrides'][module_name],
)
-                if 'reuse_kv_layer_idx' in override_config.get(
-                    'attn_config',
-                    {},
-                ):
-                    reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_idx(
+                attn_config = override_config.get('attn_config', {})
+                if 'reuse_kv_layer_idx' in attn_config and 'reuse_kv_x_layer_idx' in attn_config:
+                    raise ValueError(
+                        'Only one of reuse_kv_layer_idx and reuse_kv_x_layer_idx can be specified.',
+                    )
+
+                reuse_type = None
+                if 'reuse_kv_layer_idx' in attn_config:
+                    reuse_type = 'reuse_kv_layer_idx'
+                elif 'reuse_kv_x_layer_idx' in attn_config:
+                    reuse_type = 'reuse_kv_x_layer_idx'
+
+                if reuse_type is not None:
+                    reuse_state_layer_idx = MPTModel._resolve_reuse_state_layer_idx(
                         overrides_definition=config.block_overrides['overrides'],
                         model_modules_order_expanded=model_modules_order_expanded,
                         b_idx=b_idx,
                         override_config=override_config,
-                        reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict,
+                        reuse_state_layer_idx_dict=reuse_state_layer_idx_dicts[reuse_type],
+                        reuse_type=reuse_type,
                     )
-                    override_config['attn_config']['reuse_kv_layer_idx'
-                                                  ] = reuse_kv_layer_idx
-                    self.kv_cache_layers.add(reuse_kv_layer_idx)
+                    override_config['attn_config'][reuse_type] = reuse_state_layer_idx
+                    self.state_cache_layers[reuse_type].add(reuse_state_layer_idx)
layer_description_list.append([
b_idx,
module_name,
@@ -582,35 +601,37 @@ def _get_override_block_args_list(
return new_block_args_list

@staticmethod
-    def _resolve_reuse_kv_layer_idx(
+    def _resolve_reuse_state_layer_idx(
overrides_definition: dict[str, Any],
model_modules_order_expanded: list[str],
b_idx: int,
override_config: dict[str, Any],
-        reuse_kv_layer_idx_dict: dict[int, int],
+        reuse_state_layer_idx_dict: dict[int, int],
+        reuse_type: str,
) -> int:
override_attn_config = override_config['attn_config']
-        if override_attn_config['reuse_kv_layer_idx'] >= 0:
+        if override_attn_config[reuse_type] >= 0:
            raise ValueError(
-                f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.',
+                f'The relative index of kv layer to reuse, {override_attn_config[reuse_type]=}, should be negative.',
            )
-        reuse_kv_layer_idx = b_idx + override_attn_config['reuse_kv_layer_idx']
-        if reuse_kv_layer_idx < 0:
+        reuse_state_layer_idx = b_idx + override_attn_config[reuse_type]
+        if reuse_state_layer_idx < 0:
            raise ValueError(
-                f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.',
+                f'The absolute index of kv layer to reuse, {reuse_state_layer_idx} should be non-negative.',
            )
-        if reuse_kv_layer_idx in reuse_kv_layer_idx_dict:
-            reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx]
-        reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx
+        if reuse_state_layer_idx in reuse_state_layer_idx_dict:
+            reuse_state_layer_idx = reuse_state_layer_idx_dict[reuse_state_layer_idx]
+        reuse_state_layer_idx_dict[b_idx] = reuse_state_layer_idx

-        parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx]
+        parent_layer_name = model_modules_order_expanded[reuse_state_layer_idx]
parent_config = {} if parent_layer_name == 'default' else copy.deepcopy(
overrides_definition[parent_layer_name],
)
if 'attn_config' not in parent_config:
parent_config['attn_config'] = {}
-        parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[
-            'attn_config']['reuse_kv_layer_idx']
+        parent_config['attn_config'][reuse_type] = override_config[
+            'attn_config'][reuse_type]

if override_config != parent_config and not (
'allow_mismatch' in override_config and
@@ -620,7 +641,7 @@ def _resolve_reuse_kv_layer_idx(
                'For reusing the kv cache of a previous layer, the block config of the previous layer should match that of the current layer.',
)

-        return reuse_kv_layer_idx
+        return reuse_state_layer_idx

@staticmethod
def _get_modules_order_expanded(order: list[dict[str, Any]]) -> list[str]:
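To make the relative-index resolution concrete, a small standalone sketch of the arithmetic in _resolve_reuse_state_layer_idx (the resolve helper is hypothetical; it mirrors the checks and the chain-collapsing dict above):

reuse_state_layer_idx_dict: dict[int, int] = {}

def resolve(b_idx: int, relative_idx: int) -> int:
    # The relative index must be negative and resolve to an earlier layer.
    assert relative_idx < 0 and b_idx + relative_idx >= 0
    target = b_idx + relative_idx
    # If the target layer itself reuses another layer, collapse the chain.
    target = reuse_state_layer_idx_dict.get(target, target)
    reuse_state_layer_idx_dict[b_idx] = target
    return target

assert resolve(2, -2) == 0  # layer 2 resolves to layer 0
assert resolve(4, -2) == 0  # layer 4 -> layer 2, which already maps to 0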
@@ -935,7 +956,7 @@ def forward(
# initialize the past key values cache if it should be used
presents = () if use_cache else None
if (
-            use_cache or len(self.kv_cache_layers) > 0
+            use_cache or len(self.state_cache_layers['reuse_kv_layer_idx']) > 0
) and past_key_values is None:
past_key_values = [() for _ in range(self.config.n_layers)
] # type: ignore
@@ -954,17 +975,28 @@ def forward(
)

layer_kv_cache_dict = {}
layer_kv_x_cache_dict = {}
for b_idx, block in enumerate(self.blocks):
attn_block = block.norm_attn_norm.attn if self.blocks_fuse_norm_attn_norm else block.attn
if attn_block.reuse_kv_layer_idx is not None:
if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict:
raise KeyError(
-                        f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.',
+                        f'kv cache for layer {attn_block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.',
)
prev_layer_key_value = layer_kv_cache_dict[
attn_block.reuse_kv_layer_idx]
else:
prev_layer_key_value = None
if b_idx in self.state_cache_layers['reuse_kv_x_layer_idx']:
layer_kv_x_cache_dict[b_idx] = x
if attn_block.reuse_kv_x_layer_idx is not None:
if attn_block.reuse_kv_x_layer_idx not in layer_kv_x_cache_dict:
raise KeyError(
f'kv cache for layer {attn_block.reuse_kv_x_layer_idx} not found in {layer_kv_x_cache_dict=}.',
)
                x_prev = layer_kv_x_cache_dict[attn_block.reuse_kv_x_layer_idx]
else:
x_prev = None
if output_hidden_states:
assert all_hidden_states is not None # pyright
all_hidden_states = all_hidden_states + (x,)
@@ -974,6 +1006,8 @@ def forward(
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
if x_prev is not None:
extra_kwargs['x_prev'] = x_prev
x, attn_weights, present = block(
x,
past_key_value=past_key_value,
Expand All @@ -988,7 +1022,7 @@ def forward(
)
if presents is not None:
presents += (present,)
-            if b_idx in self.kv_cache_layers:
+            if b_idx in self.state_cache_layers['reuse_kv_layer_idx']:
layer_kv_cache_dict[b_idx] = [
present[0][:, past_position:],
present[1][:, past_position:],
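Putting the forward-pass bookkeeping together, a schematic of the new x caching (a minimal standalone sketch with stand-in values, not llmfoundry code): donor layers recorded in state_cache_layers['reuse_kv_x_layer_idx'] save their input, and a reusing layer receives that input as x_prev.

state_cache_layers = {'reuse_kv_x_layer_idx': {0}}  # layer 0 donates its x
reuse_map = {2: 0}  # layer 2's reuse_kv_x_layer_idx resolved to layer 0

layer_kv_x_cache_dict = {}
for b_idx in range(3):
    x = f'hidden_state_{b_idx}'  # stand-in for the block input tensor
    if b_idx in state_cache_layers['reuse_kv_x_layer_idx']:
        layer_kv_x_cache_dict[b_idx] = x
    x_prev = layer_kv_x_cache_dict.get(reuse_map.get(b_idx))
    # x_prev is 'hidden_state_0' at layer 2 and None at layers 0 and 1.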
