Skip to content

Commit

Permalink
using int32 instead of int64 for alibi
Browse files Browse the repository at this point in the history
  • Loading branch information
erfanzar committed May 16, 2024
1 parent 9679be5 commit 42bd785
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/python/easydel/modules/mosaic_mpt/modelling_mpt_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ def __call__(
def build_mpt_alibi_tensor(num_heads, sequence_length, alibi_bias_max=8):
alibi = jnp.arange(1 - sequence_length, 1, dtype="i4").reshape(1, 1, 1, sequence_length)
num_heads_power_of_2 = 2 ** math.ceil(math.log2(num_heads))
jax.config.update("jax_enable_x64", True)
base = jnp.arange(1, num_heads_power_of_2 + 1, dtype=jnp.int64).astype("float32")
base = jnp.arange(1, num_heads_power_of_2 + 1, dtype=jnp.int32).astype("float32")
base = base * (alibi_bias_max / num_heads_power_of_2)

slopes = 1.0 / jnp.pow(2, base)
Expand Down

0 comments on commit 42bd785

Please sign in to comment.