diff --git a/scripts/seed_llama_inference_14B.py b/scripts/seed_llama_inference_14B.py
new file mode 100644
index 0000000..054b614
--- /dev/null
+++ b/scripts/seed_llama_inference_14B.py
@@ -0,0 +1,120 @@
+import hydra
+
+import pyrootutils
+import os
+import torch
+
+from omegaconf import OmegaConf
+import json
+from typing import Optional
+import transformers
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+BOI_TOKEN = '<img>'
+EOI_TOKEN = '</img>'
+IMG_TOKEN = '<img_{:05d}>'
+
+IMG_FLAG = '<image>'
+NUM_IMG_TOKENS = 32
+NUM_IMG_CODES = 8192
+image_id_shift = 32000
+
+
+def generate(tokenizer, input_tokens, generation_config, model):
+    input_ids = tokenizer(input_tokens, add_special_tokens=False, return_tensors='pt').input_ids
+    input_ids = input_ids.to("cuda")
+
+    generate_ids = model.generate(
+        input_ids=input_ids,
+        **generation_config
+    )
+    # Return only the newly generated tokens, excluding the prompt.
+    generate_ids = generate_ids[0][input_ids.shape[1]:]
+
+    return generate_ids
+
+
+def decode_image_text(generate_ids, tokenizer, save_path=None):
+    # Locate the begin/end image delimiter tokens in the generated ids.
+    boi_list = torch.where(generate_ids == tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0])[0]
+    eoi_list = torch.where(generate_ids == tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0])[0]
+
+    if len(boi_list) == 0 and len(eoi_list) == 0:
+        # Text-only output.
+        text_ids = generate_ids
+        texts = tokenizer.decode(text_ids, skip_special_tokens=True)
+        print(texts)
+    else:
+        boi_index = boi_list[0]
+        eoi_index = eoi_list[0]
+
+        text_ids = generate_ids[:boi_index]
+        if len(text_ids) != 0:
+            texts = tokenizer.decode(text_ids, skip_special_tokens=True)
+            print(texts)
+
+        # Map ids back into the visual codebook range before decoding to an image.
+        image_ids = (generate_ids[boi_index+1:eoi_index] - image_id_shift).reshape(1, -1)
+        images = tokenizer.decode_image(image_ids)
+        images[0].save(save_path)
+
+
+device = "cuda"
+
+tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer.yaml'
+tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
+tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=True)
+
+transform_cfg_path = 'configs/transform/clip_transform.yaml'
+transform_cfg = OmegaConf.load(transform_cfg_path)
+transform = hydra.utils.instantiate(transform_cfg)
+
+model_cfg = OmegaConf.load('configs/llm/seed_llama_14b.yaml')
+model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
+model = model.eval().to(device)
+
+generation_config = {
+    'temperature': 1.0,
+    'num_beams': 1,
+    'max_new_tokens': 512,
+    'top_p': 0.5,
+    'do_sample': True
+}
+
+s_token = "[INST] "
+e_token = " [/INST]"
+sep = "\n"
+
+
+### visual question answering
+image_path = "images/cat.jpg"
+image = Image.open(image_path).convert('RGB')
+image_tensor = transform(image).to(device)
+img_ids = tokenizer.encode_image(image_torch=image_tensor)
+img_ids = img_ids.view(-1).cpu().numpy()
+img_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(item) for item in img_ids]) + EOI_TOKEN
+
+question = "What is this animal?"
+
+input_tokens = tokenizer.bos_token + s_token + img_tokens + question + e_token + sep
+generate_ids = generate(tokenizer, input_tokens, generation_config, model)
+decode_image_text(generate_ids, tokenizer)
+
+### text-to-image generation
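+# The prompt here is text-only; the model is expected to reply with a
+# <img> ... </img> span of visual codes that decode_image_text converts
+# back into an image and saves to save_path.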
+prompt = "Can you generate an image of a dog on the green grass?"
+input_tokens = tokenizer.bos_token + s_token + prompt + e_token + sep
+generate_ids = generate(tokenizer, input_tokens, generation_config, model)
+save_path = 'dog.jpg'
+decode_image_text(generate_ids, tokenizer, save_path)
+
+### multimodal prompt image generation
+instruction = "Can you make the cat wear sunglasses?"
+input_tokens = tokenizer.bos_token + s_token + img_tokens + instruction + e_token + sep
+generate_ids = generate(tokenizer, input_tokens, generation_config, model)
+save_path = 'cat_sunglasses.jpg'
+decode_image_text(generate_ids, tokenizer, save_path)
\ No newline at end of file
diff --git a/scripts/seed_llama_inference_8B.py b/scripts/seed_llama_inference_8B.py
new file mode 100644
index 0000000..73b8960
--- /dev/null
+++ b/scripts/seed_llama_inference_8B.py
@@ -0,0 +1,120 @@
+import hydra
+
+import pyrootutils
+import os
+import torch
+
+from omegaconf import OmegaConf
+import json
+from typing import Optional
+import transformers
+from PIL import Image
+from torchvision.transforms.functional import InterpolationMode
+
+pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+BOI_TOKEN = '<img>'
+EOI_TOKEN = '</img>'
+IMG_TOKEN = '<img_{:05d}>'
+
+IMG_FLAG = '<image>'
+NUM_IMG_TOKENS = 32
+NUM_IMG_CODES = 8192
+image_id_shift = 32000
+
+
+def generate(tokenizer, input_tokens, generation_config, model):
+    input_ids = tokenizer(input_tokens, add_special_tokens=False, return_tensors='pt').input_ids
+    input_ids = input_ids.to("cuda")
+
+    generate_ids = model.generate(
+        input_ids=input_ids,
+        **generation_config
+    )
+    # Return only the newly generated tokens, excluding the prompt.
+    generate_ids = generate_ids[0][input_ids.shape[1]:]
+
+    return generate_ids
+
+
+def decode_image_text(generate_ids, tokenizer, save_path=None):
+    # Locate the begin/end image delimiter tokens in the generated ids.
+    boi_list = torch.where(generate_ids == tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0])[0]
+    eoi_list = torch.where(generate_ids == tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0])[0]
+
+    if len(boi_list) == 0 and len(eoi_list) == 0:
+        # Text-only output.
+        text_ids = generate_ids
+        texts = tokenizer.decode(text_ids, skip_special_tokens=True)
+        print(texts)
+    else:
+        boi_index = boi_list[0]
+        eoi_index = eoi_list[0]
+
+        text_ids = generate_ids[:boi_index]
+        if len(text_ids) != 0:
+            texts = tokenizer.decode(text_ids, skip_special_tokens=True)
+            print(texts)
+
+        # Map ids back into the visual codebook range before decoding to an image.
+        image_ids = (generate_ids[boi_index+1:eoi_index] - image_id_shift).reshape(1, -1)
+        images = tokenizer.decode_image(image_ids)
+        images[0].save(save_path)
+
+
+device = "cuda"
+
+tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer.yaml'
+tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
+tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=True)
+
+transform_cfg_path = 'configs/transform/clip_transform.yaml'
+transform_cfg = OmegaConf.load(transform_cfg_path)
+transform = hydra.utils.instantiate(transform_cfg)
+
+model_cfg = OmegaConf.load('configs/llm/seed_llama_8b.yaml')
+model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
+model = model.eval().to(device)
+
+generation_config = {
+    'temperature': 1.0,
+    'num_beams': 1,
+    'max_new_tokens': 512,
+    'top_p': 0.5,
+    'do_sample': True
+}
+
+s_token = "USER:"
+e_token = "ASSISTANT:"
+sep = "\n"
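+
+# Unlike the 14B script's Llama-2 style "[INST] ... [/INST]" template, this
+# checkpoint uses plain "USER: ... ASSISTANT:" turns, so the separator goes
+# before e_token in the prompts below.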
"What is this animal?" + +input_tokens = tokenizer.bos_token + s_token + " " + img_tokens + question + sep + e_token +generate_ids = generate(tokenizer, input_tokens, generation_config, model) +decode_image_text(generate_ids, tokenizer) + +### text-to-image generation +prompt = "Can you generate an image of a dog on the green grass?" +input_tokens = tokenizer.bos_token + s_token + " " + prompt + sep + e_token +generate_ids = generate(tokenizer, input_tokens, generation_config, model) +save_path = 'dog.jpg' +decode_image_text(generate_ids, tokenizer, save_path) + +### multimodal prompt image generation +instruction = "Can you make the cat wear sunglasses?" +input_tokens = tokenizer.bos_token + s_token + " " + img_tokens + instruction + sep + e_token +generate_ids = generate(tokenizer, input_tokens, generation_config, model) +save_path = 'cat_sunglasses.jpg' +decode_image_text(generate_ids, tokenizer, save_path) \ No newline at end of file