diff --git a/configs/150M/H100.toml b/configs/150M/H100.toml index a6339181..94ba83cf 100644 --- a/configs/150M/H100.toml +++ b/configs/150M/H100.toml @@ -1,16 +1,21 @@ name_model = "150M" -project = "debug_150m_zero_band" +project = "adam_sweep" type_model = "llama2" [train] -micro_bs = 64 # change this base on the gpu +micro_bs = 4 # change this base on the gpu reshard_after_forward = true [optim] -batch_size = 512 +batch_size = 128 warmup_steps = 1000 -total_steps = 88_000 +total_steps = 8192 +optim.lr = 4e-4 -[optim.optim] -lr = 4e-4 +[data] +seq_length = 8192 +num_workers = 2 +dataset_name_or_paths = "/home/ubuntu/prime/datasets/fineweb-edu" +split_by_data_rank = true +[acco] diff --git a/configs/1B/H100.toml b/configs/1B/H100.toml index de9cef75..41a3db65 100644 --- a/configs/1B/H100.toml +++ b/configs/1B/H100.toml @@ -1,15 +1,22 @@ name_model = "1B" -project = "debug_1B_zero_band" +project = "adam_sweep" type_model = "llama2" [train] -micro_bs = 32 +micro_bs = 2 reshard_after_forward = true [optim] -batch_size = 1024 +batch_size = 128 warmup_steps = 1000 total_steps = 8192 -[optim.optim] -lr = 7e-4 +optim.lr = 4e-4 + +[data] +seq_length = 8192 +num_workers = 2 +dataset_name_or_paths = "/home/ubuntu/prime/datasets/fineweb-edu" +split_by_data_rank = true + +[acco] diff --git a/llama-debug/config.json b/llama-debug/config.json new file mode 100644 index 00000000..c3f7712c --- /dev/null +++ b/llama-debug/config.json @@ -0,0 +1,35 @@ +{ +"architectures": [ + "LlamaForCausalLM" +], +"attention_bias": false, +"attention_dropout": 0.0, +"bos_token_id": 128000, +"eos_token_id": 128001, +"head_dim": 64, +"hidden_act": "silu", +"hidden_size": 1024, +"initializer_range": 0.02, +"intermediate_size": 4096, +"max_position_embeddings": 1024, +"mlp_bias": false, +"model_type": "llama", +"num_attention_heads": 16, +"num_hidden_layers": 5, +"num_key_value_heads": 8, +"pretraining_tp": 1, +"rms_norm_eps": 1e-05, +"rope_scaling": { + "factor": 32.0, + "high_freq_factor": 4.0, + "low_freq_factor": 1.0, + "original_max_position_embeddings": 8192, + "rope_type": "llama3" +}, +"rope_theta": 500000.0, +"tie_word_embeddings": true, +"torch_dtype": "bfloat16", +"transformers_version": "4.45.0.dev0", +"use_cache": true, +"vocab_size": 128256 +} \ No newline at end of file diff --git a/src/zeroband/collectives.py b/src/zeroband/collectives.py index efdb3ea1..ea87575c 100644 --- a/src/zeroband/collectives.py +++ b/src/zeroband/collectives.py @@ -12,7 +12,8 @@ def gloo_all_reduce( tensor: torch.Tensor, op: dist.ReduceOp = dist.ReduceOp.SUM, group: Optional[dist.ProcessGroup] = None, -) -> None: + async_op: bool = False, +) -> None | dist.Work: """Wrap gloo all reduce""" if group is None: group = dist.distributed_c10d._get_default_group() @@ -24,7 +25,7 @@ def gloo_all_reduce( # todo check numerical stability of doing post or pre div tensor.div_(group.size()) - dist.all_reduce(tensor, op, group=group) + return dist.all_reduce(tensor, op, group=group, async_op=async_op) class Compression(Enum): diff --git a/src/zeroband/config.py b/src/zeroband/config.py index 07d4b7e0..a82a043a 100644 --- a/src/zeroband/config.py +++ b/src/zeroband/config.py @@ -8,6 +8,7 @@ from zeroband.diloco import DilocoConfig from zeroband.models.llama.model import AttnFnType from zeroband.optimizers import OptimizersConfig, AdamConfig +from zeroband.dpu import ACCOConfig class OptimConfig(BaseConfig): @@ -68,6 +69,7 @@ class Config(BaseConfig): # sub config diloco: DilocoConfig | None = None + acco: ACCOConfig | None = None data: DataConfig = 
DataConfig() optim: OptimConfig = OptimConfig() train: TrainConfig @@ -88,4 +90,3 @@ def validate_live_recovery_rank_src(self): if self.ckpt is not None and self.ckpt.live_recovery_rank_src is not None and self.diloco is None: raise ValueError("live_recovery_rank_src is only supported with diloco") return self - diff --git a/src/zeroband/dpu.py b/src/zeroband/dpu.py new file mode 100644 index 00000000..d6c5161a --- /dev/null +++ b/src/zeroband/dpu.py @@ -0,0 +1,5 @@ +from typing import Optional +from pydantic_config import BaseConfig + +class ACCOConfig(BaseConfig): + theta_t_device: Optional[str] = None diff --git a/src/zeroband/global_ddp.py b/src/zeroband/global_ddp.py new file mode 100644 index 00000000..4bf4a049 --- /dev/null +++ b/src/zeroband/global_ddp.py @@ -0,0 +1,161 @@ +import time +from typing import Generator, NamedTuple +from pydantic import model_validator +from pydantic_config import BaseConfig +import torch +import torch.nn as nn +from zeroband.comms import ElasticDeviceMesh +import torch.distributed as dist +from zeroband.collectives import Compression, gloo_all_reduce +from torch.distributed._tensor.api import DTensor +from zeroband.utils.logging import get_logger +from zeroband.utils.world_info import get_world_info + +from torch.distributed import Work + +logger = get_logger(__name__) + + +class GlobalDDPConfig(BaseConfig): + # retry_all_reduce: int = 3 + compression: Compression = Compression.NO + dpu: bool = False + enable: bool = True + + @model_validator(mode="after") + def validate_compression(self): + if self.compression != Compression.NO: + raise NotImplementedError("Compression is not implemented yet") + return self + + +def offload_grad_generator(model: nn.Module) -> Generator: + for param in model.parameters(): + if param.grad is not None: + if isinstance(param.grad, DTensor): + yield param.grad.to_local().to("cpu") + else: + yield param.grad.to("cpu") + + +def apply_staling_grad(model: nn.Module, tensors: list[torch.Tensor]): + for param, tensor in zip(model.parameters(), tensors): + if isinstance(param.grad, DTensor): + param.grad.to_local().copy_(tensor) + else: + param.grad.copy_(tensor) + + +def maybe_unwrap_dtensor(tensor: torch.Tensor | DTensor): + if isinstance(tensor, DTensor): + return tensor.to_local() + else: + return tensor + + +class AllReduceGradWork(NamedTuple): + grad: torch.Tensor + work: Work + + +def async_all_reduce(model: nn.Module, elastic_device_mesh: ElasticDeviceMesh, flag: str) -> list[AllReduceGradWork]: + """ + Triggered all reduce operation on a list of tensors in a async manner. + Return a list of async jobs that can be waited on. + """ + + elastic_device_mesh.maybe_reinit_global_pg(admit_joiners=False) + world_size = elastic_device_mesh.global_pg.size() + + global_pg = elastic_device_mesh.global_pg + elastic_device_mesh.monitored_barrier(flag) + logger.debug("Beginning all reduce") + + async_job = [] + + for param in offload_grad_generator(model): # TODO: do we need to offload when doing blocking all reduce ? + grad = maybe_unwrap_dtensor(param) + + grad.div_(world_size) + + # all_reduce(self.config.compression, grad, dist.ReduceOp.SUM, global_pg) # doing gloo all reduce direclty because of async op + + async_job.append(AllReduceGradWork(grad, gloo_all_reduce(grad, dist.ReduceOp.SUM, global_pg, True))) + + return async_job + + +class GlobalDDP: + """ + This class implements DDP over internet. 
It + + :Args: + model: The model to be trained + config: The configuration for the global DDP + elastic_device_mesh: The elastic device mesh to be used + + Example usage: + + ``` + config = GlobalDDPConfig(dpu=False) + global_ddp = GlobalDDP(model, config, elastic_device_mesh) + + for step in range(num_steps): + for micro_bs in range(num_micro_bs): + loss = model(batch) + loss.backward() + + global_ddp.all_reduce() + optimizer.step() + optimizer.zero_grad() + ``` + + """ + + flag: str = "global_ddp" + + def __init__( + self, + model: nn.Module, + config: GlobalDDPConfig, + elastic_device_mesh: ElasticDeviceMesh, + ): + self.elastic_device_mesh = elastic_device_mesh + self.config = config + + self.world_info = get_world_info() + self._logger = get_logger() + + self.model = model + + self._stalling_grad_work: list[AllReduceGradWork] | None = None + + def all_reduce(self): + if not self.config.dpu: + self._blocking_all_reduce(self.model) + else: + new_staling_grad_work = async_all_reduce(self.model, self.elastic_device_mesh, self.flag) + + if self._stalling_grad_work is None: + # if it is the first step we just store the work for the next call to this function and return + self._stalling_grad_work = new_staling_grad_work + else: + # otherwise we wait for the current staling grad work to finish + start_time = time.time() + [all_reduce_grad_work.work.wait() for all_reduce_grad_work in self._stalling_grad_work] + self._logger.debug(f"Time to wait for staling grads: {time.time() - start_time}") + # and apply the staling grads to the model + apply_staling_grad( + self.model, [all_reduce_grad_work.grad for all_reduce_grad_work in self._stalling_grad_work] + ) + # and store the new staling grad work for the next call to this function + self._stalling_grad_work = new_staling_grad_work + + def _blocking_all_reduce(self, tensor: list[torch.Tensor]): + """ + Triggered all reduce operation on a list of tensors in a blocking manner. 
+ """ + [ + all_reduce_grad_work.work.wait() + for all_reduce_grad_work in async_all_reduce(tensor, self.elastic_device_mesh, self.flag) + ] diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 87f4635d..2c91a9dc 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -6,6 +6,7 @@ from pydantic_config import parse_argv from einops import rearrange from torch.nn import functional as F +from torch import nn from transformers import AutoTokenizer @@ -18,7 +19,7 @@ from zeroband.loss import cross_entropy_max_z_loss from zeroband.models.llama.model import create_block_mask_from_seqlens -from zeroband.config import Config #, MemoryProfilerConfig +from zeroband.config import Config # , MemoryProfilerConfig from zeroband.optimizers import get_optimizer from zeroband.utils import ( @@ -39,6 +40,7 @@ from zeroband.checkpoint import CkptManager, TrainingProgress from zeroband.lr_scheduler import get_scheduler + def log_hash_training_state( config: Config, model: torch.nn.Module, @@ -76,6 +78,62 @@ def log_hash_training_state( metric_logger.log(metrics) +def compute_loss( + model: nn.Module, + gradient_accumulation_steps: int, + train_dataloader_iterator: iter, + local_pg: dist.ProcessGroup, + loss_scaling: float = 1.0, + enable_z_loss: bool = False, +): + loss_batch = 0 + z_loss_batch = 0 + + for grad_acc_step in range(gradient_accumulation_steps): + is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 + # no sync if we are accumulating gradients + model.set_requires_gradient_sync(not is_accumulating) + + batch = next(train_dataloader_iterator) + input_ids = batch["input_ids"].to("cuda") + labels = batch["labels"].to("cuda") + if config.train.sequence_packing: + seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] + block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None + else: + block_mask = None + + logits = model(tokens=input_ids, block_mask=block_mask).contiguous() + flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") + flatten_labels = rearrange(labels, "b seq -> (b seq)") + + if enable_z_loss: + ce_loss, z_loss = cross_entropy_max_z_loss(flatten_logits, flatten_labels, config.optim.z_loss_weight) + ce_loss /= gradient_accumulation_steps * loss_scaling + z_loss /= gradient_accumulation_steps * loss_scaling + + del logits + loss = ce_loss + z_loss + loss.backward() + + else: + loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps * loss_scaling + del logits + loss.backward() + + if config.optim.z_loss: + loss_batch += ce_loss.clone().detach() + z_loss_batch += z_loss.clone().detach() + else: + loss_batch += loss.clone().detach() + + dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=local_pg) + if config.optim.z_loss: + dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=local_pg) + + return loss_batch, z_loss_batch + + def train(config: Config): # batch_size is the total batch size for all GPUs assert config.optim.batch_size % world_info.local_world_size == 0 @@ -83,6 +141,9 @@ def train(config: Config): assert batch_size % config.train.micro_bs == 0 gradient_accumulation_steps = batch_size // config.train.micro_bs + if config.acco: + assert gradient_accumulation_steps % 2 == 0, "ACCO requires gradient accumulation steps to be even" + gradient_accumulation_steps //= 2 if config.ckpt is not None and config.ckpt.interval is not None and config.diloco is not None: assert ( @@ -137,7 +198,8 @@ def train(config: Config): apply_ac_ckpt(model, num) 
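For readers skimming the diff: the `dpu=True` branch of `GlobalDDP.all_reduce` above applies gradients that are one step stale, so the gloo all-reduce launched at step t overlaps with the forward/backward of step t+1. A minimal standalone sketch of that pattern, with illustrative names, an already-initialised gloo process group assumed, and the DTensor handling and monitored barrier of the real class omitted (this is not the exact `GlobalDDP` API):

```python
import torch
import torch.distributed as dist
import torch.nn as nn

PendingWork = list[tuple[nn.Parameter, torch.Tensor, dist.Work]]


def dpu_all_reduce(model: nn.Module, pending: PendingWork | None) -> PendingWork:
    """Kick off async all-reduces for the current grads; apply last step's reduced grads."""
    world = dist.get_world_size()
    new_pending: PendingWork = []
    for p in model.parameters():
        if p.grad is None:
            continue
        g = p.grad.detach().clone()   # copy so the next backward can overwrite p.grad
        g.div_(world)                 # pre-divide, mirroring gloo_all_reduce above
        new_pending.append((p, g, dist.all_reduce(g, op=dist.ReduceOp.SUM, async_op=True)))

    if pending is not None:           # not the first call: consume the one-step-stale grads
        for p, g, work in pending:
            work.wait()
            p.grad.copy_(g)           # optimizer.step() then uses last step's averaged grads
    return new_pending
```

On the very first call there is nothing to wait on, which matches `GlobalDDP.all_reduce` simply storing the work and letting the optimizer take one purely local step.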
elastic_device_mesh = ElasticDeviceMesh( - enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src + enable=config.diloco is not None or config.acco is not None, + live_recovery_rank_src=config.ckpt.live_recovery_rank_src, ) mp_policy = MixedPrecisionPolicy( @@ -165,6 +227,12 @@ def train(config: Config): # Setup optimizers inner_optimizer = get_optimizer(model.parameters(), config.optim.optim) + if config.acco is not None: + first_step = True + reduce_work = [] + theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad] + if config.acco.theta_t_device is not None: + theta_t = [p.to(config.acco.theta_t_device) for p in theta_t] diloco = Diloco(config.diloco, model, elastic_device_mesh) if config.diloco is not None else None @@ -284,57 +352,77 @@ def train(config: Config): monitor.set_stage("inner_loop") for inner_step in range(num_inner_steps): - loss_batch = 0 - z_loss_batch = 0 - - for grad_acc_step in range(gradient_accumulation_steps): - is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 - # no sync if we are accumulating gradients - model.set_requires_gradient_sync(not is_accumulating) - - batch = next(train_dataloader_iterator) - input_ids = batch["input_ids"].to("cuda") - labels = batch["labels"].to("cuda") - if config.train.sequence_packing: - seqlens = [seqlen.to("cuda") for seqlen in batch["seqlens"]] - block_mask = create_block_mask_from_seqlens(seqlens) if seqlens is not None else None - else: - block_mask = None - - logits = model(tokens=input_ids, block_mask=block_mask).contiguous() - flatten_logits = rearrange(logits, "b seq vocab -> (b seq) vocab") - flatten_labels = rearrange(labels, "b seq -> (b seq)") - - if config.optim.z_loss: - ce_loss, z_loss = cross_entropy_max_z_loss( - flatten_logits, flatten_labels, config.optim.z_loss_weight + loss_batch, z_loss_batch = compute_loss( + model, + gradient_accumulation_steps, + train_dataloader_iterator, + elastic_device_mesh.local_pg, + 0.5 if config.acco is not None else 1.0, + config.optim.z_loss, + ) + print(loss_batch, z_loss_batch) + + if config.acco is not None: + new_g_tilde = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] + for work in reduce_work: + work.wait() + + if not first_step: + # Copy in theta_t and consume g_t + for opt_param, cpu_param, _g_t, _g_tilde in zip(model.parameters(), theta_t, g_t, g_tilde): # noqa + opt_param.data.copy_(cpu_param.data, non_blocking=True) + opt_param.grad.copy_(_g_t + _g_tilde, non_blocking=True) + opt_param.grad /= batch_size * elastic_device_mesh.global_pg.size() + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() + # Update theta_t + for param, cpu_param in zip(model.parameters(), theta_t): + cpu_param.data.copy_(param.data, non_blocking=True) + first_step = False + + g_tilde = new_g_tilde + reduce_work = [ + dist.all_reduce( + _g_tilde.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True ) - ce_loss /= gradient_accumulation_steps - z_loss /= gradient_accumulation_steps - - del logits - loss = ce_loss + z_loss - loss.backward() - - else: - loss = F.cross_entropy(flatten_logits, flatten_labels) / gradient_accumulation_steps - del logits - loss.backward() + for _g_tilde in g_tilde + ] - if config.optim.z_loss: - loss_batch += ce_loss.clone().detach() - z_loss_batch += z_loss.clone().detach() - else: - loss_batch += loss.clone().detach() + loss_batch_1, z_loss_batch_1 = compute_loss( + 
model, + gradient_accumulation_steps, + train_dataloader_iterator, + elastic_device_mesh.local_pg, + 0.5 if config.acco is not None else 1.0, + config.optim.z_loss, + ) + loss_batch += loss_batch_1 + z_loss_batch += z_loss_batch_1 + + print(loss_batch, z_loss_batch) + g_t = [p.grad.detach().cpu() for p in model.parameters() if p.requires_grad] + for work in reduce_work: + work.wait() + # reduce_work = [elastic_device_mesh.global_pg.allreduce([_g_t], op=dist.ReduceOp.SUM) for _g_t in g_t] + reduce_work = [ + dist.all_reduce( + _g_t.to_local(), dist.ReduceOp.SUM, group=elastic_device_mesh.global_pg, async_op=True + ) + for _g_t in g_t + ] - dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) - if config.optim.z_loss: - dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg) + for opt_param, _g_tilde in zip(model.parameters(), g_tilde): + opt_param.grad.copy_(_g_tilde, non_blocking=True) + opt_param.grad /= batch_size // 2 * elastic_device_mesh.global_pg.size() + inner_optimizer.step() + inner_optimizer.zero_grad() - torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - inner_optimizer.step() - scheduler.step() - inner_optimizer.zero_grad() + else: + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() # logging training_progress.step += 1 diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py index 7b33a620..8a26eb34 100644 --- a/tests/test_torchrun/test_train.py +++ b/tests/test_torchrun/test_train.py @@ -36,7 +36,7 @@ def gpus_to_use(num_nodes, num_gpu, rank): return ",".join(map(str, range(rank * num_gpu, (rank + 1) * num_gpu))) -def _test_multi_gpu(num_gpus, config, extra_args=[], diloco=False): +def _test_multi_gpu(num_gpus, config, extra_args=[], multi_nodes=False): num_nodes, num_gpu = num_gpus[0], num_gpus[1] processes = [] @@ -55,7 +55,7 @@ def _test_multi_gpu(num_gpus, config, extra_args=[], diloco=False): env = copy.deepcopy(os.environ) - if diloco: + if multi_nodes: new_env = { "GLOBAL_RANK": str(i), "GLOBAL_UNIQUE_ID": str(i), @@ -85,7 +85,7 @@ def test_multi_gpu(num_gpus): @pytest.mark.parametrize("num_gpus", [[2, 1], [2, 2]] if num_gpu >= 4 else [[2, 1]]) def test_multi_gpu_diloco(num_gpus): - _test_multi_gpu(num_gpus, "debug/diloco.toml", diloco=True) + _test_multi_gpu(num_gpus, "debug/diloco.toml", multi_nodes=True) def test_act_ckpt(): @@ -101,7 +101,7 @@ def test_act_ckpt_num(): @pytest.mark.parametrize("backend", [Compression.NO, Compression.UINT8]) def test_all_reduce_diloco(backend: Compression): num_gpus = [2, 1] - _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--diloco.compression", backend.value], diloco=True) + _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--diloco.compression", backend.value], multi_nodes=True) def test_z_loss(): @@ -116,6 +116,13 @@ def test_packing(packing: bool): _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg]) +@pytest.mark.parametrize("dpu", [True, False]) +def test_global_ddp(dpu: bool): + num_gpus = [2, 1] + dpu_arg = "--global_ddp.dpu" if dpu else "--no-global_ddp.dpu" + _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[dpu_arg], multi_nodes=True) + + @pytest.mark.parametrize("diloco", [False, True]) def test_soap(diloco: bool): num_gpus = [1, 2] if diloco else [2, 1] diff --git a/train_dpu.py b/train_dpu.py new file mode 100644 index 00000000..ef6e2953 --- /dev/null +++ b/train_dpu.py @@ -0,0 +1,152 @@ +import torch.distributed as dist +from typing 
import List +import torch +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +from datasets import load_dataset +from torch.optim import AdamW, Optimizer +import wandb +import psutil +from tqdm import tqdm + + +# Loss function +def compute_loss(model: torch.nn.Module, inputs: List[str], tokenizer) -> torch.Tensor: + """ + Compute the loss for a batch of input text using a causal language modeling objective. + + Args: + model (torch.nn.Module): The pre-trained model (e.g., Llama). + inputs (List[str]): A batch of input text strings. + tokenizer: The tokenizer associated with the model. + + Returns: + torch.Tensor: The computed loss value. + """ + # Tokenize input text and prepare for model input + input_ids = tokenizer( + inputs, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings + ).input_ids + input_ids = input_ids.to(model.device) + labels = input_ids.clone() + + # Compute the loss + outputs = model(input_ids, labels=labels) + return outputs.loss + +def print_memory_usage(): + # Get CPU memory usage + memory_info = psutil.virtual_memory() + cpu_memory_used = memory_info.used / (1024 ** 2) + cpu_memory_total = memory_info.total / (1024 ** 2) + + print(f"CPU Memory Usage:") + print(f"Used: {cpu_memory_used:.2f} MB") + print(f"Total: {cpu_memory_total:.2f} MB") + print(f"Percentage: {memory_info.percent}%\n") + + # Check if CUDA is available + if torch.cuda.is_available(): + # Get current device + device = torch.device('cuda') + gpu_memory_used = torch.cuda.memory_allocated(device=device) + gpu_memory_reserved = torch.cuda.memory_reserved(device=device) + gpu_memory_total = torch.cuda.get_device_properties(device).total_memory + + print(f"GPU Memory Usage (Device: {torch.cuda.get_device_name(device)}):") + print(f"Allocated: {gpu_memory_used / (1024 ** 2):.2f} MB") + print(f"Reserved: {gpu_memory_reserved / (1024 ** 2):.2f} MB") + print(f"Total: {gpu_memory_total / (1024 ** 2):.2f} MB\n") + else: + print("CUDA is not available.") + +# Main function +def main(): + batch_size = 8 + # Load dataset + dataset = load_dataset("/root/prime/prime/datasets/fineweb-edu", split="train", streaming=True) + data_loader = DataLoader(dataset, batch_size=8, shuffle=False) + + # Load model and tokenizer + model_name = "llama-debug" # Replace with actual Llama model if available + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B") + tokenizer.pad_token = tokenizer.eos_token + config = AutoConfig.from_pretrained(model_name) + model = AutoModelForCausalLM.from_config(config) + print(f"Model params: {sum(p.numel() for p in model.parameters()):,}, Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}") + print_memory_usage() + theta_t = [p.detach().clone() for p in model.parameters() if p.requires_grad] + optimizer_copy = [p.detach().clone() for p in model.parameters() if p.requires_grad] + reduce_work = [] + model.to("cuda") + + # Define optimizer + optimizer = AdamW(optimizer_copy, lr=1e-4) + + # Run ACCO algorithm + num_steps = 100 + + first_step = True + print("Post Init") + print_memory_usage() + for step, batch in tqdm(enumerate(data_loader), total=num_steps): + if step >= num_steps: + break + + # Split the batch into two halves + batch_text = batch["text"] + mid_point = len(batch_text) // 2 + first_half, second_half = batch_text[:mid_point], batch_text[mid_point:] + + # Stage 1: Compute gradients g_tilde and theta + for p in model.parameters(): + 
p.grad = None + loss = compute_loss(model, first_half, tokenizer) + loss.backward() # Compute gradients for g_t + for work in reduce_work: + work.wait() + + if not first_step: + for opt_param, cpu_param, _g_t, _g_tilde in zip(optimizer_copy, theta_t, g_t, g_tilde): + opt_param.data = cpu_param.data + opt_param.grad = (_g_t + _g_tilde) / (batch_size * dist.get_world_size()) + optimizer.step() + for param, cpu_param, opt_param in zip(model.parameters(), theta_t, optimizer_copy): + param.data.copy_(opt_param.data, non_blocking=True) + cpu_param.data.copy_(opt_param.data, non_blocking=True) + first_step = False + + g_tilde = [p.grad.cpu() for p in model.parameters() if p.requires_grad] + reduce_work = [dist.all_reduce(_g_tilde, op=dist.ReduceOp.SUM, async_op=True) for _g_tilde in g_tilde] + + # Stage 2: Compute g_t and theta_tilde + for p in model.parameters(): + p.grad = None + loss = compute_loss(model, second_half, tokenizer) + loss.backward() + g_t = [p.grad.cpu() for p in model.parameters() if p.requires_grad] + for work in reduce_work: + work.wait() + reduce_work = [dist.all_reduce(_g_t, op=dist.ReduceOp.SUM, async_op=True) for _g_t in g_t] + + # theta_tilde + for param, _g_tilde in zip(optimizer_copy, g_tilde): + ## TODO: Weight by seen by batches + param.grad = _g_tilde / (batch_size // 2 * dist.get_world_size()) + optimizer.step() + for param, param_tilde in zip(model.parameters(), optimizer_copy): + param.data.copy_(param_tilde, non_blocking=True) + + print(f"Step {step + 1}/{num_steps}: Loss = {loss.item()}") + wandb.log({"loss": loss.item()}) + print(f"End of step {step}") + print_memory_usage() + +# Entry point +if __name__ == "__main__": + dist.init_process_group(backend="cpu:gloo,cuda:nccl") + torch.cuda.set_device(dist.get_rank()) + wandb.init() + main() + wandb.finish() + dist.destroy_process_group()
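Taken together, the changes to train.py and the standalone train_dpu.py implement the ACCO schedule: each optimizer step is split into two half-batches (hence the assertion above that `gradient_accumulation_steps` is even before it is halved; with the 150M H100 config above, batch_size 128 and micro_bs 4 give 32 micro-batches per step on a single-GPU node, i.e. 16 per half), the all-reduce of the first half's gradients (`g_tilde`) is overlapped with the backward pass of the second half (`g_t`), and the optimizer operates on a CPU copy `theta_t`, so the parameter update lags one step behind the gradients it consumes. Stripped of that one-step delay and the CPU offload, the core of a single step can be sketched as follows; `forward_backward`, `half_a`, and `half_b` are illustrative placeholders, each half is assumed to produce a mean-reduced loss, and the default process group is assumed to be initialised, so this is a simplified sketch rather than the exact loop above:

```python
import torch
import torch.distributed as dist


def acco_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
              forward_backward, half_a, half_b) -> None:
    world = dist.get_world_size()
    params = [p for p in model.parameters() if p.requires_grad]

    # Stage 1: gradients of the first half ("g_tilde"); start reducing them immediately.
    optimizer.zero_grad(set_to_none=True)
    forward_backward(half_a)
    g_tilde = [p.grad.detach().clone() for p in params]
    work = [dist.all_reduce(g, op=dist.ReduceOp.SUM, async_op=True) for g in g_tilde]

    # Stage 2: gradients of the second half ("g_t") overlap with the communication above.
    optimizer.zero_grad(set_to_none=True)
    forward_backward(half_b)
    g_t = [p.grad.detach().clone() for p in params]
    for g in g_t:
        dist.all_reduce(g, op=dist.ReduceOp.SUM)  # blocking here; pipelined in the real code
    for w in work:
        w.wait()

    # Each reduced tensor is a sum of per-worker half-batch means, so the global mean
    # gradient over both halves and all workers is (g_tilde + g_t) / (2 * world).
    for p, ga, gb in zip(params, g_tilde, g_t):
        p.grad = (ga + gb) / (2 * world)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```

The trailing blocking all-reduce on `g_t` is exactly the latency the real implementation hides by deferring the optimizer step (and the `theta_t` copy) to the next iteration.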