Add option to predownload data from s3 at the start of each checkpoint. #280

Open · wants to merge 6 commits into main
57 changes: 57 additions & 0 deletions open_lm/file_utils.py
@@ -13,6 +13,8 @@
import numpy as np
import torch

from braceexpand import braceexpand
from pathlib import Path
from typing import List, Optional
from tqdm import tqdm

@@ -481,3 +483,58 @@ def _single_epoch_string(
shard_strings_per_source.append(shard_string_source)

return shard_strings_per_source, num_samples_per_source, next_shard_per_source


def download_data_to_local(shard_strings_per_source, temp_dir, only_rename=False, max_retries=3):
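    # Rewrite "pipe:aws s3 cp ..." shard strings into local paths under temp_dir.
    # Ranks that should not download pass only_rename=True and only compute the
    # rewritten strings; downloading ranks copy the shards with `aws s3 cp`.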
if not os.path.exists(temp_dir):
os.makedirs(temp_dir, exist_ok=True)

local_shard_strings_per_source = []

for shard_string in shard_strings_per_source:
if not shard_string.startswith("pipe:aws s3 cp "):
local_shard_strings_per_source.append(shard_string)
continue

shard_string = shard_string[len("pipe:aws s3 cp ") : -len(" -")]

shards = list(braceexpand(shard_string))
shard_directory = Path(shards[0][len("s3://") :]).parent
shard_ids = [Path(s).with_suffix("").name for s in shards]

if not only_rename:
retries = 0

while True:
aws_command = [
"aws",
"s3",
"cp",
"--recursive",
f"s3://{shard_directory}",
f"{temp_dir}",
"--exclude",
"*",
]
for sf in shard_ids:
aws_command.extend(["--include", f"{sf}.tar"])

result = subprocess.run(
aws_command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if result.returncode != 0:
if retries < max_retries:
retries += 1
else:
raise RuntimeError(
f"Error: Failed to download data to local storage: {result.stderr.decode('utf-8')}"
)
else:
break

local_shard_string = temp_dir + "{" + ",".join(shard_ids) + "}.tar"
local_shard_strings_per_source.append(local_shard_string)

return local_shard_strings_per_source
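For intuition, here is a minimal sketch (outside the diff) of the shard-string rewrite this helper performs; the bucket, shard names, and temp_dir are hypothetical, and the real function additionally downloads the shards via `aws s3 cp --recursive` with `--include` filters before returning the local pattern.

from pathlib import Path
from braceexpand import braceexpand

remote = "pipe:aws s3 cp s3://my-bucket/data/shard-{0000..0002}.tar -"   # hypothetical source
temp_dir = "/scratch/openlm_cache/"                                      # hypothetical local dir

inner = remote[len("pipe:aws s3 cp ") : -len(" -")]          # s3://my-bucket/data/shard-{0000..0002}.tar
shards = list(braceexpand(inner))                            # three concrete s3:// URLs
shard_ids = [Path(s).with_suffix("").name for s in shards]   # ['shard-0000', 'shard-0001', 'shard-0002']

local_shard_string = temp_dir + "{" + ",".join(shard_ids) + "}.tar"
print(local_shard_string)  # /scratch/openlm_cache/{shard-0000,shard-0001,shard-0002}.tar

The resulting brace pattern replaces the original pipe command in the training data strings, so the rest of the data pipeline is unchanged.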
17 changes: 16 additions & 1 deletion open_lm/main.py
@@ -1,6 +1,7 @@
import atexit
import logging
import os
import shutil
import re
import sys
import random
@@ -48,7 +49,7 @@

from open_lm.utils.transformers.hf_wrapper import create_wrapped_hf_model
from open_lm.data import get_data, get_wds_dataset
from open_lm.distributed import is_master, init_distributed_device, broadcast_object
from open_lm.distributed import is_master, is_local_master, init_distributed_device, broadcast_object
from open_lm.logger import setup_logging
from open_lm.params import parse_args
from open_lm.scheduler import cosine_lr, const_lr
@@ -57,6 +58,7 @@
from open_lm.file_utils import (
pt_load,
check_exists,
download_data_to_local,
start_sync_process,
remote_sync_with_expon_backoff,
get_metadata_file,
@@ -795,6 +797,12 @@ def main(args):
shard_shuffle_seed=args.shard_shuffle_seed,
)

if args.temp_local_data_dir is not None:
download_rank = is_master(args) if args.local_dir_shared_across_nodes else is_local_master(args)
train_data_string_per_source = download_data_to_local(
train_data_string_per_source, args.temp_local_data_dir, only_rename=not download_rank
)

# In the distributed case, make sure that all nodes receive the same string
if args.distributed:
all_source_strings = ["" for _ in range(args.world_size)]
@@ -919,6 +927,13 @@
f"{num_ckpt_too_few_tokens} checkpoints happened where the number of tokens seen was {1 - args.data_tolerate_error_p} of expected. This is likely due to transient errors e.g. reading from S3."
)

if args.temp_local_data_dir is not None:
cleanup_rank = is_master(args) if args.local_dir_shared_across_nodes else is_local_master(args)
if cleanup_rank:
shutil.rmtree(args.temp_local_data_dir)
if args.distributed:
dist.barrier()

if done_training:
if is_master(args):
logging.info("Model has seen the desired number of tokens. Ending training.")
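The gating above picks one downloading rank per node, or a single global rank when the temporary directory is shared across nodes. A small sketch of that selection, with hypothetical rank flags standing in for is_master(args) / is_local_master(args) from open_lm.distributed:

def should_download(is_global_master: bool, is_local_master: bool, dir_shared_across_nodes: bool) -> bool:
    # Shared filesystem: only the global master downloads, once for everyone.
    # Node-local storage: the first rank on every node downloads its own copy.
    return is_global_master if dir_shared_across_nodes else is_local_master

print(should_download(True, True, False))    # node 0, local rank 0 -> downloads
print(should_download(False, True, False))   # node 1, local rank 0 -> downloads
print(should_download(False, False, False))  # other ranks -> call with only_rename=True

Non-downloading ranks still call download_data_to_local with only_rename=True so they compute the same local shard strings before the distributed string exchange that follows; the matching cleanup later removes the temporary directory and synchronizes with a barrier.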
11 changes: 11 additions & 0 deletions open_lm/params.py
@@ -787,6 +787,17 @@ def parse_args(args):
default=0,
help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
)
parser.add_argument(
"--temp-local-data-dir",
type=str,
default=None,
help="If set, move the data to temporary local storage at the start of each epoch, to minimize s3 errors.",
)
parser.add_argument(
"--local-dir-shared-across-nodes",
action="store_true",
help="Whether the --temp-local-data-dir argument refers to a path seen by all nodes or by each node separately.",
)

add_model_args(parser)

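A hypothetical way to exercise the two new flags through the parser, assuming parse_args accepts a list of CLI tokens and that the remaining open_lm arguments can stay at their defaults; the bucket and paths are placeholders:

from open_lm.params import parse_args

args = parse_args([
    "--train-data", "pipe:aws s3 cp s3://my-bucket/data/shard-{0000..0099}.tar -",
    "--temp-local-data-dir", "/scratch/openlm_cache/",
    "--local-dir-shared-across-nodes",
])
print(args.temp_local_data_dir)            # /scratch/openlm_cache/
print(args.local_dir_shared_across_nodes)  # True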