-
Notifications
You must be signed in to change notification settings - Fork 27.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #287 from huggingface/gpt2
Gpt2
- Loading branch information
Showing
11 changed files
with
1,588 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
import logging | ||
from tqdm import trange | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import numpy as np | ||
|
||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer | ||
|
||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', | ||
datefmt = '%m/%d/%Y %H:%M:%S', | ||
level = logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
def top_k_logits(logits, k): | ||
if k == 0: | ||
return logits | ||
values, _ = torch.topk(logits, k) | ||
min_values = values[:, -1] | ||
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) | ||
|
||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True): | ||
if start_token is None: | ||
assert context is not None, 'Specify exactly one of start_token and context!' | ||
context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) | ||
else: | ||
assert context is None, 'Specify exactly one of start_token and context!' | ||
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long) | ||
prev = context | ||
output = context | ||
past = None | ||
with torch.no_grad(): | ||
for i in trange(length): | ||
logits, past = model(prev, past=past) | ||
logits = logits[:, -1, :] / temperature | ||
logits = top_k_logits(logits, k=top_k) | ||
log_probs = F.softmax(logits, dim=-1) | ||
if sample: | ||
prev = torch.multinomial(log_probs, num_samples=1) | ||
else: | ||
_, prev = torch.topk(log_probs, k=1, dim=-1) | ||
output = torch.cat((output, prev), dim=1) | ||
return output | ||
|
||
def run_model(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint') | ||
parser.add_argument("--seed", type=int, default=0) | ||
parser.add_argument("--nsamples", type=int, default=1) | ||
parser.add_argument("--batch_size", type=int, default=-1) | ||
parser.add_argument("--length", type=int, default=-1) | ||
parser.add_argument("--temperature", type=int, default=1) | ||
parser.add_argument("--top_k", type=int, default=0) | ||
parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.') | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
if args.batch_size == -1: | ||
args.batch_size = 1 | ||
assert args.nsamples % args.batch_size == 0 | ||
|
||
np.random.seed(args.seed) | ||
torch.random.manual_seed(args.seed) | ||
torch.cuda.manual_seed(args.seed) | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) | ||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) | ||
model.to(device) | ||
model.eval() | ||
|
||
if args.length == -1: | ||
args.length = model.config.n_ctx // 2 | ||
elif args.length > model.config.n_ctx: | ||
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) | ||
|
||
while not args.unconditional: | ||
if not args.unconditional: | ||
raw_text = input("Model prompt >>> ") | ||
while not raw_text: | ||
print('Prompt should not be empty!') | ||
raw_text = input("Model prompt >>> ") | ||
context_tokens = enc.encode(raw_text) | ||
generated = 0 | ||
for _ in range(args.nsamples // args.batch_size): | ||
out = sample_sequence( | ||
model=model, length=args.length, | ||
context=context_tokens if not args.unconditional else None, | ||
start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None, | ||
batch_size=args.batch_size, | ||
temperature=args.temperature, top_k=args.top_k, device=device | ||
) | ||
out = out[:, len(context_tokens):].tolist() | ||
for i in range(args.batch_size): | ||
generated += 1 | ||
text = enc.decode(out[i]) | ||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) | ||
print(text) | ||
print("=" * 80) | ||
|
||
if __name__ == '__main__': | ||
run_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
import logging | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import numpy as np | ||
from tqdm import trange | ||
|
||
from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer | ||
|
||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', | ||
datefmt = '%m/%d/%Y %H:%M:%S', | ||
level = logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
def top_k_logits(logits, k): | ||
if k == 0: | ||
return logits | ||
values, _ = torch.topk(logits, k) | ||
min_values = values[:, -1] | ||
return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits) | ||
|
||
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'): | ||
if start_token is None: | ||
assert context is not None, 'Specify exactly one of start_token and context!' | ||
context = torch.tensor(context, device=device, dtype=torch.long) | ||
else: | ||
assert context is None, 'Specify exactly one of start_token and context!' | ||
context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long) | ||
prev = context | ||
output = context | ||
past = None | ||
with torch.no_grad(): | ||
for i in trange(length): | ||
logits, past = model(prev, past=past) | ||
logits = logits[:, -1, :] / temperature | ||
logits = top_k_logits(logits, k=top_k) | ||
log_probs = F.softmax(logits, dim=-1) | ||
prev = torch.multinomial(log_probs, num_samples=1) | ||
output = torch.cat((output, prev), dim=1) | ||
return output | ||
|
||
def sample_model(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint') | ||
parser.add_argument("--seed", type=int, default=0) | ||
parser.add_argument("--nsamples", type=int, default=0) | ||
parser.add_argument("--batch_size", type=int, default=1) | ||
parser.add_argument("--length", type=int, default=-1) | ||
parser.add_argument("--temperature", type=int, default=1) | ||
parser.add_argument("--top_k", type=int, default=0) | ||
args = parser.parse_args() | ||
print(args) | ||
|
||
np.random.seed(args.seed) | ||
torch.random.manual_seed(args.seed) | ||
torch.cuda.manual_seed(args.seed) | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path) | ||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path) | ||
model.to(device) | ||
model.eval() | ||
|
||
if args.length == -1: | ||
args.length = model.config.n_ctx | ||
elif args.length > model.config.n_ctx: | ||
raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx) | ||
|
||
generated = 0 | ||
while args.nsamples == 0 or generated < args.nsamples: | ||
out = sample_sequence( | ||
model=model, length=args.length, | ||
start_token=enc.encoder['<|endoftext|>'], | ||
batch_size=args.batch_size, | ||
temperature=args.temperature, top_k=args.top_k, device=device | ||
) | ||
out = out.tolist() | ||
for i in range(args.batch_size): | ||
generated += args.batch_size | ||
text = enc.decode(out[i]) | ||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) | ||
print(text) | ||
|
||
if __name__ == '__main__': | ||
sample_model() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
pytorch_pretrained_bert/convert_gpt2_checkpoint_to_pytorch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The HugginFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Convert OpenAI GPT checkpoint.""" | ||
|
||
from __future__ import absolute_import, division, print_function | ||
|
||
import argparse | ||
from io import open | ||
|
||
import torch | ||
|
||
from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, | ||
GPT2Config, | ||
GPT2Model, | ||
load_tf_weights_in_gpt2) | ||
|
||
|
||
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): | ||
# Construct model | ||
if gpt2_config_file == "": | ||
config = GPT2Config() | ||
else: | ||
config = GPT2Config(gpt2_config_file) | ||
model = GPT2Model(config) | ||
|
||
# Load weights from numpy | ||
load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) | ||
|
||
# Save pytorch-model | ||
pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME | ||
pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME | ||
print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) | ||
torch.save(model.state_dict(), pytorch_weights_dump_path) | ||
print("Save configuration file to {}".format(pytorch_config_dump_path)) | ||
with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: | ||
f.write(config.to_json_string()) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
## Required parameters | ||
parser.add_argument("--gpt2_checkpoint_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path the TensorFlow checkpoint path.") | ||
parser.add_argument("--pytorch_dump_folder_path", | ||
default = None, | ||
type = str, | ||
required = True, | ||
help = "Path to the output PyTorch model.") | ||
parser.add_argument("--gpt2_config_file", | ||
default = "", | ||
type = str, | ||
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" | ||
"This specifies the model architecture.") | ||
args = parser.parse_args() | ||
convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, | ||
args.gpt2_config_file, | ||
args.pytorch_dump_folder_path) |
Oops, something went wrong.