Merge pull request #80 from neverix/main
Add parallel forward (#74)
kuprel authored Aug 3, 2022
2 parents 6c9aeef + 64e0fd9 commit 48793e0
Showing 4 changed files with 330 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -7,10 +7,12 @@
 **/*.ckpt
 .vscode
 **/.ipynb_checkpoints
 generated.png
 **/generated
 **/pretrained
 **/*.msgpack
+*.egg-info/
+.idea/
 *.egg
 dist
 build
41 changes: 28 additions & 13 deletions min_dalle/models/dalle_bart_decoder.py
@@ -1,4 +1,4 @@
-from typing import Tuple, List
+from typing import Tuple, List, Optional, Union
 import torch
 from torch import nn, LongTensor, FloatTensor, BoolTensor
 from .dalle_bart_encoder import GLU, AttentionBase
@@ -34,8 +34,11 @@ def forward(
         keys = self.k_proj.forward(decoder_state)
         values = self.v_proj.forward(decoder_state)
         queries = self.q_proj.forward(decoder_state)
-        attn_state_new = torch.cat([keys, values]).to(attention_state.dtype)
-        attention_state[:, token_index] = attn_state_new
+        attn_state_new = torch.cat([keys, values])
+        if attention_state is None:
+            attention_state = attn_state_new
+        else:
+            attention_state[:, token_index[0]] = attn_state_new.to(attention_state.dtype)
         batch_count = decoder_state.shape[0]
         keys = attention_state[:batch_count]
         values = attention_state[batch_count:]
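
This hunk makes the self-attention cache optional: when attention_state is None, the freshly computed keys and values for the whole token grid become the state (the new parallel forward), otherwise the current token is written into its cache slot as before. Below is a minimal standalone sketch of the cache layout, not from the commit, with all shapes and values assumed for illustration:

    import torch

    batch_count, seq_len, embed = 2, 16, 8

    # Assumed cache layout: keys stacked over values along dim 0,
    # shape (2 * batch_count, seq_len, embed).
    attention_state = torch.zeros(2 * batch_count, seq_len, embed)

    # One incremental step yields keys/values of shape (batch_count, 1, embed).
    keys = torch.randn(batch_count, 1, embed)
    values = torch.randn(batch_count, 1, embed)
    attn_state_new = torch.cat([keys, values])             # (2 * batch_count, 1, embed)

    # Row 0 of token_index selects the slot to overwrite, as in the hunk above.
    token_index = torch.tensor([[5]] * (2 * batch_count))  # (2 * batch_count, 1)
    attention_state[:, token_index[0]] = attn_state_new

    keys_read = attention_state[:batch_count]              # (batch_count, seq_len, embed)
    values_read = attention_state[batch_count:]
    assert torch.equal(keys_read[:, 5:6], keys)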
@@ -66,13 +69,15 @@ def forward(
         self,
         decoder_state: FloatTensor,
         encoder_state: FloatTensor,
-        attention_state: FloatTensor,
+        attention_state: Optional[FloatTensor],
         attention_mask: BoolTensor,
         token_index: LongTensor
     ) -> Tuple[FloatTensor, FloatTensor]:
         # Self Attention
-        self_attn_mask = self.token_indices < token_index + 1
-        self_attn_mask = self_attn_mask[None][[0] * decoder_state.shape[0]]
+        if token_index.shape[1] == 1:
+            self_attn_mask = self.token_indices <= token_index
+        else:
+            self_attn_mask = self.token_indices[:token_index.shape[1]][None, None, :] <= token_index[:, :, None]
         residual = decoder_state
         decoder_state = self.pre_self_attn_layer_norm.forward(decoder_state)
         decoder_state, attention_state = self.self_attn.forward(
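
The self-attention mask now has two branches: the single-token path (token_index has one column) keeps the old behavior of letting the current query see positions up to its index, while the parallel path builds a full causal mask over the token grid. A standalone sketch with toy sizes, where token_indices stands in for the module's registered arange buffer:

    import torch

    batch, seq_len = 2, 6
    token_indices = torch.arange(seq_len)

    # Incremental: one query per step, allowed keys are 0..token_index.
    token_index = torch.tensor([[3]] * batch)                  # (batch, 1)
    mask_single = token_indices <= token_index                 # (batch, seq_len)

    # Parallel: query t may attend to keys 0..t, a lower-triangular mask.
    token_index = token_indices.unsqueeze(0).repeat(batch, 1)  # (batch, seq_len)
    mask_parallel = (
        token_indices[:seq_len][None, None, :] <= token_index[:, :, None]
    )                                                          # (batch, seq_len, seq_len)
    assert torch.equal(
        mask_parallel[0],
        torch.ones(seq_len, seq_len).tril().bool(),
    )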
@@ -139,18 +144,26 @@ def forward(
         settings: FloatTensor,
         attention_mask: BoolTensor,
         encoder_state: FloatTensor,
-        attention_state: FloatTensor,
+        attention_state: Optional[FloatTensor],
         prev_tokens: LongTensor,
-        token_index: LongTensor
-    ) -> Tuple[LongTensor, FloatTensor]:
+        token_index: LongTensor,
+        return_logits: bool = False
+    ) -> Union[Tuple[LongTensor, FloatTensor], FloatTensor]:
         image_count = encoder_state.shape[0] // 2
-        token_index_batched = token_index[[0] * image_count * 2]
-        prev_tokens = prev_tokens[list(range(image_count)) * 2]
+        if token_index.ndim == 1:
+            token_index = token_index.unsqueeze(0).repeat(image_count * 2, 1)
+        token_index_batched = token_index[list(range(image_count)) * 2]
+        if prev_tokens.ndim == 1:
+            prev_tokens = prev_tokens.unsqueeze(0)
+        prev_tokens = prev_tokens.T[list(range(image_count)) * 2]
         prev_tokens.clamp_(0, self.image_vocab_count)
         decoder_state = self.embed_tokens.forward(prev_tokens)
         decoder_state += self.embed_positions.forward(token_index_batched)
         decoder_state = self.layernorm_embedding.forward(decoder_state)
-        decoder_state = decoder_state[:, None]
+        if decoder_state.ndim < 3:
+            decoder_state = decoder_state[:, None]
+        if attention_state is None:
+            attention_state = [None] * self.layer_count
         for i in range(self.layer_count):
             decoder_state, attention_state[i] = self.layers[i].forward(
                 decoder_state,
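
Here prev_tokens and token_index are normalized to 2-D and each image's row is duplicated across the doubled batch, so half the batch decodes against the unconditional text encoding and half against the prompt (superconditioning). A toy walk-through of the incremental-mode reshaping, with made-up token values:

    import torch

    image_count = 2

    # Incremental mode: one previous token per image, shape (image_count,).
    prev_tokens = torch.tensor([7, 8])
    prev_tokens = prev_tokens.unsqueeze(0)                     # (1, image_count)
    prev_tokens = prev_tokens.T[list(range(image_count)) * 2]  # (2 * image_count, 1)
    assert prev_tokens.tolist() == [[7], [8], [7], [8]]

    # token_index is repeated the same way across the doubled batch.
    token_index = torch.tensor([5])
    token_index = token_index.unsqueeze(0).repeat(image_count * 2, 1)
    token_index_batched = token_index[list(range(image_count)) * 2]
    assert token_index_batched.tolist() == [[5], [5], [5], [5]]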
@@ -169,11 +182,13 @@
             logits[:image_count] * (1 - supercondition_factor) +
             logits[image_count:] * supercondition_factor
         )
+        if return_logits:
+            return logits
         logits_sorted, _ = logits.sort(descending=True)
         is_kept = logits >= logits_sorted[:, top_k - 1]
         logits -= logits_sorted[:, [0]]
         logits /= temperature
         logits.exp_()
         logits *= is_kept.to(torch.float32)
         image_tokens = torch.multinomial(logits, 1)[:, 0]
-        return image_tokens, attention_state
+        return image_tokens, attention_state
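
The new return_logits flag exits before sampling so a caller can inspect raw logits. The sampling path itself is unchanged: superconditioned logits are top-k filtered, shifted by the max for numerical stability, temperature-scaled, exponentiated, and sampled with torch.multinomial. A self-contained sketch, using plain ints where the real code unpacks tensor-valued settings, and a toy vocabulary size:

    import torch

    image_count, vocab = 2, 10
    top_k, temperature, supercondition_factor = 4, 1.0, 16.0

    # Superconditioning: blend unconditional and conditional logits.
    logits = torch.randn(2 * image_count, vocab)
    logits = (
        logits[:image_count] * (1 - supercondition_factor) +
        logits[image_count:] * supercondition_factor
    )

    logits_sorted, _ = logits.sort(descending=True)
    is_kept = logits >= logits_sorted[:, [top_k - 1]]  # keep only the top_k logits
    logits -= logits_sorted[:, [0]]                    # subtract the max for stability
    logits /= temperature
    logits.exp_()
    logits *= is_kept.to(torch.float32)                # zero out everything else
    image_tokens = torch.multinomial(logits, 1)[:, 0]  # one sampled token per image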
6 changes: 5 additions & 1 deletion min_dalle/models/dalle_bart_encoder.py
@@ -52,7 +52,11 @@ def forward(
             queries,
             keys
         )
-        attention_weights += attention_bias[:, None, None, :]
+        if attention_bias.ndim == 3:
+            attention_bias = attention_bias[:, None, :attention_weights.shape[-2], :attention_weights.shape[-1]]
+        elif attention_bias.ndim == 2:
+            attention_bias = attention_bias[:, None, None, :attention_weights.shape[-1]]
+        attention_weights += attention_bias
         attention_weights = torch.softmax(attention_weights, -1)
         attention_output: FloatTensor = torch.einsum(
             "bhqk,bkhc->bqhc",
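
AttentionBase, which the decoder reuses, now accepts either a 2-D attention bias (one value per key, e.g. a padding mask, broadcast over heads and query positions) or a 3-D bias (one value per query-key pair, e.g. the causal mask needed by the parallel forward). A shape-only sketch with assumed toy dimensions:

    import torch

    b, h, q, k = 2, 4, 6, 6
    attention_weights = torch.zeros(b, h, q, k)

    # 2-D bias: 0 for real tokens, -inf for padding, one value per key.
    bias_2d = torch.zeros(b, k)
    bias_2d[:, -2:] = float("-inf")             # pretend the last 2 keys are padding
    out_2d = attention_weights + bias_2d[:, None, None, :]  # (b, h, q, k)

    # 3-D bias: one value per (query, key) pair, here a causal mask.
    bias_3d = torch.ones(b, q, k).tril().log()  # 0 on/below the diagonal, -inf above
    out_3d = attention_weights + bias_3d[:, None, :q, :k]   # (b, h, q, k)

    assert out_2d.shape == out_3d.shape == (b, h, q, k)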
[diff for the fourth changed file not loaded]