Skip to content

Commit

Permalink
conceptlab
Browse files Browse the repository at this point in the history
  • Loading branch information
eps696 committed Jul 19, 2024
1 parent 84480e8 commit a957ac1
Show file tree
Hide file tree
Showing 4 changed files with 432 additions and 10 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Current functions:

Fine-tuning with your images:
* Add subject (new token) with [textual inversion]
* Same with **inventing novel imagery** with [ConceptLab]
* Add subject (new token + Unet delta) with [custom diffusion]
* Add subject (Unet low rank delta) with [LoRA]

Expand Down Expand Up @@ -185,6 +186,12 @@ Custom diffusion trains faster and can achieve impressive reproduction quality (
LoRA finetuning seems less precise while may affect wider spectrum of topics, and is a de-facto industry standard now.
Textual inversion is more generic but stable. Also, its embeddings can be easily combined together on load.

One can also train new token embedding for a novel unusual subject within a class, employing the trick from [ConceptLab] (see their webpage for details):
```
python src/trainew.py --token mypet --term pet
```


* Generate an image with trained weights from [LoRA]:
```
python src/gen.py -t "cosmic beast cat" --load_lora mycat1-lora.pt
Expand Down Expand Up @@ -274,5 +281,6 @@ Huge respect to the people behind [Stable Diffusion], [Hugging Face], and the wh
[SDXL-Lightning]: <https://huggingface.co/ByteDance/SDXL-Lightning>
[TCD Scheduler]: <https://mhh0318.github.io/tcd/>
[Self-Attention Guidance]: <https://github.com/KU-CVLAB/Self-Attention-Guidance>
[ConceptLab]: <https://kfirgoldberg.github.io/ConceptLab>
[Instruct pix2pix]: <https://github.com/timothybrooks/instruct-pix2pix>
[instruct-pix2pix]: <https://huggingface.co/timbrooks/instruct-pix2pix>
88 changes: 78 additions & 10 deletions src/core/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,99 @@
"nice picture in the style of {}",
]

# = = = by ConceptLab = = =

PREFIXES = ['a', 'the', 'my']

object_templates_standard = [
"Professional high-quality photo of {a} {token}. photorealistic, 4k, HQ",
"A photo of {a} {token}",
"A photo of {a} {token}. photorealistic, 4k, HQ",
]
object_templates_edits = [
"Professional high-quality art of {a} {token}. photorealistic",
"A painting of {a} {token}",
"A watercolor painting of {a} {token}",
"A painting of {a} {token} in the style of monet",
"Colorful graffiti of {a} {token}. photorealistic, 4k, HQ",
"A line drawing of {a} {token}",
"Oil painting of {a} {token}",
"Professional high-quality art of {a} {token} in the style of a cartoon",
"A close-up photo of {a} {token}",
]
import math
object_templates = object_templates_standard * math.ceil(len(object_templates_edits) / len(object_templates_standard)) + object_templates_edits

style_templates = [
"a painting in the style of {token}",
"a painting of a dog in the style of {token}",
"a painting of a cat in the style of {token}",
"a painting portrait in the style of {token}",
"a painting of a vase with flowers in the style of {token}",
"a painting of a valley in the style of {token}",
"a painting of a fruit bowl in the style of {token}",
"A painting of a bicycle in the style of {token}",
"A painting of a pair of shoes in the style of {token}",
"A painting portrait of a musician playing a musical instrument in the style of {token}",
"A painting of a cup of coffee with steam in the style of {token}",
"A painting close-up painting of a seashell with delicate textures in the style of {token}",
"A painting of a vintage camera in the style of {token}",
"A painting of a bouquet of wildflowers in the style of {token}",
"A painting table set with fine china and silverware in the style of {token}",
"A painting of a bookshelf filled with books in the style of {token}",
"A painting close-up painting of a glass jar filled with marbles in the style of {token}",
"A painting portrait of a dancer captured in mid-motion in the style of {token}",
"A painting of a collection of antique keys with intricate designs in the style of {token}",
"A painting of a pair of sunglasses reflecting a scenic landscape in the style of {token}",
]

class Capturer():
def __init__(self):
def __init__(self, model='base'):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.float16 if self.device == 'cuda' else torch.float32
self.max_len = 32

from transformers import AutoProcessor, BlipForConditionalGeneration # , Blip2ForConditionalGeneration
CAPTION_MODELS = {
'blip-base': 'Salesforce/blip-image-captioning-base', # 990MB
'blip-large': 'Salesforce/blip-image-captioning-large', # 1.9GB
'base': 'Salesforce/blip-image-captioning-base', # 990MB
'large': 'Salesforce/blip-image-captioning-large', # 1.9GB
# 'blip2-2.7b': 'Salesforce/blip2-opt-2.7b', # 15.5GB
# 'blip2-flan-t5-xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB
't5xl': 'Salesforce/blip2-flan-t5-xl', # 15.77GB
}
model_path = CAPTION_MODELS['blip-base'] # 'blip-large' is too imaginative
model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype) # Blip2ForConditionalGeneration
self.processor = AutoProcessor.from_pretrained(model_path, do_rescale=False)
model_path = CAPTION_MODELS[model]
if model.lower() in ['t5xl']:
from transformers import Blip2Processor, Blip2ForConditionalGeneration
model = Blip2ForConditionalGeneration.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)
self.processor = Blip2Processor.from_pretrained(model_path)
else:
from transformers import AutoProcessor, BlipForConditionalGeneration
model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=self.dtype)
self.processor = AutoProcessor.from_pretrained(model_path, do_rescale=False)
self.model = model.eval().to(self.device)

def __call__(self, image):
if torch.is_tensor(image): image = (image + 1.) / 2.
inputs = self.processor(images=image, return_tensors="pt").to(self.device)
inputs = inputs.to(self.dtype)
inputs = self.processor(images=image, return_tensors="pt").to(self.device, dtype=self.dtype)
tokens = self.model.generate(**inputs, max_new_tokens=self.max_len)
return self.processor.batch_decode(tokens, skip_special_tokens=True)[0].strip()

class ConceptDataset(Dataset):
def __init__(self, token, type):
self.token = token
self.templates = style_templates if type == 'style' else object_templates

def __len__(self):
return 5 # Doesn't really matter as we use steps

def __getitem__(self, i: int):
example = {}
template = random.choice(self.templates)
if '{a}' in template:
template = template.format(a=random.choice(PREFIXES), token='{token}')
text = template.format(token = self.token)
example["template"] = template
example["text"] = text
return example

class FinetuneDataset(Dataset):
def __init__(self, inputs, tokenizer, size=512, style=False, aug_img=True, aug_txt=True, add_caption=False, flip=True):
self.size = size
Expand Down
2 changes: 2 additions & 0 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import logging
logging.getLogger('xformers').setLevel(logging.ERROR)
logging.getLogger('diffusers.models.modeling_utils').setLevel(logging.CRITICAL)
import warnings
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

import torch
import torch.nn.functional as F
Expand Down
Loading

0 comments on commit a957ac1

Please sign in to comment.