remove useless memory stuff
samsja committed Dec 19, 2024
1 parent 5a39b9a commit dbddd9f
Showing 2 changed files with 0 additions and 64 deletions.
14 changes: 0 additions & 14 deletions src/zeroband/train.py
@@ -2,7 +2,6 @@
from typing import Literal
import time
import warnings
import psutil
from pydantic import model_validator
from multiprocessing.process import _children

@@ -24,7 +23,6 @@

from zeroband.utils import (
FakeTokenizer,
GPUMemoryMonitor,
PerfCounter,
get_module_signature,
get_optimizer_signature,
@@ -73,7 +71,6 @@ class TrainConfig(BaseConfig):

log_model_hash: bool = False

memory_monitor: bool = False
memory_profiler: MemoryProfilerConfig | None = None

sequence_packing: bool = True
@@ -310,8 +307,6 @@ def train(config: Config):
config, model, inner_optimizer, diloco, metric_logger, step=training_progress.step, id="resume"
)

if config.train.memory_monitor:
gpu_mem_monitor = GPUMemoryMonitor()
if config.train.memory_profiler is not None:
memory_profiler = MemoryProfiler(config.train.memory_profiler.freq, config.train.memory_profiler.snapshot_dir)

@@ -447,7 +442,6 @@ def train(config: Config):
# we count the total tokens with respect to all diloco workers
# might need to tweak this as some worker might fail to join the all reduce later
training_progress.total_tokens += new_tokens * elastic_device_mesh.global_pg.size()
remaining_cpu_ram = psutil.virtual_memory().available / (1024 * 1024 * 1024)

metrics = {
"Loss": loss_batch.item(),
@@ -456,16 +450,11 @@
"Perplexity": torch.exp(loss_batch).item(),
"total_tokens": training_progress.total_tokens,
"time": time.time(),
"remaining_cpu_ram": remaining_cpu_ram,
}

if config.optim.z_loss:
metrics["z_loss"] = z_loss_batch.item()

if config.train.memory_monitor:
peak_gpu_stats = gpu_mem_monitor.get_peak_stats()
metrics.update(peak_gpu_stats)

log = f"step: {training_progress.step}, loss: {loss_batch.item():.4f}"

tokens_per_second = perf_counter.get_tokens_per_second()
@@ -539,9 +528,6 @@ def train(config: Config):
}
)

if config.train.memory_monitor:
logger.info(f"outer step peak gpu stats: {gpu_mem_monitor.format_peak_states()}")

if training_progress.step >= config.optim.total_steps:
# we only allow to break outside of the inner loop.
# This avoids ending the training in the middle of an inner loop
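Note: the train.py changes above drop the psutil-based remaining_cpu_ram metric and the peak-GPU-memory entries from the per-step metrics dict. If that CPU-RAM figure is still wanted outside the trainer, a minimal standalone sketch (assuming psutil is installed; not part of this repository) mirrors what the removed line computed:

import psutil

# Available system RAM in GiB, as the removed remaining_cpu_ram metric reported it
remaining_cpu_ram_gib = psutil.virtual_memory().available / (1024 * 1024 * 1024)
print(f"remaining_cpu_ram: {remaining_cpu_ram_gib:.2f} GiB")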
50 changes: 0 additions & 50 deletions src/zeroband/utils/__init__.py
@@ -1,13 +1,10 @@
import hashlib
import socket
import time
from typing import Any
import torch
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed._tensor.api import DTensor

from zeroband.utils.logging import get_logger


__all__ = ["get_sharding_strategy", "get_peak_flops", "get_num_flop_per_token", "get_num_params"]

@@ -165,53 +162,6 @@ def get_tensor_list_signature(tensor_list: list[torch.Tensor]) -> str:
return hashlib.md5(str(tensors).encode("utf-8")).hexdigest()


class GPUMemoryMonitor:
# inspired from https://github.com/pytorch/torchtitan/blob/eef8bb2b1b6f0875ab0581079e1511d51654910e/torchtitan/metrics.py#L32
def __init__(self, device: str = "cuda"):
self.device = torch.device(device) # device object
self.device_capacity = torch.cuda.get_device_properties(self.device).total_memory
self.device_capacity_gib = self._to_gib(self.device_capacity)
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()

self._logger = get_logger()

def _to_gib(self, memory_in_bytes):
# NOTE: GiB (gibibyte) is 1024, vs GB is 1000
_gib_in_bytes = 1024 * 1024 * 1024
memory_in_gib = memory_in_bytes / _gib_in_bytes
return memory_in_gib

def _to_pct(self, memory):
return 100 * memory / self.device_capacity

def get_peak_stats(self) -> dict[str, Any]:
cuda_info = torch.cuda.memory_stats(self.device)

max_active = cuda_info["active_bytes.all.peak"]
max_active_gib = self._to_gib(max_active)
max_active_pct = self._to_pct(max_active)

max_reserved = cuda_info["reserved_bytes.all.peak"]
max_reserved_gib = self._to_gib(max_reserved)
max_reserved_pct = self._to_pct(max_reserved)

return {
"gpu_max_active_gib": max_active_gib,
"gpu_max_active_pct": max_active_pct,
"gpu_max_reserved_gib": max_reserved_gib,
"gpu_max_reserved_pct": max_reserved_pct,
}

def reset_peak_stats(self):
torch.cuda.reset_peak_memory_stats()

def format_peak_states(self, peak_stats: dict[str, Any] | None = None) -> str:
if peak_stats is None:
peak_stats = self.get_peak_stats()
return f"Active {peak_stats['gpu_max_active_gib']:.2f} GiB ({peak_stats['gpu_max_active_pct']:.2f}%), Reserved {peak_stats['gpu_max_reserved_gib']:.2f} GiB ({peak_stats['gpu_max_reserved_pct']:.2f}%)"


def get_random_available_port_list(num_port):
# https://stackoverflow.com/questions/1365265/on-localhost-how-do-i-pick-a-free-port-number
ports = []
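For reference, the peak-memory figures that the removed GPUMemoryMonitor class logged can still be queried ad hoc from torch.cuda. A minimal sketch based on the deleted code above (not part of this repository after this commit):

import torch

device = torch.device("cuda")
capacity = torch.cuda.get_device_properties(device).total_memory
torch.cuda.reset_peak_memory_stats()

# ... run a training step or any other CUDA workload here ...

stats = torch.cuda.memory_stats(device)
gib = 1024 ** 3  # GiB (gibibyte) is 1024^3 bytes
max_active = stats["active_bytes.all.peak"]
max_reserved = stats["reserved_bytes.all.peak"]
print(
    f"Active {max_active / gib:.2f} GiB ({100 * max_active / capacity:.2f}%), "
    f"Reserved {max_reserved / gib:.2f} GiB ({100 * max_reserved / capacity:.2f}%)"
)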
