Funasr1.0 #1275

Merged (15 commits) on Jan 19, 2024
examples/industrial_data_pretraining/paraformer/finetune.sh (2 additions, 2 deletions)
@@ -11,9 +11,9 @@ python funasr/bin/train.py \
 +model_revision="v2.0.2" \
 +train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
 +valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
-++dataset_conf.batch_size=2 \
+++dataset_conf.batch_size=64 \
 ++dataset_conf.batch_type="example" \
-++train_conf.max_epoch=2 \
+++dataset_conf.num_workers=4 \
 +output_dir="outputs/debug/ckpt/funasr2/exp2" \
 +device="cpu" \
 +debug="true"
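The override flags in this script follow Hydra's command-line grammar (train.py's entrypoint takes a DictConfig): a bare key=value overrides an existing entry, +key=value adds a new one, and ++key=value adds it or overrides it if it already exists, with dotted paths addressing nested config fields. A minimal sketch of the merge semantics using OmegaConf, the config library underneath Hydra; the keys are taken from the script above, but this is illustrative rather than FunASR's actual parsing path:

# pip install omegaconf
from omegaconf import OmegaConf

base = OmegaConf.create({"dataset_conf": {"batch_size": 2, "batch_type": "example"}})
overrides = OmegaConf.from_dotlist(["dataset_conf.batch_size=64", "dataset_conf.num_workers=4"])
merged = OmegaConf.merge(base, overrides)
print(merged.dataset_conf.batch_size)   # 64: existing value overridden
print(merged.dataset_conf.num_workers)  # 4: new key added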
funasr/auto/auto_model.py (1 addition, 1 deletion)
@@ -132,7 +132,7 @@ def __init__(self, **kwargs):
         self.punc_kwargs = punc_kwargs
         self.spk_model = spk_model
         self.spk_kwargs = spk_kwargs
-        self.model_path = kwargs["model_path"]
+        self.model_path = kwargs.get("model_path", "./")


     def build_model(self, **kwargs):
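The auto_model.py change swaps a hard key lookup for a defaulted one, so AutoModel can be constructed without an explicit model_path. The difference in one self-contained snippet (generic Python, not FunASR-specific):

kwargs = {"model": "paraformer"}
# kwargs["model_path"] would raise KeyError because the key is absent;
# dict.get returns the supplied fallback instead.
model_path = kwargs.get("model_path", "./")
print(model_path)  # ./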
funasr/bin/train.py (1 addition, 1 deletion)
@@ -40,7 +40,7 @@ def main_hydra(kwargs: DictConfig):
 
 
 def main(**kwargs):
-
+    print(kwargs)
     # set random seed
     tables.print()
     set_all_random_seed(kwargs.get("seed", 0))
funasr/datasets/audio_datasets/samplers.py (1 addition, 1 deletion)
@@ -28,7 +28,7 @@ def __init__(self, dataset,
         self.shuffle = shuffle and is_training
 
     def __len__(self):
-        return self.total_samples
+        return (self.total_samples-1) // self.batch_size + 1
 
     def set_epoch(self, epoch):
         np.random.seed(epoch)
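The samplers.py fix makes the batch sampler's __len__ report the number of batches rather than the number of samples, via integer ceiling division. A standalone check of the formula (num_batches is a hypothetical helper, not part of FunASR):

def num_batches(total_samples: int, batch_size: int) -> int:
    # Ceiling division without math.ceil: the last, possibly partial
    # batch still counts as one batch.
    return (total_samples - 1) // batch_size + 1

assert num_batches(10, 4) == 3  # batches of 4, 4, 2
assert num_batches(8, 4) == 2   # exact multiple
assert num_batches(0, 4) == 0   # empty dataset edge case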
funasr/models/fsmn_vad_streaming/model.py (0 additions, 2 deletions)
@@ -255,7 +255,6 @@ def __init__(self,
         self.waveform = None
         self.last_drop_frames = 0
 
-
 @tables.register("model_classes", "FsmnVADStreaming")
 class FsmnVADStreaming(nn.Module):
     """
@@ -500,7 +499,6 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
         # # reset class variables and clear the dict for the next query
         # self.AllResetDetection()
         return segments
 
-
     def init_cache(self, cache: dict = {}, **kwargs):
 
funasr/train_utils/trainer.py (16 additions, 4 deletions)
@@ -147,9 +147,17 @@ def run(self):
         for epoch in range(self.start_epoch, self.max_epoch + 1):
 
             self._train_epoch(epoch)
 
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
             self._validate_epoch(epoch)
-
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
 
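The added barriers make every rank finish training before any rank starts validation, and finish validation before rank 0 writes the checkpoint, so no process saves or evaluates against half-updated state. A minimal sketch of the same pattern outside this trainer, assuming torch.distributed is already initialized (e.g. by torchrun); the three epoch functions are placeholders, not FunASR APIs:

import torch.distributed as dist

def train_one_epoch(epoch): ...   # placeholder for the real training step
def validate(epoch): ...          # placeholder for the real validation step
def save_checkpoint(epoch): ...   # placeholder for the real checkpoint writer

def run_epochs(start_epoch: int, max_epoch: int):
    for epoch in range(start_epoch, max_epoch + 1):
        train_one_epoch(epoch)
        if dist.is_initialized():
            dist.barrier()  # all ranks finish the train phase together
        validate(epoch)
        if dist.is_initialized():
            dist.barrier()  # all ranks finish validation together
        if not dist.is_initialized() or dist.get_rank() == 0:
            save_checkpoint(epoch)  # only rank 0 touches the filesystem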
@@ -164,7 +172,9 @@ def run(self):
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+
+        if self.writer:
+            self.writer.close()
 
 
     def _train_epoch(self, epoch):
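The None check on the writer matters because, by the usual convention (plausibly FunASR's too, though that is an assumption here), only rank 0 constructs a TensorBoard SummaryWriter and every other rank keeps None, so an unconditional close() would raise AttributeError off rank 0. A sketch of that convention; the log_dir value is illustrative:

import os
from torch.utils.tensorboard import SummaryWriter

rank = int(os.environ.get("RANK", "0"))
# Only rank 0 logs; other ranks hold None.
writer = SummaryWriter(log_dir="outputs/tb") if rank == 0 else None
# ... training loop logs scalars via writer on rank 0 ...
if writer:  # no-op on ranks where writer is None
    writer.close()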
@@ -230,6 +240,8 @@ def _train_epoch(self, epoch):
                     continue
 
                 # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
                 self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
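For context, this barrier lands at a gradient-accumulation boundary: gradients from several micro-batches are summed before one optimizer step, and the barrier keeps ranks aligned at that step. A runnable toy sketch of such a loop; the model, data, and accum_grad value are illustrative, not the trainer's actual attributes:

import torch

model = torch.nn.Linear(8, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=10)
data = [torch.randn(4, 8) for _ in range(8)]  # toy micro-batches
accum_grad = 4

for batch_idx, x in enumerate(data):
    loss = model(x).pow(2).mean() / accum_grad  # scale so summed grads match one big batch
    loss.backward()
    if (batch_idx + 1) % accum_grad != 0:
        continue  # keep accumulating gradients
    optim.step()        # one update per accum_grad micro-batches
    scheduler.step()
    optim.zero_grad()   # clear gradients for the next accumulation stage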
@@ -244,7 +256,7 @@ def _train_epoch(self, epoch):
                     pbar.update(1)
                     if self.local_rank == 0:
                         description = (
-                            f"Epoch: {epoch}/{self.max_epoch}, "
+                            f"Train epoch: {epoch}/{self.max_epoch}, "
                             f"step {batch_idx}/{len(self.dataloader_train)}, "
                             f"{speed_stats}, "
                             f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@ def _validate_epoch(self, epoch):
                     pbar.update(1)
                     if self.local_rank == 0:
                         description = (
-                            f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                            f"validation epoch: {epoch}/{self.max_epoch}, "
                             f"step {batch_idx}/{len(self.dataloader_train)}, "
                             f"{speed_stats}, "
                             f"(loss: {loss.detach().cpu().item():.3f}), "