From cb036c02201dc3404f1e78f28ff52ae19ddd616a Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Tue, 11 Jul 2023 16:35:21 +0800 Subject: [PATCH] add redpajama data preprocessing code (#12) * add testing scripts * remove temp-dir for worker * remove test files * add redpajama dp code * ignore all notebook files * update streaming code * add write-on-host for streaming * better line alignment * move files * rename folder * rename folder and add group_files * debug * add recovery test scripts * add additional python packages * add test flag * add README and some minor fixes * change the image name * change the directory back * add training stop for the second * fix typo * add data source support * clean up a bit * restructure folders * restructure files * add script headers * reorder and add READMEs * revert back due to file movements * fix typo * fix lib import * enable mounting localdisk * change name of cc * fix dtype * performance optimization for streaming * use the latest ray * change node * add new files * bug fix * add nltk * fix hdfs after re-order folders * set default to false * use variables instead of credentials * change the training config path * update README --- .gitignore | 1 + Finetune/README.md | 42 +- Finetune/llm_pretrain_template.conf | 31 +- Finetune/plugin/trainer/pretrainer.py | 36 +- Finetune/workflow.yaml | 23 - tools/pretrain_recovery_test/README.md | 14 + tools/pretrain_recovery_test/compare_logs.py | 90 +++ tools/pretrain_recovery_test/test_end2end.sh | 52 ++ tools/redpajama_data_processing/README.md | 91 +++ .../redpajama_data_processing/count_tokens.py | 25 + .../redpajama_data_processing/group_files.py | 84 +++ .../indexed_dataset.py | 604 ++++++++++++++++++ .../merge_datasets.py | 68 ++ .../preprocess_data.py | 384 +++++++++++ .../preprocess_full.py | 248 +++++++ tools/redpajama_data_processing/run-dp.sh | 10 + .../workload_in_containers}/Dockerfile | 5 +- tools/workload_in_containers/README.md | 22 + .../workload_in_containers}/build-image.sh | 2 +- .../configs/core-site.xml | 2 +- .../configs/hadoop-env.sh | 0 .../configs/hdfs-site.xml | 4 +- .../configs/mapred-site.xml | 0 .../workload_in_containers}/configs/workers | 0 .../configs/yarn-site.xml | 0 .../workload_in_containers/launch_workload.py | 8 +- .../workload_in_containers}/run-hdfs.sh | 12 +- .../run-ray-cluster.sh | 14 +- tools/workload_in_containers/workload.yaml | 36 ++ 29 files changed, 1825 insertions(+), 83 deletions(-) mode change 100644 => 100755 Finetune/README.md mode change 100644 => 100755 Finetune/llm_pretrain_template.conf mode change 100644 => 100755 Finetune/plugin/trainer/pretrainer.py delete mode 100644 Finetune/workflow.yaml create mode 100644 tools/pretrain_recovery_test/README.md create mode 100644 tools/pretrain_recovery_test/compare_logs.py create mode 100755 tools/pretrain_recovery_test/test_end2end.sh create mode 100644 tools/redpajama_data_processing/README.md create mode 100644 tools/redpajama_data_processing/count_tokens.py create mode 100755 tools/redpajama_data_processing/group_files.py create mode 100644 tools/redpajama_data_processing/indexed_dataset.py create mode 100644 tools/redpajama_data_processing/merge_datasets.py create mode 100644 tools/redpajama_data_processing/preprocess_data.py create mode 100644 tools/redpajama_data_processing/preprocess_full.py create mode 100755 tools/redpajama_data_processing/run-dp.sh rename {Finetune => tools/workload_in_containers}/Dockerfile (86%) mode change 100644 => 100755 create mode 100644 tools/workload_in_containers/README.md rename {Finetune => tools/workload_in_containers}/build-image.sh (90%) rename {Finetune => tools/workload_in_containers}/configs/core-site.xml (95%) rename {Finetune => tools/workload_in_containers}/configs/hadoop-env.sh (100%) rename {Finetune => tools/workload_in_containers}/configs/hdfs-site.xml (94%) rename {Finetune => tools/workload_in_containers}/configs/mapred-site.xml (100%) rename {Finetune => tools/workload_in_containers}/configs/workers (100%) rename {Finetune => tools/workload_in_containers}/configs/yarn-site.xml (100%) rename Finetune/launch_workflow.py => tools/workload_in_containers/launch_workload.py (92%) rename {Finetune => tools/workload_in_containers}/run-hdfs.sh (78%) rename {Finetune => tools/workload_in_containers}/run-ray-cluster.sh (91%) create mode 100755 tools/workload_in_containers/workload.yaml diff --git a/.gitignore b/.gitignore index bee8a64b7..ae51dc12d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ __pycache__ +**.ipynb \ No newline at end of file diff --git a/Finetune/README.md b/Finetune/README.md old mode 100644 new mode 100755 index 633dea9b4..886dafea1 --- a/Finetune/README.md +++ b/Finetune/README.md @@ -1,6 +1,5 @@ -## Accelerate + Ray -### 1. Prepare environment -### Bare-metal +## 1. Prepare environment +### 1.1 Bare-metal Follow [LLM Finetune](https://wiki.ith.intel.com/pages/viewpage.action?spaceKey=AppliedML&title=LLM+Finetune). Please change ``huggingface accelerate`` repo to: [huggingface accelerate](https://github.com/KepingYan/accelerate) branch: FSDP_CPU @@ -11,22 +10,8 @@ pip install -U "ray[default] @ LINK_TO_WHEEL.whl" pip install --pre raydp pip install "ray[tune]" tabulate tensorboard ``` - -### Using Docker -```bash -# on head node -git clone https://github.com/intel-sandbox/llm-ray.git -cd llm-ray/Finetune -./build-image.sh -# save docker image -docker save -o ray-image.tar ray-llm:latest -# copy over to worker nodes, this is an optional step if all your cluster nodes are NFS-shared -scp ray-image.tar : -# on worker nodes -docker load -i ray-image.tar -``` - -### 2. Enable torch_ccl [optional] +## 2. Accelerate + Ray +### 2.1 Enable torch_ccl [optional] ```python from raydp.torch.config import TorchConfig @@ -41,7 +26,7 @@ def train_fashion_mnist(...): ... ``` -### 3. Set parameters [optional] +### 2.2 Set parameters [optional] - FSDP parameters ```python trainer = AccelerateTrainer( @@ -73,19 +58,15 @@ def train_fashion_mnist(...): } ``` -### 5. Test Ray TorchTrainer example -#### Bare-metal +### 2.3 Test Ray TorchTrainer example +#### 2.3.1 Bare-metal ```bash oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)") && source $oneccl_bindings_for_pytorch_path/env/setvars.sh python -u run_clm_no_trainer_ray.py --model_name_or_path EleutherAI/gpt-j-6B --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --per_device_train_batch_size 2 --per_device_eval_batch_size 4 --num_train_epochs 1 --address 10.165.9.53 --num_workers 2 ``` -#### Using Docker -```bash -python launch_workflow.py -w workflow.yaml -``` -## FSDP_CPU + Ray -### 1. Enable fsdp_cpu in Ray +## 3. FSDP_CPU + Ray +### 3.1 Enable fsdp_cpu in Ray Edit codes in train_loop_utils.py ```python class _TorchAccelerator(Accelerator): @@ -112,7 +93,7 @@ def train_func(config: Dict): ... ``` -### 2. enable torch_ccl in Ray +### 3.2 enable torch_ccl in Ray ```bash pip install --pre raydp ``` @@ -157,11 +138,12 @@ File "env/lib/python3.7/site-packages/ray/train/_internal/worker_group.py", line RuntimeError: no support for _allgather_base in Gloo process group ``` -### 3. Test Fashion MNIST example +### 3.3 Test Fashion MNIST example ```python python run_minist_fsdp.py ``` + ## Memory Status Reference to Applied Machine Learning team ([intel-sandbox/HuggingFace](https://github.com/intel-sandbox/HuggingFace/tree/main/test/memory)) - First run finetune code and get the pid of a Ray worker process. diff --git a/Finetune/llm_pretrain_template.conf b/Finetune/llm_pretrain_template.conf old mode 100644 new mode 100755 index c80736854..6cd3b075e --- a/Finetune/llm_pretrain_template.conf +++ b/Finetune/llm_pretrain_template.conf @@ -14,10 +14,10 @@ # The type of dataset, now only HuggingfaceDataset is supported. "type": "GroupDataset", # The name/path of dataset in huggingface. - "path": "/mnt/DP_disk2/yuliang/workspace/data/group_token5", + "path": "/home/user/tmp/pretrain_data", # Whether to use the datasets.load_from_disk() interface to load data. "load_from_disk": False, - # Config of dataset, all items will be transfered to datasets.load_dataset() or datasets.load_from_disk(). + # Config of dataset, all items will be transfscered to datasets.load_dataset() or datasets.load_from_disk(). "load_config" : { "streaming": True } @@ -58,12 +58,14 @@ "num_train_epochs": 1, # The max training step of each epoch, if set to None means unlimited. # In most cases this item is for debugging. - "max_train_step_per_episode": None, + "max_train_step_per_episode": 2, # The max evaluating step of each epoch, if set to None means unlimited. # In most cases this item is for debugging. "max_eval_step_per_episode": 0, # Output directory. Only absolute path is supported. "output": "/tmp/output", + # directory to save stepwise training states. this param is mainly used for recovery validation + "save_state_path": "/home/user/tmp/state", "dataprocesser": { # The type of dataprocesser. "type": "PlainIDProcesser", @@ -78,22 +80,29 @@ }, "checkpoint": { # The root path of checkpoint. Only absolute path is supported - "root_path": "/tmp/checkpoint", + "root_path": "/home/user/tmp/checkpoint", #"step": 39 - } + }, + "lr_scheduler": { + "enable": True, + "lr_scheduler_type": "linear", + "max_train_steps": 50 + }, }, # Ray related configuration, Only used when mode is set to ray "ray_config": { - # The config of ray.init. All items will be tranfered to ray.init(). + # The config of ray.init. All items will be transferred to ray.init(). # More information can refer to https://docs.ray.io/en/latest/ray-core/api/doc/ray.init.html "init": { # Environment variables for ray workers "runtime_env": { "env_vars": { - "OMP_NUM_THREADS": "56", + "OMP_NUM_THREADS": "60", "ACCELERATE_USE_CPU": "True", "ACCELERATE_MIXED_PRECISION": "no", - "CCL_WORKER_COUNT": "2", # CCL setting + "FI_PROVIDER": "tcp", # Network setting + "FI_TCP_IFACE": "ens39f0", + "CCL_WORKER_COUNT": "4", # CCL setting #"CCL_LOG_LEVEL": "info", "WORLD_SIZE": "2", # Enable multi-process } @@ -101,14 +110,14 @@ # The address of the Ray cluster to connect to. "address": "auto", # The IP address of the node that we are on. - "_node_ip_address": "127.0.0.1", + "_node_ip_address": "10.165.9.53", }, "scaling_config": { # Number of worker. - "num_workers": 2, + "num_workers": 4, # The amount of resources per worker. "resources_per_worker": { - "CPU": 56 + "CPU": 60 }, # The placement strategy to use for the placement group of the Ray actors. "placement_strategy": "SPREAD" diff --git a/Finetune/plugin/trainer/pretrainer.py b/Finetune/plugin/trainer/pretrainer.py old mode 100644 new mode 100755 index 4a898e066..ff07be952 --- a/Finetune/plugin/trainer/pretrainer.py +++ b/Finetune/plugin/trainer/pretrainer.py @@ -1,11 +1,14 @@ import os import math import time +import json +import shutil import torch import transformers - + from ray.air.checkpoint import Checkpoint +from pathlib import Path from .. import dataprocesser from .trainer import Trainer @@ -157,18 +160,43 @@ def prepare(self, model, tokenizer, dataset, optimizer, accelerator): train_dataloader, eval_dataloader, ) + def _check_and_mkdir(self, path): + path = Path(path) + if not path.exists(): + path.mkdir(parents=True) + + def _write_json(self, target_dict, save_path): + json_object = json.dumps(target_dict, indent=4) + with open(save_path, "w") as outfile: + outfile.write(json_object) + def train(self): num_train_epochs = self.config.get("num_train_epochs", 1) checkpoint = self.config.get("checkpoint") log_step = self.config.get("log_step", 1) max_train_step_per_episode = self.config.get("max_train_step_per_episode") max_eval_step_per_episode = self.config.get("max_eval_step_per_episode") + save_state_path = self.config.get("save_state_path") + + if save_state_path is not None and int(self.rank) == 0: + self._check_and_mkdir(save_state_path) + training_state = {} + else: + training_state = None + for idx in range(self.starting_episode, len(self.train_dataloader), 1): logger.info(f"start train episode {idx}") + if training_state is not None and int(self.rank) == 0: + training_state[f'episode_{idx}'] = {} self.model.train() current_train_dataloader = self.train_dataloader[idx] start = time.time() for step, batch in enumerate(current_train_dataloader): + if training_state is not None and int(self.rank) == 0: + training_state[f'episode_{idx}'][f'step_{step}'] = {} + training_state[f'episode_{idx}'][f'step_{step}']['data'] = batch['input_ids'][0].tolist()[:50] + training_state[f'episode_{idx}'][f'step_{step}']['learning_rate'] = self.lr_scheduler.state_dict()['_last_lr'] + with self.accelerator.accumulate(self.model): outputs = self.model(**batch) loss = outputs.loss @@ -182,6 +210,12 @@ def train(self): if step % log_step == 0: logger.info(f"train episode:[{idx}/{len(self.train_dataloader)}]\tstep:[{step}]\tloss:{loss}\tppl:{math.exp(loss)}\ttime:{time.time()-start}") start = time.time() + if training_state is not None and int(self.rank) == 0: + training_state[f'episode_{idx}'][f'step_{step}']['loss'] = loss.item() + training_state[f'episode_{idx}'][f'step_{step}']['ppl'] = math.exp(loss) + file_name = "stepwise_training_state_recovery" if self.starting_episode > 0 else "stepwise_training_state" + self._write_json(training_state, f"{save_state_path}/{file_name}.json") + if max_train_step_per_episode is not None: if step >= max_train_step_per_episode: break diff --git a/Finetune/workflow.yaml b/Finetune/workflow.yaml deleted file mode 100644 index 64316f19c..000000000 --- a/Finetune/workflow.yaml +++ /dev/null @@ -1,23 +0,0 @@ -general: - run_ray_cluster: True - run_hdfs: True - run_training_job: False - model_dir: /home/fanlilin/.cache/huggingface # all folders are NFS-shared across nodes - tmp_dir: /home/fanlilin/tmp - workspace_dir: /home/fanlilin/workspace/llm-ray-fanli - image_name: ray-llm:latest - -nodes: - - node: 10.165.9.53 - type: head - cores: 0-95 - - - node: 10.165.9.51 - type: worker - cores: 0-95 - user: fanlilin - password: Intel123 - -training_spec: - task_name: clm - config_path: ./llm_finetune_template.conf diff --git a/tools/pretrain_recovery_test/README.md b/tools/pretrain_recovery_test/README.md new file mode 100644 index 000000000..2bb88c325 --- /dev/null +++ b/tools/pretrain_recovery_test/README.md @@ -0,0 +1,14 @@ +# How to run End2End Validation of the Recovery Test? + +## Step 1: Set up Env +Please follow [this guide](../workload_in_containers/README.md) on how to set-up the container environment of this workload. When the containers are running, you can enter the container on head node using following command: +```bash +docker exec -it ray-leader bash +``` + +## Step 2: Start the script +You can use the `test_end2end.sh` to run the end-to-end validation for ray recovery mechanism. +```bash +cd tools/pretrain_recovery_test +./test_end2end.sh +``` \ No newline at end of file diff --git a/tools/pretrain_recovery_test/compare_logs.py b/tools/pretrain_recovery_test/compare_logs.py new file mode 100644 index 000000000..0c72e9d5a --- /dev/null +++ b/tools/pretrain_recovery_test/compare_logs.py @@ -0,0 +1,90 @@ +import json +import time +import argparse +import os + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +def read_json(json_file): + + with open(json_file) as file: + parsed_json = json.load(file) + + return parsed_json + +def get_all_episodes(parsed_json): + + parsed_json = dict(sorted(parsed_json.items())) + + return parsed_json.keys() + +def identify_common_episode(first_json, second_json): + + first_episodes = get_all_episodes(first_json) + second_episodes = get_all_episodes(second_json) + + common_episodes = list(set(first_episodes).intersection(second_episodes)) + + if len(common_episodes) == 0: + print("the 2 trainings have no episode overlapped. Check your json file!") + return -1 + elif len(common_episodes) > 1: + print("the 2 trainings have more than 1 overlapped episodes. Check your json files!") + return -1 + else: + return common_episodes[0] + +def compare_training_states(json1, json2, step): + + step = f'step_{step}' + + data_result = json1[step]['data'] == json2[step]['data'] + lr_result = json1[step]['learning_rate'] == json2[step]['learning_rate'] + loss_result = json1[step]['loss'] == json2[step]['loss'] + + return data_result, lr_result, loss_result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--file_path", + type=str, + default='/home/user/tmp/state', + help="absolute path of the json files" + ) + args = parser.parse_args() + + # read the json files + state1 = read_json(os.path.join(args.file_path, 'stepwise_training_state.json')) + state2 = read_json(os.path.join(args.file_path, 'stepwise_training_state_recovery.json')) + + # identify the overlapped episode + common_episode = identify_common_episode(state1, state2) + print(f"the common episode of 2 trainings: {common_episode}\n") + + # compare the different training states + data_result, lr_result, loss_result = compare_training_states(state1[common_episode], state2[common_episode], 0) + + # print out the detailed comparison results + print(f"Are the Data the same?\n{data_result}") + print(f"Are the Learning Rate the same?\n{lr_result}") + print(f"Are the Training Loss the same?\n{loss_result}") + + if data_result and lr_result and loss_result: + print(f"{bcolors.OKGREEN}\nrecovery tests all passed!{bcolors.ENDC}") + else: + print(f"{bcolors.FAIL}recovery test failed! check the detailed log above.{bcolors.ENDC}") + + +if __name__ == "__main__": + main() diff --git a/tools/pretrain_recovery_test/test_end2end.sh b/tools/pretrain_recovery_test/test_end2end.sh new file mode 100755 index 000000000..d7c1bde44 --- /dev/null +++ b/tools/pretrain_recovery_test/test_end2end.sh @@ -0,0 +1,52 @@ + +echo -e "\npreprocessing RedPajama-Data-1T-Sample..." +python ../redpajama_data_preprocessing/preprocess_data.py \ + --input togethercomputer/RedPajama-Data-1T-Sample \ + --load-batch-size 100000 \ + --max-length 2048 \ + --output-prefix processed_json \ + --output-format json \ + --num-samples 1024 \ + --parallelism 180 + +echo -e "\ncreating sample training data for recovery validation..." +python ../redpajama_data_preprocessing/group_files.py \ + --src-data-path /home/user/tmp/processed_json \ + --des-data-path /home/user/tmp/pretrain_data \ + --test + +echo -e "\nstart pre-training in the background..." +python ../../Finetune/main.py --config_path ../../Finetune/llm_pretrain_template.conf &> training_log.txt & + +echo -e "\nlet the training run for 8 mins..." +sleep 400 + +echo -e "\nmanually stop the training..." +pkill -f main.py + +echo -e "\nrestart the training..." +python ../../Finetune/main.py --config_path ../../Finetune/llm_pretrain_template.conf &> training_log2.txt & + +echo -e "\nlet the training run for 8 mins..." +sleep 400 + +echo -e "\nmanually stop the training..." +pkill -f main.py + +echo -e "\ncompare the results..." +python compare_logs.py --file_path /home/user/tmp/state + + + + + + + + + + + + + + + diff --git a/tools/redpajama_data_processing/README.md b/tools/redpajama_data_processing/README.md new file mode 100644 index 000000000..70ceb05c3 --- /dev/null +++ b/tools/redpajama_data_processing/README.md @@ -0,0 +1,91 @@ +# How to preprocess RedPajama Dataset on top of Ray? + +## Step 1: Set up Env +Please follow [this guide](../workload_in_containers/README.md) on how to set-up the container environment of this workload. When the containers are running, you can enter the container on head node using following command: +```bash +docker exec -it ray-leader bash +``` + +## Step 2: Run Preprocessing job +```python +cd tools/redpajama_data_processing +./run-dp.sh +``` +The `run-dp.sh` will run the `preprocess_data.py` under the hood. By default, it will preprocess the RedPajama Sample Data named `togethercomputer/RedPajama-Data-1T-Sample` from HuggingFace and save them under the temporary path of the container, namely the `/home/user/tmp`, with the folder name `processed_megatron`. But you can modify this script based on your need. The `preprocess_data.py` accept various flags. Run the following command to find out the possible params: +```bash +python preprocess_data.py -h +``` +If the data preprocessing gets finished, you will see the total execution time of this script in the command-line output. + +## Step 3: Merge Multiple Megatron Data Files [Optional] +For Megatron-format data, you may need to do an extra step to merge multiple data files in to one. You can use the `merge_datasets.py` as follows: + +```python +python merged_datasets.py --input --output-prefix +``` + +# Notes of running on full RedPajama dataset +Due to the large size and long processing time for the full RedPajama dataset, we provided an additional script optimized for full RedPajama dataset. Checkout the args to see the possible preprocessing options. +```python +python preprocess_full.py -h +``` +We recommend users to first download the RedPajama dataset to local disk and preprocess the data per data source. The following command will download the full dataset to your local disk: +```bash +wget 'https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt' +while read line; do + dload_loc=${line#https://data.together.xyz/redpajama-data-1T/v1.0.0/} + mkdir -p $(dirname $dload_loc) + wget "$line" -O "$dload_loc" +done < urls.txt +``` +Here is an example for starting full dataset preprocessing on the full dataset of source book: +```python +python preprocess_full.py \ + --input togethercomputer/RedPajama-Data-1T \ + --data-dir /home/user/local \ + --cache-dir /home/user/local \ + --source book \ + --output-prefix full_megatron \ + --cpu-per-worker 90 +``` + +## NLTK +If you are using the `split-sentences` flag, please make sure to run the following lines of code on each worker first before getting started with the data preprocessing for better performance. +```python +import nltk +nltk.download('punkt') +nltk.data.load("tokenizers/punkt/english.pickle") +``` +## Common_Crawl +Please note that the current `togethercomputer/RedPajama-Data-1T` data loading script is not compatible with HuggingFace streaming mode implementation for common_crawl datset. For common_crawl, you have to first download the dataset to local disk and preprocess the data in local mode. + +## Disk Space +For the full dataset mode, we need to have both + + + +# Troubleshooting +## Connection Error +When running data preprocessing on the full redpajama dataset using HuggingFace streaming API, you may encounter following errors: +```bash +aiohttp.client_exceptions.ClientPayloadError: Response payload is not completed +``` +or +```bash +MemoryError +(MapBatches(preprocess_megatron) pid=299104, ip=10.165.9.23) [2023-07-01 11:40:34,924 E 299104 300227] gcs_rpc_client.h:542: Failed to connect to GCS within 60 seconds. GCS may have been killed. It's either GCS is terminated by `ray stop` or is killed unexpectedly. If it is killed unexpectedly, see the log file gcs_server.out. https://docs.ray.io/en/master/ray-observability/ray-logging.html#logging-directory-structure. The program will terminate. +``` +or +```bash +aiohttp failed response.json() with status 500 +``` +or +``` +aiohttp.client_exceptions.ClientResponseError: 520, message='' +`` +All these errors are related to the network issue. The solution is to set a lower batch-size, e.g. 10000 or even lower. + + + + + diff --git a/tools/redpajama_data_processing/count_tokens.py b/tools/redpajama_data_processing/count_tokens.py new file mode 100644 index 000000000..d8e4859c8 --- /dev/null +++ b/tools/redpajama_data_processing/count_tokens.py @@ -0,0 +1,25 @@ +from indexed_dataset import MMapIndexedDataset +from transformers import AutoTokenizer + +import argparse + +# get the first argument as a file name, and an output file +parser = argparse.ArgumentParser() +parser.add_argument("file_name", help="the file name to read") +parser.add_argument("output_file", help="the file name to write") +args = parser.parse_args() + +ds = MMapIndexedDataset(args.file_name) + +tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + +num_tokens = [ + len(ds[i]) for i in range(len(ds)) +] + +# write it out to an output_file +with open(args.output_file, "w") as f: + for i in num_tokens: + f.write(f"{i}\n") + +print(f'Total tokens: {sum(num_tokens)}') \ No newline at end of file diff --git a/tools/redpajama_data_processing/group_files.py b/tools/redpajama_data_processing/group_files.py new file mode 100755 index 000000000..947f8c2b9 --- /dev/null +++ b/tools/redpajama_data_processing/group_files.py @@ -0,0 +1,84 @@ +""" +this script is for grouping data files into multiple data folders +with the episode names for pre-training on Ray. +""" + +import os +import argparse + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--src-data-path', + required=True, + default='/home/user/tmp/processed_json', + help="the path where the data to be grouped is located" + ) + parser.add_argument( + '--des-data-path', + required=True, + default='/home/user/tmp/train_data', + help="the path where the data to be grouped should be stored" + ) + parser.add_argument( + '--test', + default=False, + action='store_true', + help="if specified, test pretrain data folder will be created for recovery validation" + ) + parser.add_argument( + "--num-steps", + type=int, + default=10, + help="the number of folders to be created" + ) + args = parser.parse_args() + src_data_path = args.src_data_path + des_data_path = args.des_data_path + num_steps = args.num_steps + + json_files = [os.path.join(src_data_path, name) for name in os.listdir(src_data_path) if name.endswith('json')] + json_files = sorted(json_files) + num_files = len(json_files) + print(num_files) + + if args.test: + idx = 0 + for step in range(4): + + new_data_path = os.path.join(des_data_path, f"episode_{step}") + + if not os.path.exists(new_data_path): + os.makedirs(new_data_path, exist_ok=True) + + file_chunk = json_files[idx] + file_name = os.path.basename(file_chunk) + os.rename(file_chunk, os.path.join(new_data_path, file_name)) + + idx += 1 + print(f"folder {step} finalized!") + + else: + num_files_per_folder = num_files // num_steps + rest = num_files % num_steps + print(num_files_per_folder) + + true_steps = num_steps+1 if rest > 0 else num_steps + + for step in range(true_steps): + new_data_path = os.path.join(des_data_path, f"episode_{step}") + + if not os.path.exists(new_data_path): + os.makedirs(new_data_path, exist_ok=True) + + file_chunk = json_files[num_files_per_folder*step:num_files_per_folder*(step+1)] + + for file in file_chunk: + file_name = os.path.basename(file) + os.rename(file, os.path.join(new_data_path, file_name)) + + print(f"folder {step} finalized!") + + +if __name__ == "__main__": + main() diff --git a/tools/redpajama_data_processing/indexed_dataset.py b/tools/redpajama_data_processing/indexed_dataset.py new file mode 100644 index 000000000..5c2dc1547 --- /dev/null +++ b/tools/redpajama_data_processing/indexed_dataset.py @@ -0,0 +1,604 @@ +# Copyright (c) 2021, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +# copied from fairseq/fairseq/data/indexed_dataset.py +# Removed IndexedRawTextDataset since it relied on Fairseq dictionary +# other slight modifications to remove fairseq dependencies +# Added document index to index file and made it accessible. +# An empty sentence no longer separates documents. + +import os +import shutil +import struct +from functools import lru_cache +from itertools import accumulate + +import numpy as np +import torch + + +def print_rank_0(*message): + """If distributed is initialized print only on rank 0.""" + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == 0: + print(*message, flush=True) + else: + print(*message, flush=True) + + +def __best_fitting_dtype(vocab_size=None): + if vocab_size is not None and vocab_size < 65500: + return np.uint16 + else: + return np.int32 + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + else: + return None + else: + print(f"Dataset does not exist: {path}") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + return None + + +def make_builder(out_file, impl, vocab_size=None): + if impl == "mmap": + return MMapIndexedDatasetBuilder( + out_file, dtype=__best_fitting_dtype(vocab_size) + ) + else: + return IndexedDatasetBuilder(out_file) + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print( + "Path should be a basename that both .idx and .bin can be appended to get full filenames." + ) + return None + if impl == "infer": + impl = infer_dataset_impl(path) + if impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +def dataset_exists(path, impl): + if impl == "mmap": + return MMapIndexedDataset.exists(path) + else: + return IndexedDataset.exists(path) + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def write_longs(f, a): + f.write(np.array(a, dtype=np.int64)) + + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float32, + 7: np.float64, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes.keys(): + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +def create_doc_idx(sizes): + doc_idx = [0] + for i, s in enumerate(sizes): + if s == 0: + doc_idx.append(i + 1) + return doc_idx + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " + "Make sure that --dataset-impl is configured properly." + ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists( + data_file_path(path) + ) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class IndexedDatasetBuilder(object): + element_sizes = { + np.uint8: 1, + np.int8: 1, + np.int16: 2, + np.int32: 4, + np.int64: 8, + np.float32: 4, + np.float64: 8, + } + + def __init__(self, out_file, dtype=np.int32): + self.out_file = open(out_file, "wb") + self.dtype = dtype + self.data_offsets = [0] + self.dim_offsets = [0] + self.sizes = [] + self.element_size = self.element_sizes[self.dtype] + self.doc_idx = [0] + + def add_item(self, np_array): + assert isinstance(np_array, np.ndarray) and np_array.dtype == self.dtype + bytes = self.out_file.write(np_array) + self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size) + for s in np_array.shape: + self.sizes.append(s) + self.dim_offsets.append(self.dim_offsets[-1] + len(np_array.shape)) + + def end_document(self): + self.doc_idx.append(len(self.sizes)) + + def merge_file_(self, another_file): + index = IndexedDataset(another_file) + assert index.dtype == self.dtype + + begin = self.data_offsets[-1] + for offset in index.data_offsets[1:]: + self.data_offsets.append(begin + offset) + self.sizes.extend(index.sizes) + begin = self.dim_offsets[-1] + for dim_offset in index.dim_offsets[1:]: + self.dim_offsets.append(begin + dim_offset) + + with open(data_file_path(another_file), "rb") as f: + while True: + data = f.read(1024) + if data: + self.out_file.write(data) + else: + break + + def finalize(self, index_file): + self.out_file.close() + index = open(index_file, "wb") + index.write(b"TNTIDX\x00\x00") + index.write(struct.pack(" pd.DataFrame: + return batch[batch['source'] == src] + + tmp_data = tokenized_data.map_batches(filter_source, batch_format='pandas', batch_size=None) + tmp_rows = tmp_data.count() + + out_file = f'{output_dir}/{src}.bin' + idx_file = f'{output_dir}/{src}.idx' + + for docs in tmp_data.iterator().iter_batches(batch_size=tmp_rows): + save_megatron(out_file, idx_file, docs) + else: + out_file = f'{output_dir}/all_redpajama.bin' + idx_file = f'{output_dir}/all_redpajama.idx' + + num_rows = tokenized_data.count() + for docs in tokenized_data.iterator().iter_batches(batch_size=num_rows): + save_megatron(out_file, idx_file, docs) + + +def build_megatron_distributed(tokenized_data, save_on_source, output_dir): + + if save_on_source: + def write_megatron(batch: pd.DataFrame) -> pd.DataFrame: + task_id = ray.get_runtime_context().get_task_id() + all_sources = batch['source'].unique() + for src in all_sources: + batch = batch[batch['source'] == src] + + if not os.path.exists(f"{output_dir}/{src}"): + os.makedirs(f"{output_dir}/{src}") + + out_file = f'{output_dir}/{src}/{task_id[:20]}.bin' + idx_file = f'{output_dir}/{src}/{task_id[:20]}.idx' + + save_megatron(out_file, idx_file, batch) + + return pd.DataFrame({'task_id': [task_id]}) + else: + def write_megatron(batch: pd.DataFrame) -> pd.DataFrame: + + task_id = ray.get_runtime_context().get_task_id() + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + out_file = f'{output_dir}/{task_id[:20]}.bin' + idx_file = f'{output_dir}/{task_id[:20]}.idx' + + save_megatron(out_file, idx_file, batch) + + return pd.DataFrame({'task_id': [task_id]}) + + task_ids = tokenized_data.map_batches(write_megatron, batch_format="pandas", batch_size=None) + task_ids.materialize() + + +def make_megatron_dataset(tokenized_data, save_on_host, save_on_source, output_dir): + + if save_on_host: + tokenized_data = tokenized_data.repartition(1) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + build_megatron_central(tokenized_data, save_on_source, output_dir) + + if save_on_source: + for src in ['arxiv', 'book', 'common_crawl', 'c4', 'wikipedia', 'stackexchange', 'github']: + build_megatron_central(tokenized_data, src, output_dir) + print(f"samples from {src} data source were written to disk!") + else: + build_megatron_central(tokenized_data, 'all', output_dir) + print(f"all samples from data source were written to disk") + + else: + build_megatron_distributed(tokenized_data, save_on_source, output_dir) + + +def main(): + args = get_args() + + output_dir = f'{args.output_path}/{args.output_prefix}' + max_length = args.max_length + drop_tokens = args.drop_tokens + + use_sample = 'sample' in args.input.lower() + text_field = 'text' + meta_field = 'meta' if use_sample else 'red_pajama_subset' + + tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', use_fast=True) + eos_tokens = tokenizer(args.eos_text)['input_ids'] + #pad_tokens = tokenizer('<|padding|>')['input_ids'] + + ray.init(address='auto') + pprint(ray.cluster_resources()) + num_nodes = len(ray.nodes()) + parallelism = num_nodes * args.cpu_per_worker + + def preprocess_json(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + + # load tokenizer took 0.15s + tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', use_fast=True) + samples = batch[text_field].tolist() + buffer = [] + token_chunked = [] + + for sample in samples: + encoded = tokenizer(sample, + truncation=False, + padding=False) + ids = encoded['input_ids'] + buffer = buffer + ids + eos_tokens + + while len(buffer) >= max_length: + concat_sample = buffer[:max_length] + token_chunked.append(concat_sample) + buffer = buffer[max_length:] + + if not drop_tokens: + #add padding to sequence shorter than max_length + buffer = buffer + [1]*(max_length - len(buffer)) + token_chunked[-1] = buffer + + return {"tokens": np.asarray(token_chunked)} + + if use_sample: + def preprocess_megatron(batch: Dict[str, np.ndarray]) -> pd.DataFrame: + + tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', use_fast=True) + + metas = batch[meta_field].tolist() + samples = batch[text_field].tolist() + + ids = [] + lens = [] + + for sample in samples: + encoded = tokenizer(sample, + truncation=False, + padding=False) + sample_id = encoded['input_ids'] + eos_tokens + ids.append(sample_id) + lens.append(len(sample)) + + sources = [] + for meta in metas: + meta_dict = eval(meta) + meta_keys = meta_dict.keys() + if 'arxiv_id' in meta_keys: + sources.append('arxiv') + elif 'pred_label_prob' in meta_keys: + sources.append('common_crawl') + elif 'short_book_title' in meta_keys: + sources.append('book') + elif 'title' in meta_keys: + if 'url' in meta_keys: + if 'wikipedia' in meta_dict['url']: + sources.append('wikipedia') + else: + sources.append('book') + else: + sources.append('book') + else: + sources.append(meta_dict['source']) + + return pd.DataFrame({"tokens": ids, "length": lens, "source": sources}) + else: + def preprocess_megatron(batch: Dict[str, np.ndarray]) -> pd.DataFrame: + + tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', use_fast=True) + + sources = batch[meta_field].tolist() + samples = batch[text_field].tolist() + + ids = [] + lens = [] + + for sample in samples: + encoded = tokenizer(sample, + truncation=False, + padding=False) + sample_id = encoded['input_ids'] + eos_tokens + ids.append(sample_id) + lens.append(len(sample)) + + return pd.DataFrame({"tokens": ids, "length": lens, "source": sources}) + + if args.stream: + if use_sample: + dataset = load_dataset(args.input, streaming=True) + else: + dataset = load_dataset(args.input, 'default', streaming=True) + + idx = 1 + for rows in dataset['train'].iter(batch_size=args.load_batch_size): + df = pd.DataFrame(rows) + ray_dataset = ray.data.from_pandas(df) + ray_dataset = ray_dataset.repartition(parallelism) + if args.output_format == 'json': + tokenized_data = ray_dataset.map_batches(preprocess_json, batch_format="numpy", batch_size=None) + + total_rows = tokenized_data.count() + num_partition = total_rows//args.num_samples if not args.save_on_host else 1 + tokenized_data = tokenized_data.repartition(num_partition) + + if args.save_on_host: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + index = 0 + for batch in tokenized_data.iterator().iter_batches(batch_size=args.num_samples): + batch.to_json(f'{output_dir}/partition_{index}.json', orient="records", lines=True) + index += 1 + else: + tokenized_data.write_json(output_dir) + elif args.output_format == 'megatron': + tokenized_data = ray_dataset.map_batches(preprocess_megatron, batch_format="numpy", batch_size=None) + make_megatron_dataset(tokenized_data, args.save_on_host, args.save_on_source, output_dir) + + idx += 1 + if idx % 100 == 0: + print(f"{idx} * {args.load_batch_size} samples are written to disk.") + + else: + raw_dataset = load_dataset(args.input)['train'] + ray_dataset = ray.data.from_huggingface(raw_dataset) + # create multiple data blocks + ray_dataset = ray_dataset.repartition(parallelism, shuffle=True) + + if args.output_format == 'json': + fn_name = "preprocess_json" + elif args.output_format == 'megatron': + fn_name = "preprocess_megatron" + + tokenized_data = ray_dataset.map_batches(eval(fn_name), batch_format="numpy", batch_size=None) + total_rows = tokenized_data.count() + print(f"Total number of rows after processing: {total_rows}") + + if args.output_format == 'json': + num_partition = total_rows//args.num_samples if not args.save_on_host else 1 + tokenized_data = tokenized_data.repartition(num_partition) + if args.save_on_host: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + index = 0 + for batch in tokenized_data.iterator().iter_batches(batch_size=args.num_samples): + batch.to_json(f'{output_dir}/partition_{index}.json', orient="records", lines=True) + index += 1 + else: + tokenized_data.write_json(output_dir) + elif args.output_format == 'megatron': + make_megatron_dataset(tokenized_data, args.save_on_host, args.save_on_source, output_dir) + + +if __name__ == "__main__": + start = time.time() + main() + end = time.time() + print(f"\nthis script took {end-start}s.") + diff --git a/tools/redpajama_data_processing/preprocess_full.py b/tools/redpajama_data_processing/preprocess_full.py new file mode 100644 index 000000000..b72b614fe --- /dev/null +++ b/tools/redpajama_data_processing/preprocess_full.py @@ -0,0 +1,248 @@ +""" +this script is for processing redpajama full data on streaming mode and saving to megatron-format. +Different to `preprocess_data.py`, this script is mainly for performance. +""" + +import os +import time +import argparse +from pprint import pprint +from typing import Dict, List + +import nltk +import ray +import ray.data +import pandas as pd +import numpy as np +from datasets import load_dataset +from transformers import ( + CONFIG_MAPPING, + MODEL_MAPPING, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + SchedulerType, + default_data_collator, + get_scheduler, +) + +from indexed_dataset import MMapIndexedDatasetBuilder + + +class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars): + + _period_context_fmt = r""" + \S* # some word material + %(SentEndChars)s # a potential sentence ending + \s* # <-- THIS is what I changed + (?=(?P + %(NonWord)s # either other punctuation + | + (?P\S+) # <-- Normally you would have \s+ here + ))""" + + +class IdentitySplitter(object): + def tokenize(self, *text): + return text + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument( + "--input", + type=str, + required=True, + help="Name of the dataset repository,e.g. togethercomputer/RedPajama-Data-1T" + ) + group.add_argument( + "--data-dir", + type=str, + required=False, + help="for local mode, you need to provide local RedPajama dataset repository, e.g. /home/user/local" + ) + group.add_argument( + "--cache-dir", + type=str, + default='/root/.cache', + help="Hugging Face Cache dir, where the hugging face dataset it stored" + ) + group.add_argument( + "--source", + type=str, + default='default', + help="data source of the redpajama data, please choose from \ + ['arxiv', 'book', 'c4', 'common_crawl', 'github', 'stackexchange', 'wikipedia'] \ + by default the value is set to default" + ) + group.add_argument('--split-sentences', action='store_true', + help='Split documents into sentences.') + group.add_argument('--keep-newlines', action='store_true', + help='Keep newlines between sentences when splitting.') + group.add_argument( + '--stream', + default=False, + action='store_true', + help="whether to load data from hugging face using streaming mode" + ) + group.add_argument( + "--load-batch-size", type=int, default=1000, help="only needed if you use streaming mode to read data from hugging face" + ) + group = parser.add_argument_group(title="output data") + group.add_argument( + "--output-prefix", + type=str, + required=True, + help="Path to binary output file without suffix", + ) + group = parser.add_argument_group(title="runtime") + group.add_argument( + "--cpu-per-worker", type=int, default=1, help="Number of CPUs to use per worker" + ) + + args = parser.parse_args() + args.output_path = '/home/user/local' + return args + +def save_megatron(out_file, idx_file, docs): + data_builder = MMapIndexedDatasetBuilder(out_file, dtype=np.uint16) + + for doc in docs['tokens']: + data_builder.add_item(np.array(doc, dtype=data_builder.dtype)) + + data_builder.end_document() + data_builder.finalize(idx_file) + + +def main(): + args = get_args() + + output_dir = f'{args.output_path}/{args.output_prefix}' + data_source = args.source + eos_tokens = [0] + keep_newlines = args.keep_newlines + split_sentences = args.split_sentences + cache_dir = args.cache_dir + ray.init(address='auto') + pprint(ray.cluster_resources()) + num_nodes = len(ray.nodes()) + parallelism = num_nodes * args.cpu_per_worker + + def preprocess_megatron_sentence(batch: Dict[str, np.ndarray]) -> pd.DataFrame: + + task_id = ray.get_runtime_context().get_task_id() + + tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', use_fast=True) + + splitter = nltk.data.load("tokenizers/punkt/english.pickle") + + if keep_newlines: + splitter = nltk.tokenize.punkt.PunktSentenceTokenizer( + train_text = splitter._params, + lang_vars = CustomLanguageVars()) + else: + splitter = splitter + + samples = batch['text'].tolist() + + ids = [] + lens = [] + + for sample in samples: + sample_ids = [] + for sentence in splitter.tokenize(sample): + + encoded = tokenizer(sentence, + truncation=False, + padding=False)['input_ids'] + if len(encoded) > 0: + sample_ids = sample_ids + encoded + + sample_ids = sample_ids + eos_tokens + ids.append(sample_ids) + lens.append(len(sample)) + + batch = pd.DataFrame({"tokens": ids, "length": lens}) + + if not os.path.exists(f"{output_dir}/{data_source}"): + os.makedirs(f"{output_dir}/{data_source}") + + out_file = f'{output_dir}/{data_source}/{task_id[:20]}.bin' + idx_file = f'{output_dir}/{data_source}/{task_id[:20]}.idx' + + save_megatron(out_file, idx_file, batch) + + return pd.DataFrame({'task_id': [task_id]}) + + + def preprocess_megatron(batch: Dict[str, np.ndarray]) -> pd.DataFrame: + + task_id = ray.get_runtime_context().get_task_id() + + tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b', use_fast=True) + + samples = batch['text'].tolist() + + ids = [] + lens = [] + + for sample in samples: + encoded = tokenizer(sample, + truncation=False, + padding=False) + sample_id = encoded['input_ids'] + eos_tokens + ids.append(sample_id) + lens.append(len(sample)) + + batch = pd.DataFrame({"tokens": ids, "length": lens}) + + if not os.path.exists(f"{output_dir}/{data_source}"): + os.makedirs(f"{output_dir}/{data_source}") + + out_file = f'{output_dir}/{data_source}/{task_id[:20]}.bin' + idx_file = f'{output_dir}/{data_source}/{task_id[:20]}.idx' + + save_megatron(out_file, idx_file, batch) + + return pd.DataFrame({'task_id': [task_id]}) + + if split_sentences: + fn_name = 'preprocess_megatron_sentence' + else: + fn_name = 'preprocess_megatron' + + if args.stream: + dataset = load_dataset(args.input, data_source, streaming=True)['train'] + + idx = 1 + for rows in dataset.iter(batch_size=args.load_batch_size): + print("-----------------------------") + df = pd.DataFrame(rows) + ray_dataset = ray.data.from_pandas(df) + ray_dataset = ray_dataset.repartition(parallelism) + + tokenized_data = ray_dataset.map_batches(eval(fn_name), batch_format="numpy", batch_size=None) + tokenized_data.materialize() + + if idx % 10 == 0: + print(f"{idx} * {args.load_batch_size} samples were written to disk.") + idx += 1 + print("============================") + else: + os.environ["RED_PAJAMA_DATA_DIR"] = args.data_dir + ds = load_dataset(args.input, args.source, cache_dir=cache_dir)['train'] + ray_dataset = ray.data.from_huggingface(ds) + # create multiple data blocks + ray_dataset = ray_dataset.repartition(parallelism) + + tokenized_data = ray_dataset.map_batches(eval(fn_name), batch_format="numpy", batch_size=None) + tokenized_data.materialize() + + +if __name__ == "__main__": + start = time.time() + main() + end = time.time() + print(f"\nthis script took {end-start}s.") + diff --git a/tools/redpajama_data_processing/run-dp.sh b/tools/redpajama_data_processing/run-dp.sh new file mode 100755 index 000000000..0ab6a2ca1 --- /dev/null +++ b/tools/redpajama_data_processing/run-dp.sh @@ -0,0 +1,10 @@ +echo -e "\n distributed save with data source" +python preprocess_data.py \ + --input togethercomputer/RedPajama-Data-1T-Sample \ + --load-batch-size 100000 \ + --max-length 2048 \ + --output-prefix processed_megatron \ + --output-format megatron \ + --num-samples 1024 \ + --parallelism 180 \ + --save-on-source diff --git a/Finetune/Dockerfile b/tools/workload_in_containers/Dockerfile old mode 100644 new mode 100755 similarity index 86% rename from Finetune/Dockerfile rename to tools/workload_in_containers/Dockerfile index 66791072b..0d2e847a2 --- a/Finetune/Dockerfile +++ b/tools/workload_in_containers/Dockerfile @@ -15,7 +15,7 @@ RUN conda update --all RUN conda install mkl mkl-include RUN pip install --upgrade pip -RUN pip install astunparse numpy ninja pyyaml setuptools cmake typing_extensions six requests dataclasses datasets evaluate +RUN pip install astunparse nltk numpy ninja pyyaml setuptools cmake typing_extensions six requests dataclasses datasets evaluate scikit-image dm-tree gymnasium ENV CMAKE_PREFIX_PATH=/opt/conda @@ -46,7 +46,8 @@ RUN git clone https://github.com/huggingface/transformers && \ pip install . # install ray-related libs -RUN pip install -U "ray[default]" && pip install --pre raydp && pip install "ray[tune]" tabulate tensorboard +RUN pip install -U "ray[default] @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp310-cp310-manylinux2014_x86_64.whl" +RUN pip install --pre raydp && pip install "ray[tune]" tabulate tensorboard # install java RUN wget --no-check-certificate -q https://repo.huaweicloud.com/java/jdk/8u201-b09/jdk-8u201-linux-x64.tar.gz && \ diff --git a/tools/workload_in_containers/README.md b/tools/workload_in_containers/README.md new file mode 100644 index 000000000..4abf9d919 --- /dev/null +++ b/tools/workload_in_containers/README.md @@ -0,0 +1,22 @@ +# How to run this workload inside Docker containers? + +## Step 1: Build Docker Images +```bash +# on head node +./build-image.sh +# save docker image +docker save -o ray-image.tar llm-ray:latest +# copy over to worker nodes, this is an optional step if all your cluster nodes are NFS-shared +scp ray-image.tar : +# on worker nodes +docker load -i ray-image.tar +``` + +## Step 2: Specify Workload Config File +Before launch the workload, you need to configure the `workflow.yaml`. The parameters in the config file are self-explained and you need to set your own values, e.g. the cluster node name, user name and the password. Please note that setting `run_training_job` to `True` will start the training job automatically when the containers are up. Therefore, if you only need the dev environment of the workload, remember to set the `run_training_job` to `False`. + + +## Step 3: Launch Workload +```bash +python launch_workload.py -w workload.yaml +``` \ No newline at end of file diff --git a/Finetune/build-image.sh b/tools/workload_in_containers/build-image.sh similarity index 90% rename from Finetune/build-image.sh rename to tools/workload_in_containers/build-image.sh index ff067deeb..46b5f745c 100755 --- a/Finetune/build-image.sh +++ b/tools/workload_in_containers/build-image.sh @@ -3,7 +3,7 @@ dockerfile=Dockerfile docker build \ -f ${dockerfile} . \ - -t ray-llm:latest \ + -t llm-ray:latest \ --network=host \ --build-arg http_proxy=${http_proxy} \ --build-arg https_proxy=${https_proxy} \ diff --git a/Finetune/configs/core-site.xml b/tools/workload_in_containers/configs/core-site.xml similarity index 95% rename from Finetune/configs/core-site.xml rename to tools/workload_in_containers/configs/core-site.xml index 7bfe0cbcc..2630429aa 100644 --- a/Finetune/configs/core-site.xml +++ b/tools/workload_in_containers/configs/core-site.xml @@ -18,7 +18,7 @@ hadoop.tmp.dir - /home/user/tmp/hdfs/tmp + /home/user/local/tmp/hdfs/tmp diff --git a/Finetune/configs/hadoop-env.sh b/tools/workload_in_containers/configs/hadoop-env.sh similarity index 100% rename from Finetune/configs/hadoop-env.sh rename to tools/workload_in_containers/configs/hadoop-env.sh diff --git a/Finetune/configs/hdfs-site.xml b/tools/workload_in_containers/configs/hdfs-site.xml similarity index 94% rename from Finetune/configs/hdfs-site.xml rename to tools/workload_in_containers/configs/hdfs-site.xml index 778f41b37..a1901fd20 100644 --- a/Finetune/configs/hdfs-site.xml +++ b/tools/workload_in_containers/configs/hdfs-site.xml @@ -24,12 +24,12 @@ dfs.namenode.name.dir - /home/user/tmp/hdfs/nn + /home/user/local/tmp/hdfs/nn dfs.datanode.data.dir - /home/user/tmp/hdfs/dn + /home/user/local/tmp/hdfs/dn diff --git a/Finetune/configs/mapred-site.xml b/tools/workload_in_containers/configs/mapred-site.xml similarity index 100% rename from Finetune/configs/mapred-site.xml rename to tools/workload_in_containers/configs/mapred-site.xml diff --git a/Finetune/configs/workers b/tools/workload_in_containers/configs/workers similarity index 100% rename from Finetune/configs/workers rename to tools/workload_in_containers/configs/workers diff --git a/Finetune/configs/yarn-site.xml b/tools/workload_in_containers/configs/yarn-site.xml similarity index 100% rename from Finetune/configs/yarn-site.xml rename to tools/workload_in_containers/configs/yarn-site.xml diff --git a/Finetune/launch_workflow.py b/tools/workload_in_containers/launch_workload.py similarity index 92% rename from Finetune/launch_workflow.py rename to tools/workload_in_containers/launch_workload.py index 8d372f92e..7fdbfe7d4 100644 --- a/Finetune/launch_workflow.py +++ b/tools/workload_in_containers/launch_workload.py @@ -10,7 +10,8 @@ def __init__(self, cfg): self.run_hdfs = workflow_config['general']["run_hdfs"] self.run_training_job = workflow_config['general']['run_training_job'] self.model_dir = workflow_config['general']['model_dir'] - self.tmp_dir = workflow_config['general']['tmp_dir'] + self.tmp_dir = workflow_config['general']['nfs_dir'] + self.local_dir = workflow_config['general']['local_dir'] self.workspace_dir = workflow_config['general']['workspace_dir'] self.image_name = workflow_config['general']['image_name'] self.cluster_config = workflow_config['nodes'] @@ -51,6 +52,7 @@ def startup_ray_cluster(self) : -w {self.workspace_dir} \ -m {self.model_dir} \ -t {self.tmp_dir} \ + -l {self.local_dir} \ -i {self.image_name}') if ret == 0 : print("Successfully startup the ray head!") @@ -68,6 +70,7 @@ def startup_ray_cluster(self) : -w {self.workspace_dir} \ -m {self.model_dir} \ -t {self.tmp_dir} \ + -l {self.local_dir} \ -u {node["user"]} \ -p {node["password"]} \ -i {self.image_name} \ @@ -81,7 +84,7 @@ def startup_ray_cluster(self) : def startup_hdfs(self): - ret = os.system(f'docker exec ray-leader bash run-hdfs.sh -m {self.head_ip} -w {self.worker_ips}') + ret = os.system(f'docker exec ray-leader bash tools/workload_in_containers/run-hdfs.sh -m {self.head_ip} -w {self.worker_ips}') if ret == 0: print("Successfully startup HDFS!") @@ -116,6 +119,7 @@ def process(self): if self.run_training_job: self.run_training() + def parse_cmd(): args = argparse.ArgumentParser(description='parse arguments', epilog=' ', formatter_class=argparse.RawTextHelpFormatter) diff --git a/Finetune/run-hdfs.sh b/tools/workload_in_containers/run-hdfs.sh similarity index 78% rename from Finetune/run-hdfs.sh rename to tools/workload_in_containers/run-hdfs.sh index 50ff002d2..db074c02c 100755 --- a/Finetune/run-hdfs.sh +++ b/tools/workload_in_containers/run-hdfs.sh @@ -26,12 +26,12 @@ then fi echo -e "\ncopy hadoop configuration files..." -cp configs/hadoop-env.sh $HADOOP_HOME/etc/hadoop/hadoop-env.sh -cp configs/hdfs-site.xml $HADOOP_HOME/etc/hadoop/hdfs-site.xml && \ -cp configs/core-site.xml $HADOOP_HOME/etc/hadoop/core-site.xml && \ -cp configs/mapred-site.xml $HADOOP_HOME/etc/hadoop/mapred-site.xml && \ -cp configs/yarn-site.xml $HADOOP_HOME/etc/hadoop/yarn-site.xml && \ -cp configs/workers $HADOOP_HOME/etc/hadoop/workers +cp /home/user/workspace/tools/workload_in_containers/configs/hadoop-env.sh $HADOOP_HOME/etc/hadoop/hadoop-env.sh +cp /home/user/workspace/tools/workload_in_containers/configs/hdfs-site.xml $HADOOP_HOME/etc/hadoop/hdfs-site.xml && \ +cp /home/user/workspace/tools/workload_in_containers/configs/core-site.xml $HADOOP_HOME/etc/hadoop/core-site.xml && \ +cp /home/user/workspace/tools/workload_in_containers/configs/mapred-site.xml $HADOOP_HOME/etc/hadoop/mapred-site.xml && \ +cp /home/user/workspace/tools/workload_in_containers/configs/yarn-site.xml $HADOOP_HOME/etc/hadoop/yarn-site.xml && \ +cp /home/user/workspace/tools/workload_in_containers/configs/workers $HADOOP_HOME/etc/hadoop/workers sed -i 's@hadoop-leader@'"$master_ip"'@' $HADOOP_HOME/etc/hadoop/core-site.xml && \ diff --git a/Finetune/run-ray-cluster.sh b/tools/workload_in_containers/run-ray-cluster.sh similarity index 91% rename from Finetune/run-ray-cluster.sh rename to tools/workload_in_containers/run-ray-cluster.sh index 29420aa3c..a4e39d687 100755 --- a/Finetune/run-ray-cluster.sh +++ b/tools/workload_in_containers/run-ray-cluster.sh @@ -4,7 +4,7 @@ OPTIND=1 WORKSPACE_DIR=/home/user/workspace MODEL_DIR=/root/.cache/huggingface/ TMP_DIR=/home/user/tmp - +LOCAL_DIR=/home/user/local RAM=$(awk '/^Mem/ {print $2}' <(free -mh)) RAM=${RAM/Gi/} @@ -32,6 +32,8 @@ usage() { echo " hugging face model directory" echo " -t tmp_dir" echo " temporary directory" + echo " -l local_dir" + echo " non-nfs directoy which is only accessible by each worker" echo " -u user" echo " user name for access worker server" echo " -p password" @@ -47,7 +49,7 @@ usage() { echo "" } -while getopts "h?r:a:c:f:w:m:t:u:p:i:s:" opt; do +while getopts "h?r:a:c:f:w:m:t:l:u:p:i:s:" opt; do case "$opt" in h|\?) usage @@ -67,6 +69,8 @@ while getopts "h?r:a:c:f:w:m:t:u:p:i:s:" opt; do ;; t) tmp_dir=$OPTARG ;; + l) local_dir=$OPTARG + ;; u) user=$OPTARG ;; p) password=$OPTARG @@ -99,12 +103,13 @@ if [[ $run_type = "startup_head" ]]; then -v ${workspace}:${WORKSPACE_DIR} \ -v ${model_dir}:${MODEL_DIR} \ -v ${tmp_dir}:${TMP_DIR} \ + -v ${local_dir}:${LOCAL_DIR} \ -w /home/user/workspace \ --shm-size ${shm_size} \ --cpuset-cpus=${cores_range} \ --name ray-leader ${image} - docker exec ray-leader /bin/bash -c "ray start --head --node-ip-address=${head_address} --dashboard-port=9999 --ray-debugger-external --temp-dir=/home/user/tmp/ray" + docker exec ray-leader /bin/bash -c "ray start --head --node-ip-address=${head_address} --dashboard-port=9999 --ray-debugger-external --temp-dir=/home/user/local/ray" elif [[ $run_type = "startup_worker" ]]; then @@ -121,6 +126,7 @@ elif [[ $run_type = "startup_worker" ]]; then -v ${workspace}:${WORKSPACE_DIR} \ -v ${model_dir}:${MODEL_DIR} \ -v ${tmp_dir}:${TMP_DIR} \ + -v ${local_dir}:${LOCAL_DIR} \ --cpuset-cpus=${cores_range} \ --network host \ -w /home/user/workspace \ @@ -131,7 +137,7 @@ elif [[ $run_type = "startup_worker" ]]; then EOF sshpass -p $password ssh -o StrictHostKeychecking=no $user@$worker_ip bash << EOF - docker exec $worker_name /bin/bash -c "ray start --address=${head_address} --ray-debugger-external --temp-dir=/home/user/tmp/ray" + docker exec $worker_name /bin/bash -c "ray start --address=${head_address} --ray-debugger-external" EOF elif [[ $run_type = "stop_ray" ]]; then diff --git a/tools/workload_in_containers/workload.yaml b/tools/workload_in_containers/workload.yaml new file mode 100755 index 000000000..40b35c8bc --- /dev/null +++ b/tools/workload_in_containers/workload.yaml @@ -0,0 +1,36 @@ +general: + run_ray_cluster: True + run_hdfs: False + run_training_job: False + model_dir: /home/user/.cache/huggingface # all folders are NFS-shared across nodes + nfs_dir: /home/user/tmp + local_dir: /localdisk/user + workspace_dir: /home/user/workspace/llm-ray + image_name: llm-ray:latest + +nodes: + - node: 10.165.9.53 + type: head + cores: 0-95 + + - node: 10.165.9.52 + type: worker + cores: 0-95 + user: $user + password: $paassword + + - node: 10.165.9.164 + type: worker + cores: 0-95 + user: $user + password: $paassword + + - node: 10.165.9.23 + type: worker + cores: 0-95 + user: $user + password: $paassword + +training_spec: + task_name: clm + config_path: ../../Finetune/llm_pretrain_template.conf \ No newline at end of file