
Commit

fixed merge related errors
Artur-Galstyan committed Dec 2, 2023
1 parent 8c5f1c5 commit d4d3b60
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions equinox/nn/_attention.py
@@ -11,7 +11,6 @@

from .._module import field, Module
from ._dropout import Dropout
from ._embedding import RotaryPositionalEmbedding
from ._linear import Linear


@@ -122,7 +121,6 @@ class MultiheadAttention(Module):
    output_proj: Linear
    dropout: Dropout

    rope_embeddings: Optional[RotaryPositionalEmbedding] = field(static=True)
    num_heads: int = field(static=True)
    query_size: int = field(static=True)
    key_size: int = field(static=True)
@@ -134,7 +132,6 @@ class MultiheadAttention(Module):
    use_key_bias: bool = field(static=True)
    use_value_bias: bool = field(static=True)
    use_output_bias: bool = field(static=True)
    use_rope_embeddings: bool = field(static=True)
    state_length: Optional[int] = field(static=True)

    def __init__(
@@ -150,8 +147,6 @@ def __init__(
        use_key_bias: bool = False,
        use_value_bias: bool = False,
        use_output_bias: bool = False,
        use_rope_embeddings: bool = False,
        state_length: Optional[int] = None,
        dropout_p: float = 0.0,
        inference: bool = False,
        *,
@@ -212,17 +207,6 @@ def __init__(
        )
        self.dropout = Dropout(dropout_p, inference=inference)

        if use_rope_embeddings:
            if state_length is None:
                raise ValueError(
                    "state_length must be specified when use_rope_embeddings is True"
                )
            else:
                self.state_length = state_length
            self.rope_embeddings = RotaryPositionalEmbedding(qk_size, self.state_length)
        else:
            self.rope_embeddings = None

        self.num_heads = num_heads
        self.query_size = query_size
        self.key_size = key_size
@@ -234,7 +218,6 @@ def __init__(
        self.use_key_bias = use_key_bias
        self.use_value_bias = use_value_bias
        self.use_output_bias = use_output_bias
        self.use_rope_embeddings = use_rope_embeddings

    @jax.named_scope("eqx.nn.MultiheadAttention")
    def __call__(
@@ -306,11 +289,6 @@ def __call__(

        query_heads = self._project(self.query_proj, query)
        key_heads = self._project(self.key_proj, key_)
        if self.use_rope_embeddings and self.rope_embeddings is not None:
            query_heads = jax.vmap(self.rope_embeddings, in_axes=1, out_axes=1)(
                query_heads
            )
            key_heads = jax.vmap(self.rope_embeddings, in_axes=1, out_axes=1)(key_heads)
        value_heads = self._project(self.value_proj, value)

        q_shape, k_shape, v_shape = (
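For context, the removed branch applied rotary position embeddings to the projected query/key heads by vmapping the embedding module over the head axis. Below is a minimal, illustrative sketch of the same pattern applied outside of MultiheadAttention; it is not part of this commit, it assumes a released equinox where eqx.nn.RotaryPositionalEmbedding is constructed from the embedding (head) size alone rather than the two-argument constructor used in the removed code, and the shapes and names are hypothetical.

# Illustrative sketch only, not part of this commit. Assumes a recent equinox
# where eqx.nn.RotaryPositionalEmbedding takes the head dimension as its
# embedding size; check your installed version's signature.
import equinox as eqx
import jax
import jax.random as jr

seq_len, num_heads, qk_size = 8, 4, 16
rope = eqx.nn.RotaryPositionalEmbedding(embedding_size=qk_size)

# Projected heads laid out as (seq_len, num_heads, qk_size): vmapping over
# axis 1 rotates each head independently, mirroring the removed code path.
qkey, kkey = jr.split(jr.PRNGKey(0))
query_heads = jr.normal(qkey, (seq_len, num_heads, qk_size))
key_heads = jr.normal(kkey, (seq_len, num_heads, qk_size))

apply_rope = jax.vmap(rope, in_axes=1, out_axes=1)
query_heads = apply_rope(query_heads)
key_heads = apply_rope(key_heads)

Keeping the rotation outside the module reproduces the effect of the deleted lines in __call__ without baking RoPE into the attention layer itself.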
