From 9a9c3b75b5b3359701844a91a9fae6d2979866cd Mon Sep 17 00:00:00 2001 From: zhifu gao Date: Wed, 17 Jan 2024 18:28:28 +0800 Subject: [PATCH] Funasr1.0 (#1261) * funasr1.0 funetine * funasr1.0 pbar * update with main (#1260) * Update websocket_protocol_zh.md * update --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --------- Co-authored-by: Yabin Li Co-authored-by: shixian.shi --- .../paraformer/finetune.sh | 4 +- .../seaco_paraformer/demo.py | 2 +- funasr/auto/auto_model.py | 34 ++- funasr/bin/train.py | 17 +- funasr/datasets/audio_datasets/index_ds.py | 6 +- funasr/datasets/audio_datasets/samplers.py | 3 +- funasr/models/paraformer/model.py | 1 + funasr/models/paraformer/template.yaml | 1 + funasr/train_utils/average_nbest_models.py | 266 +++++++++++------- funasr/train_utils/trainer.py | 107 +++++-- 10 files changed, 296 insertions(+), 145 deletions(-) diff --git a/examples/industrial_data_pretraining/paraformer/finetune.sh b/examples/industrial_data_pretraining/paraformer/finetune.sh index 93cce73f3..7d8987602 100644 --- a/examples/industrial_data_pretraining/paraformer/finetune.sh +++ b/examples/industrial_data_pretraining/paraformer/finetune.sh @@ -9,9 +9,11 @@ python funasr/bin/train.py \ +model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \ +model_revision="v2.0.2" \ -+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \ ++train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \ ++valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \ ++dataset_conf.batch_size=2 \ ++dataset_conf.batch_type="example" \ +++train_conf.max_epoch=2 \ +output_dir="outputs/debug/ckpt/funasr2/exp2" \ +device="cpu" \ +debug="true" \ No newline at end of file diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 5f17252f9..19ad1c9c5 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -15,6 +15,6 @@ spk_model_revision="v2.0.2", ) -res = model.generate(input=f"{model.model_path}/example/asr_example.wav", +res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", hotword='达摩院 魔搭') print(res) \ No newline at end of file diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py index 740614c74..bedc17d16 100644 --- a/funasr/auto/auto_model.py +++ b/funasr/auto/auto_model.py @@ -221,7 +221,8 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** speed_stats = {} asr_result_list = [] num_samples = len(data_list) - pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) + disable_pbar = kwargs.get("disable_pbar", False) + pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None time_speech_total = 0.0 time_escape_total = 0.0 for beg_idx in range(0, num_samples, batch_size): @@ -239,8 +240,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, ** time2 = time.perf_counter() asr_result_list.extend(results) - pbar.update(1) - + # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item() batch_data_time = meta_data.get("batch_data_time", -1) time_escape = time2 - time1 @@ -252,12 +252,15 @@ def inference(self, input, input_len=None, model=None, 
kwargs=None, key=None, ** description = ( f"{speed_stats}, " ) - pbar.set_description(description) + if pbar: + pbar.update(1) + pbar.set_description(description) time_speech_total += batch_data_time time_escape_total += time_escape - - pbar.update(1) - pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") + + if pbar: + pbar.update(1) + pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}") torch.cuda.empty_cache() return asr_result_list @@ -309,8 +312,11 @@ def inference_with_vad(self, input, input_len=None, **cfg): time_speech_total_per_sample = speech_lengths/16000 time_speech_total_all_samples += time_speech_total_per_sample + pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True) + all_segments = [] for j, _ in enumerate(range(0, n)): + pbar_sample.update(1) batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0]) if j < n - 1 and ( batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and ( @@ -319,13 +325,14 @@ def inference_with_vad(self, input, input_len=None, **cfg): batch_size_ms_cum = 0 end_idx = j + 1 speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx]) - results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg) + results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg) if self.spk_model is not None: - + + # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]] for _b in range(len(speech_j)): - vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \ - sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \ + vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, + sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, speech_j[_b]]] segments = sv_chunk(vad_segments) all_segments.extend(segments) @@ -338,12 +345,13 @@ def inference_with_vad(self, input, input_len=None, **cfg): results_sorted.extend(results) - pbar_total.update(1) + end_asr_total = time.time() time_escape_total_per_sample = end_asr_total - beg_asr_total - pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " + pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, " f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, " f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}") + restored_data = [0] * n for j in range(n): diff --git a/funasr/bin/train.py b/funasr/bin/train.py index 7ae687ef9..0334006c5 100644 --- a/funasr/bin/train.py +++ b/funasr/bin/train.py @@ -141,30 +141,37 @@ def main(**kwargs): scheduler_class = scheduler_classes.get(scheduler) scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf")) - # import pdb; - # pdb.set_trace() + # dataset dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset")) dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf")) + dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer, + **kwargs.get("dataset_conf")) # dataloader batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler") - batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) + batch_sampler_val = None if batch_sampler is not None: + batch_sampler_class = 
tables.batch_sampler_classes.get(batch_sampler)
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
+        # the validation sampler must wrap the validation set, not the training set
+        # (a sampler built over dataset_tr can emit out-of-range indices when the
+        # validation set is smaller); is_training=False disables shuffling.
+        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
 
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                 collate_fn=dataset_tr.collator,
                                                 batch_sampler=batch_sampler,
                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
                                                 pin_memory=True)
+    dataloader_val = torch.utils.data.DataLoader(dataset_val,
+                                                 collate_fn=dataset_val.collator,
+                                                 batch_sampler=batch_sampler_val,
+                                                 num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
+                                                 pin_memory=True)
     trainer = Trainer(
         model=model,
         optim=optim,
         scheduler=scheduler,
         dataloader_train=dataloader_tr,
-        dataloader_val=None,
+        dataloader_val=dataloader_val,
         local_rank=local_rank,
         use_ddp=use_ddp,
         use_fsdp=use_fsdp,
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
index 8e5b05cf3..c94d20961 100644
--- a/funasr/datasets/audio_datasets/index_ds.py
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -54,7 +54,11 @@ def __len__(self):
         return len(self.contents)
 
     def __getitem__(self, index):
-        return self.contents[index]
+        try:
+            data = self.contents[index]
+        except IndexError:
+            # log the offending index, then re-raise instead of falling through
+            # to an unbound variable
+            print(f"index {index} is out of range for dataset of size {len(self.contents)}")
+            raise
+        return data
 
     def get_source_len(self, data_dict):
         return data_dict["source_len"]
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 4af35e9cd..e170c681b 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -13,6 +13,7 @@ def __init__(self, dataset,
                  buffer_size: int = 30,
                  drop_last: bool = False,
                  shuffle: bool = True,
+                 is_training: bool = True,
                  **kwargs):
 
         self.drop_last = drop_last
@@ -24,7 +25,7 @@ def __init__(self, dataset,
         self.buffer_size = buffer_size
         self.max_token_length = kwargs.get("max_token_length", 5000)
         self.shuffle_idx = np.arange(self.total_samples)
-        self.shuffle = shuffle
+        self.shuffle = shuffle and is_training
 
     def __len__(self):
         return self.total_samples
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index f92441d39..9f3c3f3b6 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -164,6 +164,7 @@ def __init__(
         self.use_1st_decoder_loss = use_1st_decoder_loss
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
+        self.error_calculator = None
 
     def forward(
         self,
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
index 94eebf7bd..3972caaa1 100644
--- a/funasr/models/paraformer/template.yaml
+++ b/funasr/models/paraformer/template.yaml
@@ -95,6 +95,7 @@ train_conf:
     - acc
     - max
   keep_nbest_models: 10
+  avg_nbest_model: 5
   log_interval: 50
 
 optim: adam
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 96e138428..f117804f3 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -9,117 +9,173 @@
 import torch
 from typing import Collection
+import os
+import re
+from collections import OrderedDict
 
-from funasr.train.reporter import Reporter
-
-@torch.no_grad()
-def average_nbest_models(
-    output_dir: Path,
-    reporter: Reporter,
-    best_model_criterion: Sequence[Sequence[str]],
-    nbest: Union[Collection[int], int],
-    suffix: Optional[str] = None,
-    oss_bucket=None,
-    pai_output_dir=None,
-) -> None:
-    """Generate averaged model from n-best models
-    Args:
-        output_dir: The directory contains the model file for each epoch
-        reporter: Reporter instance
-        best_model_criterion: Give criterions to decide the best model.
-            e.g. 
[("valid", "loss", "min"), ("train", "acc", "max")] - nbest: Number of best model files to be averaged - suffix: A suffix added to the averaged model file name +def _get_checkpoint_paths(output_dir: str, last_n: int=5): """ - if isinstance(nbest, int): - nbests = [nbest] - else: - nbests = list(nbest) - if len(nbests) == 0: - warnings.warn("At least 1 nbest values are required") - nbests = [1] - if suffix is not None: - suffix = suffix + "." - else: - suffix = "" - - # 1. Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]] - nbest_epochs = [ - (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)]) - for ph, k, m in best_model_criterion - if reporter.has(ph, k) - ] - - _loaded = {} - for ph, cr, epoch_and_values in nbest_epochs: - _nbests = [i for i in nbests if i <= len(epoch_and_values)] - if len(_nbests) == 0: - _nbests = [1] - - for n in _nbests: - if n == 0: - continue - elif n == 1: - # The averaged model is same as the best model - e, _ = epoch_and_values[0] - op = output_dir / f"{e}epoch.pb" - sym_op = output_dir / f"{ph}.{cr}.ave_1best.{suffix}pb" - if sym_op.is_symlink() or sym_op.exists(): - sym_op.unlink() - sym_op.symlink_to(op.name) - else: - op = output_dir / f"{ph}.{cr}.ave_{n}best.{suffix}pb" - logging.info( - f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}' - ) + Get the paths of the last 'last_n' checkpoints by parsing filenames + in the output directory. + """ + # List all files in the output directory + files = os.listdir(output_dir) + # Filter out checkpoint files and extract epoch numbers + checkpoint_files = [f for f in files if f.startswith("model.pt.e")] + # Sort files by epoch number in descending order + checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True) + # Get the last 'last_n' checkpoint paths + checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]] + return checkpoint_paths - avg = None - # 2.a. Averaging model - for e, _ in epoch_and_values[:n]: - if e not in _loaded: - if oss_bucket is None: - _loaded[e] = torch.load( - output_dir / f"{e}epoch.pb", - map_location="cpu", - ) - else: - buffer = BytesIO( - oss_bucket.get_object(os.path.join(pai_output_dir, f"{e}epoch.pb")).read()) - _loaded[e] = torch.load(buffer) - states = _loaded[e] +@torch.no_grad() +def average_checkpoints(output_dir: str, last_n: int=5): + """ + Average the last 'last_n' checkpoints' model state_dicts. + If a tensor is of type torch.int, perform sum instead of average. + """ + checkpoint_paths = _get_checkpoint_paths(output_dir, last_n) + state_dicts = [] - if avg is None: - avg = states - else: - # Accumulated - for k in avg: - avg[k] = avg[k] + states[k] - for k in avg: - if str(avg[k].dtype).startswith("torch.int"): - # For int type, not averaged, but only accumulated. - # e.g. BatchNorm.num_batches_tracked - # (If there are any cases that requires averaging - # or the other reducing method, e.g. max/min, for integer type, - # please report.) - pass - else: - avg[k] = avg[k] / n + # Load state_dicts from checkpoints + for path in checkpoint_paths: + if os.path.isfile(path): + state_dicts.append(torch.load(path, map_location='cpu')['state_dict']) + else: + print(f"Checkpoint file {path} not found.") + continue - # 2.b. 
Save the ave model and create a symlink
-                if oss_bucket is None:
-                    torch.save(avg, op)
-                else:
-                    buffer = BytesIO()
-                    torch.save(avg, buffer)
-                    oss_bucket.put_object(os.path.join(pai_output_dir, f"{ph}.{cr}.ave_{n}best.{suffix}pb"),
-                                          buffer.getvalue())
+    # Check if we have any state_dicts to average
+    if not state_dicts:
+        raise RuntimeError("No checkpoints found for averaging.")
 
-        # 3. *.*.ave.pb is a symlink to the max ave model
-        if oss_bucket is None:
-            op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.{suffix}pb"
-            sym_op = output_dir / f"{ph}.{cr}.ave.{suffix}pb"
-            if sym_op.is_symlink() or sym_op.exists():
-                sym_op.unlink()
-            sym_op.symlink_to(op.name)
+    # Average or sum weights
+    avg_state_dict = OrderedDict()
+    for key in state_dicts[0].keys():
+        tensors = [state_dict[key].cpu() for state_dict in state_dicts]
+        # Check the type of the tensor
+        if str(tensors[0].dtype).startswith("torch.int"):
+            # Perform sum for integer tensors
+            summed_tensor = sum(tensors)
+            avg_state_dict[key] = summed_tensor
+        else:
+            # Perform average for other types of tensors
+            stacked_tensors = torch.stack(tensors)
+            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
+
+    torch.save({'state_dict': avg_state_dict}, os.path.join(output_dir, f"model.pt.avg{last_n}"))
+    return avg_state_dict
\ No newline at end of file
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index da346c39c..91b30b0a8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -7,10 +7,11 @@
 from contextlib import nullcontext
 # from torch.utils.tensorboard import SummaryWriter
 from tensorboardX import SummaryWriter
+from pathlib import Path
 
 from funasr.train_utils.device_funcs import to_device
 from funasr.train_utils.recursive_op import recursive_average
-
+from funasr.train_utils.average_nbest_models import average_checkpoints
 
 class Trainer:
     """
@@ -66,10 +67,9 @@ def __init__(self, model,
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
         self.device = next(model.parameters()).device
+        self.avg_nbest_model = kwargs.get("avg_nbest_model", 5)
         self.kwargs = kwargs
-        if self.resume:
-            self._resume_checkpoint(self.resume)
 
         try:
             rank = dist.get_rank()
@@ -102,9 +102,17 @@ def _save_checkpoint(self, epoch):
         }
         # Create output directory if it does not exist
         os.makedirs(self.output_dir, exist_ok=True)
-        filename = os.path.join(self.output_dir, f'model.e{epoch}.pb')
+        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
         torch.save(state, filename)
         print(f'Checkpoint saved to {filename}')
+        latest = Path(os.path.join(self.output_dir, 'model.pt'))
+        try:
+            latest.unlink()
+        except FileNotFoundError:
+            # first epoch: no previous symlink to replace
+            pass
+        # link by basename so the symlink stays valid if output_dir is moved
+        latest.symlink_to(os.path.basename(filename))
 
     def _resume_checkpoint(self, resume_path):
         """
@@ -114,29 +122,50 @@ def _resume_checkpoint(self, resume_path):
-        Resumes training from a checkpoint at the given file path.
+        Resumes training from the latest checkpoint in the given directory.
         Loads the model's state, the optimizer's state, and the epoch number.
 
         Args:
-            resume_path (str): The file path to the checkpoint to resume from.
+            resume_path (str): The directory containing the model.pt checkpoint
+                symlink to resume from.
         """
-        if os.path.isfile(resume_path):
-            checkpoint = torch.load(resume_path)
+        ckpt = os.path.join(resume_path, "model.pt")
+        if os.path.isfile(ckpt):
+            checkpoint = torch.load(ckpt)
             self.start_epoch = checkpoint['epoch'] + 1
             self.model.load_state_dict(checkpoint['state_dict'])
             self.optim.load_state_dict(checkpoint['optimizer'])
             self.scheduler.load_state_dict(checkpoint['scheduler'])
-            print(f"Checkpoint loaded successfully from '{resume_path}' at (epoch {checkpoint['epoch']})")
+            print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
-            print(f"No checkpoint found at '{resume_path}', starting from scratch")
+            print(f"No checkpoint found at '{ckpt}', starting from scratch")
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
 
     def run(self):
         """
         Starts the training process, iterating over epochs, training the model,
         and saving checkpoints at the end of each epoch.
         """
+        if self.resume:
+            self._resume_checkpoint(self.output_dir)
+
         for epoch in range(self.start_epoch, self.max_epoch + 1):
             self._train_epoch(epoch)
-            # self._validate_epoch(epoch)
+
+            self._validate_epoch(epoch)
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             self.scheduler.step()
+
+        if self.rank == 0:
+            average_checkpoints(self.output_dir, self.avg_nbest_model)
+
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
 
         self.writer.close()
 
     def _train_epoch(self, epoch):
         """
@@ -157,8 +186,7 @@ def _train_epoch(self, epoch):
         for batch_idx, batch in enumerate(self.dataloader_train):
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time5:0.3f}"
-            # import pdb;
-            # pdb.set_trace()
+
             batch = to_device(batch, self.device)
 
             my_context = self.model.no_sync if batch_idx % accum_grad != 0 else nullcontext
@@ -211,13 +239,12 @@ def _train_epoch(self, epoch):
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 
                 speed_stats["total_time"] = total_time
-
-            # import pdb;
-            # pdb.set_trace()
+
+            pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch + 1}/{self.max_epoch}, "
+                    f"Epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -248,6 +275,50 @@ def _validate_epoch(self, epoch):
         """
         self.model.eval()
         with torch.no_grad():
-            for data, target in self.dataloader_val:
-                # Implement the model validation steps here
-                pass
+            pbar = tqdm(colour="red", desc=f"Validation Epoch: {epoch}", total=len(self.dataloader_val),
+                        dynamic_ncols=True)
+            speed_stats = {}
+            time5 = time.perf_counter()
+            for batch_idx, batch in enumerate(self.dataloader_val):
+                time1 = time.perf_counter()
+                speed_stats["data_load"] = f"{time1 - time5:0.3f}"
+                batch = to_device(batch, self.device)
+                time2 = time.perf_counter()
+                retval = self.model(**batch)
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
+                loss, stats, weight = retval
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if self.use_ddp or self.use_fsdp:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed=True)
+                    # Now weight is summation over all workers
+                    loss /= weight
+                    # Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= self.world_size
+                time4 = time.perf_counter()
+
+                pbar.update(1)
+                if self.local_rank == 0:
+                    description = (
+                        f"validation, Epoch: {epoch}/{self.max_epoch}, "
+                        f"step {batch_idx}/{len(self.dataloader_val)}, "
+                        f"{speed_stats}, "
+                        f"(loss: {loss.detach().cpu().item():.3f}), "
+                        f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
+                    )
+                    pbar.set_description(description)
+                    if self.writer:
+                        self.writer.add_scalar('Loss/val', loss.item(),
+                                               epoch * len(self.dataloader_val) + batch_idx)
+                        for key, var in stats.items():
+                            self.writer.add_scalar(f'{key}/val', var.item(),
+                                                   epoch * len(self.dataloader_val) + batch_idx)
+                        for key, var in speed_stats.items():
+                            # speed_stats values are formatted floats; parse rather than eval()
+                            self.writer.add_scalar(f'{key}/val', float(var),
+                                                   epoch * len(self.dataloader_val) + batch_idx)
\ No newline at end of file
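
With this patch, AutoModel.inference reads a disable_pbar flag from its kwargs and skips creating the inner tqdm bar when the flag is set; inference_with_vad uses exactly this to silence the per-batch bar while its own per-sample bar runs. A minimal sketch of flipping the same switch from user code, assuming generate forwards extra keyword arguments into the model kwargs the way the demos above pass hotword:

    from funasr import AutoModel

    model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                      model_revision="v2.0.2")
    # disable_pbar=True suppresses the per-batch tqdm bar; when the flag is
    # absent, kwargs.get("disable_pbar", False) keeps the old behaviour.
    res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
                         disable_pbar=True)
    print(res)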
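_save_checkpoint now writes one file per epoch, model.pt.ep{epoch}, and repoints a model.pt symlink at the newest one, while _resume_checkpoint takes the output directory rather than a checkpoint file and follows that symlink. A minimal sketch of the resulting on-disk contract (the outputs/exp path is hypothetical):

    import os
    import torch

    output_dir = "outputs/exp"  # hypothetical experiment directory
    # model.pt is the symlink maintained by _save_checkpoint; it always points
    # at the newest model.pt.ep{epoch} file in output_dir.
    ckpt = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
    start_epoch = ckpt["epoch"] + 1   # run() resumes from the following epoch
    model_state = ckpt["state_dict"]  # weights restored via load_state_dict
    optim_state = ckpt["optimizer"]   # optimizer state
    sched_state = ckpt["scheduler"]   # LR scheduler state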
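At the end of run(), rank 0 calls average_checkpoints(self.output_dir, self.avg_nbest_model), which averages the weights of the last N model.pt.ep* checkpoints (integer tensors such as BatchNorm's num_batches_tracked are summed rather than averaged) and writes the result next to them. The helper can also be invoked standalone; a short sketch, reusing the output_dir from finetune.sh above:

    import os
    import torch
    from funasr.train_utils.average_nbest_models import average_checkpoints

    output_dir = "outputs/debug/ckpt/funasr2/exp2"
    avg_state = average_checkpoints(output_dir, last_n=5)
    # The averaged weights are persisted as <output_dir>/model.pt.avg5, in the
    # same {'state_dict': ...} layout as the per-epoch checkpoints.
    ckpt = torch.load(os.path.join(output_dir, "model.pt.avg5"), map_location="cpu")
    assert set(ckpt["state_dict"]) == set(avg_state)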