fixing some issues in attention module
erfanzar committed May 14, 2024
1 parent d6c0fd3 commit a96c2ae
Showing 2 changed files with 4 additions and 8 deletions.
8 changes: 2 additions & 6 deletions src/python/easydel/modules/attention_module.py
@@ -164,7 +164,7 @@ def __init__(
         shard_attention_computation: bool = True,
         use_sharding_constraint: Optional[bool] = False,
         axis_name: str = "sp",
-        backward_pass_impl: Literal["triton", "xla"] = "xla"
+        backward_pass_impl: Literal["triton", "xla"] = "triton"
     ):
         platform = jax.lib.xla_bridge.get_backend().platform
         if sm_scale is None:
@@ -935,10 +935,6 @@ def pallas_flash_attention(
         query_sequence_length: int = None,
         bias: Optional[Array] = None,
     ) -> AttentionOutput:
-        """
-        TIP: for using this attention module set bias_partition_spec to (("dp","fsdp",),"sp")
-        """
-
         if query_sequence_length is None:
             query_sequence_length = query_states.shape[1]
         qps, kps, vps, bps, aps, is_gen = self.get_partition_specs(qs=query_sequence_length)
@@ -988,7 +984,7 @@ def cuddn_flash_attention(
         except (ModuleNotFoundError, ImportError) as err:
             raise RuntimeError(
                 "Please install transformer_engine first. you can install that by running "
-                f"`pip install git+https://github.com/NVIDIA/transformer_engine`"
+                f"`pip install git+https://github.com/NVIDIA/TransformerEngine`"
                 f"\nhere's extra information on error\n{err}"
             )
         batch, query_sequence_length, num_attention_heads, head_dim = query_states.shape
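
Note on the first hunk above: it both queries the JAX backend platform and flips the default backward_pass_impl to "triton", a GPU-oriented kernel path, so a caller on CPU or TPU may still want to select "xla" explicitly. A minimal, self-contained sketch of that idea; the platform-based selection rule is an illustrative assumption, not logic taken from this file:

import jax

# Illustrative sketch only: "triton" kernels run on GPU, so fall back to the
# "xla" backward pass on other platforms. This rule is an assumption for
# demonstration, not EasyDeL's own logic.
platform = jax.lib.xla_bridge.get_backend().platform  # e.g. "cpu", "gpu", "tpu"
backward_pass_impl = "triton" if platform == "gpu" else "xla"
print(platform, backward_pass_impl)
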
4 changes: 2 additions & 2 deletions src/python/easydel/modules/easydel_modelling_utils.py
@@ -107,7 +107,7 @@ def __init__(
         scan_mlp_chunk_size: int = 1024,
         attention_axis_name: str = "sp",
         quantize_kv_cache: bool = False,
-        flash_attention_backward_pass_impl: Literal["triton", "xla"] = "xla",
+        flash_attention_backward_pass_impl: Literal["triton", "xla"] = "triton",
         **kwargs
     ):
         self.query_partition_spec = query_partition_spec
@@ -413,7 +413,7 @@ def add_basic_configurations(
         set_attrs_smartly(self, "scan_mlp_chunk_size", 1024, scan_mlp_chunk_size)
         set_attrs_smartly(self, "attention_axis_name", "sp", attention_axis_name)
         set_attrs_smartly(self, "quantize_kv_cache", False, quantize_kv_cache)
-        set_attrs_smartly(self, "flash_attention_backward_pass_impl", "xla", flash_attention_backward_pass_impl)
+        set_attrs_smartly(self, "flash_attention_backward_pass_impl", "triton", flash_attention_backward_pass_impl)

     def __repr__(self):
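
Downstream model configs pick up the same "triton" default through add_basic_configurations. A hedged usage sketch of overriding it back to "xla" follows; the config class name (EasyDelPretrainedConfig) and the ability to construct it with defaults are assumptions based on the file being edited, not something this diff shows explicitly — only the flash_attention_backward_pass_impl keyword appears above.

# Hypothetical usage sketch; the class name and import path are assumptions
# inferred from the edited file, not confirmed by this diff.
from easydel.modules.easydel_modelling_utils import EasyDelPretrainedConfig

# Opt back into the previous "xla" backward pass instead of the new "triton" default.
config = EasyDelPretrainedConfig(flash_attention_backward_pass_impl="xla")

# add_basic_configurations (shown in the hunk above) stores the same setting
# via set_attrs_smartly, so an explicit value passed there also overrides the default.
config.add_basic_configurations(flash_attention_backward_pass_impl="xla")
print(config.flash_attention_backward_pass_impl)  # -> "xla"
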
