From 8651a7789eb39c295547a19c0d32d37494e8bcf1 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Sun, 20 Oct 2024 18:13:00 +0200 Subject: [PATCH 1/7] add distributed inference example for llava_next --- .../inference/distributed/llava_next_video.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 examples/inference/distributed/llava_next_video.py diff --git a/examples/inference/distributed/llava_next_video.py b/examples/inference/distributed/llava_next_video.py new file mode 100644 index 00000000000..4c18e65843b --- /dev/null +++ b/examples/inference/distributed/llava_next_video.py @@ -0,0 +1,163 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fire +from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration +import numpy as np +import torch +import time +from accelerate import PartialState +import os +import av +from huggingface_hub import hf_hub_download +import json +from accelerate.utils import gather_object + +START_TIME = time.strftime("%Y%m%d_%H%M%S") +DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + + +""" +Example: + +accelerate launch --num_processes=2 llava_next_video.py +""" + + +def main( + model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf", + save_dir: str = "./evaluation/examples", + dtype: str = "fp16", + low_mem: bool = True, +): + # Start up the distributed environment without needing the Accelerator. 
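+    # PartialState exposes each process's index and device, plus helpers such as
+    # split_between_processes(), without constructing a full Accelerator object.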
+ distributed_state = PartialState() + + processor = LlavaNextVideoProcessor.from_pretrained(model_name) + model = LlavaNextVideoForConditionalGeneration.from_pretrained( + model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device + ) + + if distributed_state.is_main_process: + if not os.path.exists(save_dir): + os.makedirs(save_dir) + print(f"Directory '{save_dir}' created successfully.") + else: + print(f"Directory '{save_dir}' already exists.") + + # Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos) + video_path = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset" + ) + container = av.open(video_path) + total_frames = container.streams.video[0].frames + indices = np.arange(0, total_frames, total_frames / 8).astype(int) + video = read_video_pyav(container, indices) + + conversations = [ + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Why is this video funny?"}, + {"type": "video"}, + ], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Why is this video sad?"}, + {"type": "video"}, + ], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do you think about this video?"}, + {"type": "video"}, + ], + } + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Do you like this video?"}, + {"type": "video"}, + ], + } + ], + ] + + formatted_prompts = [ + processor.apply_chat_template(conversations[i], add_generation_prompt=True) + for i in range(0, len(conversations)) + ] + + count = 0 + distributed_state.num_processes = len(formatted_prompts) + with distributed_state.split_between_processes(formatted_prompts) as prompts: + input = processor(text=prompts, videos=video, return_tensors="pt").to(model.device) + output = model.generate(**input, max_new_tokens=60) + generated_text = processor.decode(output[0][2:], skip_special_tokens=True) + + distributed_state.wait_for_everyone() + + answers = gather_object(generated_text) + input_prompts = gather_object(prompts) + + if distributed_state.is_main_process: + for ans, prompt in zip(answers, input_prompts): + count += 1 + example_file = f"example_{count}" + temp_dir = os.path.join(save_dir, example_file) + + metadata = { + "prompt": prompt, + "generated_answer": ans, + } + with open(temp_dir, "w") as f: + json.dump(metadata, f, indent=4) + + if distributed_state.is_main_process: + print(f">>> Video answer generation Finished. Saved in {save_dir}") + + +def read_video_pyav(container, indices): + """ + Decode the video with PyAV decoder. + Args: + container (`av.container.input.InputContainer`): PyAV container. + indices (`List[int]`): List of frame indices to decode. + Returns: + result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). 
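+    Example (illustrative; the path and frame count are placeholders):
+        >>> container = av.open("sample.mp4")
+        >>> clip = read_video_pyav(container, indices=list(range(8)))
+        >>> clip.shape  # (8, height, width, 3)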
+ """ + frames = [] + container.seek(0) + start_index = indices[0] + end_index = indices[-1] + for i, frame in enumerate(container.decode(video=0)): + if i > end_index: + break + if i >= start_index and i in indices: + frames.append(frame) + return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + +if __name__ == "__main__": + fire.Fire(main) From 0b727360ee883fba019432ad23b1dbc5db573119 Mon Sep 17 00:00:00 2001 From: VladOS95-cyber Date: Sun, 20 Oct 2024 20:00:23 +0200 Subject: [PATCH 2/7] some fixes and refactoring --- .../inference/distributed/llava_next_video.py | 71 +++++++++++-------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/examples/inference/distributed/llava_next_video.py b/examples/inference/distributed/llava_next_video.py index 4c18e65843b..50dc81face6 100644 --- a/examples/inference/distributed/llava_next_video.py +++ b/examples/inference/distributed/llava_next_video.py @@ -22,7 +22,9 @@ import av from huggingface_hub import hf_hub_download import json -from accelerate.utils import gather_object +import queue +from concurrent.futures import ThreadPoolExecutor +import pathlib START_TIME = time.strftime("%Y%m%d_%H%M%S") DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} @@ -31,7 +33,7 @@ """ Example: -accelerate launch --num_processes=2 llava_next_video.py +accelerate launch llava_next_video.py """ @@ -39,6 +41,7 @@ def main( model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf", save_dir: str = "./evaluation/examples", dtype: str = "fp16", + num_workers: int = 1, low_mem: bool = True, ): # Start up the distributed environment without needing the Accelerator. @@ -46,7 +49,7 @@ def main( processor = LlavaNextVideoProcessor.from_pretrained(model_name) model = LlavaNextVideoForConditionalGeneration.from_pretrained( - model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device + model_name, torch_dtype=dtype[DTYPE_MAP], low_cpu_mem_usage=low_mem, device_map=distributed_state.device ) if distributed_state.is_main_process: @@ -109,33 +112,43 @@ def main( for i in range(0, len(conversations)) ] - count = 0 - distributed_state.num_processes = len(formatted_prompts) - with distributed_state.split_between_processes(formatted_prompts) as prompts: - input = processor(text=prompts, videos=video, return_tensors="pt").to(model.device) - output = model.generate(**input, max_new_tokens=60) - generated_text = processor.decode(output[0][2:], skip_special_tokens=True) - - distributed_state.wait_for_everyone() - - answers = gather_object(generated_text) - input_prompts = gather_object(prompts) - - if distributed_state.is_main_process: - for ans, prompt in zip(answers, input_prompts): - count += 1 - example_file = f"example_{count}" - temp_dir = os.path.join(save_dir, example_file) - - metadata = { - "prompt": prompt, - "generated_answer": ans, - } - with open(temp_dir, "w") as f: - json.dump(metadata, f, indent=4) + def save_results(output_queue: queue.Queue, output_dir: pathlib.Path): + count = 0 + while True: + try: + item = output_queue.get(timeout=5) + if item is None: + break + example_file = f"example_{count}" + temp_dir = os.path.join(output_dir, example_file) + + metadata = { + "prompt": item[0], + "generated_answer": item[1], + } + with open(temp_dir, "w") as f: + json.dump(metadata, f, indent=4) + count += 1 + + except queue.Empty: + continue - if distributed_state.is_main_process: - print(f">>> Video answer generation Finished. 
Saved in {save_dir}") + distributed_state.num_processes = len(formatted_prompts) + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=num_workers) + save_future = save_thread.submit(save_results, output_queue, save_dir) + + try: + with distributed_state.split_between_processes(formatted_prompts) as prompt: + input = processor(text=prompt, videos=video, padding=True, return_tensors="pt").to(model.device) + output = model.generate(**input, max_new_tokens=60) + generated_text = processor.decode(output[0][2:], skip_special_tokens=True) + output_queue.put((prompt, generated_text)) + finally: + output_queue.put(None) + save_thread.shutdown(wait=True) + + save_future.result() def read_video_pyav(container, indices): From 48b1aa04299cfb4932737311aac284b3ae13eaaf Mon Sep 17 00:00:00 2001 From: VladOS95-cyber Date: Mon, 4 Nov 2024 18:50:28 +0100 Subject: [PATCH 3/7] add captions extraction --- .../inference/distributed/llava_next_video.py | 187 +++++++----------- 1 file changed, 69 insertions(+), 118 deletions(-) diff --git a/examples/inference/distributed/llava_next_video.py b/examples/inference/distributed/llava_next_video.py index 50dc81face6..dc8806ce90e 100644 --- a/examples/inference/distributed/llava_next_video.py +++ b/examples/inference/distributed/llava_next_video.py @@ -12,19 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import fire -from transformers import LlavaNextVideoProcessor, LlavaNextVideoForConditionalGeneration -import numpy as np -import torch -import time -from accelerate import PartialState -import os -import av -from huggingface_hub import hf_hub_download import json +import os +import pathlib import queue +import time from concurrent.futures import ThreadPoolExecutor -import pathlib + +from itertools import chain +import fire +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor + +from accelerate import PartialState + START_TIME = time.strftime("%Y%m%d_%H%M%S") DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} @@ -33,13 +36,51 @@ """ Example: -accelerate launch llava_next_video.py +accelerate launch llava_next_video.py """ +def save_results(output_queue: queue.Queue, output_dir: pathlib.Path): + count = 0 + while True: + try: + item = output_queue.get(timeout=5) + if item is None: + break + example_file = f"example_{count}" + temp_dir = os.path.join(output_dir, example_file) + + metadata = { + "caption": item[0], + "generated_answer": item[1], + } + with open(temp_dir, "w") as f: + json.dump(metadata, f, indent=4) + count += 1 + + except queue.Empty: + continue + + +def get_batches(captions, batch_size): + num_batches = (len(captions) + batch_size - 1) // batch_size + batches = [] + + for i in range(num_batches): + start_index = i * batch_size + end_index = min((i + 1) * batch_size, len(captions)) + batch = captions[start_index:end_index] + batches.append(batch) + + return batches + + def main( model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf", save_dir: str = "./evaluation/examples", + max_captions: int = 10, + max_new_tokens: int = 100, + batch_size: int = 4, dtype: str = "fp16", num_workers: int = 1, low_mem: bool = True, @@ -49,7 +90,7 @@ def main( processor = LlavaNextVideoProcessor.from_pretrained(model_name) model = LlavaNextVideoForConditionalGeneration.from_pretrained( - model_name, torch_dtype=dtype[DTYPE_MAP], 
low_cpu_mem_usage=low_mem, device_map=distributed_state.device + model_name, torch_dtype=DTYPE_MAP[dtype], low_cpu_mem_usage=low_mem, device_map=distributed_state.device ) if distributed_state.is_main_process: @@ -59,118 +100,28 @@ def main( else: print(f"Directory '{save_dir}' already exists.") - # Load the video as an np.array, sampling uniformly 8 frames (can sample more for longer videos) - video_path = hf_hub_download( - repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset" - ) - container = av.open(video_path) - total_frames = container.streams.video[0].frames - indices = np.arange(0, total_frames, total_frames / 8).astype(int) - video = read_video_pyav(container, indices) - - conversations = [ - [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Why is this video funny?"}, - {"type": "video"}, - ], - } - ], - [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Why is this video sad?"}, - {"type": "video"}, - ], - } - ], - [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What do you think about this video?"}, - {"type": "video"}, - ], - } - ], - [ - { - "role": "user", - "content": [ - {"type": "text", "text": "Do you like this video?"}, - {"type": "video"}, - ], - } - ], - ] - - formatted_prompts = [ - processor.apply_chat_template(conversations[i], add_generation_prompt=True) - for i in range(0, len(conversations)) - ] - - def save_results(output_queue: queue.Queue, output_dir: pathlib.Path): - count = 0 - while True: - try: - item = output_queue.get(timeout=5) - if item is None: - break - example_file = f"example_{count}" - temp_dir = os.path.join(output_dir, example_file) - - metadata = { - "prompt": item[0], - "generated_answer": item[1], - } - with open(temp_dir, "w") as f: - json.dump(metadata, f, indent=4) - count += 1 - - except queue.Empty: - continue - - distributed_state.num_processes = len(formatted_prompts) + captions = load_dataset("nkp37/OpenVid-1M", split="train")["caption"][:max_captions] + + # split long-text captions into small sentences + splitted_captions = list(chain.from_iterable([captions[i].split(".") for i in range(len(captions))])) + batches = get_batches(splitted_captions, batch_size) + output_queue = queue.Queue() save_thread = ThreadPoolExecutor(max_workers=num_workers) save_future = save_thread.submit(save_results, output_queue, save_dir) - - try: - with distributed_state.split_between_processes(formatted_prompts) as prompt: - input = processor(text=prompt, videos=video, padding=True, return_tensors="pt").to(model.device) - output = model.generate(**input, max_new_tokens=60) - generated_text = processor.decode(output[0][2:], skip_special_tokens=True) - output_queue.put((prompt, generated_text)) - finally: - output_queue.put(None) - save_thread.shutdown(wait=True) + for _, caption_batch in tqdm(enumerate(batches), total=len(batches)): + try: + with distributed_state.split_between_processes(caption_batch) as caption: + input = processor(caption, padding=True, return_tensors="pt").to(model.device) + output = model.generate(**input, max_new_tokens=max_new_tokens) + generated_text = processor.decode(output[0][2:], skip_special_tokens=True) + output_queue.put((caption, generated_text)) + finally: + output_queue.put(None) + save_thread.shutdown(wait=True) save_future.result() -def read_video_pyav(container, indices): - """ - Decode the video with PyAV decoder. - Args: - container (`av.container.input.InputContainer`): PyAV container. 
- indices (`List[int]`): List of frame indices to decode. - Returns: - result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). - """ - frames = [] - container.seek(0) - start_index = indices[0] - end_index = indices[-1] - for i, frame in enumerate(container.decode(video=0)): - if i > end_index: - break - if i >= start_index and i in indices: - frames.append(frame) - return np.stack([x.to_ndarray(format="rgb24") for x in frames]) - - if __name__ == "__main__": fire.Fire(main) From 9ecb971293b31237bed02614e913af6066dd58d5 Mon Sep 17 00:00:00 2001 From: VladOS95-cyber Date: Tue, 5 Nov 2024 09:13:30 +0100 Subject: [PATCH 4/7] small refactoring --- examples/inference/distributed/llava_next_video.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/examples/inference/distributed/llava_next_video.py b/examples/inference/distributed/llava_next_video.py index dc8806ce90e..4737bbd5d47 100644 --- a/examples/inference/distributed/llava_next_video.py +++ b/examples/inference/distributed/llava_next_video.py @@ -19,7 +19,6 @@ import time from concurrent.futures import ThreadPoolExecutor -from itertools import chain import fire import torch from datasets import load_dataset @@ -100,11 +99,9 @@ def main( else: print(f"Directory '{save_dir}' already exists.") - captions = load_dataset("nkp37/OpenVid-1M", split="train")["caption"][:max_captions] - - # split long-text captions into small sentences - splitted_captions = list(chain.from_iterable([captions[i].split(".") for i in range(len(captions))])) - batches = get_batches(splitted_captions, batch_size) + captions = load_dataset("nkp37/OpenVid-1M", split="train")["caption"] + reduced_captions = captions[: min(len(captions), max_captions)] + batches = get_batches(reduced_captions, batch_size) output_queue = queue.Queue() save_thread = ThreadPoolExecutor(max_workers=num_workers) @@ -114,7 +111,7 @@ def main( with distributed_state.split_between_processes(caption_batch) as caption: input = processor(caption, padding=True, return_tensors="pt").to(model.device) output = model.generate(**input, max_new_tokens=max_new_tokens) - generated_text = processor.decode(output[0][2:], skip_special_tokens=True) + generated_text = processor.batch_decode(output, skip_special_tokens=True) output_queue.put((caption, generated_text)) finally: output_queue.put(None) From 336def55c8d33d5c467545f27acc378fd6e3e346 Mon Sep 17 00:00:00 2001 From: VladOS95-cyber Date: Sun, 24 Nov 2024 14:50:44 +0100 Subject: [PATCH 5/7] add captions generation --- .../inference/distributed/llava_next_video.py | 105 ++++++++++++++---- 1 file changed, 86 insertions(+), 19 deletions(-) diff --git a/examples/inference/distributed/llava_next_video.py b/examples/inference/distributed/llava_next_video.py index 4737bbd5d47..0eb12cb105a 100644 --- a/examples/inference/distributed/llava_next_video.py +++ b/examples/inference/distributed/llava_next_video.py @@ -17,11 +17,13 @@ import pathlib import queue import time +import av from concurrent.futures import ThreadPoolExecutor import fire import torch -from datasets import load_dataset +from huggingface_hub import snapshot_download +import numpy as np from tqdm import tqdm from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor @@ -46,13 +48,11 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path): item = output_queue.get(timeout=5) if item is None: break + prompt, video, generated_text = item example_file = f"example_{count}" temp_dir = 
os.path.join(output_dir, example_file) - metadata = { - "caption": item[0], - "generated_answer": item[1], - } + metadata = {"prompt": prompt, "video": video, "generated_text": generated_text} with open(temp_dir, "w") as f: json.dump(metadata, f, indent=4) count += 1 @@ -61,23 +61,85 @@ def save_results(output_queue: queue.Queue, output_dir: pathlib.Path): continue -def get_batches(captions, batch_size): - num_batches = (len(captions) + batch_size - 1) // batch_size +def get_batches(processed_videos, batch_size): + num_batches = (len(processed_videos) + batch_size - 1) // batch_size batches = [] for i in range(num_batches): start_index = i * batch_size - end_index = min((i + 1) * batch_size, len(captions)) - batch = captions[start_index:end_index] + end_index = min((i + 1) * batch_size, len(processed_videos)) + batch = processed_videos[start_index:end_index] batches.append(batch) return batches +def read_video_pyav(container, indices): + """ + Decode the video with PyAV decoder. + Args: + container (`av.container.input.InputContainer`): PyAV container. + indices (`List[int]`): List of frame indices to decode. + Returns: + result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + """ + frames = [] + container.seek(0) + start_index = indices[0] + end_index = indices[-1] + for i, frame in enumerate(container.decode(video=0)): + if i > end_index: + break + if i >= start_index and i in indices: + frames.append(frame) + return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + +def get_video_paths(video_dir): + """Get paths to all video files in the directory and its subdirectories.""" + video_extensions = (".mp4", ".avi", ".mov", ".mkv") # Add more extensions if needed + video_paths = [] + + for root, _, files in os.walk(video_dir): + for file in files: + if file.lower().endswith(video_extensions): + video_paths.append(os.path.join(root, file)) + + return video_paths + + +def process_videos(video_paths, processor, prompt): + """Process a batch of videos and prepare them for the model.""" + batch_inputs = [] + + for video_path in video_paths: + try: + container = av.open(video_path) + total_frames = container.streams.video[0].frames + indices = np.arange(0, total_frames, total_frames / 8).astype(int) + clip = read_video_pyav(container, indices) + container.close() + + processed = processor(text=prompt, videos=clip, return_tensors="pt") + batch_inputs.append( + { + "input_ids": processed["input_ids"], + "pixel_values_videos": processed["pixel_values_videos"], + "video": video_path, + } + ) + + except Exception as e: + print(f"Error processing video {video_path}: {str(e)}") + continue + + return batch_inputs + + def main( model_name: str = "llava-hf/LLaVA-NeXT-Video-7B-hf", save_dir: str = "./evaluation/examples", - max_captions: int = 10, + prompt: str = "USER:
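
Taken together, the patches above revolve around one Accelerate pattern: create a PartialState, load the model onto each process's device, shard the inputs with split_between_processes, and collect or persist the per-rank outputs. The sketch below distills that pattern with a small text-only model so it runs without the video stack; the model name, prompts, and launch command are illustrative placeholders rather than part of the patches, and it follows the gather_object approach from the first revision instead of the later queue-based writer.

# Minimal sketch of the splitting pattern, launched for example with:
#   accelerate launch --num_processes=2 sketch.py
import torch
from accelerate import PartialState
from accelerate.utils import gather_object
from transformers import AutoModelForCausalLM, AutoTokenizer

state = PartialState()

# Any causal LM works for the illustration; fp16 assumes a CUDA device.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained(
    "gpt2", torch_dtype=torch.float16, device_map=state.device
)

prompts = [
    "Describe a cat riding a skateboard.",
    "Describe a sunset over the sea.",
    "Describe a robot cooking dinner.",
    "Describe a crowded train station.",
]

results = []
# Each process receives a contiguous, roughly equal slice of `prompts`.
with state.split_between_processes(prompts) as shard:
    for text in shard:
        inputs = tokenizer(text, return_tensors="pt").to(state.device)
        output = model.generate(**inputs, max_new_tokens=30)
        results.append(tokenizer.decode(output[0], skip_special_tokens=True))

state.wait_for_everyone()
# Concatenate the per-process result lists on every rank; print on rank 0 only.
all_results = gather_object(results)
if state.is_main_process:
    for text in all_results:
        print(text)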