-
Notifications
You must be signed in to change notification settings - Fork 6
/
embed.py
38 lines (31 loc) · 1.28 KB
/
embed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy
import jax.numpy as np
import flax.linen as nn
class EmbeddingLearned(nn.Module):
    """Token embedding summed with a learned absolute-position embedding.

    Both embeddings are `nn.Embed` lookup tables of width `embed_size`; the
    position table is indexed by 0..length-1 along the last axis of the input.
    """

    # Number of rows in the token embedding table.
    vocab_size: int
    # Feature width shared by the token and position tables.
    embed_size: int
    # Number of rows in the position embedding table, i.e. the longest
    # sequence that has a position embedding.
    max_seq_len: int

    @nn.compact
    def __call__(self, inputs):
        """Embed `inputs` and add a learned position embedding.

        Args:
            inputs: integer array of token ids (integer dtype is required by
                `nn.Embed`); the last axis is treated as the sequence axis.

        Returns:
            Array of shape `inputs.shape + (embed_size,)`: the token
            embeddings with the `[length, embed_size]` position embeddings
            broadcast over any leading axes.

        Raises:
            ValueError: if the sequence length exceeds `max_seq_len`.
        """
        length = inputs.shape[-1]
        # Shapes are static, so this check also works under jit. Without it,
        # positions past the table would be looked up anyway: JAX does not
        # raise on out-of-bounds indices, so the result would be silently
        # invalid embeddings rather than an error.
        if length > self.max_seq_len:
            raise ValueError(
                f"sequence length {length} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )
        positions = np.arange(length)
        # Creation order is kept (token table first, then position table) so
        # the auto-generated submodule names stay the same.
        word_embedding = nn.Embed(self.vocab_size, self.embed_size)(inputs)
        position_embedding = nn.Embed(self.max_seq_len, self.embed_size)(positions)
        return word_embedding + position_embedding
class EmbeddingFixed(nn.Module):
    """Token embedding summed with a fixed sinusoidal position encoding.

    The position table is computed once in `setup` with NumPy and is not a
    trainable parameter; only the token `nn.Embed` table is learned.
    """

    # Number of rows in the token embedding table.
    vocab_size: int
    # Feature width of the token table and of the sinusoidal table.
    embed_size: int
    # Number of positions precomputed in the sinusoidal table.
    max_seq_len: int

    def setup(self):
        """Precompute the sinusoidal table `self.pe`, shape [1, T, H].

        Even feature columns hold sin(pos * w_i) and odd columns hold
        cos(pos * w_i), with w_i = exp(-2i * ln(10000) / embed_size).
        """
        pe = numpy.zeros((self.max_seq_len, self.embed_size), dtype=numpy.float32)
        position = numpy.arange(0, self.max_seq_len)[:, numpy.newaxis]
        div_term = numpy.exp(
            numpy.arange(0, self.embed_size, 2)
            * -(numpy.log(10000.0) / self.embed_size)
        )
        angles = position * div_term  # [T, ceil(H / 2)]
        pe[:, 0::2] = numpy.sin(angles)
        # With an odd embed_size there is one fewer odd column than even
        # column, so trim the angles to fit; for even sizes this slice is a
        # no-op. Without it the assignment fails with a shape mismatch.
        pe[:, 1::2] = numpy.cos(angles[:, : self.embed_size // 2])
        self.pe = np.array(pe[numpy.newaxis, :, :])  # [1, T, H]

    @nn.compact
    def __call__(self, x):
        """Embed `x` and add the fixed position encoding.

        Args:
            x: integer array of token ids (integer dtype is required by
                `nn.Embed`). Axis 1 is used as the sequence axis, so a
                [batch, seq] layout is expected.

        Returns:
            The token embeddings plus the first `x.shape[1]` rows of the
            sinusoidal table, broadcast over the batch axis.

        Raises:
            ValueError: if the sequence length exceeds `max_seq_len`.
        """
        seq_len = x.shape[1]
        # Slicing past the end of the table silently truncates, which would
        # otherwise surface as an opaque broadcasting failure (or, when
        # max_seq_len == 1, broadcast without any error at all).
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len "
                f"{self.max_seq_len}"
            )
        word_embedding = nn.Embed(self.vocab_size, self.embed_size)(x)
        positional_embedding = self.pe[:, 0:seq_len, :]
        return word_embedding + positional_embedding