forked from openai/DALL-E
Commit 2d47195 (0 parents)
Showing 11 changed files with 499 additions and 0 deletions.
**.gitignore**

```
# OS specific
*.DS_Store

# Python
/build
/dist
__pycache__
*.ipynb_checkpoints
*.egg-info

# Vim
*.vim
*.swk
*.swl
*.swm
*.swn
*.swo
*.swp
```
**LICENSE**

Modified MIT License

Software Copyright (c) 2021 OpenAI

We don’t claim ownership of the content you create with the DALL-E discrete VAE, so it is yours to do with as you please. We only ask that you use the model responsibly and clearly indicate that it was used.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. The above copyright notice and this permission notice need not be included with content created by the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
**README.md**

# Overview

[[Blog]](https://openai.com/blog/dall-e/) [[Paper]](https://arxiv.org/abs/2102.12092) [[Model Card]](model_card.md) [[Usage]](notebooks/usage.ipynb)

This is the official PyTorch package for the discrete VAE used for DALL·E.

# Installation

Before running [the example notebook](notebooks/usage.ipynb), you will need to install the package using

    pip install git+https://github.com/openai/DALL-E.git
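For orientation, here is a minimal end-to-end reconstruction sketch of how the files in this commit fit together, in the spirit of the linked usage notebook. It is not part of the commit itself: the checkpoint URLs, the 256×256 preprocessing, and the input filename are assumptions.

```python
# Hedged sketch (not part of this commit). The checkpoint URLs and the
# preprocessing pipeline are assumptions based on the linked usage notebook.
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

from dall_e import load_model, map_pixels, unmap_pixels

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc = load_model('https://cdn.openai.com/dall-e/encoder.pkl', device)  # assumed URL
dec = load_model('https://cdn.openai.com/dall-e/decoder.pkl', device)  # assumed URL

# Resize/crop to 256x256 and squeeze pixels into [eps, 1 - eps] (see utils.py).
preprocess = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])
img = Image.open('input.png').convert('RGB')  # hypothetical input file
x = map_pixels(preprocess(img).unsqueeze(0).to(device))

# Encode to a 32x32 grid of discrete tokens, then decode back to pixels.
z_logits = enc(x)
z = torch.argmax(z_logits, dim=1)
z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
x_stats = dec(z).float()
x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))  # first 3 maps are the means
```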
**dall_e/__init__.py**

```python
import io, requests
import torch
import torch.nn as nn

from dall_e.encoder import Encoder
from dall_e.decoder import Decoder
from dall_e.utils import map_pixels, unmap_pixels

def load_model(path: str, device: torch.device = None) -> nn.Module:
    if path.startswith('http://') or path.startswith('https://'):
        resp = requests.get(path)
        resp.raise_for_status()

        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)
    else:
        with open(path, 'rb') as f:
            return torch.load(f, map_location=device)
```
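Because `load_model` branches on the path scheme, the same call works for a checkpoint that has already been downloaded. A tiny sketch, with a hypothetical local path:

```python
# Hedged sketch (not part of the commit): loading a locally saved checkpoint.
import torch
from dall_e import load_model

dec = load_model('checkpoints/decoder.pkl', torch.device('cpu'))  # hypothetical path
```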
**dall_e/decoder.py**

```python
import attr
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d

@attr.s(eq=False, repr=False)
class DecoderBlock(nn.Module):
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', make_conv(self.n_in, self.n_hid, 1)),
            ('relu_2', nn.ReLU()),
            ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_3', nn.ReLU()),
            ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_4', nn.ReLU()),
            ('conv_4', make_conv(self.n_hid, self.n_out, 3)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)

@attr.s(eq=False, repr=False)
class Decoder(nn.Module):
    group_count: int = 4
    n_init: int = attr.ib(default=128, validator=lambda i, a, x: x >= 8)
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    output_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers = self.group_count * self.n_blk_per_group
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk = partial(DecoderBlock, n_layers=n_layers, device=self.device,
                           requires_grad=self.requires_grad)

        self.blocks = nn.Sequential(OrderedDict([
            ('input', make_conv(self.vocab_size, self.n_init, 1, use_float16=False)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(self.n_init if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(8 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
                ('upsample', nn.Upsample(scale_factor=2, mode='nearest')),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', make_conv(1 * self.n_hid, 2 * self.output_channels, 1)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.vocab_size:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.vocab_size}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
```
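Each `DecoderBlock` adds its residual branch scaled by `post_gain = 1 / n_layers ** 2`, which keeps activations stable at initialization regardless of depth. To make the decoder's tensor contract concrete, here is a hedged sketch (not part of the commit) using the default hyperparameters: three of the four groups upsample by 2×, so a 32×32 one-hot token grid decodes to 256×256 with `2 * output_channels` statistic maps.

```python
# Hedged sketch (not part of the commit): checking the Decoder's shapes on CPU.
import torch
import torch.nn.functional as F
from dall_e.decoder import Decoder

dec = Decoder()                                    # vocab_size=8192, output_channels=3
tokens = torch.randint(0, 8192, (1, 32, 32))       # a random 32x32 token grid
z = F.one_hot(tokens, num_classes=8192).permute(0, 3, 1, 2).float()
out = dec(z)
print(out.shape)  # torch.Size([1, 6, 256, 256]): 32 -> 256 via three 2x upsamples,
                  # with two logit-Laplace statistics per RGB channel
```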
**dall_e/encoder.py**

```python
import attr
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from functools import partial
from dall_e.utils import Conv2d

@attr.s(eq=False, repr=False)
class EncoderBlock(nn.Module):
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 4 == 0)
    n_layers: int = attr.ib(validator=lambda i, a, x: x >= 1)

    device: torch.device = attr.ib(default=None)
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()
        self.n_hid = self.n_out // 4
        self.post_gain = 1 / (self.n_layers ** 2)

        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        self.id_path = make_conv(self.n_in, self.n_out, 1) if self.n_in != self.n_out else nn.Identity()
        self.res_path = nn.Sequential(OrderedDict([
            ('relu_1', nn.ReLU()),
            ('conv_1', make_conv(self.n_in, self.n_hid, 3)),
            ('relu_2', nn.ReLU()),
            ('conv_2', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_3', nn.ReLU()),
            ('conv_3', make_conv(self.n_hid, self.n_hid, 3)),
            ('relu_4', nn.ReLU()),
            ('conv_4', make_conv(self.n_hid, self.n_out, 1)),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.id_path(x) + self.post_gain * self.res_path(x)

@attr.s(eq=False, repr=False)
class Encoder(nn.Module):
    group_count: int = 4
    n_hid: int = attr.ib(default=256, validator=lambda i, a, x: x >= 64)
    n_blk_per_group: int = attr.ib(default=2, validator=lambda i, a, x: x >= 1)
    input_channels: int = attr.ib(default=3, validator=lambda i, a, x: x >= 1)
    vocab_size: int = attr.ib(default=8192, validator=lambda i, a, x: x >= 512)

    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)
    use_mixed_precision: bool = attr.ib(default=True)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        blk_range = range(self.n_blk_per_group)
        n_layers = self.group_count * self.n_blk_per_group
        make_conv = partial(Conv2d, device=self.device, requires_grad=self.requires_grad)
        make_blk = partial(EncoderBlock, n_layers=n_layers, device=self.device,
                           requires_grad=self.requires_grad)

        self.blocks = nn.Sequential(OrderedDict([
            ('input', make_conv(self.input_channels, 1 * self.n_hid, 7)),
            ('group_1', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(1 * self.n_hid, 1 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_2', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(1 * self.n_hid if i == 0 else 2 * self.n_hid, 2 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_3', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(2 * self.n_hid if i == 0 else 4 * self.n_hid, 4 * self.n_hid)) for i in blk_range],
                ('pool', nn.MaxPool2d(kernel_size=2)),
            ]))),
            ('group_4', nn.Sequential(OrderedDict([
                *[(f'block_{i + 1}', make_blk(4 * self.n_hid if i == 0 else 8 * self.n_hid, 8 * self.n_hid)) for i in blk_range],
            ]))),
            ('output', nn.Sequential(OrderedDict([
                ('relu', nn.ReLU()),
                ('conv', make_conv(8 * self.n_hid, self.vocab_size, 1, use_float16=False)),
            ]))),
        ]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) != 4:
            raise ValueError(f'input shape {x.shape} is not 4d')
        if x.shape[1] != self.input_channels:
            raise ValueError(f'input has {x.shape[1]} channels but model built for {self.input_channels}')
        if x.dtype != torch.float32:
            raise ValueError('input must have dtype torch.float32')

        return self.blocks(x)
```
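The first three groups each halve the spatial resolution with `MaxPool2d`, so a 256×256 image maps to a 32×32 grid of logits over the 8192-entry codebook; taking the per-position argmax yields the discrete tokens. A hedged sketch, not part of the commit:

```python
# Hedged sketch (not part of the commit): an image batch becomes a 32x32 grid
# of token IDs, since 256 is pooled down by 2x three times.
import torch
from dall_e.encoder import Encoder

enc = Encoder()                            # input_channels=3, vocab_size=8192
x = torch.rand(1, 3, 256, 256)             # float32 pixels (after map_pixels)
z_logits = enc(x)                          # torch.Size([1, 8192, 32, 32])
tokens = torch.argmax(z_logits, dim=1)     # torch.Size([1, 32, 32])
```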
**dall_e/utils.py**

```python
import attr
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

logit_laplace_eps: float = 0.1

@attr.s(eq=False)
class Conv2d(nn.Module):
    n_in: int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
    kw: int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)

    use_float16: bool = attr.ib(default=True)
    device: torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
                        device=self.device, requires_grad=self.requires_grad)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
                        requires_grad=self.requires_grad)
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_float16 and 'cuda' in self.w.device.type:
            if x.dtype != torch.float16:
                x = x.half()

            w, b = self.w.half(), self.b.half()
        else:
            if x.dtype != torch.float32:
                x = x.float()

            w, b = self.w, self.b

        return F.conv2d(x, w, b, padding=(self.kw - 1) // 2)

def map_pixels(x: torch.Tensor) -> torch.Tensor:
    if len(x.shape) != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    return (1 - 2 * logit_laplace_eps) * x + logit_laplace_eps

def unmap_pixels(x: torch.Tensor) -> torch.Tensor:
    if len(x.shape) != 4:
        raise ValueError('expected input to be 4d')
    if x.dtype != torch.float:
        raise ValueError('expected input to have type float')

    return torch.clamp((x - logit_laplace_eps) / (1 - 2 * logit_laplace_eps), 0, 1)
```
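`map_pixels` applies the affine map x ↦ (1 − 2ε)x + ε with ε = 0.1, keeping pixel values away from the boundaries of [0, 1] where the logit-Laplace likelihood degenerates; `unmap_pixels` inverts it and clamps. A quick round-trip sketch (not part of the commit):

```python
# Hedged sketch (not part of the commit): map_pixels squeezes [0, 1] into
# [0.1, 0.9]; unmap_pixels inverts the affine map and clamps back to [0, 1].
import torch
from dall_e.utils import map_pixels, unmap_pixels

x = torch.rand(1, 3, 8, 8)
y = map_pixels(x)                                  # y = 0.8 * x + 0.1
assert y.min() >= 0.1 and y.max() <= 0.9
assert torch.allclose(unmap_pixels(y), x, atol=1e-5)
```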
**model_card.md**

# Model Card: DALL·E dVAE

Following [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we're providing some information about the discrete VAE (dVAE) that was used to train DALL·E.

## Model Details

The dVAE was developed by researchers at OpenAI to reduce the memory footprint of the transformer trained on the text-to-image generation task. The details involved in training the dVAE are described in [the paper][dalle_paper]. This model card describes the first version of the model, released in February 2021. The model consists of a convolutional encoder and decoder whose architectures are described [here](dall_e/encoder.py) and [here](dall_e/decoder.py), respectively. For questions or comments about the models or the code release, please file a GitHub issue.

## Model Use

### Intended Use

The model is intended for others to use for training their own generative models.

### Out-of-Scope Use Cases

This model is inappropriate for high-fidelity image processing applications. We also do not recommend its use as a general-purpose image compressor.

## Training Data

The model was trained on publicly available text-image pairs collected from the internet. This data consists partly of [Conceptual Captions][cc] and a filtered subset of [YFCC100M][yfcc100m]. We used a subset of the filters described in [Sharma et al.][cc_paper] to construct this dataset; further details are described in [our paper][dalle_paper]. We will not be releasing the dataset.

## Performance and Limitations

The heavy compression from the encoding process results in a noticeable loss of detail in the reconstructed images. This renders it inappropriate for applications that require fine-grained details of the image to be preserved.

[dalle_paper]: https://arxiv.org/abs/2102.12092
[cc]: https://ai.google.com/research/ConceptualCaptions
[cc_paper]: https://www.aclweb.org/anthology/P18-1238/
[yfcc100m]: http://projects.dfki.uni-kl.de/yfcc100m/