-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy pathGPT2.swift
152 lines (126 loc) · 5.57 KB
/
GPT2.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Checkpoints
import Foundation
import ModelSupport
import TensorFlow
public class GPT2 {
  /// Location of the OpenAI GPT-2 117M checkpoint, used when no other checkpoint is supplied.
  public static let remoteCheckpoint: URL =
    URL(string: "https://openaipublic.blob.core.windows.net/gpt-2/models/117M/model.ckpt")!

  public enum GPT2Error: Error {
    /// A sampled token id has no corresponding entry in the BPE vocabulary.
    case invalidEncoding(id: Int32)
  }

  /// The transformer language model loaded from the checkpoint.
  public var model: TransformerLM
  /// Byte-pair encoder built from the checkpoint's `encoder.json` and `vocab.bpe` files.
  public let bpe: BytePairEncoder
  /// Maximum context length, read from the checkpoint's `hparams.json`.
  public let contextSize: Int
  /// Current conditioning tokens; `generate()` replaces this with each newly sampled token.
  public var seed: Tensor<Int32>
  /// Softmax temperature applied to the logits before sampling (1.0 = unscaled).
  public var temperature: Float = 1.0
  /// Per-layer attention key/value caches, carried across successive `generate()` calls.
  private var states: [AttentionContext]
  private let endOfText = "<|endoftext|>"
  private var endOfTextId = 0
  /// Local directory into which the checkpoint's auxiliary files were unpacked.
  internal var storage: URL

  /// Loads a GPT-2 model, its configuration, and its tokenizer from `checkpoint`.
  ///
  /// - Parameter checkpoint: Location of a GPT-2 TensorFlow checkpoint; defaults to the
  ///   remote OpenAI 117M checkpoint.
  /// - Throws: Any error raised while fetching or reading the checkpoint, decoding
  ///   `hparams.json`, or constructing the byte-pair encoder.
  public init(checkpoint: URL = GPT2.remoteCheckpoint) throws {
    // Placeholder configuration; overwritten below by the checkpoint's hparams.json.
    var parameters = TransformerLMConfig(
      vocabSize: 1, contextSize: 1024,
      embeddingSize: 768, headCount: 12, layerCount: 12)
    // Try loading from the given checkpoint.
    do {
      // Auxiliary files fetched alongside the checkpoint itself.
      let auxiliary: [String] = [
        "checkpoint",
        "encoder.json",
        "hparams.json",
        "model.ckpt.meta",
        "vocab.bpe",
      ]
      let reader: CheckpointReader = try CheckpointReader(
        checkpointLocation: checkpoint,
        modelName: "GPT2-\(checkpoint.pathComponents.dropLast().last ?? "")",
        additionalFiles: auxiliary)
      // TODO(michellecasbon): expose this.
      reader.isCRCVerificationEnabled = false

      storage = reader.localCheckpointLocation.deletingLastPathComponent()

      // Load the model configuration (vocab size, context size, layer geometry).
      let hparamsFile: URL = storage.appendingPathComponent("hparams.json")
      parameters = try JSONDecoder().decode(
        TransformerLMConfig.self,
        from: Data(contentsOf: hparamsFile))

      // Initialize a model with the given config.
      model = TransformerLM(reader: reader, config: parameters, scope: "model")

      // Create a byte-pair encoder from the checkpoint's token mappings and merge rules.
      // (The encoder reads these files itself; no need to load their contents here.)
      let vocabularyFile: URL = storage.appendingPathComponent("encoder.json")
      let mergesFile: URL = storage.appendingPathComponent("vocab.bpe")
      bpe = try BytePairEncoder(
        vocabularyFile: vocabularyFile, mergesFile: mergesFile)

      // "<|endoftext|>" is part of every GPT-2 vocabulary; absence indicates a corrupt
      // checkpoint, which we treat as a programmer/data error.
      endOfTextId = bpe.vocabulary.id(forToken: endOfText)!
      print("GPT-2 loaded from checkpoint successfully.")
    } catch {
      // If checkpoint is invalid, throw the error and exit.
      print("Failed to load GPT-2 from checkpoint. \(error)")
      throw error
    }

    contextSize = parameters.contextSize

    // TODO: Add argument that controls this.
    // Start generation conditioned on the end-of-text token.
    seed = Tensor(shape: [1, 1], scalars: [Int32(endOfTextId)])

    // Reset attention context for each layer: empty key/value caches of shape
    // [headCount, 0, headDimension].
    let empty =
      Tensor<Float>(zeros: [
        parameters.headCount, 0,
        parameters.embeddingSize / parameters.headCount,
      ])
    states = (0..<parameters.layerCount).map { _ in
      AttentionContext(key: empty, value: empty)
    }

    print("GPT-2 init complete.")
  }

  /// Encodes `string` into a `[1, tokenCount]` tensor of vocabulary ids.
  ///
  /// - Parameter string: Text to tokenize with the GPT-2 byte-pair encoder.
  public func embedding(for string: String) -> Tensor<Int32> {
    let tokens = bpe.encode(token: string, variant: .gpt2)
    // TODO(michellecasbon): Decide how to prevent OOV or choose a better ID (probably not 0).
    let ids = tokens.map { Int32(bpe.vocabulary.id(forToken: $0) ?? 0) }
    return Tensor(shape: [1, ids.count], scalars: ids)
  }

  /// Samples one token from the model, makes it the new conditioning `seed`, and returns
  /// its decoded text.
  ///
  /// - Returns: The decoded token, with line breaks normalized to "\r\n"; the end-of-text
  ///   token is rendered as "\r\n".
  /// - Throws: `GPT2Error.invalidEncoding` if the sampled id is not in the vocabulary.
  public func generate() throws -> String {
    let result = model(seed, states: &states)
    let (batchSize, timesteps, vocabularySize) =
      (result.shape[0], result.shape[1], result.shape[2])
    // Keep only the logits for the final timestep, scaled by temperature.
    let logits =
      result.slice(
        lowerBounds: [0, timesteps - 1, 0],
        upperBounds: [batchSize, timesteps, vocabularySize]) / temperature
    // Sample the next token and feed it back as the new context.
    seed = Tensor(
      randomCategorialLogits: logits.squeezingShape(at: 1),
      sampleCount: 1)
    let id = Int32(seed[0][0])!
    if id == Int32(endOfTextId) {
      // Replace with newline.
      return "\r\n"
    }
    if let token: String = bpe.vocabulary.token(forId: Int(id)) {
      let decodedToken = BytePairEncoder.decode(token: token)
      // Make any line breaks universal.
      return decodedToken.replacingOccurrences(of: "\n", with: "\r\n")
    }
    throw GPT2Error.invalidEncoding(id: id)
  }
}