Starcoder2: KVCache and flash attention (FusedSDPA) enablement (huggingface#1149)

Co-authored-by: Colabrese <[email protected]>
Co-authored-by: Abhilash Majumder <[email protected]>
Co-authored-by: Sayantan Sarkar <[email protected]>
Co-authored-by: regisss <[email protected]>
5 people authored Aug 6, 2024
1 parent ec90e05 commit 13b6452
Showing 8 changed files with 665 additions and 308 deletions.
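
Context for the title: FusedSDPA is the HPU fused scaled-dot-product-attention (flash-attention-style) kernel that Gaudi-optimized attention classes call in place of the eager softmax(QK^T)V path. Below is a minimal sketch of the import-with-fallback pattern, assuming the Habana PyTorch bridge is installed; the exact FusedSDPA.apply() argument list varies across Synapse releases, so the call shown is illustrative only.

    import torch
    import torch.nn.functional as F

    try:
        # HPU fused scaled dot-product attention kernel.
        from habana_frameworks.torch.hpex.kernels import FusedSDPA
    except ImportError:
        FusedSDPA = None

    def sdpa(query, key, value, attn_mask=None):
        # All tensors: (batch, num_heads, seq_len, head_dim).
        if FusedSDPA is not None:
            # Illustrative argument order (mask, dropout_p, is_causal, scale);
            # check the installed habana_frameworks release for the real signature.
            return FusedSDPA.apply(query, key, value, attn_mask, 0.0, False, None)
        # Fallback with equivalent semantics on CPU/GPU.
        return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)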
13 changes: 11 additions & 2 deletions examples/text-generation/run_lm_eval.py
@@ -102,13 +102,22 @@ def __init__(self, tokenizer, model, args, options):
        self.options = options
        self._device = args.device
        self.model_inputs = {"use_cache": self.options.use_cache}
-        if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral", "qwen2", "gptj"]:
+        if self.model.config.model_type in [
+            "llama",
+            "mistral",
+            "falcon",
+            "phi",
+            "mixtral",
+            "qwen2",
+            "gptj",
+            "starcoder2",
+        ]:
            self.model_inputs.update(
                {
                    "reuse_cache": self.options.reuse_cache,
                }
            )
-        if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon"]:
+        if self.model.config.model_type in ["llama", "mistral", "qwen2", "falcon", "starcoder2"]:
            if self.model.config.model_type != "falcon":
                self.model_inputs.update(
                    {
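The reuse_cache flag wired in above is forwarded to generate(), so the lm-eval harness exercises the same preallocated KV cache path as the text-generation example. A hypothetical direct call for comparison (model, tokenizer, and sizes are placeholders, not from this diff):

    # Sketch: generation with the preallocated ("reused") KV cache.
    # reuse_cache is an optimum-habana generation kwarg, not a stock transformers one.
    inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("hpu")
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        use_cache=True,
        reuse_cache=True,  # allocate the KV cache once, reuse it across decode steps
    )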
2 changes: 1 addition & 1 deletion examples/text-generation/utils.py
@@ -381,7 +381,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger):

    model = deepspeed.init_inference(model, **ds_inference_kwargs)
    model = model.module
-    if model.config.model_type in ["llama", "falcon", "qwen2"]:
+    if model.config.model_type in ["llama", "falcon", "qwen2", "starcoder2"]:
        patch_scoped_linear_all_reduce(model)

    if args.quant_config:
8 changes: 5 additions & 3 deletions optimum/habana/transformers/generation/utils.py
@@ -94,6 +94,7 @@
"starcoder2",
"persimmon",
"qwen2",
"starcoder2",
"llava",
"llava_next",
"stablelm",
@@ -435,7 +436,7 @@ def create_pad_arg(pad_amount, i, j):
            else:
                assert False
        elif model_kwargs["past_key_values"][0][0].dim() == 4:
-            return (0, 0, 0, pad_amount)  # llama, falcon, qwen2
+            return (0, 0, 0, pad_amount)  # llama, falcon, qwen2, starcoder2
        else:
            assert False, "Unknown case, please handle, or dont use bucketing"

@@ -860,7 +861,8 @@ def generate(
"phi",
"qwen2",
"gptj",
], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2 and gptj at the moment"
"starcoder2",
], "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2 and starcoder2 at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
@@ -1016,7 +1018,7 @@ def generate(
            model_kwargs["kv_cache_len"] = calculated_max_length
            model_kwargs["kv_cache_pad_len"] = generation_config.max_new_tokens

-        if self.config.model_type in ["llama", "falcon", "mistral", "qwen2", "gptj"]:
+        if self.config.model_type in ["llama", "falcon", "mistral", "qwen2", "gptj", "starcoder2"]:
            if self.config.max_position_embeddings < calculated_max_length:
                unwrap_deepspeed_model(self).update_sincos_cache(seq_len=calculated_max_length)

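Why (0, 0, 0, pad_amount) is the right pad argument in the 4-D case above: KV cache tensors are laid out as (batch, num_heads, seq_len, head_dim), and torch.nn.functional.pad consumes its pad tuple from the last dimension backwards, so the first pair (0, 0) leaves head_dim untouched and the second pair (0, pad_amount) grows the sequence axis up to the bucket size. A self-contained check (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    # Illustrative 4-D KV cache slice: (batch, num_heads, seq_len, head_dim).
    kv = torch.zeros(1, 8, 16, 64)

    # Pairs apply last-dim-first: (head_dim_left, head_dim_right, seq_left, seq_right).
    padded = F.pad(kv, (0, 0, 0, 48))
    print(padded.shape)  # torch.Size([1, 8, 64, 64]): only seq_len grew, 16 -> 64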
8 changes: 4 additions & 4 deletions optimum/habana/transformers/modeling_utils.py
@@ -88,8 +88,10 @@
    GaudiQwen2Model,
    GaudiStableLmDecoderLayer,
    GaudiStableLmForCausalLM,
+    GaudiStarcoder2Attention,
    GaudiStarcoder2DecoderLayer,
    GaudiStarcoder2ForCausalLM,
+    GaudiStarcoder2Model,
    LlamaConfig,
    MistralConfig,
    MixtralConfig,
@@ -175,8 +177,6 @@
    gaudi_SpeechT5DecoderLayer_forward,
    gaudi_stablelm_attention_forward,
    gaudi_stablelm_model_forward,
-    gaudi_starcoder2_attention_forward,
-    gaudi_starcoder2_model_forward,
    gaudi_swin_get_attn_mask,
    gaudi_t5_layernorm_forward,
    gaudi_T5Attention_forward,
@@ -517,8 +517,8 @@ def adapt_transformers_to_gaudi():

    # Optimization for starcoder2 on Gaudi
    transformers.models.starcoder2.modeling_starcoder2.Starcoder2ForCausalLM = GaudiStarcoder2ForCausalLM
-    transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model.forward = gaudi_starcoder2_model_forward
-    transformers.models.starcoder2.modeling_starcoder2.Starcoder2Attention.forward = gaudi_starcoder2_attention_forward
+    transformers.models.starcoder2.modeling_starcoder2.Starcoder2Model = GaudiStarcoder2Model
+    transformers.models.starcoder2.modeling_starcoder2.Starcoder2Attention = GaudiStarcoder2Attention
    transformers.models.starcoder2.modeling_starcoder2.Starcoder2DecoderLayer = GaudiStarcoder2DecoderLayer

# Optimization for qwen2 on Gaudi
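Note that the Starcoder2 hooks now swap whole classes (which carry the KV cache state and the FusedSDPA call path) instead of patching individual forward methods, so adapt_transformers_to_gaudi() must run before the model is instantiated. A usage sketch, with the checkpoint name chosen purely for illustration:

    import torch
    from transformers import AutoModelForCausalLM
    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

    # Patch transformers in place so Starcoder2* resolves to the Gaudi classes above.
    adapt_transformers_to_gaudi()

    model = AutoModelForCausalLM.from_pretrained(
        "bigcode/starcoder2-3b",  # illustrative checkpoint
        torch_dtype=torch.bfloat16,
    ).to("hpu")  # requires a Gaudi device and the Habana PyTorch bridge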
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/__init__.py
@@ -191,10 +191,10 @@
    gaudi_stablelm_model_forward,
)
from .starcoder2 import (
+    GaudiStarcoder2Attention,
    GaudiStarcoder2DecoderLayer,
    GaudiStarcoder2ForCausalLM,
-    gaudi_starcoder2_attention_forward,
-    gaudi_starcoder2_model_forward,
+    GaudiStarcoder2Model,
)
from .swin import gaudi_swin_get_attn_mask
from .t5 import (
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/starcoder2/__init__.py
@@ -1,6 +1,6 @@
from .modeling_starcoder2 import (
+    GaudiStarcoder2Attention,
    GaudiStarcoder2DecoderLayer,
    GaudiStarcoder2ForCausalLM,
-    gaudi_starcoder2_attention_forward,
-    gaudi_starcoder2_model_forward,
+    GaudiStarcoder2Model,
)