
updating attention module and now pallas_flash is fully supported across all the platforms
erfanzar committed May 14, 2024
1 parent bbcb1bc commit 3c2154a
Showing 4 changed files with 20 additions and 25 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -44,7 +44,9 @@ With its comprehensive set of features and tools, EasyDeL aims to streamline and
of machine learning models, particularly in the domain of large language models and video-related applications.

> **News**
>
>
> `pallas_flash` is now available for CPU/GPU/TPU with a custom Pallas kernel.
>
> DeepseekV2 Model is Added (beta mode).
>
> OpenELM Model is Added.
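
A hypothetical usage sketch for the news item above: the `attn_mechanism="pallas_flash"` string matches the new branch in `attention_module.py`, but the loader class and keyword shown here are assumptions, not the verified EasyDeL API; check the repository README for the exact call.

```python
# Hypothetical sketch: loader name and kwargs are assumptions, not verified EasyDeL API.
import jax.numpy as jnp
from easydel import AutoEasyDeLModelForCausalLM  # class name may differ by version

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "org/some-causal-lm-checkpoint",   # placeholder repo id
    dtype=jnp.bfloat16,
    attn_mechanism="pallas_flash",     # select the new Pallas flash-attention kernel
)
```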
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
"jax>=0.4.20",
"jaxlib>=0.4.20",
"flax>=0.7.5",
"fjformer>=0.0.56",
"fjformer>=0.0.57",
"transformers>=4.33.0",
"einops~=0.6.1",
"optax~=0.1.7",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ typing~=3.7.4.3
jax>=0.4.20
jaxlib>=0.4.20
flax>=0.7.5
fjformer>=0.0.56
fjformer>=0.0.57
transformers>=4.34.0
einops~=0.6.1
optax~=0.1.7
37 changes: 15 additions & 22 deletions src/python/easydel/modules/attention_module.py
@@ -10,12 +10,12 @@
from fjformer import with_sharding_constraint

try:
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention as tpu_flash_attention
from jax.experimental.pallas.ops.tpu.flash_attention import BlockSizes as BlockSizesFlashAttn

except (ModuleNotFoundError, ImportError) as e:
from fjformer.pallas_operations.flash_attention.tpu import flash_attention
from fjformer.pallas_operations.flash_attention.tpu import BlockSizes as BlockSizesFlashAttn
from fjformer.pallas_operations.tpu_flash_attention.tpu import flash_attention as tpu_flash_attention
from fjformer.pallas_operations.tpu_flash_attention.tpu import BlockSizes as BlockSizesFlashAttn
from fjformer.pallas_operations.ring_attention import ring_flash_attention_tpu

try:
@@ -105,11 +105,11 @@ def get_flash_attention() -> Tuple[Callable, bool, bool]:
if platform == "gpu":
warnings.warn("for GPU backend use `cudnn` or `pallas_flash`")
float32_logits = False
ring_attention_fn = mha
ring_attention_fn = flash_attention
do_shard_map = True
elif platform == "tpu":
float32_logits = True
ring_attention_fn = flash_attention
ring_attention_fn = tpu_flash_attention
do_shard_map = False
else:
raise ValueError(f"Unsupported platform {platform}")
@@ -387,14 +387,12 @@ def __call__(
key_value_sequence_length=key_value_sequence_length
)
elif self.attn_mechanism == "pallas_flash":
return self.pallas_mha(
return self.pallas_flash_attention(
query_states=query_states,
key_states=key_states,
value_states=value_states,
query_sequence_length=query_sequence_length,
attention_mask=attention_mask,
bias=bias,
causal=causal
)
elif self.attn_mechanism == "splash":
if segment_ids is not None:
@@ -926,23 +924,18 @@ def splash_attention_call(q, k, v, am):
attention_weights=None
)

def pallas_mha(
def pallas_flash_attention(
self,
*,
query_states: Array,
key_states: Array,
value_states: Array,
query_sequence_length: int = None,
attention_mask: Optional[Array] = None,
bias: Optional[Array] = None,
causal: bool = True,
) -> AttentionOutput:
"""
TIP: when using this attention module, set bias_partition_spec to (("dp","fsdp",),"sp")
"""
assert attention_mask is not None, "`attention_mask` is required for `pallas_mha`"
if attention_mask.ndim == 4:
attention_mask = attention_mask[:, 0, -1]

if query_sequence_length is None:
query_sequence_length = query_states.shape[1]
@@ -955,16 +948,16 @@ def pallas_mha(
query_states = with_sharding_constraint(query_states, qps)
key_states = with_sharding_constraint(key_states, kps)
value_states = with_sharding_constraint(value_states, vps)
attention_mask = with_sharding_constraint(attention_mask, PartitionSpec(qps[0], qps[1]))
bias = with_sharding_constraint(bias, bps)
attention_outputs = flash_attention(
query_states,
key_states,
value_states,
attention_mask=attention_mask.astype("int"),
bias=bias,
sm_scale=self.sm_scale,
causal=True,
block_k=self.block_k,
block_q=self.block_q
block_q=self.block_q,
interpret=True if self.platform == "cpu" else None # auto-decide
)
attention_outputs = with_sharding_constraint(attention_outputs, aps)
return AttentionOutput(
@@ -1127,11 +1120,11 @@ def make_inputs():
"local_ring",
"blockwise",
"vanilla",
# "wise_ring",
"wise_ring",
"sharded_vanilla",
# "flash",
# "splash",
# "cudnn",
"flash",
"splash",
"cudnn",
"pallas_flash"
]
fns = {
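
The change that makes `pallas_flash` usable beyond GPU/TPU is the `interpret=True if self.platform == "cpu" else None` argument passed to the kernel above: Pallas can emulate a kernel with ordinary JAX ops when no accelerator backend is available. A minimal self-contained sketch of that mechanism, using a toy kernel rather than the flash-attention kernel itself:

```python
# Toy Pallas kernel illustrating interpret mode; not the flash-attention kernel.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Each ref is a block of the corresponding array; write the elementwise sum.
    o_ref[...] = x_ref[...] + y_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = jnp.ones(8, dtype=jnp.float32)

out = pl.pallas_call(
    add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # run the kernel through the interpreter, so it works on CPU too
)(x, y)
print(out)  # [1. 2. 3. 4. 5. 6. 7. 8.]
```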
