adding option to modify the backward kernel for pallas flash attention
erfanzar committed May 14, 2024
1 parent 3c2154a commit d6c0fd3
Showing 15 changed files with 38 additions and 16 deletions.
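The commit threads one new knob, flash_attention_backward_pass_impl (either "triton" or "xla", defaulting to "xla"), from the shared EasyDeLPretrainedConfig base class through every model's attention setup and into the Pallas flash-attention call. A minimal usage sketch, not part of the commit itself: the import path is inferred from the file layout below and the remaining config arguments are assumed to have defaults.

# Hedged sketch; import path guessed from
# src/python/easydel/modules/easydel_modelling_utils.py and may differ.
from easydel.modules.easydel_modelling_utils import EasyDeLPretrainedConfig

config = EasyDeLPretrainedConfig(
    flash_attention_backward_pass_impl="triton",  # default is "xla"
)
assert config.flash_attention_backward_pass_impl == "triton"

Each model module below then reads this value off its config and passes it to the attention module as backward_pass_impl.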
3 changes: 2 additions & 1 deletion src/python/easydel/modules/arctic/modelling_arctic_flax.py
@@ -150,7 +150,8 @@ def setup(self) -> None:
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

@staticmethod
7 changes: 5 additions & 2 deletions src/python/easydel/modules/attention_module.py
@@ -164,6 +164,7 @@ def __init__(
shard_attention_computation: bool = True,
use_sharding_constraint: Optional[bool] = False,
axis_name: str = "sp",
backward_pass_impl: Literal["triton", "xla"] = "xla"
):
platform = jax.lib.xla_bridge.get_backend().platform
if sm_scale is None:
@@ -202,6 +203,7 @@ def __init__(
self.generation_bias_partition_spec = generation_bias_partition_spec
self.generation_attention_partition_spec = generation_attention_partition_spec
self.axis_name = axis_name
self.backward_pass_impl = backward_pass_impl
if attn_mechanism == "splash" and self.platform != "tpu":
raise OSError("splash attention is only supported on TPU.")
if attn_mechanism == "flash" and self.platform != "tpu":
@@ -957,7 +959,8 @@ def pallas_flash_attention(
sm_scale=self.sm_scale,
block_k=self.block_k,
block_q=self.block_q,
interpret=True if self.platform == "cpu" else None # auto-decide
interpret=True if self.platform == "cpu" else None, # auto-decide
backward_pass_impl=self.backward_pass_impl
)
attention_outputs = with_sharding_constraint(attention_outputs, aps)
return AttentionOutput(
@@ -1138,7 +1141,7 @@ def make_inputs():
end = time.time() - start
outs_and_grads[nm] = out + (end,)
except Exception as e:
print(f"{nm} failled :\n\n{e}")
print(f"{nm} is Failed :\n\n{e}")
outs_and_grads[nm] = (None, None, None)
frame_out = {}
for key, (out, grad, time_took) in outs_and_grads.items():
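The attention_module.py hunks above accept the new backward_pass_impl keyword, store it on the module, and forward it to the Pallas flash-attention kernel alongside sm_scale, block_q, block_k, and interpret. A minimal sketch of that plumbing, with a stub standing in for the real kernel (the stub and class names here are illustrative, not EasyDeL's API):

from typing import Literal
import jax
import jax.numpy as jnp

def _pallas_flash_attention_stub(q, k, v, bias=None, *, sm_scale, block_q,
                                 block_k, interpret, backward_pass_impl):
    # Stand-in for the real Pallas MHA kernel; only the keyword surface
    # matters for this sketch. In the real kernel, block_q/block_k control
    # tiling and backward_pass_impl selects the backward-pass implementation.
    del block_q, block_k, interpret, backward_pass_impl
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * sm_scale
    if bias is not None:
        scores = scores + bias
    return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(scores, axis=-1), v)

class AttentionPlumbingSketch:
    def __init__(self, sm_scale: float, platform: str = "gpu",
                 block_q: int = 128, block_k: int = 128,
                 backward_pass_impl: Literal["triton", "xla"] = "xla"):
        self.sm_scale = sm_scale
        self.platform = platform
        self.block_q = block_q
        self.block_k = block_k
        # New in this commit: remember which backward-pass kernel to request.
        self.backward_pass_impl = backward_pass_impl

    def pallas_flash_attention(self, q, k, v, bias=None):
        return _pallas_flash_attention_stub(
            q, k, v, bias,
            sm_scale=self.sm_scale,
            block_q=self.block_q,
            block_k=self.block_k,
            interpret=True if self.platform == "cpu" else None,
            backward_pass_impl=self.backward_pass_impl,
        )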
3 changes: 2 additions & 1 deletion src/python/easydel/modules/cohere/modelling_cohere_flax.py
@@ -192,7 +192,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

def _merge_heads(self, hidden_states):
3 changes: 2 additions & 1 deletion src/python/easydel/modules/dbrx/modelling_dbrx_flax.py
@@ -171,7 +171,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)
self.resid_dropout = flax.linen.Dropout(rate=config.resid_pdrop)

8 changes: 7 additions & 1 deletion src/python/easydel/modules/easydel_modelling_utils.py
@@ -68,6 +68,7 @@ class EasyDeLPretrainedConfig(PretrainedConfig):
:param use_sharding_constraint: bool: whether to use sharding constraint for the arrays
:param use_scan_mlp: bool: Determine whether to use scan_mlp or not
:param backend: Optional[None]: Specify the backend to use
:param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Specify the backward pass kernel for flash attention
"""

def __init__(
@@ -106,6 +107,7 @@ def __init__(
scan_mlp_chunk_size: int = 1024,
attention_axis_name: str = "sp",
quantize_kv_cache: bool = False,
flash_attention_backward_pass_impl: Literal["triton", "xla"] = "xla",
**kwargs
):
self.query_partition_spec = query_partition_spec
@@ -142,6 +144,7 @@ def __init__(
self.use_sharding_constraint = use_sharding_constraint
self.attention_axis_name = attention_axis_name
self.quantize_kv_cache = quantize_kv_cache
self.flash_attention_backward_pass_impl = flash_attention_backward_pass_impl
super().__init__(**kwargs)

@staticmethod
@@ -284,7 +287,8 @@ def add_basic_configurations(
use_scan_mlp: bool = ...,
scan_mlp_chunk_size: int = ...,
attention_axis_name: str = ...,
quantize_kv_cache: bool = ...
quantize_kv_cache: bool = ...,
flash_attention_backward_pass_impl: Literal["triton", "xla"] = ...
):
"""
It initializes all the attributes of an object, and it's called when you create a new instance of that class.
@@ -326,6 +330,7 @@ def add_basic_configurations(
:param scan_mlp_chunk_size: int: Size of chunks in scan MLP.
:param attention_axis_name: str: Name of the attention axis name
:param quantize_kv_cache: bool: Whether to quantize Key/Value in attention for generation process.
:param flash_attention_backward_pass_impl: Literal["triton", "xla"]: Specify the backward pass kernel for flash attention
"""
set_attrs_smartly(self, "axis_dims", (1, -1, 1, 1), axis_dims)
set_attrs_smartly(self, "axis_names", ("dp", "fsdp", "tp", "sp"), axis_names)
@@ -408,6 +413,7 @@ def add_basic_configurations(
set_attrs_smartly(self, "scan_mlp_chunk_size", 1024, scan_mlp_chunk_size)
set_attrs_smartly(self, "attention_axis_name", "sp", attention_axis_name)
set_attrs_smartly(self, "quantize_kv_cache", False, quantize_kv_cache)
set_attrs_smartly(self, "flash_attention_backward_pass_impl", "xla", flash_attention_backward_pass_impl)

def __repr__(self):

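In add_basic_configurations the new attribute is installed via set_attrs_smartly(self, "flash_attention_backward_pass_impl", "xla", flash_attention_backward_pass_impl), with ... used as the "not given" sentinel in the signature. The helper's body is not part of this diff, so the following is only a guessed reconstruction of its behaviour for readers unfamiliar with it:

def set_attrs_smartly_sketch(obj, attr_name, default, new_value):
    # Hypothetical reconstruction of set_attrs_smartly based solely on how it
    # is called above; the real EasyDeL helper may behave differently.
    if new_value is not Ellipsis:
        setattr(obj, attr_name, new_value)   # explicit value wins
    elif not hasattr(obj, attr_name):
        setattr(obj, attr_name, default)     # otherwise fall back to default

Under this reading, flash_attention_backward_pass_impl defaults to "xla" unless the caller explicitly passes "triton" (or "xla").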
3 changes: 2 additions & 1 deletion src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py
@@ -141,7 +141,8 @@ def setup(self) -> None:
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.config.n_heads),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)
if self.config.qk_ln:
self.q_ln = nn.LayerNorm(use_bias=self.config.use_norm_bias)
3 changes: 2 additions & 1 deletion src/python/easydel/modules/openelm/modelling_openelm_flax.py
@@ -175,7 +175,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

self.head_dim = config.head_dim
3 changes: 2 additions & 1 deletion src/python/easydel/modules/phi/modelling_phi_flax.py
@@ -171,7 +171,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

def _merge_heads(self, hidden_states):
3 changes: 2 additions & 1 deletion src/python/easydel/modules/phi3/modelling_phi3_flax.py
@@ -187,7 +187,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

def _merge_heads(self, hidden_states):
3 changes: 2 additions & 1 deletion src/python/easydel/modules/qwen1/modelling_qwen1_flax.py
@@ -269,7 +269,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

def _merge_heads(self, hidden_states):
3 changes: 2 additions & 1 deletion src/python/easydel/modules/qwen2/modelling_qwen_flax.py
@@ -234,7 +234,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)
self.resid_dropout = flax.linen.Dropout(rate=config.resid_pdrop)

@@ -252,7 +252,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)
self.resid_dropout = flax.linen.Dropout(rate=config.attention_dropout)

3 changes: 2 additions & 1 deletion src/python/easydel/modules/roberta/modelling_roberta_flax.py
@@ -115,7 +115,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)
self.query = Linear(
self.config.hidden_size,
@@ -216,7 +216,8 @@ def setup(self):
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

def _merge_heads(self, hidden_states):
3 changes: 2 additions & 1 deletion src/python/easydel/modules/whisper/modelling_whisper_flax.py
@@ -120,7 +120,8 @@ def setup(self) -> None:
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
axis_name=self.config.attention_axis_name
axis_name=self.config.attention_axis_name,
backward_pass_impl=self.config.flash_attention_backward_pass_impl
)

def __call__(