diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 612d6b9642..13c2a36014 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -459,6 +459,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        reuse_kv_x_layer_idx: Optional[int] = None,
         attn_logit_softcapping: Optional[float] = None,
         kv_dim: Optional[int] = None,
     ):
@@ -474,7 +475,17 @@ def __init__(
         self.n_heads = n_heads
         self.kv_n_heads = kv_n_heads
         self.sliding_window_size = sliding_window_size
+        if reuse_kv_x_layer_idx is not None:
+            if reuse_kv_layer_idx is not None:
+                raise ValueError(
+                    'Only one of reuse_kv_layer_idx and reuse_kv_x_layer_idx can be set.',
+                )
+            if self.fused_qkv:
+                raise ValueError(
+                    'reuse_kv_x_layer_idx is not supported with fused_qkv.',
+                )
         self.reuse_kv_layer_idx = reuse_kv_layer_idx
+        self.reuse_kv_x_layer_idx = reuse_kv_x_layer_idx
         self.attn_logit_softcapping = attn_logit_softcapping
         self.kv_dim = kv_dim if kv_dim is not None else self.d_model
@@ -600,11 +611,14 @@ def forward(
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
         key_value_states: Optional[torch.Tensor] = None,
+        x_prev: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
         torch.Tensor, torch.Tensor]]]:
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        if x_prev is not None:
+            extra_kwargs['x_prev'] = x_prev
         query, key, value = self.get_qkv(
             x=x,
             key_value_states=key_value_states,
@@ -648,6 +662,7 @@ def forward(
     def get_qkv(
         self,
         x: torch.Tensor,
+        x_prev: Optional[torch.Tensor] = None,
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
         key_value_states: Optional[torch.Tensor] = None,
@@ -656,6 +671,7 @@ def get_qkv(

         Args:
             x (torch.Tensor): The input query tensor.
+            x_prev (Optional[torch.Tensor]): The input tensor of the previous layer, used to compute the keys and values.
             prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
             key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
@@ -709,11 +725,22 @@ def get_qkv(
         else:
             query = self.Wq(x)
             if key_value_states is not None:
+                if self.reuse_kv_x_layer_idx is not None:
+                    raise NotImplementedError(
+                        'reuse_kv_x_layer_idx is not supported with key_value_states.',
+                    )
                 key = self.Wk(key_value_states)
                 value = self.Wv(key_value_states)
             else:
-                key = self.Wk(x)
-                value = self.Wv(x)
+                kv_input = x
+                if self.reuse_kv_x_layer_idx is not None:
+                    if x_prev is None:
+                        raise ValueError(
+                            'x_prev is None, cannot reuse the x of a previous layer.',
+                        )
+                    kv_input = x_prev
+                key = self.Wk(kv_input)
+                value = self.Wv(kv_input)

         if self.clip_qkv:
             query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index c88cf33d1b..c70e678fff 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -165,11 +165,14 @@ def forward(
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
         key_value_states: Optional[torch.Tensor] = None,
+        x_prev: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
         torch.Tensor, torch.Tensor]]]:
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        if x_prev is not None:
+            extra_kwargs['x_prev'] = x_prev
         if key_value_states is not None:
             extra_kwargs['key_value_states'] = key_value_states
@@ -332,12 +335,15 @@ def forward(
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
         key_value_states: Optional[torch.Tensor] = None,
+        x_prev: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[tuple[torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        if x_prev is not None:
+            extra_kwargs['x_prev'] = x_prev
         if key_value_states is not None:
             extra_kwargs['key_value_states'] = key_value_states
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 1adb64dc21..b9f42427c4 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -406,5 +406,6 @@ def allowed_block_overrides(self):
             'attn_config': {
                 'sliding_window_size': None,
                 'reuse_kv_layer_idx': None,
+                'reuse_kv_x_layer_idx': None,
             },
         }
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 94e5fa29d5..89a54e4ecf 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -495,7 +495,10 @@ def construct_blocks(self, config: MPTConfig) -> nn.ModuleList:
             nn.ModuleList: The list of Transformer blocks.
""" block_args = self.extract_block_args(config.to_dict()) - self.kv_cache_layers = set() + self.state_cache_layers = { + 'reuse_kv_layer_idx': set(), + 'reuse_kv_x_layer_idx': set(), + } self.blocks_fuse_norm_attn_norm = block_args.get( 'fuse_norm_attn_norm', False, @@ -537,7 +540,10 @@ def _get_override_block_args_list( new_block_args_list = [] layer_description_list = [] - reuse_kv_layer_idx_dict = {} + reuse_state_layer_idx_dicts = { + 'reuse_kv_layer_idx': {}, + 'reuse_kv_x_layer_idx': {}, + } for b_idx in range(config.n_layers): module_name = model_modules_order_expanded[b_idx] override_config = {} @@ -545,22 +551,35 @@ def _get_override_block_args_list( override_config = copy.deepcopy( config.block_overrides['overrides'][module_name], ) - if 'reuse_kv_layer_idx' in override_config.get( - 'attn_config', - {}, - ): - reuse_kv_layer_idx = MPTModel._resolve_reuse_kv_layer_idx( + attn_config = override_config.get('attn_config', {}) + if 'reuse_kv_layer_idx' in attn_config and 'reuse_kv_x_layer_idx' in attn_config: + raise ValueError( + 'Only one of reuse_kv_layer_idx and reuse_kv_x_layer_idx can be specified.', + ) + + reuse_type = None + if 'reuse_kv_layer_idx' in attn_config: + reuse_type = 'reuse_kv_layer_idx' + elif 'reuse_kv_x_layer_idx' in attn_config: + reuse_type = 'reuse_kv_x_layer_idx' + + if reuse_type is not None: + reuse_state_layer_idx = MPTModel._resolve_reuse_state_layer_idx( overrides_definition=config. block_overrides['overrides'], model_modules_order_expanded= model_modules_order_expanded, b_idx=b_idx, override_config=override_config, - reuse_kv_layer_idx_dict=reuse_kv_layer_idx_dict, + reuse_state_layer_idx_dict=reuse_state_layer_idx_dicts[ + reuse_type], + reuse_type=reuse_type, + ) + override_config['attn_config'][reuse_type + ] = reuse_state_layer_idx + self.state_cache_layers[reuse_type].add( + reuse_state_layer_idx, ) - override_config['attn_config']['reuse_kv_layer_idx' - ] = reuse_kv_layer_idx - self.kv_cache_layers.add(reuse_kv_layer_idx) layer_description_list.append([ b_idx, module_name, @@ -582,35 +601,37 @@ def _get_override_block_args_list( return new_block_args_list @staticmethod - def _resolve_reuse_kv_layer_idx( + def _resolve_reuse_state_layer_idx( overrides_definition: dict[str, Any], model_modules_order_expanded: list[str], b_idx: int, override_config: dict[str, Any], - reuse_kv_layer_idx_dict: dict[int, int], + reuse_state_layer_idx_dict: dict[int, int], + reuse_type: str, ) -> int: override_attn_config = override_config['attn_config'] - if override_attn_config['reuse_kv_layer_idx'] >= 0: + if override_attn_config[reuse_type] >= 0: raise ValueError( - f'The relative index of kv layer to reuse, {override_attn_config["reuse_kv_layer_idx"]=}, should be negative.', + f'The relative index of kv layer to reuse, {override_attn_config[reuse_type]=}, should be negative.', ) - reuse_kv_layer_idx = b_idx + override_attn_config['reuse_kv_layer_idx'] - if reuse_kv_layer_idx < 0: + reuse_state_layer_idx = b_idx + override_attn_config[reuse_type] + if reuse_state_layer_idx < 0: raise ValueError( - f'The absolute index of kv layer to reuse, {reuse_kv_layer_idx} should be non-negative.', + f'The absolute index of kv layer to reuse, {reuse_state_layer_idx} should be non-negative.', ) - if reuse_kv_layer_idx in reuse_kv_layer_idx_dict: - reuse_kv_layer_idx = reuse_kv_layer_idx_dict[reuse_kv_layer_idx] - reuse_kv_layer_idx_dict[b_idx] = reuse_kv_layer_idx + if reuse_state_layer_idx in reuse_state_layer_idx_dict: + reuse_state_layer_idx = reuse_state_layer_idx_dict[ + 
+                reuse_state_layer_idx]
+        reuse_state_layer_idx_dict[b_idx] = reuse_state_layer_idx

-        parent_layer_name = model_modules_order_expanded[reuse_kv_layer_idx]
+        parent_layer_name = model_modules_order_expanded[reuse_state_layer_idx]
         parent_config = {} if parent_layer_name == 'default' else copy.deepcopy(
             overrides_definition[parent_layer_name],
         )
         if 'attn_config' not in parent_config:
             parent_config['attn_config'] = {}
-        parent_config['attn_config']['reuse_kv_layer_idx'] = override_config[
-            'attn_config']['reuse_kv_layer_idx']
+        parent_config['attn_config'][reuse_type] = override_config[
+            'attn_config'][reuse_type]

         if override_config != parent_config and not (
             'allow_mismatch' in override_config and
@@ -620,7 +641,7 @@ def _resolve_reuse_kv_layer_idx(
                 'For reusing the kv cache of a previous layer, the previous layer should match the block config as the current layer.',
             )

-        return reuse_kv_layer_idx
+        return reuse_state_layer_idx

     @staticmethod
     def _get_modules_order_expanded(order: list[dict[str, Any]]) -> list[str]:
@@ -935,7 +956,7 @@ def forward(
         # initialize the past key values cache if it should be used
         presents = () if use_cache else None
         if (
-            use_cache or len(self.kv_cache_layers) > 0
+            use_cache or len(self.state_cache_layers['reuse_kv_layer_idx']) > 0
         ) and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)
                               ]  # type: ignore
@@ -954,17 +975,28 @@ def forward(
         )

         layer_kv_cache_dict = {}
+        layer_kv_x_cache_dict = {}
         for b_idx, block in enumerate(self.blocks):
             attn_block = block.norm_attn_norm.attn if self.blocks_fuse_norm_attn_norm else block.attn
             if attn_block.reuse_kv_layer_idx is not None:
                 if attn_block.reuse_kv_layer_idx not in layer_kv_cache_dict:
                     raise KeyError(
-                        f'kv cache for layer {block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.',
+                        f'kv cache for layer {attn_block.reuse_kv_layer_idx} not found in {layer_kv_cache_dict=}.',
                     )
                 prev_layer_key_value = layer_kv_cache_dict[
                     attn_block.reuse_kv_layer_idx]
             else:
                 prev_layer_key_value = None
+            if b_idx in self.state_cache_layers['reuse_kv_x_layer_idx']:
+                layer_kv_x_cache_dict[b_idx] = x
+            if attn_block.reuse_kv_x_layer_idx is not None:
+                if attn_block.reuse_kv_x_layer_idx not in layer_kv_x_cache_dict:
+                    raise KeyError(
+                        f'Layer input x for layer {attn_block.reuse_kv_x_layer_idx} not found in {layer_kv_x_cache_dict=}.',
+                    )
+                x_prev = layer_kv_x_cache_dict[attn_block.reuse_kv_x_layer_idx]
+            else:
+                x_prev = None
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
                 all_hidden_states = all_hidden_states + (x,)
@@ -974,6 +1006,8 @@ def forward(
             extra_kwargs = {}
             if prev_layer_key_value is not None:
                 extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+            if x_prev is not None:
+                extra_kwargs['x_prev'] = x_prev
             x, attn_weights, present = block(
                 x,
                 past_key_value=past_key_value,
@@ -988,7 +1022,7 @@ def forward(
             )
             if presents is not None:
                 presents += (present,)
-            if b_idx in self.kv_cache_layers:
+            if b_idx in self.state_cache_layers['reuse_kv_layer_idx']:
                 layer_kv_cache_dict[b_idx] = [
                     present[0][:, past_position:],
                     present[1][:, past_position:],
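
Illustrative usage (not part of the patch): the sketch below shows how the new override might be wired up through MPTConfig.block_overrides so that a block projects its keys and values from the input (x) of an earlier block while still computing queries from its own input. Only the 'overrides' dict, the 'default' module name, the negative relative index, the fused_qkv restriction, and the 'reuse_kv_x_layer_idx' key are confirmed by the diff above; the 'order'/'name' layout, the import path, and the remaining config values follow existing llm-foundry conventions and should be treated as assumptions.

# Sketch only: model sizes, module names, and the exact block_overrides
# schema ('order' entries with 'name') are assumptions, not taken from
# the diff above.
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

config = MPTConfig(
    d_model=256,
    n_heads=8,
    n_layers=2,
    # reuse_kv_x_layer_idx is rejected when fused_qkv is enabled, so use
    # unfused Q/K/V projections; 'torch' avoids a flash-attn dependency.
    attn_config={'fused_qkv': False, 'attn_impl': 'torch'},
    block_overrides={
        'order': [
            {'name': 'default'},           # layer 0: regular block
            {'name': 'kv_x_reuse_block'},  # layer 1: reuses layer 0's input for K/V
        ],
        'overrides': {
            'kv_x_reuse_block': {
                'attn_config': {
                    # Relative index; must be negative. -1 resolves to the
                    # immediately preceding layer (absolute index 0 here).
                    'reuse_kv_x_layer_idx': -1,
                },
            },
        },
    },
)
model = MPTForCausalLM(config)

Unlike reuse_kv_layer_idx, which shares the already-projected key/value tensors of an earlier layer, reuse_kv_x_layer_idx re-projects that layer's input with the current block's own Wk/Wv weights, which is why it requires unfused projections and only the reuse_kv_layer_idx set still forces past_key_values allocation in MPTModel.forward.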