fixing smap_req for pallas-flash attention
erfanzar committed May 16, 2024
1 parent a96c2ae · commit 6fa2a38
Showing 1 changed file with 20 additions and 10 deletions.
src/python/easydel/modules/attention_module.py (30 changes: 20 additions & 10 deletions)
@@ -943,20 +943,30 @@ def pallas_flash_attention(
             lambda s: s.astype(jnp.float32),
             (query_states, key_states, value_states)
         )
-        query_states = with_sharding_constraint(query_states, qps)
-        key_states = with_sharding_constraint(key_states, kps)
-        value_states = with_sharding_constraint(value_states, vps)
-        bias = with_sharding_constraint(bias, bps)
-        attention_outputs = flash_attention(
-            query_states,
-            key_states,
-            value_states,
-            bias=bias,
+        # query_states = with_sharding_constraint(query_states, qps)
+        # key_states = with_sharding_constraint(key_states, kps)
+        # value_states = with_sharding_constraint(value_states, vps)
+        # bias = with_sharding_constraint(bias, bps)
+        wrapped_fn = partial(
+            flash_attention,
             sm_scale=self.sm_scale,
             block_k=self.block_k,
             block_q=self.block_q,
             interpret=True if self.platform == "cpu" else None,  # auto-decide
-            backward_pass_impl=self.backward_pass_impl
+            backward_pass_impl=self.backward_pass_impl,
+            debug=False
         )
+        attention_outputs = shard_map(
+            f=wrapped_fn,
+            in_specs=(qps, kps, vps, bps),
+            out_specs=aps,
+            mesh=self.mesh,
+            check_rep=False,
+        )(
+            query_states,
+            key_states,
+            value_states,
+            bias,
+        )
         attention_outputs = with_sharding_constraint(attention_outputs, aps)
         return AttentionOutput(
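For context on the pattern the commit switches to: instead of wrapping the inputs in with_sharding_constraint and calling the Pallas kernel directly, the kernel options are bound with functools.partial and the bound function is run under jax.experimental.shard_map, so each device receives only its local shard of query/key/value/bias. Below is a minimal, self-contained sketch of that wiring. The simple_attention stand-in, the single "dp" mesh axis, the tensor layout, and the partition specs are illustrative assumptions, not EasyDeL's actual kernel, mesh, or specs.

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P


def simple_attention(q, k, v, bias):
    # Stand-in for the real Pallas flash_attention kernel: plain softmax
    # attention computed on the local (per-device) shards that shard_map
    # passes in.
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    if bias is not None:
        scores = scores + bias
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


# Hypothetical 1-D mesh over a data-parallel "dp" axis; EasyDeL builds its own mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))
qps = kps = vps = aps = P("dp", None, None, None)  # (batch, heads, seq, head_dim)
bps = P("dp", None, None, None)                    # (batch, heads, q_len, kv_len)

# Same wiring as the commit: bind any kernel options with partial (none are
# needed for this stand-in), then shard_map the bound function; check_rep=False
# mirrors the commit and skips the replication check.
sharded_attention = shard_map(
    partial(simple_attention),
    mesh=mesh,
    in_specs=(qps, kps, vps, bps),
    out_specs=aps,
    check_rep=False,
)

b, h, s, d = 8, 4, 128, 64  # batch must divide evenly across the devices on "dp"
q = k = v = jnp.zeros((b, h, s, d), jnp.float32)
bias = jnp.zeros((b, h, s, s), jnp.float32)
out = sharded_attention(q, k, v, bias)  # same shape as q

Inside shard_map each argument already arrives sliced according to its PartitionSpec, which is why the explicit with_sharding_constraint calls become unnecessary and are commented out in the diff above.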
