-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
507976b
commit 06347bc
Showing
81 changed files
with
1,403 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
.env | ||
.env | ||
|
||
aura-palette-model/samples | ||
|
||
/aura-palette-model/data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,53 @@ | ||
# The Aura Palette | ||
|
||
Welcome to Aura Palette: a color palette generator using machine learning! With this tool, you can easily create a beautiful and cohesive color scheme for your design project, simply by providing a short piece of text as input. The generator will analyze the text and return a palette of colors that captures its essence and mood. Whether you are a graphic designer, artist, or just looking to add some color to your life, this tool is for you. So let's get started! Simply type in your text below, and discover the magic of machine learning and color. | ||
|
||
|
||
## Team | ||
The Aura Palette is an ML-based color palette generator. It is being developed by Suleyman, Ata, Can, Zeynepnur and Ayda as their senior project | ||
in Bilkent University for CS491/492 course. | ||
|
||
### Instructions: | ||
|
||
##### Install yarn with (for MacOS): | ||
|
||
|
||
### Boring stuff: | ||
|
||
|
||
|
||
### Instructions for frontend | ||
|
||
``` | ||
brew install yarn | ||
yarn install | ||
yarn start | ||
``` | ||
|
||
### Instructions for backend | ||
|
||
|
||
### Instructions for ML | ||
|
||
##### (i) Install requirements | ||
|
||
```bash | ||
$ pip install -r requirements.txt | ||
``` | ||
##### Start project with: | ||
|
||
##### Training Text-to-Palette Generation Networks (TPN) with PAT data | ||
|
||
```bash | ||
$ python main.py --mode train_TPN | ||
``` | ||
yarn start | ||
|
||
##### Testing TPN | ||
```bash | ||
$ python main.py --mode test_TPN --resume_epoch 500 | ||
``` | ||
|
||
##### (ii) For custom test | ||
|
||
```bash | ||
$ python server.py | ||
# in another terminal window: | ||
curl -i -H "Content-Type: application/json" -X POST -d '{"queryStrings": ["your first query", "your second query"], "numPalettesPerQuery":1}' http://localhost:8000/palette | ||
``` |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2018 awesome-davian | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
import torch | ||
import torch.utils.data as data | ||
import pickle | ||
import os | ||
import numpy as np | ||
from skimage.color import rgb2lab | ||
import warnings | ||
|
||
class PAT_Dataset(data.Dataset): | ||
def __init__(self, src_path, trg_path, input_dict): | ||
with open(src_path, 'rb') as fin: | ||
self.src_seqs = pickle.load(fin) | ||
with open(trg_path, 'rb') as fin: | ||
self.trg_seqs = pickle.load(fin) | ||
|
||
words_index = [] | ||
for index, palette_name in enumerate(self.src_seqs): | ||
temp = [0] * input_dict.max_len | ||
|
||
for i, word in enumerate(palette_name): | ||
temp[i] = input_dict.word2index[word] | ||
words_index.append(temp) | ||
self.src_seqs = torch.LongTensor(words_index) | ||
|
||
palette_list = [] | ||
for index, palettes in enumerate(self.trg_seqs): | ||
temp = [] | ||
for palette in palettes: | ||
rgb = np.array([palette[0], palette[1], palette[2]]) / 255.0 | ||
warnings.filterwarnings("ignore") | ||
lab = rgb2lab(rgb[np.newaxis, np.newaxis, :], illuminant='D50').flatten() | ||
temp.append(lab[0]) | ||
temp.append(lab[1]) | ||
temp.append(lab[2]) | ||
palette_list.append(temp) | ||
|
||
self.trg_seqs = torch.FloatTensor(palette_list) | ||
self.num_total_seqs = len(self.src_seqs) | ||
|
||
def __getitem__(self, index): | ||
src_seq = self.src_seqs[index] | ||
trg_seq = self.trg_seqs[index] | ||
return src_seq, trg_seq | ||
|
||
def __len__(self): | ||
return self.num_total_seqs | ||
|
||
|
||
def t2p_loader(batch_size, input_dict): | ||
train_src_path = os.path.join('./data/hexcolor_vf/train_names.pkl') | ||
train_trg_path = os.path.join('./data/hexcolor_vf/train_palettes_rgb.pkl') | ||
val_src_path = os.path.join('./data/hexcolor_vf/test_names.pkl') | ||
val_trg_path = os.path.join('./data/hexcolor_vf/test_palettes_rgb.pkl') | ||
|
||
train_dataset = PAT_Dataset(train_src_path, train_trg_path, input_dict) | ||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, | ||
batch_size=batch_size, | ||
num_workers=2, | ||
drop_last=True, | ||
shuffle=True) | ||
|
||
test_dataset = PAT_Dataset(val_src_path, val_trg_path, input_dict) | ||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, | ||
batch_size=batch_size, | ||
num_workers=2, | ||
drop_last=True, | ||
shuffle=False) | ||
|
||
return train_loader, test_loader | ||
|
||
|
||
class Image_Dataset(data.Dataset): | ||
def __init__(self, image_dir, pal_dir): | ||
with open(image_dir, 'rb') as f: | ||
self.image_data = np.asarray(pickle.load(f)) / 255 | ||
|
||
with open(pal_dir, 'rb') as f: | ||
self.pal_data = rgb2lab(np.asarray(pickle.load(f)) | ||
.reshape(-1, 5, 3) / 256 | ||
, illuminant='D50') | ||
|
||
self.data_size = self.image_data.shape[0] | ||
|
||
def __len__(self): | ||
return self.data_size | ||
|
||
def __getitem__(self, idx): | ||
return self.image_data[idx], self.pal_data[idx] | ||
|
||
|
||
def p2c_loader(dataset, batch_size, idx=0): | ||
if dataset == 'imagenet': | ||
|
||
train_img_path = './data/imagenet/train_palette_set_origin/train_images_%d.txt' % (idx) | ||
train_pal_path = './data/imagenet/train_palette_set_origin/train_palette_%d.txt' % (idx) | ||
|
||
train_dataset = Image_Dataset(train_img_path, train_pal_path) | ||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
num_workers=2) | ||
|
||
imsize = 256 | ||
|
||
elif dataset == 'bird256': | ||
|
||
train_img_path = './data/bird256/train_palette/train_images_origin.txt' | ||
train_pal_path = './data/bird256/train_palette/train_palette_origin.txt' | ||
|
||
train_dataset = Image_Dataset(train_img_path, train_pal_path) | ||
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
num_workers=2) | ||
|
||
imsize = 256 | ||
|
||
return train_loader, imsize | ||
|
||
|
||
class Test_Dataset(data.Dataset): | ||
def __init__(self, input_dict, txt_path, pal_path, img_path, transform=None): | ||
self.transform = transform | ||
with open(img_path, 'rb') as f: | ||
self.images = np.asarray(pickle.load(f)) / 255 | ||
with open(txt_path, 'rb') as fin: | ||
self.src_seqs = pickle.load(fin) | ||
with open(pal_path, 'rb') as fin: | ||
self.trg_seqs = pickle.load(fin) | ||
|
||
# ==================== Preprocessing src_seqs ====================# | ||
# Return a list of indexes, one for each word in the sentence. | ||
words_index = [] | ||
for index, palette_name in enumerate(self.src_seqs): | ||
# Set list size to the longest palette name. | ||
temp = [0] * input_dict.max_len | ||
for i, word in enumerate(palette_name): | ||
temp[i] = input_dict.word2index[word] | ||
words_index.append(temp) | ||
|
||
self.src_seqs = torch.LongTensor(words_index) | ||
|
||
# ==================== Preprocessing trg_seqs ====================# | ||
palette_list = [] | ||
for palettes in self.trg_seqs: | ||
temp = [] | ||
for palette in palettes: | ||
rgb = np.array([palette[0], palette[1], palette[2]]) / 255.0 | ||
warnings.filterwarnings("ignore") | ||
lab = rgb2lab(rgb[np.newaxis, np.newaxis, :], illuminant='D50').flatten() | ||
temp.append(lab[0]) | ||
temp.append(lab[1]) | ||
temp.append(lab[2]) | ||
palette_list.append(temp) | ||
|
||
self.trg_seqs = torch.FloatTensor(palette_list) | ||
|
||
self.num_total_data = len(self.src_seqs) | ||
|
||
def __len__(self): | ||
return self.num_total_data | ||
|
||
def __getitem__(self, idx): | ||
"""Returns one data pair.""" | ||
text = self.src_seqs[idx] | ||
palette = self.trg_seqs[idx] | ||
image = self.images[idx] | ||
if self.transform: | ||
image = self.transform(image) | ||
|
||
return text, palette, image | ||
|
||
|
||
def test_loader(dataset, batch_size, input_dict): | ||
|
||
if dataset == 'bird256': | ||
|
||
txt_path = './data/hexcolor_vf/test_names.pkl' | ||
pal_path = './data/hexcolor_vf/test_palettes_rgb.pkl' | ||
img_path = './data/bird256/test_palette/test_images_origin.txt' | ||
|
||
test_dataset = Test_Dataset(input_dict, txt_path, pal_path, img_path) | ||
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, | ||
batch_size=batch_size, | ||
shuffle=False, | ||
num_workers=2) | ||
imsize = 256 | ||
|
||
return test_loader, imsize |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#! /usr/bin/env python | ||
from __future__ import division | ||
import os | ||
import argparse | ||
from solver import Solver | ||
|
||
|
||
def main(args): | ||
|
||
# Create directory if it doesn't exist. | ||
if not os.path.exists(args.text2pal_dir): | ||
os.makedirs(args.text2pal_dir) | ||
if not os.path.exists(args.pal2color_dir): | ||
os.makedirs(args.pal2color_dir) | ||
if not os.path.exists(args.train_sample_dir): | ||
os.makedirs(args.train_sample_dir) | ||
if not os.path.exists(os.path.join(args.test_sample_dir, args.mode)): | ||
os.makedirs(os.path.join(args.test_sample_dir, args.mode)) | ||
|
||
# Solver for training and testing Text2Colors. | ||
solver = Solver(args) | ||
|
||
# Train or test. | ||
if args.mode == 'train_TPN': | ||
solver.train_TPN() | ||
|
||
elif args.mode == 'train_PCN': | ||
solver.train_PCN() | ||
|
||
elif args.mode == 'test_TPN': | ||
solver.test_TPN() | ||
|
||
elif args.mode == 'test_text2colors': | ||
solver.test_text2colors() | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
|
||
# Model configuration. | ||
# text2pal | ||
parser.add_argument('--hidden_size', type=int, default=150) | ||
parser.add_argument('--n_layers', type=int, default=1) | ||
# pal2color | ||
parser.add_argument('--always_give_global_hint', type=int, default=1) | ||
parser.add_argument('--add_L', type=int, default=1) | ||
|
||
# Training and testing configuration. | ||
parser.add_argument('--mode', type=str, default='train_TPN', | ||
choices=['train_TPN', 'train_PCN', 'test_TPN', 'test_text2colors']) | ||
parser.add_argument('--dataset', type=str, default='bird256', choices=['imagenet', 'bird256']) | ||
parser.add_argument('--lr', type=float, default=5e-4, help='initial learning rate') | ||
parser.add_argument('--num_epochs', type=int, default=1000, help='number of epochs for training') | ||
parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch') | ||
parser.add_argument('--batch_size', type=int, default=32, help='batch size for training') | ||
parser.add_argument('--dropout_p', type=float, default=0.2) | ||
parser.add_argument('--weight_decay', type=float, default=5e-5) | ||
parser.add_argument('--beta1', type=float, default=0.5) | ||
parser.add_argument('--beta2', type=float, default=0.99) | ||
parser.add_argument('--lambda_sL1', type=float, default=100.0, help='weight for L1 loss') | ||
parser.add_argument('--lambda_KL', type=float, default=0.5, help='weight for KL loss') | ||
parser.add_argument('--lambda_GAN', type=float, default=0.1) | ||
|
||
# Directories. | ||
parser.add_argument('--text2pal_dir', type=str, default='./models/TPN') | ||
parser.add_argument('--pal2color_dir', type=str, default='./models/PCN') | ||
parser.add_argument('--train_sample_dir', type=str, default='./samples/train') | ||
parser.add_argument('--test_sample_dir', type=str, default='./samples/test') | ||
|
||
# Step size. | ||
parser.add_argument('--log_interval', type=int, default=1, | ||
help='how many steps to wait before logging training status') | ||
parser.add_argument('--sample_interval', type=int, default=20, | ||
help='how many steps to wait before saving the training output') | ||
parser.add_argument('--save_interval', type=int, default=50, | ||
help='how many steps to wait before saving the trained models') | ||
args = parser.parse_args() | ||
print(args) | ||
main(args) |
Oops, something went wrong.