-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbertEmbeddings.py
55 lines (46 loc) · 2 KB
/
bertEmbeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import tensorflow as tf
import numpy as np
from tfRecordTools import *
from dataLoader import *
def createBertEmbeddingExample(wordVector, recordID, reverseWordIndex, encoder, preprocessor):
    """
    Create a tf.train.Example holding one sample's BERT sentence embedding and its ID.

    Args:
        wordVector - (np.ndarray) encoded sample; it is decoded back to text with
            decodeReview(wordVector, reverseWordIndex)
        recordID - (int) ID of the sample; stored as str(recordID) under the 'id' key
        reverseWordIndex - (dict) reverse word index passed to decodeReview
        encoder - (callable) BERT encoder model; called on the preprocessor's output and
            must return a mapping that contains a 'pooled_output' entry
        preprocessor - (callable) BERT preprocessing model; called on the decoded text
    Returns:
        example - (tf.train.Example) example with an 'id' feature (built by bytesFeature)
            and an 'embedding' feature (the flattened pooled output, built by floatFeature)
    """
    text = decodeReview(wordVector, reverseWordIndex)
    # Turn the decoded text into a 1-D batch before preprocessing. This is a batch of
    # length 1 if decodeReview returns a single string (TODO confirm).
    textBatch = tf.reshape(text, shape=[-1])
    encoderInputs = preprocessor(textBatch)
    # 'pooled_output' is presumably the [batch_size, hidden] sentence-level vector
    # that TF Hub BERT encoders produce. Verify against the encoder actually used.
    pooledOutput = encoder(encoderInputs)['pooled_output']
    # Flatten to 1-D so the embedding can be stored as a flat list of floats.
    sentenceEmbedding = tf.reshape(pooledOutput, shape=[-1])
    features = {
        # NOTE(review): bytesFeature receives a str, not bytes. Confirm that
        # tfRecordTools.bytesFeature encodes it before building the BytesList.
        'id': bytesFeature(str(recordID)),
        'embedding': floatFeature(sentenceEmbedding.numpy()),
    }
    return tf.train.Example(features=tf.train.Features(feature=features))
def createBertEmbedding(wordVectors, outputPath, startingRecordId, reverseWordIndex, encoder, preprocessor):
    """
    Write a BERT embedding example for every sample in wordVectors to one TFRecord file.

    Samples get consecutive IDs, starting at startingRecordId, in iteration order.

    Args:
        wordVectors - (np.ndarray or iterable) samples to encode; each element is passed
            to createBertEmbeddingExample as its wordVector
        outputPath - (string) path of the TFRecord file to write
        startingRecordId - (int) ID assigned to the first sample (converted with int())
        reverseWordIndex - (dict) reverse word index used to decode each sample
        encoder - (callable) BERT encoder model passed through to createBertEmbeddingExample
        preprocessor - (callable) BERT preprocessing model passed through to
            createBertEmbeddingExample
    Returns:
        recordID - (int) the ID *after* the last written sample, i.e.
            startingRecordId + number of samples. Passing it as startingRecordId to a
            later call keeps IDs from overlapping.
    """
    recordID = int(startingRecordId)
    with tf.io.TFRecordWriter(outputPath) as writer:
        for wordVector in wordVectors:
            example = createBertEmbeddingExample(wordVector, recordID, reverseWordIndex, encoder, preprocessor)
            writer.write(example.SerializeToString())
            recordID += 1
    return recordID