enable llava1.6 with deepspeed
schoi-habana committed Oct 23, 2024
1 parent 03fa6dd commit 1f4a2df
Showing 2 changed files with 74 additions and 12 deletions.
67 changes: 62 additions & 5 deletions examples/image-to-text/run_pipeline.py
@@ -36,6 +36,28 @@
logger = logging.getLogger(__name__)


def override_print(enable):
import builtins as __builtin__

builtin_print = __builtin__.print

def print(*args, **kwargs):
force = kwargs.pop("force", False)
if force or enable:
builtin_print(*args, **kwargs)

__builtin__.print = print

def override_logger(logger, enable):
logger_info = logger.info

def info(*args, **kwargs):
force = kwargs.pop("force", False)
if force or enable:
logger_info(*args, **kwargs)

logger.info = info
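
# The two helpers above monkey-patch builtins.print and logger.info so that, in a
# multi-card run, only the selected rank emits output (main() later calls them with
# args.global_rank == 0); any caller can still pass force=True to print on every rank.
# Illustrative use (hypothetical values, not part of this script):
#   override_print(enable=(args.global_rank == 0))
#   print("loading weights")                  # silent on ranks other than 0
#   print("unrecoverable error", force=True)  # printed on every rank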

def setup_quantization(model, args):
from neural_compressor.torch.quantization import FP8Config, convert, prepare

@@ -129,6 +151,11 @@ def main():

# set args.quant_config with env variable if it is set
args.quant_config = os.getenv("QUANT_CONFIG", "")

args.local_rank = int(os.getenv("LOCAL_RANK", "0"))
args.world_size = int(os.getenv("WORLD_SIZE", "0"))
args.global_rank = int(os.getenv("RANK", "0"))

os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
adapt_transformers_to_gaudi()

@@ -181,12 +208,47 @@ def main():

htcore.hpu_set_env()

# Optional HPU tuning knobs, left disabled:
# os.environ.setdefault("PT_HPU_LAZY_ACC_PAR_MODE", "0")
# os.environ.setdefault("PT_HPU_ENABLE_LAZY_COLLECTIVES", "TRUE")

generator = pipeline(
"image-to-text",
model=args.model_name_or_path,
torch_dtype=model_dtype,
device="hpu",
)

# Multi-card run: shard the model across ranks with DeepSpeed inference.
if args.world_size > 1:
override_print(args.global_rank == 0)
override_logger(logger, args.global_rank == 0)

import deepspeed

logger.info("DeepSpeed is enabled.")
deepspeed.init_distributed(dist_backend="hccl")
generator.model.eval()

ds_inference_kwargs = {"dtype": model_dtype}
ds_inference_kwargs["tensor_parallel"] = {"tp_size": args.world_size}
ds_inference_kwargs["enable_cuda_graph"] = args.use_hpu_graphs
# injection_policy (e.g. from get_ds_injection_policy) is intentionally omitted; DeepSpeed's automatic tensor parallelism is used.

generator.model = deepspeed.init_inference(generator.model, **ds_inference_kwargs).module

else:
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

generator.model = wrap_in_hpu_graph(generator.model)


generate_kwargs = {
"lazy_mode": True,
"hpu_graphs": args.use_hpu_graphs,
@@ -198,11 +260,6 @@
if args.use_kv_cache:
generate_kwargs["use_cache"] = args.use_kv_cache

if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

generator.model = wrap_in_hpu_graph(generator.model)

if args.quant_config:
generator.model = setup_quantization(generator.model, args)
htcore.hpu_initialize(generator.model)
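For reference, the multi-card branch added above follows the standard DeepSpeed inference pattern; below is a condensed, self-contained sketch of the same wrapping (the helper name, model, world size, and dtype are placeholders, not part of the commit):

import deepspeed
import torch

def shard_for_inference(model, world_size, dtype=torch.bfloat16, use_graphs=True):
    """Hypothetical helper mirroring the wrapping done in run_pipeline.py above."""
    deepspeed.init_distributed(dist_backend="hccl")  # HCCL is the collective backend on Gaudi
    engine = deepspeed.init_inference(
        model.eval(),
        dtype=dtype,
        tensor_parallel={"tp_size": world_size},  # shard the weights across all ranks
        enable_cuda_graph=use_graphs,             # the script passes args.use_hpu_graphs here
    )
    return engine.module  # the unwrapped module keeps the usual generate() API

The script only takes this path when WORLD_SIZE > 1, i.e. when it is started by a distributed launcher that sets RANK, LOCAL_RANK and WORLD_SIZE for each process.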
19 changes: 12 additions & 7 deletions optimum/habana/transformers/models/clip/modeling_clip.py
@@ -90,15 +90,18 @@ def forward(
attn_weights_reshaped = None
# get query proj
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)  # (bsz, heads per rank, seq_len, head_dim), e.g. (5, 2, 577, 64) on 8 cards
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
#print(f"*****************self.num_heads:{self.num_heads}, self.head_dim:{self.head_dim}, query_states shape:{query_states.shape}")
#8x: num_heads 2, head_dim 64, [5, 577, 128] #heads are split across 8x
#1x: num_heads 16, head_dim 64, [5,577, 1024] (bs, tgt_len, hidden_size)
proj_shape = (bsz * self.num_heads, -1, self.head_dim) #(5*2, 577, 64) #actual shape [5,577,128]
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) #(5*2, 577, 64)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)

src_len = key_states.size(1)
src_len = key_states.size(1)  # source sequence length (577 in the debug runs noted above)
if FusedSDPA and use_flash_attention:
import habana_frameworks.torch.hpu as ht

@@ -154,9 +157,11 @@ def forward(
f" {attn_output.size()}"
)

attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
# With heads sharded, num_heads * head_dim is the per-rank width (e.g. 128 on 8 cards),
# not the full embed_dim (1024), so reshape to -1 rather than to embed_dim.
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)  # (bsz, heads, tgt_len, head_dim)
attn_output = attn_output.transpose(1, 2)  # (bsz, tgt_len, heads, head_dim)
attn_output = attn_output.reshape(bsz, tgt_len, -1)  # (bsz, tgt_len, heads * head_dim)

attn_output = self.out_proj(attn_output)

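The reshape change above is what lets the CLIP vision tower run under tensor parallelism: once DeepSpeed splits the attention heads across ranks, num_heads * head_dim on a single rank no longer equals embed_dim, so reshaping to embed_dim would fail. A small worked example with the shapes from the debug notes (batch size 5 and sequence length 577 are illustrative; the sharded out_proj is assumed to restore the full hidden size):

# Full CLIP ViT-L attention: 16 heads of width 64 -> hidden size 1024.
embed_dim, total_heads, head_dim = 1024, 16, 64
bsz, tgt_len = 5, 577                          # example batch and patch-sequence length

tp_size = 8                                    # heads split across 8 cards
local_heads = total_heads // tp_size           # 2 heads per rank
local_width = local_heads * head_dim           # 128 per rank

# reshape(bsz, tgt_len, embed_dim) would need bsz * tgt_len * 1024 elements,
# but each rank only holds bsz * tgt_len * 128, hence the failure under TP.
assert local_width != embed_dim
per_rank_output = (bsz, tgt_len, local_width)  # what reshape(bsz, tgt_len, -1) yields
print(per_rank_output)                         # (5, 577, 128)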
