Commit
update to SEED LLaMA V2
sijeh committed Oct 19, 2023
1 parent 7999b56 commit f5e7552
Showing 234 changed files with 4,433 additions and 19,437 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
pretrained/*
!pretrained/.gitkeep
**/__pycache__/**
log/
File renamed without changes.
5 changes: 5 additions & 0 deletions configs/llm/seed_llama_14b.yaml
@@ -0,0 +1,5 @@
_target_: models.model_tools.get_pretrained_llama_causal_model
pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_14b_sft

torch_dtype: fp16
low_cpu_mem_usage: True
5 changes: 5 additions & 0 deletions configs/llm/seed_llama_8b.yaml
@@ -0,0 +1,5 @@
_target_: models.model_tools.get_pretrained_llama_causal_model
pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_8b_sft

torch_dtype: fp16
low_cpu_mem_usage: True
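
As a rough illustration (not part of this commit), a config like the two above can be turned into a model object with hydra.utils.instantiate, which calls the callable named in _target_ with the remaining keys as keyword arguments. The snippet below is a hedged sketch: it assumes hydra-core and omegaconf are installed, that it is run from the repository root, and that the pretrained/seed_llama_8b_sft checkpoint has been downloaded.

import os

import hydra
from omegaconf import OmegaConf

# The configs reference ${oc.env:PROJECT_ROOT}, so the environment variable
# must be set before the config values are resolved.
os.environ.setdefault('PROJECT_ROOT', os.getcwd())

llm_cfg = OmegaConf.load('configs/llm/seed_llama_8b.yaml')
# instantiate() calls the _target_ function with the remaining config keys as keyword arguments.
llm = hydra.utils.instantiate(llm_cfg)
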
4 changes: 4 additions & 0 deletions configs/tokenizer/seed_llama_tokenizer.yaml
@@ -0,0 +1,4 @@
_target_: models.seed_llama_tokenizer.SeedLlamaTokenizer.from_pretrained
pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_tokenizer
fp16: True
load_diffusion: True
4 changes: 4 additions & 0 deletions configs/transform/clip_transform.yaml
@@ -0,0 +1,4 @@
_target_: models.transforms.get_transform
type: clip
image_size: 224
keep_ratio: False
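
The tokenizer and transform configs can be instantiated the same way; a hedged sketch follows. Here _target_ points at the classmethod SeedLlamaTokenizer.from_pretrained, so the exact keyword arguments it accepts are defined in models/seed_llama_tokenizer.py, and the pretrained/seed_tokenizer weights are assumed to be present.

import os

import hydra
from omegaconf import OmegaConf

os.environ.setdefault('PROJECT_ROOT', os.getcwd())  # used by ${oc.env:PROJECT_ROOT}

tokenizer_cfg = OmegaConf.load('configs/tokenizer/seed_llama_tokenizer.yaml')
transform_cfg = OmegaConf.load('configs/transform/clip_transform.yaml')

tokenizer = hydra.utils.instantiate(tokenizer_cfg)  # SEED tokenizer, optionally with its diffusion decoder
transform = hydra.utils.instantiate(transform_cfg)  # 224x224 CLIP-style image preprocessing
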
39 changes: 0 additions & 39 deletions demo_recon.py

This file was deleted.

Binary file removed demos/cat.jpg
Binary file not shown.
190 changes: 190 additions & 0 deletions gradio_demo/conversation.py
@@ -0,0 +1,190 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple

import io
import base64
import os
from PIL import Image
import copy

IMG_FLAG = '<image>'  # placeholder marking where an image appears inside a message's text


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


def decode_image(encoded_image: str) -> Image.Image:
    """Decode a base64 string back into a PIL image."""
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    """Encode a PIL image as a base64 string."""
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[dict]  # multi-turn -> user & assistant -> {'text': str, 'images': [path, ...], 'images_ids': [...]}
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        """Flatten the history into {'text': prompt, 'images': [base64 or cached ids]}."""
        messages = copy.deepcopy(self.messages)
        if self.sep_style == SeparatorStyle.SINGLE:
            if self.system is None or self.system == '':
                text = ''
            else:
                text = self.system + self.sep
            images = []
            for message in messages:
                text += message['role'] + ": " + message['message']['text'] + self.sep
                for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
                    if image_ids is not None:
                        images.append(image_ids)
                    else:
                        image = Image.open(image_path).resize((256, 256))
                        image_base64 = encode_image(image)
                        images.append(image_base64)

            text += self.roles[1] + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            b_token = "[INST] "
            e_token = " [/INST]"
            if self.system is None or self.system == '':
                text = ''
            else:
                text = f"<<SYS>>\n{self.system}\n<</SYS>>\n\n"
            images = []
            for idx, message in enumerate(messages):
                if idx % 2 == 0:  # user turn
                    text += b_token + message['message']['text'] + e_token + self.sep
                else:  # assistant turn
                    text += message['message']['text'] + self.sep

                for image_path, image_ids in zip(message['message']['images'], message['message']['images_ids']):
                    if image_ids is not None:
                        images.append(image_ids)
                    else:
                        image = Image.open(image_path).resize((256, 256))
                        image_base64 = encode_image(image)
                        images.append(image_base64)
        else:
            raise NotImplementedError

        return {'text': text, 'images': images}

    def update_image_ids(self, images_ids):
        """Fill in the image ids that are still None, in order of appearance."""
        image_count = 0
        for message in self.messages:
            for idx in range(len(message['message']['images_ids'])):
                if message['message']['images_ids'][idx] is None:
                    message['message']['images_ids'][idx] = images_ids[image_count]
                    image_count += 1

        assert len(images_ids) == image_count, print(len(images_ids), image_count)

    def append_message(self, role, message):
        # Stored as a dict so get_prompt / to_gradio_chatbot can read 'role' and 'message'.
        self.messages.append({'role': role, 'message': message})

    def to_gradio_chatbot(self):
        """Convert the history into [[user_markdown, assistant_markdown], ...] rows."""
        dialog = []
        for i, single_turn in enumerate(self.messages[self.offset:]):
            single_turn = single_turn['message']
            text_list = single_turn['text'].split(IMG_FLAG)
            assert len(text_list) == len(single_turn['images']) + 1, print(text_list, len(single_turn['images']))
            message = ''
            for image_idx in range(len(single_turn['images'])):
                image_path = single_turn['images'][image_idx]
                if image_path == '':
                    message += text_list[image_idx] + '<corrupt_image>'
                else:
                    message += text_list[image_idx] + f'![](file={image_path})'
            message += text_list[-1]

            if i % 2 == 0:
                dialog.append([message, None])
            else:
                dialog[-1][-1] = message

        return dialog

    def copy(self):
        return Conversation(system=self.system,
                            roles=self.roles,
                            messages=copy.deepcopy(self.messages),
                            offset=self.offset,
                            sep_style=self.sep_style,
                            sep=self.sep,
                            sep2=self.sep2,
                            version=self.version)

    def dict(self):
        messages = copy.deepcopy(self.messages)
        for message in messages:
            if 'images_ids' in message:
                message.pop('images_ids')
            for i in range(len(message['message']['images'])):
                message['message']['images'][i] = os.path.basename(message['message']['images'][i])
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_seed_llama = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep='\n',
)

conv_seed_llama_2 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. ",
    roles=("USER", "ASSISTANT"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep='\n',
)

conv_seed_llama_3 = Conversation(
    system="",
    roles=("[INST]", "[/INST]"),
    version="v2",
    messages=[],
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep='\n',
)
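
For reference, a short usage sketch of the Conversation container defined above (not part of the commit). Each message is stored as {'role': ..., 'message': {'text', 'images', 'images_ids'}}, text marks image positions with the <image> flag, and get_prompt returns the flattened prompt plus base64-encoded images (or cached image ids). The image path and its contents below are made up for illustration, and the import assumes the repository root is on PYTHONPATH.

from PIL import Image

from gradio_demo.conversation import conv_seed_llama, IMG_FLAG

# Create a placeholder image so the sketch is self-contained.
Image.new('RGB', (256, 256), 'white').save('example.jpg')

conv = conv_seed_llama.copy()
conv.append_message(conv.roles[0], {
    'text': f'{IMG_FLAG} What is in this image?',
    'images': ['example.jpg'],  # one path per IMG_FLAG occurrence in 'text'
    'images_ids': [None],       # filled in later via update_image_ids
})
conv.append_message(conv.roles[1], {
    'text': 'A white square.',
    'images': [],
    'images_ids': [],
})

prompt = conv.get_prompt()            # {'text': 'USER: ... ASSISTANT: ...', 'images': [base64 or ids]}
chat_rows = conv.to_gradio_chatbot()  # [[user_markdown, assistant_markdown], ...]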