embedding.py
# Copyright (c) 2023 Graphcore Ltd.
# This is a re-implementation of Llama 2 by Graphcore Ltd
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from poptransformer import ops
from poptransformer.layers import BaseLayer, Embedding


class BaseTransformerEmbedding(BaseLayer):
    """Token embedding lookup used in the sharded execution mode."""

    def __init__(self, context, name, vocab_size, embd_size):
        super().__init__(context, name)
        self.vocab_size = vocab_size
        self.embd_size = embd_size
        self.collect_bind_layer_weights()

    def collect_bind_layer_weights(self):
        # Word-token embedding table of shape (vocab_size, embd_size).
        self.wte = Embedding(self.context, None, self.vocab_size, self.embd_size)

    def __call__(self, graph, input_ids, sequence_length):
        with graph.nameScope(self.context):
            embeds = self.wte(graph, input_ids, sequence_length)
            return embeds

class TPTransformerEmbedding(BaseTransformerEmbedding):
    """Token embedding lookup used in the tensor-parallel execution mode."""

    def __call__(self, graph, input_ids, sequence_length):
        with graph.nameScope(self.context):
            embeds = self.wte(graph, input_ids, sequence_length)
            # Sum the partial lookups held by each replica into the full
            # embedding before returning it.
            embeds = graph.aiGraphcore.replicatedallreduce([embeds])
            return embeds
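
# Note: the all-reduce above implies the embedding table is split across
# replicas (the usual vocab-parallel layout, in which out-of-shard ids
# produce zero rows locally and the cross-replica sum restores the full
# lookup); the exact sharding scheme is not visible in this file.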

class TransformerEmbedding(TPTransformerEmbedding, BaseTransformerEmbedding):
    # Maps the parallelism strategy to the implementation class.
    layer_class_map = {
        'tp': TPTransformerEmbedding,
        'shard': BaseTransformerEmbedding}

    def __init__(self, context, name, vocab_size, embd_size):
        # model_type and logger are assumed to be provided by BaseLayer.
        model_type = self.model_type
        self.layer_class = self.layer_class_map.get(model_type, None)
        if not self.layer_class:
            raise ValueError(
                f"Invalid model_type {model_type}, options: {list(self.layer_class_map.keys())}")
        self.logger.debug(f'initializing model type: {self.layer_class.__name__}')
        super().__init__(context, name, vocab_size, embd_size)

    def __call__(self, graph, input_ids, sequence_length):
        # Dispatch explicitly through the selected class rather than the MRO.
        return self.layer_class.__call__(self, graph, input_ids, sequence_length)

    def collect_bind_layer_weights(self):
        return self.layer_class.collect_bind_layer_weights(self)
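
To make the tensor-parallel path concrete, here is a minimal NumPy sketch of vocab-parallel embedding: the table is split along the vocab axis, each replica looks up only the rows it owns, and summing the partial results (the role played by replicatedallreduce above) recovers the full lookup. All names and shapes below are illustrative, not part of poptransformer.

import numpy as np

vocab_size, embd_size, n_replicas = 8, 4, 2      # illustrative sizes
table = np.random.rand(vocab_size, embd_size)    # the full embedding table
ids = np.array([1, 5, 6])                        # token ids to look up

shards = np.split(table, n_replicas, axis=0)     # vocab-axis split, one shard per replica
rows = vocab_size // n_replicas

partials = []
for r, shard in enumerate(shards):
    local = ids - r * rows                       # shard-local row indices
    hit = (local >= 0) & (local < rows)          # ids this replica owns
    gathered = shard[np.clip(local, 0, rows - 1)]
    partials.append(np.where(hit[:, None], gathered, 0.0))  # zeros for misses

# Summing the partials plays the role of replicatedallreduce.
assert np.allclose(sum(partials), table[ids])

The dispatcher is also worth a closer look: TransformerEmbedding inherits from both variants but routes every call through self.layer_class explicitly, bypassing Python's MRO. A stripped-down sketch of the pattern, with all stub classes hypothetical:

class Base:
    model_type = 'shard'             # in poptransformer this comes from configuration

class ShardImpl(Base):
    def __call__(self):
        return 'full-table lookup'

class TPImpl(ShardImpl):
    def __call__(self):
        return 'sharded lookup + all-reduce'

class Dispatcher(TPImpl, ShardImpl):
    layer_class_map = {'tp': TPImpl, 'shard': ShardImpl}

    def __init__(self):
        self.layer_class = self.layer_class_map[self.model_type]

    def __call__(self):
        # Explicit dispatch: uses layer_class, not the MRO (which would pick TPImpl).
        return self.layer_class.__call__(self)

print(Dispatcher()())                # -> 'full-table lookup'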