[LLaVA-OV] Support LLaVA-OneVision eval loss and data filtering #305

Open · wants to merge 1 commit into base: main
26 changes: 19 additions & 7 deletions flagscale/train/models/llava_onevision/dataloader_provider.py
@@ -67,11 +67,16 @@ def datasets_provider(worker_config=None):

def train_valid_test_dataloaders_provider(train_val_test_num_samples):
"""Build multimodal train, validation and test dataloaders."""
args = get_args()

# In LLaVA-OV, set --skip-train to evaluate every sample once.
# Training while evaluating is not supported yet.
if args.skip_train:
args.eval_iters = args.train_iters

if get_tensor_model_parallel_rank() != 0:
return None, None, None

args = get_args()

worker_debug_path = None
worker_log_level = 0

@@ -110,11 +115,18 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
"loading dataloader checkpoint failed. Skipping. " + str(e)
)
if args.training_dataset_only:
return (
EnergonDataloader(train_dataloader),
EnergonDataloader(None),
EnergonDataloader(None),
)
if not args.skip_train:
return (
EnergonDataloader(train_dataloader),
None,
None,
)
else:
return (
None,
EnergonDataloader(train_dataloader),
None,
)
valid_dataloader = [
EnergonDataloader(get_loader(valid_ds, worker_config=worker_config))
for valid_ds in valid_ds1
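For context, with --skip-train the eval pass is meant to visit every sample exactly once, so train_iters (and hence eval_iters, which is set to the same value above) should equal the dataset size divided by the global batch size. A back-of-envelope sketch; the helper is hypothetical and not part of this PR, and it assumes micro batch size 1 as the per-sample loss logging in loss_func requires:

# Hypothetical helper (not part of the PR): choose iteration counts so the
# eval-only pass covers the dataset exactly once.
def iters_for_full_pass(num_samples: int, global_batch_size: int) -> int:
    assert num_samples % global_batch_size == 0, "avoid a ragged final batch"
    return num_samples // global_batch_size

# e.g. 10,000 samples with data-parallel size 4 and micro batch size 1:
print(iters_for_full_pass(10_000, 4))  # -> 2500, passed as --train-iters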
19 changes: 16 additions & 3 deletions flagscale/train/models/llava_onevision/dataset_helpers.py
@@ -36,6 +36,8 @@ class AnyResTaskSample:
images: List[torch.Tensor]
image_sizes: List[torch.Tensor]
modalities: List[torch.Tensor]
ids: torch.Tensor
ids_shape: torch.Tensor

# Typing for the resulting batch data after encode_batch()
@dataclass
@@ -50,6 +52,8 @@ class AnyResTaskBatch(Batch):
image_sizes: torch.Tensor
split_image_sizes: torch.Tensor
modalities: torch.Tensor
ids: torch.Tensor
ids_shape: torch.Tensor


class AnyResTaskEncoder(DefaultTaskEncoder[InterleavedSample, InterleavedSample, AnyResTaskBatch, dict]):
@@ -84,6 +88,10 @@ def encode_interleaved(self, sample: InterleavedSample):
else:
raise ValueError(f"The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.")

id = "".join(sample.__key__.split("/")[1:])
ids_tensor = torch.tensor([ord(c) for c in id], dtype=torch.uint8)
ids_shape = torch.tensor(ids_tensor.shape)

# process modalities to tensor
modalities_list = []
for modality in modalities:
@@ -107,7 +115,9 @@ def encode_interleaved(self, sample: InterleavedSample):
labels_shape=torch.tensor(labels.shape),
images=images,
image_sizes=image_sizes,
modalities=modalities
modalities=modalities,
ids=ids_tensor,
ids_shape=ids_shape
)

def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
@@ -121,7 +131,8 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
# Adapt video data by decord
image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0)
modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0)

ids = torch.cat([s.ids.flatten() for s in samples], dim=0)
ids_shape = torch.stack([s.ids_shape for s in samples], dim=0)
batch = AnyResTaskBatch(
__keys__=[s.__key__ for s in samples],
__subflavors__=[s.__subflavors__ for s in samples],
@@ -132,7 +143,9 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
images=images,
image_sizes=image_sizes,
split_image_sizes=split_image_sizes,
modalities=modalities
modalities=modalities,
ids=ids,
ids_shape=ids_shape,
)

return batch
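The sample key is carried through the batch as a uint8 tensor of character codes plus a companion shape tensor, so it can be collated and broadcast like any other field. A minimal standalone round-trip sketch (assuming ASCII keys):

import torch

# Encode, mirroring encode_interleaved: strip the shard prefix, then store
# one uint8 per character.
key = "shard_0000/000123"
sample_id = "".join(key.split("/")[1:])                    # -> "000123"
ids_tensor = torch.tensor([ord(c) for c in sample_id], dtype=torch.uint8)
ids_shape = torch.tensor(ids_tensor.shape)                 # used to split after concat

# Decode on the consumer side, as loss_func does.
decoded = "".join(chr(c) for c in ids_tensor.cpu().numpy())
assert decoded == sample_id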
34 changes: 30 additions & 4 deletions flagscale/train/train_llava_onevision.py
@@ -191,6 +191,10 @@ def get_batch(data_iterator):
labels_shape = tensor_parallel.broadcast_data(["labels_shape"], data, torch.int64)[
"labels_shape"
]
ids = tensor_parallel.broadcast_data(["ids"], data, torch.uint8)["ids"]
ids_shape = tensor_parallel.broadcast_data(["ids_shape"], data, torch.int64)[
"ids_shape"
]
images = tensor_parallel.broadcast_data(["images"], data, torch.float32)["images"]
split_image_sizes = tensor_parallel.broadcast_data(
["split_image_sizes"], data, torch.int64
@@ -229,6 +233,17 @@ def get_batch(data_iterator):
assert start_idx == labels.numel()
labels = labels_list

# ids to list
ids_list = []
start_idx = 0
for shape in ids_shape:
num_elements = torch.prod(shape).item()
sub_tensor = ids[start_idx : start_idx + num_elements].reshape(shape.tolist())
ids_list.append(sub_tensor)
start_idx += num_elements
assert start_idx == ids.numel()
ids = ids_list

# images to list
images_list = []
start_idx = 0
@@ -288,7 +303,7 @@ def get_batch(data_iterator):
attention_mask = input_ids.ne(tokenizer.pad_token_id)
torch.cuda.nvtx.range_pop()

return input_ids, labels, attention_mask, images, image_sizes, modalities
return input_ids, labels, attention_mask, images, image_sizes, modalities, ids


def pad_sequence(input_ids, batch_first, padding_value, tokenizer):
@@ -316,7 +331,13 @@ def get_image_token_count():
return num_image_tokens


def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
def loss_func(
labels: torch.Tensor,
loss_mask: torch.Tensor,
ids,
logits: torch.Tensor,
):
args = get_args()
labels = labels.transpose(0, 1).contiguous() # [b s] => [s b]
logits = logits.transpose(0, 1).contiguous() # [b s h] => [s b h]

@@ -334,6 +355,11 @@ def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
loss = torch.mean(losses)

# Reduce loss for logging.
if args.skip_train:
assert isinstance(ids, list) and len(ids) == 1, "per-sample eval loss requires micro batch size 1"
sample_id = "".join(chr(c) for c in ids[0].cpu().numpy())
print(f"Evaluating id: {sample_id}, loss: {loss.detach().clone().item()}", flush=True)

averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}

@@ -354,7 +380,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):

# Get the batch.
timers("batch-generator", log_level=2).start()
input_ids, labels, attention_mask, images, image_sizes, modalities = get_batch(
input_ids, labels, attention_mask, images, image_sizes, modalities, ids = get_batch(
data_iterator
)
if "text" in modalities and ("image" in modalities or "video" in modalities):
@@ -367,7 +393,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):
input_ids, labels, attention_mask, images, image_sizes, modalities
)

return output_tensor, partial(loss_func, labels, loss_mask)
return output_tensor, partial(loss_func, labels, loss_mask, ids)


def add_multimodal_extra_args(parser):
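The ids follow the same pack/unpack convention get_batch already uses for labels and images: variable-length tensors are flattened and concatenated for the tensor-parallel broadcast, and a companion shape tensor cuts the flat buffer back into per-sample tensors. A toy illustration of that convention:

import torch

# Pack: variable-length per-sample tensors -> one flat buffer + per-sample shapes.
samples = [torch.tensor([72, 105], dtype=torch.uint8),     # "Hi"
           torch.tensor([79, 75, 33], dtype=torch.uint8)]  # "OK!"
flat = torch.cat([s.flatten() for s in samples], dim=0)
shapes = torch.stack([torch.tensor(s.shape) for s in samples], dim=0)

# Unpack: replay the shapes to recover each sample, as get_batch does.
out, start_idx = [], 0
for shape in shapes:
    num_elements = torch.prod(shape).item()
    out.append(flat[start_idx : start_idx + num_elements].reshape(shape.tolist()))
    start_idx += num_elements
assert start_idx == flat.numel()
assert all(torch.equal(a, b) for a, b in zip(out, samples))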
99 changes: 99 additions & 0 deletions tools/datasets/llava_onevision/filter_by_json.py
@@ -0,0 +1,99 @@

import os
import json
import argparse
import webdataset as wds
from multiprocessing import Pool


def find_tar_files(input_dir):
    tar_files = []
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.tar'):
                input_tarfile = os.path.join(root, file)
                tar_files.append(input_tarfile)
    assert tar_files
    return tar_files


def process_tar_files_in_node(tar_files, input_dir, output_dir, ids_to_keep):
    num_tar_files_in_node = len(tar_files)
    assert num_tar_files_in_node > 0
    # Roughly one worker per two tar files, capped at 8.
    num_process = max(1, min(8, num_tar_files_in_node // 2))
    # A single pool distributes all tar files across the workers.
    with Pool(num_process) as p:
        p.starmap(
            process_tarfile,
            [(tf, input_dir, output_dir, ids_to_keep) for tf in tar_files],
        )


def process_tarfile(input_tarfile, input_dir, output_dir, ids_to_keep):
    print(f"Process {os.getpid()} is processing {input_tarfile}")
    relative_path = os.path.relpath(input_tarfile, input_dir)
    output_tarfile = os.path.join(output_dir, relative_path)
    output_tarfile_dir = os.path.dirname(output_tarfile)
    os.makedirs(output_tarfile_dir, exist_ok=True)

    dataset = wds.WebDataset(input_tarfile, shardshuffle=False)
    keep_samples = []

    for sample in dataset:
        if sample["__key__"] not in ids_to_keep:
            print(f"id {sample['__key__']} is filtered out.")
        else:
            new_sample = {"__key__": sample["__key__"], "sequence.pyd": sample["sequence.pyd"]}
            keep_samples.append(new_sample)
    if not keep_samples:
        print(f"All ids in {input_tarfile} are filtered out.")
    else:
        output_dataset = wds.TarWriter(output_tarfile)
        print(f"Writing {input_tarfile} to {output_tarfile} ...")
        for sample in keep_samples:
            output_dataset.write(sample)
        output_dataset.close()
        print(f"Writing {output_tarfile} done.")

def main():
    parser = argparse.ArgumentParser(description='Filter a webdataset and save the filtered samples.')
    parser.add_argument('--input_dir', type=str, help='Original directory')
    parser.add_argument('--output_dir', type=str, help='Directory to save the filtered data')
    parser.add_argument('--json_file', type=str, help='Path to the JSON file containing ids')
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    json_file = args.json_file

    # Node count and node rank under MPI (defaults suit single-node runs).
    size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))

    with open(json_file, 'r') as f:
        ids_to_keep = json.load(f)["ids"]

    tar_files = find_tar_files(input_dir)
    num_tar_files = len(tar_files)
    assert num_tar_files > size
    num_tar_files_per_node = num_tar_files // size
    start_index = rank * num_tar_files_per_node
    end_index = (rank + 1) * num_tar_files_per_node if rank != size - 1 else num_tar_files
    node_tar_files = tar_files[start_index:end_index]
    print("node_tar_files:", node_tar_files)
    process_tar_files_in_node(node_tar_files, input_dir, output_dir, ids_to_keep)
    print("Done")



if __name__ == "__main__":
    main()
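filter_by_json.py consumes the JSON written by filter_to_json.py below: an object with an "ids" list whose entries must match the webdataset __key__ values as seen inside each tar (the encoder strips the shard prefix, so these are typically the bare sample ids). A hedged sketch of preparing such a file by hand; paths and ids here are placeholders:

import json

# Hypothetical keep-list; the ids must equal __key__ values in the tars.
with open("ids_to_keep.json", "w") as f:
    json.dump({"ids": ["000123", "000456"]}, f, indent=4)

# The script reads OMPI_COMM_WORLD_SIZE/RANK, so launch it under MPI to let
# each rank filter its own slice of the tar files, e.g.:
#   mpirun -np 4 python tools/datasets/llava_onevision/filter_by_json.py \
#       --input_dir /data/wds --output_dir /data/wds_filtered \
#       --json_file ids_to_keep.json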
45 changes: 45 additions & 0 deletions tools/datasets/llava_onevision/filter_to_json.py
@@ -0,0 +1,45 @@
import os
import re
import json
import argparse
from typing import Dict


def main():
    parser = argparse.ArgumentParser(description='Grep id and loss from log files.')
    parser.add_argument('--input_dir', type=str, help='Directory to search for log files.')
    parser.add_argument('--output', type=str, help='Path to save the result.')
    args = parser.parse_args()

    result_dict: Dict[str, float] = {}
    for root, dirs, files in os.walk(args.input_dir):
        for file in files:
            if file.endswith('.log'):
                file_path = os.path.join(root, file)
                with open(file_path, 'r') as f:
                    for line in f:
                        match = re.search(r'Evaluating id: (\d+), loss: ([\d.]+)', line)
                        if match:
                            evaluating_id = match.group(1)
                            loss = float(match.group(2))
                            if evaluating_id in result_dict:
                                assert loss == result_dict[evaluating_id]
                            # Customize filtering rules here, e.g.:
                            #     if loss < 0.5:
                            #         result_dict[evaluating_id] = loss

                            # NOTE: no filtering is applied currently; comment out
                            # the line below when customizing the rules above.
                            result_dict[evaluating_id] = loss

    ids = list(result_dict.keys())
    print("Keep id count:", len(ids))
    result = {"ids": ids}
    assert args.output.endswith(".json")
    with open(args.output, 'w') as f:
        json.dump(result, f, indent=4)
    print("Done")


if __name__ == "__main__":
    main()
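A quick sanity check that the regex matches the log line format emitted by loss_func (a minimal sketch; the loss value is made up):

import re

line = "Evaluating id: 000123, loss: 0.4821"
match = re.search(r'Evaluating id: (\d+), loss: ([\d.]+)', line)
assert match is not None
assert match.group(1) == "000123"
assert float(match.group(2)) == 0.4821

Note the pattern assumes a plain decimal loss; a value printed in scientific notation (e.g. 4.2e-05) would be truncated at the "e".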