Skip to content

Commit

Permalink
machine learning added
Browse files Browse the repository at this point in the history
  • Loading branch information
solomonNSI committed Dec 19, 2022
1 parent 507976b commit 06347bc
Show file tree
Hide file tree
Showing 81 changed files with 1,403 additions and 5 deletions.
Binary file modified .DS_Store
Binary file not shown.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
.env
.env

aura-palette-model/samples

/aura-palette-model/data
45 changes: 41 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,53 @@
# The Aura Palette

Welcome to Aura Palette: a color palette generator using machine learning! With this tool, you can easily create a beautiful and cohesive color scheme for your design project, simply by providing a short piece of text as input. The generator will analyze the text and return a palette of colors that captures its essence and mood. Whether you are a graphic designer, artist, or just looking to add some color to your life, this tool is for you. So let's get started! Simply type in your text below, and discover the magic of machine learning and color.


## Team
The Aura Palette is an ML-based color palette generator. It is being developed by Suleyman, Ata, Can, Zeynepnur and Ayda as their senior project
in Bilkent University for CS491/492 course.

### Instructions:

##### Install yarn with (for MacOS):


### Boring stuff:



### Instructions for frontend

```
brew install yarn
yarn install
yarn start
```

### Instructions for backend


### Instructions for ML

##### (i) Install requirements

```bash
$ pip install -r requirements.txt
```
##### Start project with:

##### Training Text-to-Palette Generation Networks (TPN) with PAT data

```bash
$ python main.py --mode train_TPN
```
yarn start

##### Testing TPN
```bash
$ python main.py --mode test_TPN --resume_epoch 500
```

##### (ii) For custom test

```bash
$ python server.py
# in another terminal window:
curl -i -H "Content-Type: application/json" -X POST -d '{"queryStrings": ["your first query", "your second query"], "numPalettesPerQuery":1}' http://localhost:8000/palette
```
Binary file added aura-palette-model/.DS_Store
Binary file not shown.
21 changes: 21 additions & 0 deletions aura-palette-model/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 awesome-davian

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
189 changes: 189 additions & 0 deletions aura-palette-model/data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import torch
import torch.utils.data as data
import pickle
import os
import numpy as np
from skimage.color import rgb2lab
import warnings

class PAT_Dataset(data.Dataset):
def __init__(self, src_path, trg_path, input_dict):
with open(src_path, 'rb') as fin:
self.src_seqs = pickle.load(fin)
with open(trg_path, 'rb') as fin:
self.trg_seqs = pickle.load(fin)

words_index = []
for index, palette_name in enumerate(self.src_seqs):
temp = [0] * input_dict.max_len

for i, word in enumerate(palette_name):
temp[i] = input_dict.word2index[word]
words_index.append(temp)
self.src_seqs = torch.LongTensor(words_index)

palette_list = []
for index, palettes in enumerate(self.trg_seqs):
temp = []
for palette in palettes:
rgb = np.array([palette[0], palette[1], palette[2]]) / 255.0
warnings.filterwarnings("ignore")
lab = rgb2lab(rgb[np.newaxis, np.newaxis, :], illuminant='D50').flatten()
temp.append(lab[0])
temp.append(lab[1])
temp.append(lab[2])
palette_list.append(temp)

self.trg_seqs = torch.FloatTensor(palette_list)
self.num_total_seqs = len(self.src_seqs)

def __getitem__(self, index):
src_seq = self.src_seqs[index]
trg_seq = self.trg_seqs[index]
return src_seq, trg_seq

def __len__(self):
return self.num_total_seqs


def t2p_loader(batch_size, input_dict):
train_src_path = os.path.join('./data/hexcolor_vf/train_names.pkl')
train_trg_path = os.path.join('./data/hexcolor_vf/train_palettes_rgb.pkl')
val_src_path = os.path.join('./data/hexcolor_vf/test_names.pkl')
val_trg_path = os.path.join('./data/hexcolor_vf/test_palettes_rgb.pkl')

train_dataset = PAT_Dataset(train_src_path, train_trg_path, input_dict)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
num_workers=2,
drop_last=True,
shuffle=True)

test_dataset = PAT_Dataset(val_src_path, val_trg_path, input_dict)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
num_workers=2,
drop_last=True,
shuffle=False)

return train_loader, test_loader


class Image_Dataset(data.Dataset):
def __init__(self, image_dir, pal_dir):
with open(image_dir, 'rb') as f:
self.image_data = np.asarray(pickle.load(f)) / 255

with open(pal_dir, 'rb') as f:
self.pal_data = rgb2lab(np.asarray(pickle.load(f))
.reshape(-1, 5, 3) / 256
, illuminant='D50')

self.data_size = self.image_data.shape[0]

def __len__(self):
return self.data_size

def __getitem__(self, idx):
return self.image_data[idx], self.pal_data[idx]


def p2c_loader(dataset, batch_size, idx=0):
if dataset == 'imagenet':

train_img_path = './data/imagenet/train_palette_set_origin/train_images_%d.txt' % (idx)
train_pal_path = './data/imagenet/train_palette_set_origin/train_palette_%d.txt' % (idx)

train_dataset = Image_Dataset(train_img_path, train_pal_path)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2)

imsize = 256

elif dataset == 'bird256':

train_img_path = './data/bird256/train_palette/train_images_origin.txt'
train_pal_path = './data/bird256/train_palette/train_palette_origin.txt'

train_dataset = Image_Dataset(train_img_path, train_pal_path)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2)

imsize = 256

return train_loader, imsize


class Test_Dataset(data.Dataset):
def __init__(self, input_dict, txt_path, pal_path, img_path, transform=None):
self.transform = transform
with open(img_path, 'rb') as f:
self.images = np.asarray(pickle.load(f)) / 255
with open(txt_path, 'rb') as fin:
self.src_seqs = pickle.load(fin)
with open(pal_path, 'rb') as fin:
self.trg_seqs = pickle.load(fin)

# ==================== Preprocessing src_seqs ====================#
# Return a list of indexes, one for each word in the sentence.
words_index = []
for index, palette_name in enumerate(self.src_seqs):
# Set list size to the longest palette name.
temp = [0] * input_dict.max_len
for i, word in enumerate(palette_name):
temp[i] = input_dict.word2index[word]
words_index.append(temp)

self.src_seqs = torch.LongTensor(words_index)

# ==================== Preprocessing trg_seqs ====================#
palette_list = []
for palettes in self.trg_seqs:
temp = []
for palette in palettes:
rgb = np.array([palette[0], palette[1], palette[2]]) / 255.0
warnings.filterwarnings("ignore")
lab = rgb2lab(rgb[np.newaxis, np.newaxis, :], illuminant='D50').flatten()
temp.append(lab[0])
temp.append(lab[1])
temp.append(lab[2])
palette_list.append(temp)

self.trg_seqs = torch.FloatTensor(palette_list)

self.num_total_data = len(self.src_seqs)

def __len__(self):
return self.num_total_data

def __getitem__(self, idx):
"""Returns one data pair."""
text = self.src_seqs[idx]
palette = self.trg_seqs[idx]
image = self.images[idx]
if self.transform:
image = self.transform(image)

return text, palette, image


def test_loader(dataset, batch_size, input_dict):

if dataset == 'bird256':

txt_path = './data/hexcolor_vf/test_names.pkl'
pal_path = './data/hexcolor_vf/test_palettes_rgb.pkl'
img_path = './data/bird256/test_palette/test_images_origin.txt'

test_dataset = Test_Dataset(input_dict, txt_path, pal_path, img_path)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2)
imsize = 256

return test_loader, imsize
79 changes: 79 additions & 0 deletions aura-palette-model/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#! /usr/bin/env python
from __future__ import division
import os
import argparse
from solver import Solver


def main(args):

# Create directory if it doesn't exist.
if not os.path.exists(args.text2pal_dir):
os.makedirs(args.text2pal_dir)
if not os.path.exists(args.pal2color_dir):
os.makedirs(args.pal2color_dir)
if not os.path.exists(args.train_sample_dir):
os.makedirs(args.train_sample_dir)
if not os.path.exists(os.path.join(args.test_sample_dir, args.mode)):
os.makedirs(os.path.join(args.test_sample_dir, args.mode))

# Solver for training and testing Text2Colors.
solver = Solver(args)

# Train or test.
if args.mode == 'train_TPN':
solver.train_TPN()

elif args.mode == 'train_PCN':
solver.train_PCN()

elif args.mode == 'test_TPN':
solver.test_TPN()

elif args.mode == 'test_text2colors':
solver.test_text2colors()


if __name__ == '__main__':
parser = argparse.ArgumentParser()

# Model configuration.
# text2pal
parser.add_argument('--hidden_size', type=int, default=150)
parser.add_argument('--n_layers', type=int, default=1)
# pal2color
parser.add_argument('--always_give_global_hint', type=int, default=1)
parser.add_argument('--add_L', type=int, default=1)

# Training and testing configuration.
parser.add_argument('--mode', type=str, default='train_TPN',
choices=['train_TPN', 'train_PCN', 'test_TPN', 'test_text2colors'])
parser.add_argument('--dataset', type=str, default='bird256', choices=['imagenet', 'bird256'])
parser.add_argument('--lr', type=float, default=5e-4, help='initial learning rate')
parser.add_argument('--num_epochs', type=int, default=1000, help='number of epochs for training')
parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
parser.add_argument('--batch_size', type=int, default=32, help='batch size for training')
parser.add_argument('--dropout_p', type=float, default=0.2)
parser.add_argument('--weight_decay', type=float, default=5e-5)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.99)
parser.add_argument('--lambda_sL1', type=float, default=100.0, help='weight for L1 loss')
parser.add_argument('--lambda_KL', type=float, default=0.5, help='weight for KL loss')
parser.add_argument('--lambda_GAN', type=float, default=0.1)

# Directories.
parser.add_argument('--text2pal_dir', type=str, default='./models/TPN')
parser.add_argument('--pal2color_dir', type=str, default='./models/PCN')
parser.add_argument('--train_sample_dir', type=str, default='./samples/train')
parser.add_argument('--test_sample_dir', type=str, default='./samples/test')

# Step size.
parser.add_argument('--log_interval', type=int, default=1,
help='how many steps to wait before logging training status')
parser.add_argument('--sample_interval', type=int, default=20,
help='how many steps to wait before saving the training output')
parser.add_argument('--save_interval', type=int, default=50,
help='how many steps to wait before saving the trained models')
args = parser.parse_args()
print(args)
main(args)
Loading

0 comments on commit 06347bc

Please sign in to comment.