From fce5fa1f8a6ba5562f1706d98181869a0b7895a6 Mon Sep 17 00:00:00 2001 From: aldo Date: Thu, 20 Jun 2024 19:04:30 -0400 Subject: [PATCH 1/7] async logger first commit --- requirements.txt | 2 +- src/instructlab/training/async_logger.py | 45 ++++++++++++++++++++++++ src/instructlab/training/main_ds.py | 4 ++- 3 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 src/instructlab/training/async_logger.py diff --git a/requirements.txt b/requirements.txt index 1bbec7e4..981f7fc2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,4 @@ pydantic>=2.7.0 # deepspeed needs to be at the bottom or it'll break during installation deepspeed>=0.14.3 - +aiofiles>=23.2.1 \ No newline at end of file diff --git a/src/instructlab/training/async_logger.py b/src/instructlab/training/async_logger.py new file mode 100644 index 00000000..72841fae --- /dev/null +++ b/src/instructlab/training/async_logger.py @@ -0,0 +1,45 @@ +# File: async_logger.py + +import json +import asyncio +from datetime import datetime +import aiofiles +import threading + +class AsyncStructuredLogger: + def __init__(self, file_name='training_log.json'): + self.file_name = file_name + self.logs = [] + self.loop = asyncio.new_event_loop() + t = threading.Thread(target=self._run_event_loop, args=(self.loop,), daemon=True) + t.start() + asyncio.run_coroutine_threadsafe(self._initialize_log_file(), self.loop) + + def _run_event_loop(self, loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + async def _initialize_log_file(self): + try: + async with aiofiles.open(self.file_name, 'r') as f: + self.logs = json.loads(await f.read()) + except FileNotFoundError: + async with aiofiles.open(self.file_name, 'w') as f: + await f.write(json.dumps(self.logs, indent=4)) + + async def log(self, data): + if not isinstance(data, dict): + raise ValueError("Logged data must be a dictionary") + data['timestamp'] = datetime.now().isoformat() + self.logs.append(data) + await self._write_logs_to_file() + + 
async def _write_logs_to_file(self): + async with aiofiles.open(self.file_name, 'w') as f: + await f.write(json.dumps(self.logs, indent=4)) + + def log_sync(self, data: dict): + asyncio.run_coroutine_threadsafe(self.log(data), self.loop) + + def __repr__(self): + return f"" diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 1eb21a68..9ac5706c 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -30,6 +30,7 @@ ) from instructlab.training.token_dataset import setup_dataloader, setup_dataset from instructlab.training.tokenizer_utils import setup_tokenizer +from instructlab.training.async_logger import AsyncStructuredLogger from instructlab.training.utils import ( StreamablePopen, add_noisy_embeddings, @@ -434,9 +435,10 @@ def train(args, model, tokenizer, train_loader, grad_accum): def main(args): # Third Party import yaml - + metric_logger = AsyncStructuredLogger(args.output_dir + "/training_params_and_metrics.json") if os.environ["LOCAL_RANK"] == "0": print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m") + metric_logger.log_sync({'script_params': vars(args)}) setup_logger(args.log_level) CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path) From 8a5e9f403908b68a147d37a4e927ecd9002e1023 Mon Sep 17 00:00:00 2001 From: aldo Date: Fri, 21 Jun 2024 02:34:04 -0400 Subject: [PATCH 2/7] jsonl not being populated --- requirements.txt | 2 +- src/instructlab/training/async_logger.py | 28 ++++++++++----- src/instructlab/training/main_ds.py | 44 +++++++++++++++++++++--- 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/requirements.txt b/requirements.txt index 981f7fc2..a380cd8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ torch>=2.3.0a0 transformers>=4.41.2 datasets>=2.15.0 numba -numpy +numpy==1.26.4 rich dolomite-engine @ git+https://github.com/ibm-granite/dolomite-engine.git@main trl>=0.9.4 diff --git 
a/src/instructlab/training/async_logger.py b/src/instructlab/training/async_logger.py index 72841fae..ac1cd430 100644 --- a/src/instructlab/training/async_logger.py +++ b/src/instructlab/training/async_logger.py @@ -5,9 +5,10 @@ from datetime import datetime import aiofiles import threading +import os class AsyncStructuredLogger: - def __init__(self, file_name='training_log.json'): + def __init__(self, file_name='training_log.jsonl'): self.file_name = file_name self.logs = [] self.loop = asyncio.new_event_loop() @@ -20,12 +21,15 @@ def _run_event_loop(self, loop): loop.run_forever() async def _initialize_log_file(self): - try: - async with aiofiles.open(self.file_name, 'r') as f: - self.logs = json.loads(await f.read()) - except FileNotFoundError: - async with aiofiles.open(self.file_name, 'w') as f: - await f.write(json.dumps(self.logs, indent=4)) + self.logs = [] + try: + async with aiofiles.open(self.file_name, 'r') as f: + async for line in f: + if line.strip(): # Avoid empty lines + self.logs.append(json.loads(line.strip())) + except FileNotFoundError: + pass + async def log(self, data): if not isinstance(data, dict): @@ -35,8 +39,14 @@ async def log(self, data): await self._write_logs_to_file() async def _write_logs_to_file(self): - async with aiofiles.open(self.file_name, 'w') as f: - await f.write(json.dumps(self.logs, indent=4)) + temp_file_name = f"{self.file_name}.tmp" + async with aiofiles.open(temp_file_name, 'w') as temp_file: + await temp_file.write(json.dumps(self.logs[-1], indent=None) + '\n') + await temp_file.flush() # Flush the file buffer + os.fsync(temp_file.fileno()) # Sync the file with the storage device + + # Rename the temporary file to the main file name + os.replace(temp_file_name, self.file_name) def log_sync(self, data: dict): asyncio.run_coroutine_threadsafe(self.log(data), self.loop) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 9ac5706c..b1815c5a 100644 --- 
a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -317,7 +317,7 @@ def maybe_resume_training(args, model): return model -def train(args, model, tokenizer, train_loader, grad_accum): +def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): model.train() global_step = 1 @@ -395,6 +395,11 @@ def train(args, model, tokenizer, train_loader, grad_accum): current_lr = model.lr_scheduler.get_last_lr()[0] cuda_mem_allocated = torch.cuda.memory_allocated() / (1024**3) cuda_malloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + global_grad_norm = model.get_global_grad_norm() + global_grad_norm = ( + float(global_grad_norm) if global_grad_norm is not None else None + ) + weight_norm = float(model.optimizer.single_partition_of_fp32_groups[0].norm()) print( f"throughput: {overall_throughput} " @@ -404,9 +409,29 @@ def train(args, model, tokenizer, train_loader, grad_accum): f"cuda_malloc_retries: {cuda_malloc_retries} " f"num_loss_counted_tokens: {num_loss_counted_tokens} " f"batch_size: {aggregated_values[1]} " - f"total loss: {aggregated_values[2]/num_loss_counted_tokens}" + f"total loss: {aggregated_values[2]/num_loss_counted_tokens} " + f"gradnorm: {global_grad_norm} " + f"weight_norm: {weight_norm}" + ) + metric_logger.log_sync( + { + "epoch": epoch, + "step": global_step, + "rank": torch.distributed.get_rank(), + "loss": loss.item(), + "overall_throughput": overall_throughput, + "lr": current_lr, + "cuda_mem_allocated": cuda_mem_allocated, + "cuda_malloc_retries": cuda_malloc_retries, + "num_loss_counted_tokens": int(num_loss_counted_tokens), + "batch_size": int(aggregated_values[1]), + "total_loss": float(aggregated_values[2] / num_loss_counted_tokens), + "gradnorm": global_grad_norm if global_grad_norm is not None else None, + "weight_norm": weight_norm, + } ) + if global_step * batch_size % args.save_samples == 0: save_hf_format_ds( args, @@ -435,7 +460,7 @@ def train(args, model, tokenizer, train_loader, 
grad_accum): def main(args): # Third Party import yaml - metric_logger = AsyncStructuredLogger(args.output_dir + "/training_params_and_metrics.json") + metric_logger = AsyncStructuredLogger(args.output_dir + "/training_params_and_metrics.jsonl") if os.environ["LOCAL_RANK"] == "0": print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m") metric_logger.log_sync({'script_params': vars(args)}) @@ -496,11 +521,22 @@ def main(args): f"avg_samples_per_batch: {len(dataset)/len(train_loader)}\n" f"samples_per_gpu: {args.samples_per_gpu}\033[0m" ) + metric_logger.log_sync({ + 'num_gpus': torch.distributed.get_world_size(), + 'avg_sample_len': dataset.get_lengths().mean(), + 'effective_batch_size': args.effective_batch_size, + 'max_batch_len_per_gpu': args.max_batch_len, + 'packing_max_batch_len': packing_max_batch_len, + 'grad_accum': grad_accum, + 'num_batches': len(train_loader), + 'avg_samples_per_batch': len(dataset)/len(train_loader), + 'samples_per_gpu': args.samples_per_gpu + }) model = setup_model(args, tokenizer, train_loader, grad_accum) model = maybe_resume_training(args, model) - train(args, model, tokenizer, train_loader, grad_accum) + train(args, model, tokenizer, train_loader, grad_accum, metric_logger) torch.distributed.barrier() torch.distributed.destroy_process_group() From 2731161365ca173ada56a5aab45664e637e20918 Mon Sep 17 00:00:00 2001 From: aldo Date: Fri, 21 Jun 2024 15:33:42 -0400 Subject: [PATCH 3/7] async logger fix bug --- src/instructlab/training/async_logger.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/instructlab/training/async_logger.py b/src/instructlab/training/async_logger.py index ac1cd430..752b7a9c 100644 --- a/src/instructlab/training/async_logger.py +++ b/src/instructlab/training/async_logger.py @@ -1,5 +1,3 @@ -# File: async_logger.py - import json import asyncio from datetime import datetime @@ -17,7 +15,7 @@ def __init__(self, file_name='training_log.jsonl'): 
asyncio.run_coroutine_threadsafe(self._initialize_log_file(), self.loop) def _run_event_loop(self, loop): - asyncio.set_event_loop(loop) + asyncio.set_event_loop(loop) # loop.run_forever() async def _initialize_log_file(self): @@ -28,28 +26,26 @@ async def _initialize_log_file(self): if line.strip(): # Avoid empty lines self.logs.append(json.loads(line.strip())) except FileNotFoundError: + # File does not exist but the first log will create it. pass async def log(self, data): + '''logs a dictionary as a new line in a jsonl file with a timestamp''' if not isinstance(data, dict): raise ValueError("Logged data must be a dictionary") data['timestamp'] = datetime.now().isoformat() self.logs.append(data) - await self._write_logs_to_file() - - async def _write_logs_to_file(self): - temp_file_name = f"{self.file_name}.tmp" - async with aiofiles.open(temp_file_name, 'w') as temp_file: - await temp_file.write(json.dumps(self.logs[-1], indent=None) + '\n') - await temp_file.flush() # Flush the file buffer - os.fsync(temp_file.fileno()) # Sync the file with the storage device + await self._write_logs_to_file(data) - # Rename the temporary file to the main file name - os.replace(temp_file_name, self.file_name) + async def _write_logs_to_file(self, data): + '''appends to the log instead of writing the whole log each time''' + async with aiofiles.open(self.file_name, 'a') as f: + await f.write(json.dumps(data, indent=None) + '\n') def log_sync(self, data: dict): + '''runs the log coroutine non-blocking''' asyncio.run_coroutine_threadsafe(self.log(data), self.loop) def __repr__(self): - return f"" + return f"" \ No newline at end of file From d4342a217d9d743bca97a3bfb254423bc6680c53 Mon Sep 17 00:00:00 2001 From: aldo Date: Fri, 21 Jun 2024 15:39:56 -0400 Subject: [PATCH 4/7] async add print logic as well --- src/instructlab/training/async_logger.py | 1 + src/instructlab/training/main_ds.py | 46 ++++++++++++------------ 2 files changed, 24 insertions(+), 23 deletions(-) diff --git 
a/src/instructlab/training/async_logger.py b/src/instructlab/training/async_logger.py index 752b7a9c..96fc3067 100644 --- a/src/instructlab/training/async_logger.py +++ b/src/instructlab/training/async_logger.py @@ -37,6 +37,7 @@ async def log(self, data): data['timestamp'] = datetime.now().isoformat() self.logs.append(data) await self._write_logs_to_file(data) + {{ print(f"\033[92m{json.dumps(data, indent=4)}\033[0m") }} async def _write_logs_to_file(self, data): '''appends to the log instead of writing the whole log each time''' diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index b1815c5a..50148476 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -401,18 +401,18 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): ) weight_norm = float(model.optimizer.single_partition_of_fp32_groups[0].norm()) - print( - f"throughput: {overall_throughput} " - f"samples/s, lr: {current_lr}, " - f"loss: {loss.item()} " - f"cuda_mem_allocated: {cuda_mem_allocated} GB " - f"cuda_malloc_retries: {cuda_malloc_retries} " - f"num_loss_counted_tokens: {num_loss_counted_tokens} " - f"batch_size: {aggregated_values[1]} " - f"total loss: {aggregated_values[2]/num_loss_counted_tokens} " - f"gradnorm: {global_grad_norm} " - f"weight_norm: {weight_norm}" - ) + # print( + # f"throughput: {overall_throughput} " + # f"samples/s, lr: {current_lr}, " + # f"loss: {loss.item()} " + # f"cuda_mem_allocated: {cuda_mem_allocated} GB " + # f"cuda_malloc_retries: {cuda_malloc_retries} " + # f"num_loss_counted_tokens: {num_loss_counted_tokens} " + # f"batch_size: {aggregated_values[1]} " + # f"total loss: {aggregated_values[2]/num_loss_counted_tokens} " + # f"gradnorm: {global_grad_norm} " + # f"weight_norm: {weight_norm}" + # ) metric_logger.log_sync( { "epoch": epoch, @@ -510,17 +510,17 @@ def main(args): ) if args.local_rank == 0: - print( - f"\033[96mnum_gpus: 
{torch.distributed.get_world_size()}\n" - f"avg_sample_len: {dataset.get_lengths().mean()}\n" - f"effective_batch_size: {args.effective_batch_size}\n" - f"max_batch_len_per_gpu: {args.max_batch_len}\n" - f"packing_max_batch_len: {packing_max_batch_len}\n" - f"grad_accum: {grad_accum}\n" - f"num batches: {len(train_loader)}\n" - f"avg_samples_per_batch: {len(dataset)/len(train_loader)}\n" - f"samples_per_gpu: {args.samples_per_gpu}\033[0m" - ) + # print( + # f"\033[96mnum_gpus: {torch.distributed.get_world_size()}\n" + # f"avg_sample_len: {dataset.get_lengths().mean()}\n" + # f"effective_batch_size: {args.effective_batch_size}\n" + # f"max_batch_len_per_gpu: {args.max_batch_len}\n" + # f"packing_max_batch_len: {packing_max_batch_len}\n" + # f"grad_accum: {grad_accum}\n" + # f"num batches: {len(train_loader)}\n" + # f"avg_samples_per_batch: {len(dataset)/len(train_loader)}\n" + # f"samples_per_gpu: {args.samples_per_gpu}\033[0m" + # ) metric_logger.log_sync({ 'num_gpus': torch.distributed.get_world_size(), 'avg_sample_len': dataset.get_lengths().mean(), From d82dd0fc2f5112e9c44b3ddeb4589116f3e865de Mon Sep 17 00:00:00 2001 From: aldo Date: Fri, 21 Jun 2024 15:50:22 -0400 Subject: [PATCH 5/7] removed prints --- src/instructlab/training/main_ds.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 50148476..40af2969 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -401,18 +401,6 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): ) weight_norm = float(model.optimizer.single_partition_of_fp32_groups[0].norm()) - # print( - # f"throughput: {overall_throughput} " - # f"samples/s, lr: {current_lr}, " - # f"loss: {loss.item()} " - # f"cuda_mem_allocated: {cuda_mem_allocated} GB " - # f"cuda_malloc_retries: {cuda_malloc_retries} " - # f"num_loss_counted_tokens: {num_loss_counted_tokens} " - # 
class AsyncStructuredLogger:
    """Non-blocking JSONL metrics logger.

    Owns a private asyncio event loop running on a daemon thread;
    ``log_sync`` schedules work onto that loop so the (synchronous)
    training loop never blocks on file I/O. Each record is appended as
    one JSON object per line to ``file_name``.
    """

    def __init__(self, file_name="training_log.jsonl"):
        self.file_name = file_name
        # In-memory mirror of every record seen this run (pre-existing
        # lines are loaded by _initialize_log_file).
        self.logs = []
        self.loop = asyncio.new_event_loop()
        t = threading.Thread(
            target=self._run_event_loop, args=(self.loop,), daemon=True
        )
        t.start()
        # Load any pre-existing log lines before the first write.
        asyncio.run_coroutine_threadsafe(self._initialize_log_file(), self.loop)

    def _run_event_loop(self, loop):
        """Thread target: make ``loop`` current for this thread and run it forever."""
        asyncio.set_event_loop(loop)
        loop.run_forever()

    async def _initialize_log_file(self):
        """Populate ``self.logs`` from an existing JSONL file, if one is present."""
        self.logs = []
        try:
            async with aiofiles.open(self.file_name, "r") as f:
                async for line in f:
                    if line.strip():  # Avoid empty lines
                        self.logs.append(json.loads(line.strip()))
        except FileNotFoundError:
            # File does not exist but the first log will create it.
            pass

    async def log(self, data):
        """Timestamp ``data``, append it to the JSONL file, and echo it.

        Args:
            data: the record to log; must be a dictionary. A
                ``"timestamp"`` key is added (ISO-8601, local time).

        Raises:
            ValueError: if ``data`` is not a dictionary.
        """
        if not isinstance(data, dict):
            raise ValueError("Logged data must be a dictionary")
        data["timestamp"] = datetime.now().isoformat()
        self.logs.append(data)
        await self._write_logs_to_file(data)
        # BUG FIX: this line previously read `{{print(...)}}`, which builds
        # a set containing a set and raises TypeError (sets are unhashable)
        # on every log call; the intended statement is a plain print.
        print(f"\033[92m{json.dumps(data, indent=4)}\033[0m")

    async def _write_logs_to_file(self, data):
        """appends to the log instead of writing the whole log each time"""
        async with aiofiles.open(self.file_name, "a") as f:
            await f.write(json.dumps(data, indent=None) + "\n")

    def log_sync(self, data: dict):
        """runs the log coroutine non-blocking"""
        asyncio.run_coroutine_threadsafe(self.log(data), self.loop)

    def __repr__(self):
        # NOTE(review): the original repr body was reduced to `f""`,
        # apparently by markup stripping of its angle brackets; this is a
        # reconstruction of a conventional repr — confirm against upstream.
        return f"<{self.__class__.__name__}(file_name={self.file_name})>"
b/src/instructlab/training/main_ds.py index 40af2969..97ec95cb 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -399,7 +399,9 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): global_grad_norm = ( float(global_grad_norm) if global_grad_norm is not None else None ) - weight_norm = float(model.optimizer.single_partition_of_fp32_groups[0].norm()) + weight_norm = float( + model.optimizer.single_partition_of_fp32_groups[0].norm() + ) metric_logger.log_sync( { @@ -413,13 +415,16 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): "cuda_malloc_retries": cuda_malloc_retries, "num_loss_counted_tokens": int(num_loss_counted_tokens), "batch_size": int(aggregated_values[1]), - "total_loss": float(aggregated_values[2] / num_loss_counted_tokens), - "gradnorm": global_grad_norm if global_grad_norm is not None else None, + "total_loss": float( + aggregated_values[2] / num_loss_counted_tokens + ), + "gradnorm": ( + global_grad_norm if global_grad_norm is not None else None + ), "weight_norm": weight_norm, } ) - if global_step * batch_size % args.save_samples == 0: save_hf_format_ds( args, @@ -448,10 +453,13 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): def main(args): # Third Party import yaml - metric_logger = AsyncStructuredLogger(args.output_dir + "/training_params_and_metrics.jsonl") + + metric_logger = AsyncStructuredLogger( + args.output_dir + "/training_params_and_metrics.jsonl" + ) if os.environ["LOCAL_RANK"] == "0": print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m") - metric_logger.log_sync({'script_params': vars(args)}) + metric_logger.log_sync({"script_params": vars(args)}) setup_logger(args.log_level) CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path) @@ -498,17 +506,19 @@ def main(args): ) if args.local_rank == 0: - metric_logger.log_sync({ - 'num_gpus': torch.distributed.get_world_size(), - 
'avg_sample_len': dataset.get_lengths().mean(), - 'effective_batch_size': args.effective_batch_size, - 'max_batch_len_per_gpu': args.max_batch_len, - 'packing_max_batch_len': packing_max_batch_len, - 'grad_accum': grad_accum, - 'num_batches': len(train_loader), - 'avg_samples_per_batch': len(dataset)/len(train_loader), - 'samples_per_gpu': args.samples_per_gpu - }) + metric_logger.log_sync( + { + "num_gpus": torch.distributed.get_world_size(), + "avg_sample_len": dataset.get_lengths().mean(), + "effective_batch_size": args.effective_batch_size, + "max_batch_len_per_gpu": args.max_batch_len, + "packing_max_batch_len": packing_max_batch_len, + "grad_accum": grad_accum, + "num_batches": len(train_loader), + "avg_samples_per_batch": len(dataset) / len(train_loader), + "samples_per_gpu": args.samples_per_gpu, + } + ) model = setup_model(args, tokenizer, train_loader, grad_accum) model = maybe_resume_training(args, model) From bf61b2644927130d4c2d38886aec0109c16b02f3 Mon Sep 17 00:00:00 2001 From: aldo Date: Fri, 21 Jun 2024 15:52:43 -0400 Subject: [PATCH 7/7] redundant calc --- src/instructlab/training/main_ds.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 97ec95cb..5783700e 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -418,9 +418,7 @@ def train(args, model, tokenizer, train_loader, grad_accum, metric_logger): "total_loss": float( aggregated_values[2] / num_loss_counted_tokens ), - "gradnorm": ( - global_grad_norm if global_grad_norm is not None else None - ), + "gradnorm": global_grad_norm, "weight_norm": weight_norm, } )