
Commit

merge deltanet changes (#25)
Signed-off-by: Mayank Mishra <[email protected]>
Co-authored-by: Yikang Shen <[email protected]>
mayank31398 and yikangshen authored Sep 26, 2024
1 parent 95e54d3 commit 9add88a
Showing 2 changed files with 12 additions and 126 deletions.
87 changes: 0 additions & 87 deletions configs/research/rnn/rnn_debug.yml

This file was deleted.

51 changes: 12 additions & 39 deletions dolomite_engine/hf_models/models/rnn_dolomite/attention/deltanet.py
@@ -20,23 +20,6 @@
from fla.ops.delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule


def simple_norm(x: torch.Tensor) -> torch.Tensor:
return F.normalize(x, dim=-1) * x.shape[-1] ** 0.5


def elu_p1(x: torch.Tensor) -> torch.Tensor:
return F.elu(x) + 1


def sum_norm(x: torch.Tensor) -> torch.Tensor:
return x / x.sum(-1, keepdim=True)


def elu_norm(x: torch.Tensor) -> torch.Tensor:
x = elu_p1(x)
return x / x.sum(-1, keepdim=True)


if is_fla_available():

class ParameterizedShortConvolution(ShortConvolution):
@@ -102,8 +85,6 @@ def __init__(
self.head_v_dim = self.value_dim // self.num_heads
self.layer_idx = layer_idx

self.silu = nn.SiLU()

assert mode in ["chunk", "fused_chunk", "fused_recurrent"], f"Not supported mode `{mode}`."
assert self.key_dim % self.num_heads == 0, f"key dim must be divisible by num_heads of {self.num_heads}"
assert self.value_dim % self.num_heads == 0, f"value dim must be divisible by num_heads of {self.num_heads}"
@@ -113,9 +94,7 @@ def __init__(
init_method = InitMethod(config.init_method)
if init_method == InitMethod.mup:
std_in /= math.sqrt(config.m_width)
self.q_proj = ParameterizedLinear(self.hidden_size, self.key_dim, bias=False, std=std_in)
self.k_proj = ParameterizedLinear(self.hidden_size, self.key_dim, bias=False, std=std_in)
self.v_proj = ParameterizedLinear(self.hidden_size, self.value_dim, bias=False, std=std_in)
self.c_attn = ParameterizedLinear(self.hidden_size, 2 * self.key_dim + self.value_dim, bias=False, std=std_in)

if use_short_conv:
std_conv = initializer_range
@@ -173,31 +152,24 @@ def forward(
if attention_mask.shape[-1] != hidden_states.shape[-2]:
attention_mask = attention_mask[:, -1:]

c_attn = self.c_attn(hidden_states)
q, k, v = c_attn.split((self.key_dim, self.key_dim, self.value_dim), dim=-1)

if self.use_short_conv:
conv_state = last_state[0] if use_cache else None
if self.share_conv_kernel:
# conv state is updated inplace
hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state)

q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
else:
conv_state_q = last_state[0] if use_cache else None
conv_state_k = last_state[1] if use_cache else None
conv_state_v = last_state[2] if use_cache else None

q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)

q = self.q_conv1d(q, attention_mask, conv_state_q)
k = self.k_conv1d(k, attention_mask, conv_state_k)
v = self.v_conv1d(v, attention_mask, conv_state_v)
else:
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.silu(self.v_proj(hidden_states))
v = F.silu(v)

# dealing with left-padding
if attention_mask is not None:
@@ -207,22 +179,23 @@

if self.qk_activation != "silu":
if self.qk_activation == "relu":
q, k = q.relu(), k.relu()
q = F.relu(q)
k = F.relu(k)
elif self.qk_activation == "elu":
q = elu_p1(q)
k = elu_p1(k)
q = F.elu(q) + 1
k = F.elu(k) + 1
elif self.qk_activation == "identity":
pass
else:
raise NotImplementedError

if self.qk_norm is not None:
if self.qk_norm == "l2":
k = F.normalize(k, dim=-1, p=2)
q = F.normalize(q, dim=-1, p=2)
k = F.normalize(k, dim=-1, p=2)
elif self.qk_norm == "sum":
q = sum_norm(q)
k = sum_norm(k)
q = q / q.sum(-1, keepdim=True)
k = k / k.sum(-1, keepdim=True)

if self.use_beta:
beta = rearrange(self.b_proj(hidden_states), "b l h -> b h l").sigmoid()
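For context, the core refactor visible in this diff replaces the three separate `q_proj`/`k_proj`/`v_proj` layers with a single fused `c_attn` projection whose output is split along the feature dimension. The sketch below shows that pattern in isolation; it uses a plain `nn.Linear` instead of the repository's `ParameterizedLinear`, and the class name and dimensions are illustrative, not the actual `DeltaNet` class.

```python
import torch
import torch.nn as nn


class FusedQKVProjection(nn.Module):
    """Minimal sketch of the fused projection: one matmul instead of three."""

    def __init__(self, hidden_size: int, key_dim: int, value_dim: int) -> None:
        super().__init__()
        self.key_dim = key_dim
        self.value_dim = value_dim
        # single weight of shape (2 * key_dim + value_dim, hidden_size)
        self.c_attn = nn.Linear(hidden_size, 2 * key_dim + value_dim, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, ...]:
        c_attn = self.c_attn(hidden_states)
        # split the fused output back into q, k, v along the last dimension
        q, k, v = c_attn.split((self.key_dim, self.key_dim, self.value_dim), dim=-1)
        return q, k, v


# usage: q and k come out with key_dim features, v with value_dim features
proj = FusedQKVProjection(hidden_size=512, key_dim=256, value_dim=512)
q, k, v = proj(torch.randn(2, 16, 512))
```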

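The diff also drops the module-level helpers (`simple_norm`, `elu_p1`, `sum_norm`, `elu_norm`) and the `nn.SiLU()` attribute, inlining the needed expressions at their call sites. A short sketch of the equivalent functional forms, applied to an arbitrary illustrative tensor rather than the repository's actual activations:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 64)

# elu_p1(x): shifted ELU, keeps every entry strictly positive
y = F.elu(x) + 1

# sum_norm(y): rescale so the last dimension sums to 1 (applied to a positive tensor)
y = y / y.sum(-1, keepdim=True)

# simple_norm(x): L2-normalize, then rescale by sqrt(d) of the last dimension
z = F.normalize(x, dim=-1) * x.shape[-1] ** 0.5

# the nn.SiLU() module attribute is likewise replaced by a functional call
v = F.silu(x)
```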
