From 1a3b439d698978f9a94c25af02315bb7c0fe77af Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Thu, 19 Dec 2024 05:26:18 +0000
Subject: [PATCH 1/3] use shared function for log modelhash

---
 src/zeroband/train.py | 96 ++++++++++++++++++++++---------------------
 1 file changed, 49 insertions(+), 47 deletions(-)

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 7ab7cb8d..d54391e3 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -32,7 +32,7 @@
 )
 from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader, DataConfig
-from zeroband.utils.metric_logger import WandbMetricLogger, DummyMetricLogger
+from zeroband.utils.metric_logger import MetricLogger, WandbMetricLogger, DummyMetricLogger
 from zeroband.utils.monitor import HttpMonitor
 from zeroband.models.llama import get_model
 from zeroband.utils.profiler import MemoryProfiler
@@ -127,6 +127,40 @@ def validate_live_recovery_rank_src(self):
         return self
 
 
+def log_hash_training_state(
+    config: Config,
+    model: torch.nn.Module,
+    inner_optimizer: torch.optim.Optimizer,
+    diloco: Diloco | None,
+    metric_logger: MetricLogger,
+    id: str = "",
+):
+    """Log the hash of the model and optimizer. This function is slow"""
+    if config.train.log_model_hash:
+        inner_model_hash = get_module_signature(model)
+        inner_optimizer_hash = get_optimizer_signature(inner_optimizer)
+
+        logger.debug(f"inner diloco model {id} : {inner_model_hash}")
+        logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}")
+
+        if world_info.rank == 0:
+            metric_logger.log(
+                {"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash}
+            )
+
+        if config.diloco is not None and diloco is not None:
+            outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
+            outer_model_hash = get_tensor_list_signature(diloco.param_list_cpu)
+
+            logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
+            logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
+
+            if world_info.rank == 0:
+                metric_logger.log(
+                    {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+                )
+
+
 def train(config: Config):
     # batch_size is the total batch size for all GPUs
     assert config.optim.batch_size % world_info.local_world_size == 0
@@ -244,6 +278,16 @@ def train(config: Config):
         diloco_offloaded_param_list=diloco.param_list_cpu if config.diloco is not None else None,
     )
 
+    if world_info.rank == 0:
+        logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
+        metric_logger = logger_cls(
+            project=config.project,
+            config={"config": config.model_dump(), "world_info": world_info.json()},
+            resume=config.wandb_resume,
+        )
+    else:
+        metric_logger = None
+
     if config.train.torch_compile:
         # we need to compile AFTER creating the CKPT manager, DON'T ASK ME WHY
         model = torch.compile(model)
@@ -256,21 +300,7 @@ def train(config: Config):
             skip_dataloader=config.ckpt.skip_dataloader,
             data_path=config.ckpt.data_path,
         )
-        if config.train.log_model_hash:
-            logger.info(f"model hash: {get_module_signature(model)}")
-            logger.info(f"optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
-            if config.diloco is not None:
-                logger.info(f"outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                logger.info(f"outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
-
-    if world_info.rank == 0:
-        logger_cls = WandbMetricLogger if config.metric_logger_type == "wandb" else DummyMetricLogger
-        metric_logger = logger_cls(
-            project=config.project,
-            config={"config": config.model_dump(), "world_info": world_info.json()},
-            resume=config.wandb_resume,
-        )
+        log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume")
 
     if config.train.memory_monitor:
         gpu_mem_monitor = GPUMemoryMonitor()
@@ -305,15 +335,6 @@ def train(config: Config):
                 maybe_dest_rank = elastic_device_mesh.live_recovery.should_send_ckpt_to()
                 if maybe_dest_rank is not None:
                     logger.info(f"Start live recovery to rank {maybe_dest_rank}")
-                    if config.train.log_model_hash:
-                        logger.info(
-                            f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
-                        )
-                        logger.info(
-                            f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}"
-                        )
-                        logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
                     ckpt_manager.send_ckpt_to_peer(elastic_device_mesh.global_pg, maybe_dest_rank, blocking=True)
 
                     elastic_device_mesh.live_recovery.reset()
@@ -328,13 +349,7 @@ def train(config: Config):
                 ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)
 
-                if config.train.log_model_hash:
-                    logger.info(
-                        f"live recovery outer optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}"
-                    )
-                    logger.info(f"live recovery outer model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
-                    logger.info(f"inner optimizer hash: {get_optimizer_signature(inner_optimizer)}")
-
+                log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="live_reco_recv")
                 need_live_recovery = False
 
             if config.ckpt.remote_data_load:
@@ -461,9 +476,6 @@ def train(config: Config):
                     memory_profiler.step()
 
             if config.diloco is not None:
-                if config.train.log_model_hash:
-                    logger.debug("Pre diloco model: %s", get_module_signature(model))
-
                 if world_info.rank == 0 and config.monitor is not None:
                     monitor.set_stage("outer_loop")
 
@@ -471,11 +483,7 @@
                 diloco.step(model=model, flag=training_progress.outer_step)
                 diloco_time = time.perf_counter() - time_start_inner
 
-                if config.train.log_model_hash:
-                    logger.debug("inner diloco model: %s", get_module_signature(model))
-                    logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                    logger.debug(f"outer diloco optimizer hash: {get_optimizer_signature(diloco.outer_optimizer)}")
-                    logger.debug(f"outer diloco model hash: {get_tensor_list_signature(diloco.param_list_cpu)}")
+                log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="outer_step")
 
                 training_progress.outer_step += 1
 
@@ -488,13 +496,7 @@ def train(config: Config):
             do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0
             ckpt_manager.save(remote=do_remote)
 
-            if config.train.log_model_hash:
-                logger.debug("Post saved model: %s", get_module_signature(model))
-                logger.debug("Post saved optimizer: %s", get_optimizer_signature(inner_optimizer))
-
-                if config.diloco is not None:
-                    logger.debug("Post saved outer model: %s", get_tensor_list_signature(diloco.param_list_cpu))
-                    logger.debug("optimizer hash: %s", get_optimizer_signature(diloco.outer_optimizer))
+            log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="ckpt save")
 
         if config.diloco:
             tokens_per_second = (

From 5a39b9add313f397df925a4f0d827b0e89f4eb3f Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Thu, 19 Dec 2024 07:46:59 +0000
Subject: [PATCH 2/3] add ckpt tests

---
 src/zeroband/data.py                  |   9 +-
 src/zeroband/models/llama/__init__.py |   2 +
 src/zeroband/models/llama/model.py    |   9 +-
 src/zeroband/train.py                 |  44 +++++---
 tests/test_torchrun/test_train.py     | 138 ++++++++++++++++++++++++++
 5 files changed, 187 insertions(+), 15 deletions(-)

diff --git a/src/zeroband/data.py b/src/zeroband/data.py
index ed90f61c..9dc7c42d 100644
--- a/src/zeroband/data.py
+++ b/src/zeroband/data.py
@@ -46,18 +46,23 @@ def __init__(self, seq_len: int, vocab_size: int):
         self.seq_len = seq_len
         self.vocab_size = vocab_size
         assert vocab_size > 3, "Vocab size must be greater than 3"
+        self.step = 0
 
     def __iter__(self) -> Generator[dict[str, Any], Any, None]:
         while True:
             len_ = random.randint(1, self.seq_len)
             input_ids = torch.randint(3, self.vocab_size, (len_,)).tolist()
+            self.step += 1
             yield {"input_ids": input_ids}
 
     def state_dict(self):
-        return {}
+        return {"step": self.step}
 
     def load_state_dict(self, state_dict):
-        pass
+        self.step = state_dict["step"]
+        itera = iter(self)
+        for _ in range(self.step):
+            next(itera)
 
 
 class BatchOutput(TypedDict):

diff --git a/src/zeroband/models/llama/__init__.py b/src/zeroband/models/llama/__init__.py
index 30b54963..ff142d26 100644
--- a/src/zeroband/models/llama/__init__.py
+++ b/src/zeroband/models/llama/__init__.py
@@ -85,6 +85,7 @@ def get_model(
     type_model: str,
     vocab_size: int,
     seq_length: int,
+    math_attn: bool,
 ) -> tuple[Transformer, ModelArgs]:
     """get the transformer model"""
 
@@ -97,5 +98,6 @@
 
     config.vocab_size = vocab_size
     config.max_seq_len = seq_length
+    config.math_attn = math_attn
 
     return Transformer(config), config

diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py
index 234b8fee..c1a63403 100644
--- a/src/zeroband/models/llama/model.py
+++ b/src/zeroband/models/llama/model.py
@@ -11,6 +11,7 @@
 
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
+import contextlib
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
@@ -20,6 +21,7 @@
 from zeroband.models.norms import build_norm
 
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask, _DEFAULT_SPARSE_BLOCK_SIZE
+from torch.nn.attention import SDPBackend, sdpa_kernel
 
 _flex_attention_compiled = torch.compile(flex_attention, dynamic=False)
 
@@ -58,6 +60,8 @@ class ModelArgs:
     depth_init: bool = True
     norm_type: str = "fused_rmsnorm"
 
+    math_attn: bool = False  # slow for testing
+
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
     """
@@ -222,6 +226,8 @@ def __init__(self, model_args: ModelArgs):
         self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(model_args.n_heads * self.head_dim, model_args.dim, bias=False)
 
+        self.math_attn = model_args.math_attn
+
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
             nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02)
@@ -271,7 +277,8 @@ def forward(
         return self.wo(output)
 
     def _sdpa_attention(self, xq, xk, xv) -> torch.Tensor:
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        with sdpa_kernel(SDPBackend.MATH) if self.math_attn else contextlib.nullcontext():
+            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
         output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         return output

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index d54391e3..15d0cd93 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -79,6 +79,8 @@ class TrainConfig(BaseConfig):
     sequence_packing: bool = True
     attn_fn: Literal["flash", "sdpa"] | None = None
 
+    math_attn: bool = False  # slow
+
     @model_validator(mode="after")
     def validate_attn_fn(self):
         if self.attn_fn is not None:
@@ -133,6 +135,7 @@ def log_hash_training_state(
     inner_optimizer: torch.optim.Optimizer,
     diloco: Diloco | None,
     metric_logger: MetricLogger,
+    step: int,
     id: str = "",
 ):
     """Log the hash of the model and optimizer. This function is slow"""
     if config.train.log_model_hash:
@@ -143,10 +146,11 @@ def log_hash_training_state(
         logger.debug(f"inner diloco model {id} : {inner_model_hash}")
         logger.debug(f"inner optimizer hash {id} : {inner_optimizer_hash}")
 
-        if world_info.rank == 0:
-            metric_logger.log(
-                {"inner_model_hash_{id}": inner_model_hash, "inner_optimizer_hash_{id}": inner_optimizer_hash}
-            )
+        metrics = {
+            "step": step,
+            f"inner_model_hash_{id}": inner_model_hash,
+            f"inner_optimizer_hash_{id}": inner_optimizer_hash,
+        }
 
         if config.diloco is not None and diloco is not None:
             outer_optimizer_hash = get_optimizer_signature(diloco.outer_optimizer)
@@ -155,10 +159,11 @@
             logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
             logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")
 
-            if world_info.rank == 0:
-                metric_logger.log(
-                    {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
-                )
+            metrics.update(
+                {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+            )
+        if world_info.rank == 0:
+            metric_logger.log(metrics)
 
 
 def train(config: Config):
@@ -198,6 +203,7 @@ def train(config: Config):
         config.type_model,
         vocab_size=len(tokenizer) if config.name_model != "debugmodel" or not config.data.fake else TEST_VOCAB_SIZE,
         seq_length=config.data.seq_length,
+        math_attn=config.train.math_attn,
     )
     model = model.to(world_info.local_rank)
@@ -300,7 +306,9 @@ def train(config: Config):
             skip_dataloader=config.ckpt.skip_dataloader,
             data_path=config.ckpt.data_path,
         )
-        log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="resume")
+        log_hash_training_state(
+            config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
+        )
 
     if config.train.memory_monitor:
         gpu_mem_monitor = GPUMemoryMonitor()
@@ -349,7 +357,15 @@ def train(config: Config):
                 ckpt_manager.recv_ckpt_from_peer(elastic_device_mesh.global_pg)
 
-                log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="live_reco_recv")
+                log_hash_training_state(
+                    config,
+                    model,
+                    inner_optimizer,
+                    diloco,
+                    metric_logger,
+                    step=training_progress.step,
+                    id="live_reco_recv",
+                )
                 need_live_recovery = False
 
             if config.ckpt.remote_data_load:
@@ -483,7 +499,9 @@ def train(config: Config):
                 diloco.step(model=model, flag=training_progress.outer_step)
                 diloco_time = time.perf_counter() - time_start_inner
 
-                log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="outer_step")
+                log_hash_training_state(
+                    config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="outer_step"
+                )
 
                 training_progress.outer_step += 1
 
@@ -496,7 +514,9 @@ def train(config: Config):
             do_remote = config.ckpt.remote is not None and training_progress.step % config.ckpt.remote.interval == 0
             ckpt_manager.save(remote=do_remote)
 
-            log_hash_training_state(config, model, inner_optimizer, diloco, metric_logger, id="ckpt save")
+            log_hash_training_state(
+                config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="save"
+            )
 
         if config.diloco:
             tokens_per_second = (

diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index e5703fe3..dc36701a 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -1,5 +1,7 @@
 import copy
 import os
+from pathlib import Path
+import pickle
 import subprocess
 import pytest
 import socket
@@ -112,3 +114,139 @@ def test_packing(packing: bool):
     num_gpus = [2, 1]
     packing_arg = "--train.sequence_packing" if packing else "--no-train.sequence_packing"
     _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=[packing_arg])
+
+
+def test_ckpt(tmp_path: Path):
+    num_gpus = [1, 2]
+    v1_file = tmp_path / "v1.log"
+    v2_file = tmp_path / "v2.log"
+    # v3_file = tmp_path / "v3.log"
+
+    v1_ckpt = tmp_path / "v1_ckpt"
+    v2_ckpt = tmp_path / "v2_ckpt"
+    # v3_ckpt = tmp_path / "v3_ckpt"
+
+    os.mkdir(v1_ckpt)
+    os.mkdir(v2_ckpt)
+    # os.mkdir(v3_ckpt)
+
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml",
+        extra_args=[
+            "--project",
+            str(v1_file),
+            "--ckpt.path",
+            str(v1_ckpt),
+            "--ckpt.interval",
+            "5",
+            "--optim.total_steps",
+            "20",
+            "--train.log_model_hash",
+            "--no-train.sequence_packing",
+            "--train.math_attn",
+        ],
+        diloco=True,
+    )
+    _test_multi_gpu(
+        num_gpus,
+        "debug/diloco.toml",
+        extra_args=[
+            "--project",
+            str(v2_file),
+            "--ckpt.path",
+            str(v2_ckpt),
+            "--ckpt.interval",
+            "5",
+            "--ckpt.resume",
+            str(v1_ckpt / "step_5"),
+            "--optim.total_steps",
+            "20",
+            "--train.log_model_hash",
+            "--no-train.sequence_packing",
+            "--train.math_attn",
+        ],
+        diloco=True,
+    )
+    # _test_multi_gpu(
+    #     num_gpus,
+    #     "debug/diloco.toml",
+    #     extra_args=[
+    #         "--project",
+    #         str(v3_file),
+    #         "--ckpt.path",
+    #         str(v3_ckpt),
+    #         "--ckpt.interval",
+    #         "5",
+    #         "--ckpt.resume",
+    #         str(v2_ckpt / "step_10"),
+    #         "--optim.total_steps",
+    #         "20",
+    #         "--train.log_model_hash",
+    #         "--no-train.sequence_packing",
+    #         "--train.math_attn",
+    #     ],
+    #     diloco=True,
+    # )
+
+    key_to_round = ["Perplexity", "Loss"]
+    digit_to_round = [0, 3]
+
+    def read_logs(path: Path):
+        with path.open("rb") as f:
+            data = pickle.load(f)
+
+        filtered_data = {}
+        for entry in data:
+            step = entry.pop("step")
+
+            # Round perplexity and loss
+            for key, digit in zip(key_to_round, digit_to_round):
+                if key in entry:
+                    entry[key] = round(entry[key], digit)
+
+            if step in filtered_data:
+                filtered_data[step].update(entry)
+            else:
+                filtered_data[step] = entry
+
+        return filtered_data
+
+    v1_data = read_logs(v1_file)
+    v2_data = read_logs(v2_file)
+    # v3_data = read_logs(v3_file)
+
+    ## check that loading from v1 to v2 worked
+
+    # first check that the hash of saving is the same as the hash of loading
+    assert v1_data[5]["inner_model_hash_save"] == v2_data[5]["inner_model_hash_resume"]
+    assert v1_data[5]["inner_optimizer_hash_save"] == v2_data[5]["inner_optimizer_hash_resume"]
+    assert v1_data[5]["outer_optimizer_hash_save"] == v2_data[5]["outer_optimizer_hash_resume"]
+    assert v1_data[5]["outer_model_hash_save"] == v2_data[5]["outer_model_hash_resume"]
+
+    # then we check that the loss and lr value are the same after loading the ckpt
+    for step, data_v2 in v2_data.items():
+        if step == 5:
+            continue  # not testing step 5 as it's the one we restarted from
+
+        data_v1 = v1_data[step]
+        assert data_v1["Loss"] == data_v2["Loss"]
+        assert data_v1["inner_lr"] == data_v2["inner_lr"]
+        assert data_v1["total_tokens"] == data_v2["total_tokens"]
+
+    # ## check that the second loading is working
+    # ## why ? We had bugs where the ckpt was working but not when resuming training from it
+
+    # assert v2_data[10]["inner_model_hash_save"] == v3_data[10]["inner_model_hash_resume"]
+    # assert v2_data[10]["inner_optimizer_hash_save"] == v3_data[10]["inner_optimizer_hash_resume"]
+    # assert v2_data[10]["outer_optimizer_hash_save"] == v3_data[10]["outer_optimizer_hash_resume"]
+    # assert v2_data[10]["outer_model_hash_save"] == v3_data[10]["outer_model_hash_resume"]
+
+    # for step, data_v3 in v3_data.items():
+    #     if step == 10:
+    #         continue  # not testing step 10 as it's the one we restarted from
+
+    #     data_v2 = v2_data[step]
+    #     assert data_v2["Loss"] == data_v3["Loss"]
+    #     assert data_v2["inner_lr"] == data_v3["inner_lr"]
+    #     assert data_v2["total_tokens"] == data_v3["total_tokens"]

From dbddd9f6fd5c6bd5a14407b21dc21f343012e0d0 Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Thu, 19 Dec 2024 09:49:04 +0000
Subject: [PATCH 3/3] remove useless memory stuff

---
 src/zeroband/train.py          | 14 ----------
 src/zeroband/utils/__init__.py | 50 ----------------------------
 2 files changed, 64 deletions(-)

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 15d0cd93..a7f6d486 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -2,7 +2,6 @@
 from typing import Literal
 import time
 import warnings
-import psutil
 
 from pydantic import model_validator
 from multiprocessing.process import _children
@@ -24,7 +23,6 @@
 from zeroband.utils import (
     FakeTokenizer,
-    GPUMemoryMonitor,
     PerfCounter,
     get_module_signature,
     get_optimizer_signature,
@@ -73,7 +71,6 @@ class TrainConfig(BaseConfig):
 
     log_model_hash: bool = False
 
-    memory_monitor: bool = False
     memory_profiler: MemoryProfilerConfig | None = None
 
     sequence_packing: bool = True
@@ -310,8 +307,6 @@ def train(config: Config):
             config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
         )
 
-    if config.train.memory_monitor:
-        gpu_mem_monitor = GPUMemoryMonitor()
     if config.train.memory_profiler is not None:
         memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)
 
@@ -447,7 +442,6 @@ def train(config: Config):
                 # we count the total tokens with respect to all diloco workers
                 # might need to tweak this as some worker might fail to join the all reduce later
                 training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size()
-                remaining_cpu_ram = psutil.virtual_memory().available / (1024 * 1024 * 1024)
 
                 metrics = {
                     "Loss": loss_batch.item(),
@@ -456,16 +450,11 @@ def train(config: Config):
                     "Perplexity": torch.exp(loss_batch).item(),
                     "total_tokens": training_progress.total_tokens,
                     "time": time.time(),
-                    "remaining_cpu_ram": remaining_cpu_ram,
                 }
 
                 if config.optim.z_loss:
                     metrics["z_loss"] = z_loss_batch.item()
 
-                if config.train.memory_monitor:
-                    peak_gpu_stats = gpu_mem_monitor.get_peak_stats()
-                    metrics.update(peak_gpu_stats)
-
                 log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}"
 
                 tokens_per_second = perf_counter.get_tokens_per_second()
@@ -539,9 +528,6 @@ def train(config: Config):
                     }
                 )
 
-            if config.train.memory_monitor:
-                logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")
-
         if training_progress.step >= config.optim.total_steps:
             # we only allow to break outisde of the inner loop.
            # This avoid ending the training in the middle of a the inner loop

diff --git a/src/zeroband/utils/__init__.py b/src/zeroband/utils/__init__.py
index 1bb454fb..c0ea3699 100644
--- a/src/zeroband/utils/__init__.py
+++ b/src/zeroband/utils/__init__.py
@@ -1,13 +1,10 @@
 import hashlib
 import socket
 import time
-from typing import Any
 
 import torch
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed._tensor.api import DTensor
 
-from zeroband.utils.logging import get_logger
-
 __all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]
 
@@ -165,53 +162,6 @@ def get_tensor_list_signature(tensor_list: list[torch.Tensor]) -> str:
     return hashlib.md5(str(tensors).encode("utf-8")).hexdigest()
 
 
-class GPUMemoryMonitor:
-    # inspired from https://github.com/pytorch/torchtitan/blob/eef8bb2b1b6f0875ab0581079e1511d51654910e/torchtitan/metrics.py#L32
-    def __init__(self, device: str = "cuda"):
-        self.device = torch.device(device)  # device object
-        self.device_capacity = torch.cuda.get_device_properties(self.device).total_memory
-        self.device_capacity_gib = self._to_gib(self.device_capacity)
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
-
-        self._logger = get_logger()
-
-    def _to_gib(self, memory_in_bytes):
-        # NOTE: GiB (gibibyte) is 1024, vs GB is 1000
-        _gib_in_bytes = 1024 * 1024 * 1024
-        memory_in_gib = memory_in_bytes / _gib_in_bytes
-        return memory_in_gib
-
-    def _to_pct(self, memory):
-        return 100 * memory / self.device_capacity
-
-    def get_peak_stats(self) -> dict[str, Any]:
-        cuda_info = torch.cuda.memory_stats(self.device)
-
-        max_active = cuda_info["active_bytes.all.peak"]
-        max_active_gib = self._to_gib(max_active)
-        max_active_pct = self._to_pct(max_active)
-
-        max_reserved = cuda_info["reserved_bytes.all.peak"]
-        max_reserved_gib = self._to_gib(max_reserved)
-        max_reserved_pct = self._to_pct(max_reserved)
-
-        return {
-            "gpu_max_active_gib": max_active_gib,
-            "gpu_max_active_pct": max_active_pct,
-            "gpu_max_reserved_gib": max_reserved_gib,
-            "gpu_max_reserved_pct": max_reserved_pct,
-        }
-
-    def reset_peak_stats(self):
-        torch.cuda.reset_peak_memory_stats()
-
-    def format_peak_states(self, peak_stats: dict[str, Any] | None = None) -> str:
-        if peak_stats is None:
-            peak_stats = self.get_peak_stats()
-        return f"Active {peak_stats['gpu_max_active_gib']:.2f} GiB ({peak_stats['gpu_max_active_pct']:.2f}%), Reserved {peak_stats['gpu_max_reserved_gib']:.2f} GiB ({peak_stats['gpu_max_reserved_pct']:.2f}%)"
-
-
 def get_random_available_port_list(num_port):
     # https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
     ports = []