diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index c08b60151..a73cf2a68 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -24,6 +24,7 @@
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt2_dataset import GPT2Dataset
 from megatron.data.pairwise_dataset import PairwiseDataset
+from megatron.data.online_dataset import OnlineDataset
 from megatron.data.samplers import DistributedBatchSampler
@@ -532,7 +533,56 @@ def build_train_valid_test_data_loaders(neox_args):
         pipe_load = True
 
     # Data loader only on rank 0 of each model parallel group.
-    if mpu.get_model_parallel_rank() == 0 and pipe_load:
+    if (
+        pipe_load
+        and (neox_args.dataset_impl == "online")
+        and (mpu.get_model_parallel_rank() == 0)
+    ):
+        # Can skip most of the work...
+        train_iters = neox_args.train_iters
+        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
+        test_iters = neox_args.eval_iters
+        # Build datasets...
+        print(
+            f"train_iters: {train_iters}, eval_iters: {eval_iters}, test_iters: {test_iters}"
+        )
+        train_datasets = OnlineDataset(
+            leave_one_out=neox_args.reinforce_leave_one_out,
+            data_split="train",
+            num_samples=train_iters * neox_args.train_batch_size,
+            seq_length=neox_args.seq_length,
+            dataserver_ips=neox_args.online_dataserver_ips,
+            dataserver_ports=neox_args.online_dataserver_ports,
+        )
+        valid_datasets = OnlineDataset(
+            leave_one_out=neox_args.reinforce_leave_one_out,
+            data_split="valid",
+            num_samples=eval_iters * neox_args.train_batch_size,
+            seq_length=neox_args.seq_length,
+            dataserver_ips=neox_args.online_dataserver_ips,
+            dataserver_ports=neox_args.online_dataserver_ports,
+        )
+        test_datasets = OnlineDataset(
+            leave_one_out=neox_args.reinforce_leave_one_out,
+            data_split="test",
+            num_samples=test_iters * neox_args.train_batch_size,
+            seq_length=neox_args.seq_length,
+            dataserver_ips=neox_args.online_dataserver_ips,
+            dataserver_ports=neox_args.online_dataserver_ports,
+        )
+        # print length of datasets
+        # Build dataloaders.
+        train_dataloader = make_data_loader(train_datasets, neox_args=neox_args)
+        valid_dataloader = make_data_loader(valid_datasets, neox_args=neox_args)
+        test_dataloader = make_data_loader(test_datasets, neox_args=neox_args)
+
+        # Flags to know if we need to do training/validation/testing.
+        do_train = train_dataloader is not None and neox_args.train_iters > 0
+        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
+        do_test = test_dataloader is not None and neox_args.eval_iters > 0
+        # Need to broadcast num_tokens and num_type_tokens.
+        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
+    elif mpu.get_model_parallel_rank() == 0 and pipe_load:
         # Number of train/valid/test samples.
         if neox_args.train_iters is not None:
             train_iters = neox_args.train_iters
diff --git a/megatron/data/online_dataset.py b/megatron/data/online_dataset.py
new file mode 100644
index 000000000..9a12c1875
--- /dev/null
+++ b/megatron/data/online_dataset.py
@@ -0,0 +1,128 @@
+# Copyright (c) 2024, EleutherAI
+# This file is based on code by the authors denoted below and has been modified from its original version.
+#
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Online dataset."""
+from typing import Union, List
+
+import numpy as np
+import torch
+import torch.utils.data
+import socket
+import pickle
+from megatron.mpu.initialize import get_data_parallel_rank
+
+
+class OnlineDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        num_samples,
+        seq_length,
+        leave_one_out=False,
+        data_split="train",
+        dataserver_ips: Union[str, List[str]] = "localhost",
+        dataserver_ports: Union[int, List[int]] = 10000,
+    ):
+        self.num_samples = num_samples
+        self.global_rank = get_data_parallel_rank()
+        self.leave_one_out = leave_one_out
+        self.reward_buffer = []
+        self.online_batching_data = []
+        self.data_split = data_split
+        self.seq_length = seq_length
+        self.dataserver_ips = dataserver_ips
+        self.dataserver_ports = dataserver_ports
+
+    def __len__(self):
+        # dummy value since it's decided by the Online Trainer
+        return self.num_samples
+
+    def update_online_batches(self):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        if isinstance(self.dataserver_ips, str):
+            ipaddr = self.dataserver_ips
+        else:
+            ipaddr = self.dataserver_ips[self.global_rank]
+        if isinstance(self.dataserver_ports, int):
+            # a single port shared by every data-parallel rank
+            port = self.dataserver_ports
+        else:
+            # in case we want to use different ports for different ranks, e.g. per machine sampling
+            port = self.dataserver_ports[self.global_rank]
+        print(f"Connecting to {ipaddr}:{port}")
+        s.connect((ipaddr, port))
+        s.send(self.data_split.encode())
+        data = b""
+        while True:
+            chunk = s.recv(4096)
+            if not chunk:
+                break
+            data += chunk
+        batch_data = pickle.loads(data)
+        s.close()
+        print(f"Received {len(batch_data)} samples from the server.")
+        for data in batch_data:
+            if self.leave_one_out:
+                rewards = list()
+                for i in range(len(data["rewards"])):
+                    rewards.append(
+                        data["rewards"][i]
+                        - np.mean(
+                            [
+                                data["rewards"][j]
+                                for j in range(len(data["rewards"]))
+                                if j != i
+                            ]
+                        )
+                    )
+                data["raw_rewards"] = data["rewards"]
+                data["rewards"] = rewards
+            else:
+                moving_average = 0
+                if len(self.reward_buffer) > 0:
+                    moving_average = np.mean(self.reward_buffer)
+                self.reward_buffer.append(np.mean(data["rewards"]))
+                if len(self.reward_buffer) > 100:
+                    self.reward_buffer.pop(0)
+                # For metrics...
+                data["raw_rewards"] = data["rewards"]
+                data["rewards"] = [r - moving_average for r in data["rewards"]]
+            for i in range(len(data["completions"])):
+                self.online_batching_data.append(
+                    [
+                        data["prefix"],
+                        data["completions"][i],
+                        data["rewards"][i],
+                        data["raw_rewards"][i],
+                    ]
+                )
+
+    def __getitem__(self, idx):
+        if len(self.online_batching_data) == 0:
+            self.update_online_batches()
+        batch = self.online_batching_data.pop(0)
+        text = batch[0] + batch[1]
+        label = [-100 for _ in batch[0]] + batch[1]
+        # +1 because of causal masking
+        if len(text) <= self.seq_length:
+            text = text + [0] * ((self.seq_length + 1) - len(text))
+            label = label + [-100] * ((self.seq_length + 1) - len(label))
+        return {
+            "text": np.array(text, dtype=np.int64),
+            "label": np.array(label, dtype=np.int64),
+            "reward": np.array([batch[2]], dtype=np.float32),
+            "raw_reward": np.array([batch[3]], dtype=np.float32),
+        }
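`OnlineDataset` above defines a very small wire protocol: the trainer connects, sends the split name as raw bytes, and reads a single pickled list of dicts, each with `prefix` (prompt token ids), `completions` (a list of token-id lists), and `rewards` (one float per completion), until the server closes the connection. A minimal sketch of a compatible server follows; the host, port, and the `make_batch` helper are illustrative placeholders rather than anything shipped in this patch (see `online_data_example_llama3.py` below for a full example).

```python
# Minimal sketch of a data server that OnlineDataset.update_online_batches can talk to.
# "make_batch" is a hypothetical stand-in for real sampling + reward scoring.
import pickle
import socket


def make_batch(split, batch_size=64):
    # Placeholder: return prompt tokens, sampled completions, and one reward per completion.
    return [
        {"prefix": [1, 2, 3], "completions": [[4, 5], [6, 7]], "rewards": [0.9, 0.1]}
        for _ in range(batch_size)
    ]


def serve(host="localhost", port=10000):
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind((host, port))
    server.listen(1)
    while True:
        conn, _ = server.accept()
        split = conn.recv(4096).decode()  # "train" / "valid" / "test"
        conn.send(pickle.dumps(make_batch(split)))
        conn.close()  # OnlineDataset reads until EOF, so close after sending


if __name__ == "__main__":
    serve()
```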
diff --git a/megatron/model/weight_server.py b/megatron/model/weight_server.py
new file mode 100644
index 000000000..987db3434
--- /dev/null
+++ b/megatron/model/weight_server.py
@@ -0,0 +1,64 @@
+from typing import Union, List
+
+import torch
+import socket
+import pickle
+
+
+def send_tensor(state_dict_key, data, sock, end: bool):
+    storage = data.storage()
+    (
+        storage_device,
+        storage_handle,
+        storage_size_bytes,
+        storage_offset_bytes,
+        ref_counter_handle,
+        ref_counter_offset,
+        event_handle,
+        event_sync_required,
+    ) = storage._share_cuda_()
+    sock.send(
+        pickle.dumps(
+            {
+                "state_dict_key": state_dict_key,
+                "dtype": data.dtype,
+                "tensor_size": data.shape,
+                "tensor_stride": data.stride(),
+                "tensor_offset": data.storage_offset(),  # !Not sure about this one.
+                "storage_cls": type(storage),
+                "storage_device": storage_device,
+                "storage_handle": storage_handle,
+                "storage_size_bytes": storage_size_bytes,
+                "storage_offset_bytes": storage_offset_bytes,
+                "requires_grad": False,
+                "ref_counter_handle": ref_counter_handle,
+                "ref_counter_offset": ref_counter_offset,
+                "event_handle": event_handle,
+                "event_sync_required": event_sync_required,
+                "end": end,
+            }
+        )
+    )
+
+
+def send_state_dict(state_dict, sock):
+    for i, key in enumerate(state_dict.keys()):
+        print(key)
+        end = i == len(state_dict.keys()) - 1
+        send_tensor(key, state_dict[key], sock, end)
+        sock.recv(4096)
+
+
+def start_server(model, ports: Union[int, List[int]] = 6000):
+    global_rank = torch.distributed.get_rank()
+    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    if type(ports) == int:
+        port = ports + global_rank
+    else:
+        port = ports[global_rank]
+    s.bind(("localhost", port))
+    s.listen(1)
+    conn, addr = s.accept()
+    state_dict = model.state_dict()
+    send_state_dict(state_dict, conn)
+    conn.close()
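The server above never copies weights over the socket: it pickles the CUDA IPC handles returned by `storage._share_cuda_()`, so a co-located process (the synth-vllm fork, in this patch's setup) can map the same GPU memory. A consumer could look roughly like the sketch below. The use of `torch.multiprocessing.reductions.rebuild_cuda_tensor`, its keyword names, and the one-pickle-per-`recv` framing are assumptions about the receiving side, not something this patch provides; both processes must share a machine for the IPC handles to be valid.

```python
# Hypothetical receiving side for megatron/model/weight_server.py, assuming
# torch.multiprocessing.reductions.rebuild_cuda_tensor accepts the keyword
# names mirrored by the pickled dict (verify against your torch version).
import pickle
import socket

import torch
from torch.multiprocessing.reductions import rebuild_cuda_tensor


def receive_state_dict(host="localhost", port=6000):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((host, port))
    state_dict = {}
    while True:
        # Simplification: assumes each tensor's metadata arrives in one recv.
        meta = pickle.loads(sock.recv(2**20))
        tensor = rebuild_cuda_tensor(
            torch.Tensor,
            tensor_size=meta["tensor_size"],
            tensor_stride=meta["tensor_stride"],
            tensor_offset=meta["tensor_offset"],
            storage_cls=meta["storage_cls"],
            dtype=meta["dtype"],
            storage_device=meta["storage_device"],
            storage_handle=meta["storage_handle"],
            storage_size_bytes=meta["storage_size_bytes"],
            storage_offset_bytes=meta["storage_offset_bytes"],
            requires_grad=meta["requires_grad"],
            ref_counter_handle=meta["ref_counter_handle"],
            ref_counter_offset=meta["ref_counter_offset"],
            event_handle=meta["event_handle"],
            event_sync_required=meta["event_sync_required"],
        )
        state_dict[meta["state_dict_key"]] = tensor
        sock.send(b"ack")  # send_state_dict blocks on sock.recv after each tensor
        if meta["end"]:
            break
    return state_dict
```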
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index deca060b5..9c8d3635f 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -502,6 +502,28 @@ class NeoXArgsModel(NeoXArgsTemplate):
     Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
     """
 
+    serve_model_weights: bool = False
+    """
+    If true, serve model weight pointers over a socket connection
+    """
+
+    weight_server_port: Union[int, List[int]] = 6000
+    """
+    Port(s) to serve model weights over.
+    If an integer is provided, each rank serves on that port plus its global rank (e.g. 6000 + rank).
+    If a list is provided, the ports are used in order, e.g. rank 0 uses weight_server_port[0].
+    """
+
+    online_dataserver_ips: Union[str, List[str]] = "localhost"
+    """
+    IP addresses to connect to for online data serving, defaults to localhost
+    """
+
+    online_dataserver_ports: Union[int, List[int]] = 10000
+    """
+    Port(s) to connect to for online data serving, defaults to 10000
+    """
+
     te_columnparallel: bool = False
     """
     Use TransformerEngine for RowParallelLinear layer.
@@ -1132,14 +1154,14 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets
     """
 
-    dataset_impl: Literal["gpt2", "pairwise"] = "gpt2"
+    dataset_impl: Literal["gpt2", "pairwise", "online"] = "gpt2"
     """
-    Dataset implementation, can be one of "gpt2" or "pairwise"
+    Dataset implementation, can be one of "gpt2", "pairwise", or "online"
     """
 
-    train_impl: Literal["normal", "dpo", "rm", "kto"] = "normal"
+    train_impl: Literal["normal", "dpo", "rm", "kto", "reinforce"] = "normal"
     """
-    Training implementation, can be one of "normal", "dpo", "kto", or "rm"
+    Training implementation, can be one of "normal", "dpo", "kto", "reinforce", or "rm"
     """
 
     dpo_fp32: bool = True
@@ -1184,6 +1206,27 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Beta value for KTO
     """
 
+    fp32_reinforce: bool = True
+    """
+    Whether to cast logits to fp32 for Reinforce loss calculation.
+    """
+
+    kl_impl: Literal["abs", "mse", "kl", "full"] = "mse"
+    """
+    KL divergence implementation, can be one of "abs", "mse", "kl", or "full"
+    """
+
+    kl_div_beta: float = 0.1
+    """
+    Beta value for KL divergence in Reinforce loss calculation.
+    """
+
+    reinforce_leave_one_out: bool = False
+    """
+    Whether to use REINFORCE leave-one-out baselining for training
+    (from https://arxiv.org/abs/2402.14740 and https://api.semanticscholar.org/CorpusID:198489118)
+    """
+
     allow_chopped: bool = True
     """
     WARNING: if your packing impl is packed, this is ignored.
diff --git a/megatron/training.py b/megatron/training.py
index 1965faea8..3def74860 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -62,6 +62,7 @@
     get_total_params,
     CharCounter,
 )
+from megatron.model.weight_server import start_server
 from megatron.model.gpt2_model import cross_entropy
 
 from megatron.mpu import vocab_parallel_cross_entropy
@@ -253,6 +254,13 @@ def pretrain(neox_args):
     )
     timers("model and optimizer").stop()
 
+    if neox_args.serve_model_weights:
+        start_server(model)
+        # sync...
+        torch.distributed.barrier()
+
+    # Start data stuff:
+
     # Make and configure iterators
     timers("train/valid/test data iterators").start()
     (
@@ -382,7 +390,7 @@ def get_batch(neox_args, data_iterator):
     """Generate a batch"""
 
     # Items and their type.
-    if neox_args.train_impl in ["normal", "kto"]:
+    if neox_args.train_impl in ["normal", "kto", "reinforce"]:
         keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"]
     elif neox_args.train_impl in ["dpo", "rm"]:
         keys = (
@@ -427,6 +435,20 @@ def get_batch(neox_args, data_iterator):
             else None
         )
         return tup + (rw_data, ref_data)
+    elif neox_args.train_impl == "reinforce":
+
+        tup = _get_batch(
+            neox_args=neox_args,
+            tokenizer=neox_args.tokenizer,
+            keys=keys,
+            data=data,
+            datatype=datatype,
+        )
+        rw_data = mpu.broadcast_data(["reward"], data, torch.float)["reward"]
+        raw_rw_data = mpu.broadcast_data(["raw_reward"], data, torch.float)[
+            "raw_reward"
+        ]
+        return tup + (rw_data, raw_rw_data)
     elif neox_args.train_impl in ["dpo", "rm"]:
         pos_tup = _get_batch(
             neox_args=neox_args,
@@ -604,6 +626,16 @@ def forward_step(
             rewards,
             ref_logp,
         ) = get_batch(neox_args=neox_args, data_iterator=data_iterator)
+    elif neox_args.train_impl == "reinforce":
+        (
+            tokens,
+            labels,
+            loss_mask,
+            attention_mask,
+            position_ids,
+            rewards,
+            raw_rewards,
+        ) = get_batch(neox_args=neox_args, data_iterator=data_iterator)
     if neox_args.train_impl in ["dpo", "rm"]:
         tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch(
             neox_args=neox_args, data_iterator=data_iterator
@@ -841,6 +873,70 @@ def forward_step(
         # print(loss.shape)
         loss = loss.mean()
         # print(loss.shape)
+    elif neox_args.train_impl == "reinforce":
+        if reference_model is not None:
+            with torch.no_grad():
+                ref_outputs = reference_model(
+                    (tokens, position_ids, attention_mask), neox_args=neox_args
+                )
+                if type(ref_outputs) is tuple:
+                    ref_outputs, _ = ref_outputs
+                ref_outputs = ref_outputs
+                if neox_args.kl_impl == "full":
+                    # Have to do the loss over all tokens...
+                    ref_outputs = gather_from_model_parallel_region(ref_outputs)
+                    if neox_args.fp32_reinforce:
+                        ref_outputs = ref_outputs.float()
+                    ref_logp = ref_outputs.log_softmax(dim=-1).detach()
+                    ref_per_token_logp = torch.gather(
+                        ref_logp.clone(), dim=2, index=labels.unsqueeze(2)
+                    ).squeeze(2)
+                else:
+                    ref_per_token_logp = get_logp(
+                        ref_outputs, labels, neox_args.fp32_reinforce
+                    )
+                metrics["ref_logp"] = ref_per_token_logp.clone().detach().mean()
+        outputs = model((tokens, position_ids, attention_mask), neox_args=neox_args)
+        if type(outputs) is tuple:
+            outputs, _ = outputs
+        if neox_args.kl_impl == "full":
+            # Have to do the loss over all tokens...
+            outputs = gather_from_model_parallel_region(outputs)
+            if neox_args.fp32_reinforce:
+                outputs = outputs.float()
+            logp = outputs.log_softmax(dim=-1)
+            per_token_logp = torch.gather(
+                logp.clone(), dim=2, index=labels.unsqueeze(2)
+            ).squeeze(2)
+        else:
+            per_token_logp = get_logp(outputs, labels, neox_args.fp32_reinforce)
+        with torch.no_grad():
+            metrics["logp"] = per_token_logp.clone().detach().mean()
+            metrics["reward"] = raw_rewards.clone().detach().mean()
+            metrics["reward_std"] = raw_rewards.clone().detach().std()
+        loss_mask_sum = loss_mask.sum()
+        if reference_model is not None:
+            if neox_args.kl_impl == "full":
+                # Following along with
+                # https://github.com/huggingface/trl/blob/104a02d207b63a4a062882aaff68f2d275493399/trl/trainer/ppo_trainer.py#L1109
+                kl = F.kl_div(ref_logp, logp, log_target=True, reduction="none").sum(-1)
+            else:
+                kl = per_token_logp - ref_per_token_logp
+                if neox_args.kl_impl == "abs":
+                    kl = kl.abs()
+                elif neox_args.kl_impl == "mse":
+                    kl = 0.5 * (kl).square()
+                elif neox_args.kl_impl == "kl":
+                    pass
+            with torch.no_grad():
+                metrics["kl"] = kl.clone().detach().mean()
+            loss = (-per_token_logp * rewards) + (neox_args.kl_div_beta * kl)
+            loss = (loss * loss_mask).sum(-1) / loss_mask_sum
+            loss = loss.mean()
+        else:
+            loss = -(rewards * per_token_logp)
+            loss = (loss * loss_mask).sum(-1) / loss_mask_sum
+            loss = loss.mean()
     if neox_args.memory_profiling:
         torch.cuda.nvtx.range_pop()
     if return_logits:
@@ -1146,10 +1242,17 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
     """Setup model and optimizer."""
 
     needs_reference_model = (
-        (neox_args.train_impl == "dpo")
-        and (neox_args.precompute_model_name is None)
-        and (not neox_args.dpo_reference_free)
-    ) or ((neox_args.train_impl == "kto") and (neox_args.precompute_model_name is None))
+        (
+            (neox_args.train_impl == "dpo")
+            and (neox_args.precompute_model_name is None)
+            and (not neox_args.dpo_reference_free)
+        )
+        or (
+            (neox_args.train_impl == "kto")
+            and (neox_args.precompute_model_name is None)
+        )
+        or ((neox_args.train_impl == "reinforce") and (neox_args.kl_div_beta > 0.0))
+    )
     model = get_model(neox_args=neox_args, use_cache=use_cache)
     if needs_reference_model:
         reference_model = get_model(neox_args=neox_args, use_cache=use_cache)
@@ -1281,7 +1384,6 @@ def train_step(
     reference_model=None,
 ):
     """Single training step."""
-
     # Pipeline parallelism schedules forward/backward/step
     if neox_args.is_pipe_parallel:
         reduced_loss = train_step_pipe(
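In the `forward_step` branch above, the REINFORCE loss is a reward-weighted negative log-likelihood plus an optional KL penalty against the frozen reference model, averaged over non-masked tokens. The sketch below works through the reward shaping and the reduction with made-up numbers; all values and shapes are illustrative only and assume `kl_impl="mse"` and `kl_div_beta=0.1`.

```python
# Illustrative-only NumPy walk-through of the reward shaping and loss reduction.
import numpy as np

rewards = np.array([0.9, 0.1, 0.4, 0.6])  # raw reward per sampled completion

# reinforce_leave_one_out: the baseline for sample i is the mean of the other rewards.
advantages = np.array(
    [r - np.mean(np.delete(rewards, i)) for i, r in enumerate(rewards)]
)
# -> [ 0.533, -0.533, -0.133, 0.133]; otherwise a running mean of recent batch
# rewards is subtracted instead (see OnlineDataset.update_online_batches).

per_token_logp = np.random.randn(4, 8) - 1.0  # stand-in for policy log-probs
ref_per_token_logp = per_token_logp - 0.05    # stand-in for reference log-probs
loss_mask = np.ones((4, 8))

kl = 0.5 * (per_token_logp - ref_per_token_logp) ** 2       # kl_impl == "mse"
loss = (-per_token_logp * advantages[:, None]) + 0.1 * kl   # kl_div_beta == 0.1
loss = (loss * loss_mask).sum(-1) / loss_mask.sum()
loss = loss.mean()
```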
diff --git a/post-training/README.md b/post-training/README.md
index fb7ac8eb4..940cef428 100644
--- a/post-training/README.md
+++ b/post-training/README.md
@@ -2,6 +2,8 @@
 
 Examples for running post-training with ultrafeedback data for SFT/DPO/RM training.
 
+For [REINFORCE](https://arxiv.org/abs/2402.14740) style training, see [Online Training](online_training.md).
+
 ```bash
 python tools/ckpts/convert_hf_llama_to_neox.py --tp 4 --model meta-llama/Meta-Llama-3-8B-Instruct --model_path checkpoints/neox_converted/llama3-8b-instruct
 ```
diff --git a/post-training/configs/llama3-8b-reinforce.yml b/post-training/configs/llama3-8b-reinforce.yml
new file mode 100644
index 000000000..8d8e04462
--- /dev/null
+++ b/post-training/configs/llama3-8b-reinforce.yml
@@ -0,0 +1,119 @@
+{
+  "pipe_parallel_size": 0,
+  "model_parallel_size": 4,
+  "make_vocab_size_divisible_by": 1,
+
+  # model settings
+  "num_layers": 32,
+  "hidden_size": 4096,
+  "num_attention_heads": 32,
+  "num_kv_heads": 8,
+  # llama3 supports more than this but this is just for testing.
+  "seq_length": 1024,
+  "max_position_embeddings": 1024,
+  "pos_emb": "rotary",
+  "rotary_pct": 1,
+  "rotary_emb_base": 500000,
+  "rope_fusion": true,
+  "no_weight_tying": true,
+  "gpt_j_residual": false,
+  "output_layer_parallelism": "column",
+  "norm": "rmsnorm",
+  "rms_norm_epsilon": 1.0e-5,
+
+  "attention_config": [[["flash"], 32]],
+
+  "scaled_upper_triang_masked_softmax_fusion": true,
+  "bias_gelu_fusion": false,
+  "use_bias_in_norms": false,
+  "use_bias_in_attn_linear": false,
+  "use_bias_in_mlp": false,
+  "use_flashattn_swiglu": true,
+  "activation": "swiglu",
+  "intermediate_size": 14336,
+  "mlp_multiple_of": 14336,
+
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": 0.00001,
+      "betas": [0.9, 0.95],
+      "eps": 1.0e-8
+    }
+  },
+  "min_lr": 0.000001,
+
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 1260000000,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 1260000000,
+    "contiguous_gradients": true,
+    "cpu_offload": false
+  },
+
+  "train_impl": "reinforce",
+  "dataset_impl": "online",
+  "reinforce_leave_one_out": true,
+  "fp32_reinforce": true,
+  "kl_impl": "abs",
+  "online_dataserver_ports": [10000, 10001],
+  "serve_model_weights": true,
+  "train_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ],
+  "test_label_data_paths": [ "data/sft/llama3_test_messages_label_document" ],
+  "valid_label_data_paths": [ "data/sft/llama3_train_messages_label_document" ],
+  "train_data_paths": [ "data/sft/llama3_train_messages_document" ],
+  "test_data_paths": [ "data/sft/llama3_test_messages_document" ],
+  "valid_data_paths": [ "data/sft/llama3_train_messages_document" ],
+
+  "train_micro_batch_size_per_gpu": 8,
+  "gradient_accumulation_steps": 4,
+  "data_impl": "mmap",
+  "pack_impl": "unpacked",
+  "num_workers": 1,
+
+  "checkpoint_activations": true,
+  "checkpoint_num_layers": 1,
+  "partition_activations": true,
+  "synchronize_each_layer": true,
+
+  "gradient_clipping": 1.0,
+  "weight_decay": 0.1,
+  "hidden_dropout": 0,
+  "attention_dropout": 0,
+
+  "precision": "bfloat16",
+  "fp32_allreduce": true,
+  "bf16": {
+    "enabled": true
+  },
+  "data_types": {
+    "grad_accum_dtype": "fp32"
+  },
+
+  "train_iters": 477,
+  "lr_decay_iters": 477,
+  "distributed_backend": "nccl",
+  "lr_decay_style": "cosine",
+  "warmup": 0.1,
+  "checkpoint_factor": 1000,
+  "eval_interval": 100,
+  "eval_iters": 10,
+
+  "log_interval": 1,
+  "steps_per_print": 1,
+  "wall_clock_breakdown": true,
+
+
+  "save": "checkpoints/reinforce/llama3/llama3-8b-instruct",
+  #"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save"
+  "load": "checkpoints/neox_converted/llama3-8b-instruct",
+  "vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json",
+  "use_wandb": true,
+  "wandb_group": "llama3-8b-instruct",
+  "wandb_project": "reinforce-test",
+  "finetune": true, # set to false once resuming from intermediate finetuning step
+  "tokenizer_type": "HFTokenizer",
+}
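One thing the config does not state explicitly is how many online samples a run will request from the data servers. Assuming the 8-GPU layout used by `online_example.sh` (model parallel 4, hence a data-parallel size of 2, which is an assumption about the hardware rather than something in the config), the arithmetic works out as sketched below.

```python
# Back-of-the-envelope check of how llama3-8b-reinforce.yml maps onto the online
# data servers (assumes 8 GPUs: model_parallel_size=4 -> 2 data-parallel replicas).
train_micro_batch_size_per_gpu = 8
gradient_accumulation_steps = 4
data_parallel_size = 2

# DeepSpeed's effective batch size; OnlineDataset.num_samples is train_iters * this.
train_batch_size = (
    train_micro_batch_size_per_gpu * gradient_accumulation_steps * data_parallel_size
)  # 64
train_iters, eval_interval, eval_iters = 477, 100, 10
train_samples = train_iters * train_batch_size                                      # 30528
valid_samples = (train_iters // eval_interval + 1) * eval_iters * train_batch_size  # 3200

# Each request to the example data server returns bs_per_dp * num_completions
# sequences (64 * 4 = 256), so one round trip covers several training iterations.
```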
diff --git a/post-training/online_data_example_llama3.py b/post-training/online_data_example_llama3.py
new file mode 100644
index 000000000..bdd902512
--- /dev/null
+++ b/post-training/online_data_example_llama3.py
@@ -0,0 +1,177 @@
+import socket
+import threading
+import datasets
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+import requests
+import pickle
+from collections import defaultdict
+import time
+
+
+def get_positive_score(scores):
+    "Extract value associated with a positive sentiment from pipeline's output"
+    return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]
+
+
+def http_bot(url, pload):
+    for i in range(10):
+        try:
+            headers = {"User-Agent": "vLLM Client"}
+            response = requests.post(url, headers=headers, json=pload, stream=True)
+            data = response.json()
+            return data
+        except Exception as e:
+            # give it a few seconds to recover
+            time.sleep(5)
+            print(e)
+            continue
+    raise Exception("Failed to connect to server")
+
+
+def threaded_data_gatherer(
+    prefix,
+    max_completion_len,
+    tokenizer,
+    model_name,
+    num_completions,
+    i,
+    dp_idx,
+    data_to_send,
+    rm_pipeline,
+):
+    pload = {
+        "temperature": 1.0,
+        "max_tokens": 0,
+        "stop": "<|eot_id|>",
+        "stream": False,
+        "model": model_name,
+        "prompt": "",
+        "n": num_completions,
+    }
+    # Grab tokens...
+    prefix_tokens = tokenizer.encode(prefix)
+    prompt = tokenizer.apply_chat_template(
+        [
+            {
+                "role": "user",
+                "content": "Please write a mildly negative movie review starting with "
+                + prefix,
+            }
+        ],
+        add_generation_prompt=True,
+        tokenize=False,
+    )
+    prompt_tokens = tokenizer.encode(prompt)
+    pload["max_tokens"] = max_completion_len - len(prefix_tokens)
+    pload["prompt"] = prompt + prefix
+    completions = http_bot(f"http://localhost:{8000+dp_idx}/v1/completions", pload)
+    completions = [completion["text"].strip() for completion in completions["choices"]]
+
+    def reward_fn(samples, **kwargs):
+        sentiments = list(map(get_positive_score, rm_pipeline(samples)))
+        return sentiments
+
+    rewards = reward_fn([prefix + " " + completion for completion in completions])
+    if i == 0 and dp_idx == 0:
+        print(completions)
+    completions = [
+        tokenizer.encode(completion + "<|eot_id|>") for completion in completions
+    ]
+    data_to_send.append(
+        {"prefix": prompt_tokens, "completions": completions, "rewards": rewards}
+    )
+
+
+def data_generator(
+    bs_per_dp,
+    dataset,
+    tokenizer,
+    model_name,
+    max_prefix_len,
+    max_completion_len,
+    num_completions,
+    dp_idx,
+    dp_size,
+    tp_size,
+    rm_pipeline,
+):
+    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    server.bind(
+        ("localhost", 10000 + dp_idx)
+    )  # only one data loader per data parallel group
+    split_counter = defaultdict(lambda: dp_idx)
+    while True:
+        server.listen(1)
+        conn, addr = server.accept()
+        split = conn.recv(4096).decode()
+        if split == "valid":
+            split = "unsupervised"
+        data_to_send = list()
+        threads = list()
+        for i in range(bs_per_dp):
+            prefix = " ".join(
+                dataset[split][split_counter[split]]["text"].split()[:5]
+            )  # grab a few words to prompt it...
+            split_counter[split] = (split_counter[split] + dp_size) % len(
+                dataset[split]
+            )
+            threads.append(
+                threading.Thread(
+                    target=threaded_data_gatherer,
+                    args=(
+                        prefix,
+                        max_completion_len,
+                        tokenizer,
+                        model_name,
+                        num_completions,
+                        i,
+                        dp_idx,
+                        data_to_send,
+                        rm_pipeline,
+                    ),
+                )
+            )
+            threads[-1].start()
+        for thread in threads:
+            thread.join()
+        conn.send(pickle.dumps(data_to_send))
+        conn.close()
+        print(
+            f"Sent data to {dp_idx} for {split} split at iter {split_counter[split]}..."
+ ) + + +if __name__ == "__main__": + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device="cpu", + ) + dataset = datasets.load_dataset("imdb") + threads = list() + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct") + for i in range(2): + threads.append( + threading.Thread( + target=data_generator, + args=( + 64, # bs_per_dp + dataset, # dataset + tokenizer, # tokenizer + "meta-llama/Meta-Llama-3-8B-Instruct", # model_name + 128, # max_prefix_len + 256, # max_completion_len + 4, # num_completions + i, # dp_idx + 2, # dp_size + 4, # tp_size + sentiment_fn, # rm_pipeline + ), + ) + ) + threads[-1].start() + for thread in threads: + thread.join() diff --git a/post-training/online_example.sh b/post-training/online_example.sh new file mode 100644 index 000000000..abe601faa --- /dev/null +++ b/post-training/online_example.sh @@ -0,0 +1,7 @@ +# Launch vllm +CUDA_VISIBLE_DEVICES=0,1,2,3 conda run --no-capture-output -n vllm python -m vllm.entrypoints.openai.api_server --model=meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --from-remote-program --tensor-parallel-size=4 --enforce-eager --gpu-memory-utilization=0.2 --port 8000 --max-model-len=1024 --max-num-seqs=512 & + +CUDA_VISIBLE_DEVICES=4,5,6,7 conda run --no-capture-output -n vllm python -m vllm.entrypoints.openai.api_server --model=meta-llama/Meta-Llama-3-8B-Instruct --dtype auto --from-remote-program --tensor-parallel-size=4 --enforce-eager --gpu-memory-utilization=0.2 --port 8001 --max-model-len=1024 --max-num-seqs=512 & + +# Launch training +conda run --no-capture-output -n neox python deepy.py train.py post-training/configs/llama3-8b-reinforce.yml diff --git a/post-training/online_training.md b/post-training/online_training.md new file mode 100644 index 000000000..28f45c7cd --- /dev/null +++ b/post-training/online_training.md @@ -0,0 +1,56 @@ +# Online Training + +## Prerequisites +Want to use [REINFORCE](https://arxiv.org/abs/2402.14740) to train your model? First you'll need to build a custom vllm package. + +[synth-vllm](https://github.com/SynthLabsAI/synth-vllm) is a fork of [vllm](https://github.com/vllm-project/vllm) maintained by [SynthLabs](https://www.synthlabs.ai/) +that has been modified to support using the weights in GPT-NeoX by sharing the GPU memory location of the model weights. + +It currently supports Llama and Pythia models. + +### Building the package + +Here is a reference on how the package has been built before, using conda: +(Note this should be taken as a reference, and may not work as is due to your system configuration) + +```bash +# cd to the synth vllm directory... +conda create -n vllm python=3.10 +conda deactivate +conda activate vllm +conda install -y pytorch pytorch-cuda=12.1 -c pytorch -c nvidia +conda install -y nvidia/label/cuda-12.1.0::cuda-toolkit +conda install -y nvidia/label/cuda-12.1.0::cuda-cudart +conda install -y nvidia/label/cuda-12.1.0::cuda-compiler +conda install -y nvidia/label/cuda-12.1.0::cuda-nvcc +conda install -y nvidia/label/cuda-12.1.0::cuda-profiler-api +conda install -y nvidia/label/cuda-12.1.0::cuda-cudarty +conda install -y -c nvidia cuda-nvprof=12.1 +conda install -y conda-forge::cuda-version=12.1 +conda install -y gcc_linux-64=12.3.0 +conda install -y -c conda-forge gxx_linux-64=12.3.0 +pip install -e . 
+``` + +## Training + +If you haven't already, run this command to generate a copy of the Llama-3 weights in GPT-NeoX format: +```bash +python tools/ckpts/convert_hf_llama_to_neox.py --tp 4 --model meta-llama/Meta-Llama-3-8B-Instruct --model_path checkpoints/neox_converted/llama3-8b-instruct +``` + +[online_example.sh](online_example.sh), [online_data_example_llama3.py](online_data_example_llama3.py) is an example of +how to train a model using the synth-vllm package on a single node. + +This assumes you are using a conda environment with GPT-NeoX installed under the name `neox`. + +To run the example, execute the following commands: + +```bash +# It may be preferable to run these in two separate terminals +python post-training/online_data_example_llama3.py & +bash post-training/online_example.sh +``` + +This will train a model using the synth-vllm package on the llama3-8b-instruct model. It will optimize a positive reward +from a sentiment classifier.