Commit
add 8bit quantization model
sijeh committed Oct 23, 2023
1 parent 212f935 commit 32e2906
Showing 11 changed files with 79 additions and 29 deletions.
5 changes: 5 additions & 0 deletions configs/llm/seed_llama_14b_8bit.yaml
@@ -0,0 +1,5 @@
+_target_: transformers.LlamaForCausalLM.from_pretrained
+pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_14b_sft
+load_in_8bit: True
+# device_map: auto
+low_cpu_mem_usage: True
5 changes: 5 additions & 0 deletions configs/llm/seed_llama_8b_8bit.yaml
@@ -0,0 +1,5 @@
+_target_: transformers.LlamaForCausalLM.from_pretrained
+pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_llama_8b_sft
+load_in_8bit: True
+# device_map: auto
+low_cpu_mem_usage: True
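For context, loading either config through Hydra amounts to the following call (a minimal sketch, assuming bitsandbytes and accelerate are installed; 8-bit loading fails without them). hydra.utils.instantiate resolves _target_ and forwards the remaining keys as keyword arguments:

from transformers import LlamaForCausalLM

# Sketch of what hydra.utils.instantiate(model_cfg) expands to for these
# YAMLs (path shown with ${oc.env:PROJECT_ROOT} already resolved).
# load_in_8bit=True routes the weights through bitsandbytes int8
# quantization at load time, which is why accelerate is added to
# requirements.txt below.
model = LlamaForCausalLM.from_pretrained(
    'pretrained/seed_llama_14b_sft',
    load_in_8bit=True,
    low_cpu_mem_usage=True,
    device_map='cuda:0',  # left commented out in the YAML; the flask demo passes it per run
)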
3 changes: 2 additions & 1 deletion configs/tokenizer/seed_llama_tokenizer.yaml
@@ -1,4 +1,5 @@
 _target_: models.seed_llama_tokenizer.SeedLlamaTokenizer.from_pretrained
 pretrained_model_name_or_path: ${oc.env:PROJECT_ROOT}/pretrained/seed_tokenizer
 fp16: True
-load_diffusion: True
+load_diffusion: True
+encoder_url: https://huggingface.co/AILab-CVC/seed-tokenizer-2/blob/main/seed_quantizer.pt
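A sketch of how this config is consumed (hedged: per the models/seed_llama_tokenizer.py diff below, from_pretrained forwards these extra keys into SeedLlamaTokenizer.__init__, and a non-None encoder_url replaces the local seed_quantizer.pt path):

from models.seed_llama_tokenizer import SeedLlamaTokenizer

# Equivalent of hydra.utils.instantiate on the YAML above.
tokenizer = SeedLlamaTokenizer.from_pretrained(
    'pretrained/seed_tokenizer',
    fp16=True,
    load_diffusion=True,
    encoder_url='https://huggingface.co/AILab-CVC/seed-tokenizer-2/blob/main/seed_quantizer.pt',
    device='cuda:0',  # the flask demo passes this explicitly
)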
6 changes: 3 additions & 3 deletions gradio_demo/conversation.py
@@ -159,7 +159,7 @@ def dict(self):
         }


-conv_seed_llama = Conversation(
+conv_seed_vicuna = Conversation(
     system="",
     roles=("USER", "ASSISTANT"),
     version="v2",
@@ -169,7 +169,7 @@ def dict(self):
     sep='\n',
 )

-conv_seed_llama_2 = Conversation(
+conv_seed_vicuna_system = Conversation(
     system="A chat between a curious user and an artificial intelligence assistant. ",
     roles=("USER", "ASSISTANT"),
     version="v2",
@@ -179,7 +179,7 @@ def dict(self):
     sep='\n',
 )

-conv_seed_llama_3 = Conversation(
+conv_seed_llama2 = Conversation(
     system="",
     roles=("[INST]", "[/INST]"),
     version="v2",
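The renames make each template's chat convention explicit. Illustratively (hypothetical prompt shapes, not the class's actual formatting logic):

# conv_seed_vicuna / conv_seed_vicuna_system: Vicuna-style turns
vicuna_turn = "USER: {message}\nASSISTANT: {reply}\n"
# conv_seed_llama2: Llama-2-chat-style turns
llama2_turn = "[INST] {message} [/INST] {reply}\n"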
40 changes: 28 additions & 12 deletions gradio_demo/seed_llama_flask.py
@@ -13,6 +13,7 @@
 import io
 import base64
 from PIL import Image
+import gc

 pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

@@ -49,7 +50,8 @@ class Arguments:
     port: Optional[str] = field(default=80, metadata={"help": "network port"})
     llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
     tokenizer_device: Optional[str] = field(default='cuda:0', metadata={"help": "tokenizer device"})
-    offload_tokenizer: Optional[bool] = field(default=False, metadata={"help": "offload image tokenizer"})
+    offload_encoder: Optional[bool] = field(default=False, metadata={"help": "offload image tokenizer"})
+    offload_decoder: Optional[bool] = field(default=True, metadata={"help": "offload image tokenizer"})


 parser = transformers.HfArgumentParser(Arguments)
@@ -64,13 +66,22 @@ def __init__(self, args) -> None:
         self.image_id_shift = 32000

         self.image_transform = hydra.utils.instantiate(image_transform_cfg)
-        tokenizer_device = 'cpu' if args.offload_tokenizer else args.tokenizer_device
-        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=tokenizer_device, load_diffusion=True)
-        model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
-        self.model = model.eval().to(args.llm_device)
+        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=args.tokenizer_device, load_diffusion=True)
+
+        if args.offload_encoder:
+            self.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
+        if args.offload_decoder:
+            self.tokenizer.image_tokenizer.diffusion_model.to('cpu')
+
+        # model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
+        # self.model = model.eval().to(args.llm_device)
+        model = hydra.utils.instantiate(model_cfg, device_map=args.llm_device).eval()
+        self.model = model
+        print(model.get_memory_footprint())
         self.llm_device = args.llm_device
         self.tokenizer_device = args.tokenizer_device
-        self.offload_tokenizer = args.offload_tokenizer
+        self.offload_encoder = args.offload_encoder
+        self.offload_decoder = args.offload_decoder
         self.boi_token_id = self.tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0]
         self.eoi_token_id = self.tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0]
         print('Init Done...')
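The two offload flags implement a standard PyTorch juggling pattern: park a heavy submodule on the CPU, move it onto the GPU only around its forward pass, then release cached VRAM. A minimal sketch of that pattern (the helper name is hypothetical; the real code inlines it in generate(), as the next hunks show):

import gc

import torch

def with_offload(module, device, fn):
    # Hypothetical helper mirroring the inline pattern used below.
    module.to(device)                 # stage the submodule on the GPU
    try:
        return fn()                   # e.g. run encode_image / decode_image
    finally:
        module.to('cpu')              # park it back on the CPU
        torch.cuda.empty_cache()      # return cached blocks to the driver
        gc.collect()                  # drop lingering Python references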
@@ -111,11 +122,14 @@ def generate():

     if len(images_tensor_list) > 0:
         images_tensor = torch.stack(images_tensor_list, dim=0).to(service.tokenizer_device)
-        if args.offload_tokenizer:
-            service.tokenizer.image_tokenizer.model.to(service.tokenizer_device)
+        if service.offload_encoder:
+            service.tokenizer.image_tokenizer.model.visual_encoder.to(service.tokenizer_device)
+
         images_ids_1 = service.tokenizer.encode_image(image_torch=images_tensor).cpu()
-        if args.offload_tokenizer:
-            service.tokenizer.image_tokenizer.model.to('cpu')
+        if args.offload_encoder:
+            service.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
+            torch.cuda.empty_cache()
+            gc.collect()
         num_image_ids = images_ids_1.shape[-1]
     else:
         num_image_ids = len(images_ids_list[-1])
@@ -188,11 +202,13 @@ def generate():
                 error_msg.append(f'Some image_id out of range: [0, {NUM_IMG_CODES})')
                 image_base64 = ''
             else:
-                if service.offload_tokenizer:
+                if service.offload_decoder:
                     service.tokenizer.image_tokenizer.diffusion_model.to(service.tokenizer_device)
                 image = service.tokenizer.decode_image(image_ids)[0]
-                if service.offload_tokenizer:
+                if service.offload_decoder:
                     service.tokenizer.image_tokenizer.diffusion_model.to('cpu')
+                    torch.cuda.empty_cache()
+                    gc.collect()
                 image_base64 = encode_image(image)

             generated_image_base64_list.append(image_base64)
14 changes: 11 additions & 3 deletions gradio_demo/seed_llama_gradio.py
@@ -21,7 +21,7 @@
 import requests

 from utils import build_logger
-from conversation import conv_seed_llama_3 as conv_seed_llama
+from conversation import conv_seed_vicuna, conv_seed_llama2
 # from conversation import conv_seed_llama

 IMG_FLAG = '<image>'
@@ -41,11 +41,19 @@
 class Arguments:
     server_port: Optional[int] = field(default=7860, metadata={"help": "network port"})
     server_name: Optional[str] = field(default='0.0.0.0', metadata={"help": "network address"})
-    request_address: Optional[str] = field(default='http://0.0.0.0:7890/generate', metadata={"help": "request address"})
+    request_address: Optional[str] = field(default='http://127.0.0.1:7890/generate', metadata={"help": "request address"})
+    model_type: Optional[str] = field(default='seed-llama-14b', metadata={"help": "choice: [seed-llama-8b, seed-llama-14b]"})

 parser = transformers.HfArgumentParser(Arguments)
 args, = parser.parse_args_into_dataclasses()
+
+if args.model_type == 'seed-llama-8b':
+    conv_seed_llama = conv_seed_vicuna
+elif args.model_type == 'seed-llama-14b':
+    conv_seed_llama = conv_seed_llama2
+else:
+    raise ValueError


 def decode_image(encoded_image: str) -> Image:
     decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
@@ -283,7 +291,7 @@ def http_bot(dialog_state, input_state, temperature, top_p, max_new_tokens, num_
         'max_new_tokens': int(max_new_tokens),
         'num_beams': int(num_beams)
     })
-
+    print('request_address', args.request_address)
     response = requests.request(method="POST", url=args.request_address, headers=headers, json=payload)
     results = response.json()
     print('response: ', {'text': results['text'], 'images_ids': results['images_ids'], 'error_msg': results['error_msg']})
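For reference, a minimal standalone client for the backend's /generate endpoint (a sketch: only fields visible in this diff are included; the real payload also carries the serialized conversation and images, and the sampling values here are placeholders):

import requests

payload = {
    'temperature': 1.0,   # placeholder values; the demo fills these
    'top_p': 0.5,         # in from its Gradio sliders
    'max_new_tokens': 512,
    'num_beams': 1,
}
response = requests.request(method='POST', url='http://127.0.0.1:7890/generate', json=payload)
results = response.json()
print(results['text'], results['images_ids'], results['error_msg'])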
18 changes: 13 additions & 5 deletions models/seed_llama_tokenizer.py
@@ -81,7 +81,8 @@ def encode(self, image_torch):
         if len(image_torch.shape) == 3:
             image_torch = image_torch.unsqueeze(0)

-        img = image_torch.to(self.device)
+        # img = image_torch.to(self.device)
+        img = image_torch
         if self.fp16:
             img = img.half()
         with torch.no_grad():
@@ -126,20 +127,25 @@ def __init__(self,
                  device='cuda',
                  fp16=True,
                  load_diffusion=False,
+                 encoder_url=None,
                  **kwargs):
         super().__init__(vocab_file, unk_token, bos_token, eos_token, pad_token, sp_model_kwargs, add_bos_token, add_eos_token,
                          clean_up_tokenization_spaces, **kwargs)
         self.device = device
         self.fp16 = fp16
         self.pad_token = self.unk_token
         self.load_diffusion = load_diffusion
+        self.encoder_url = encoder_url

         self.load_image_tokenizer()

     def load_image_tokenizer(self):
         assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
         if not hasattr(self, '_image_tokenizer'):
-            model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
+            if self.encoder_url is not None:
+                model_path = self.encoder_url
+            else:
+                model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
             # diffusion_model_path = os.path.join(self.name_or_path, DIFFUSION_NAME)
             diffusion_model_path = 'stabilityai/stable-diffusion-2-1-unclip'
             self._image_tokenizer = ImageTokenizer(model_path=model_path,
@@ -152,8 +158,10 @@ def load_image_tokenizer(self):
     def image_tokenizer(self):
         assert hasattr(self, 'name_or_path') and os.path.exists(self.name_or_path)
         if not hasattr(self, '_image_tokenizer'):
-
-            model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
+            if self.encoder_url is not None:
+                model_path = self.encoder_url
+            else:
+                model_path = os.path.join(self.name_or_path, WEIGHTS_NAME)
             # diffusion_model_path = os.path.join(self.name_or_path, DIFFUSION_NAME)
             diffusion_model_path = 'stabilityai/stable-diffusion-2-1-unclip'
             self._image_tokenizer = ImageTokenizer(model_path=model_path,
@@ -188,7 +196,7 @@ def encode_image(
         if image_pil is not None:
             image_torch = self.image_tokenizer.processor(image_pil)

-        image_torch = image_torch.to(self.device)
+        image_torch = image_torch.to(self.device)
         return self.image_tokenizer.encode(image_torch)

     def decode_image(self, indices, negative_indices=None, guidance_scale=10):
8 changes: 6 additions & 2 deletions models/seed_qformer/qformer_quantizer.py
@@ -17,7 +17,7 @@

 from .blip2 import Blip2Base, disabled_train
 from .vit import Block
-
+from .utils import download_cached_file, is_url

 class VectorQuantizer2(nn.Module):
     """
@@ -363,7 +363,11 @@ def from_pretrained(cls, pretrained_model_path, **kwargs):
             max_txt_len=max_txt_len,
         )

-        ckpt = torch.load(pretrained_model_path, map_location="cpu")
+        if pretrained_model_path.startswith('http'):
+            cached_file = download_cached_file(pretrained_model_path, check_hash=False, progress=True)
+            ckpt = torch.load(cached_file, map_location="cpu")
+        else:
+            ckpt = torch.load(pretrained_model_path, map_location="cpu")
         missing, unexcepted = model.load_state_dict(ckpt, strict=False)
         print('missing keys: ', len(missing), 'unexpected keys:', len(unexcepted))
         return model
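The same URL-or-path loading logic as a standalone helper (a sketch; download_cached_file is the vendored utility imported above, called with exactly the arguments shown in this hunk):

import torch

from models.seed_qformer.utils import download_cached_file  # assumed module path, per the relative import above

def load_ckpt(path_or_url):
    # Remote checkpoints are cached locally first; both branches
    # deserialize onto CPU so GPU memory is untouched during init.
    if path_or_url.startswith('http'):
        path_or_url = download_cached_file(path_or_url, check_hash=False, progress=True)
    return torch.load(path_or_url, map_location='cpu')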
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,4 +10,5 @@ flask
 gradio
 timm
 diffusers>=0.20.2
+accelerate
 einops
6 changes: 4 additions & 2 deletions scripts/start_backend.sh
@@ -2,7 +2,9 @@
 python3 gradio_demo/seed_llama_flask.py \
     --image_transform configs/transform/clip_transform.yaml \
     --tokenizer configs/tokenizer/seed_llama_tokenizer.yaml \
-    --model configs/llm/seed_llama_14b.yaml \
+    --model configs/llm/seed_llama_14b_8bit.yaml \
     --port 7890 \
     --llm_device cuda:0 \
-    --tokenizer_device cuda:1
+    --tokenizer_device cuda:0 \
+    --offload_encoder \
+    --offload_decoder
2 changes: 1 addition & 1 deletion scripts/start_frontend.sh
@@ -1 +1 @@
-python3 gradio_demo/seed_llama_gradio.py
+python3 gradio_demo/seed_llama_gradio.py --server_port 80 --request_address http://127.0.0.1:7890/generate --model_type seed-llama-14b
