Commit
Funasr1.0 (#1261)
* funasr1.0 finetune

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>
3 people authored Jan 17, 2024
1 parent b185783 commit 9a9c3b7
Showing 10 changed files with 296 additions and 145 deletions.
4 changes: 3 additions & 1 deletion examples/industrial_data_pretraining/paraformer/finetune.sh
@@ -9,9 +9,11 @@
python funasr/bin/train.py \
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+model_revision="v2.0.2" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
+valid_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
++dataset_conf.batch_size=2 \
++dataset_conf.batch_type="example" \
++train_conf.max_epoch=2 \
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
+device="cpu" \
+debug="true"
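The finetune script pins the Paraformer model and revision, points the train/validation lists at jsonl manifests, and supplies the remaining settings as command-line overrides, where a leading "+" adds a top-level key and "++" overrides a nested one such as dataset_conf.batch_size. As a rough sketch of how dotted overrides of this kind can be folded into a nested config dict (FunASR's actual CLI parser is separate and may behave differently):

# Illustrative only: fold "+key=value" / "++a.b=value" overrides into a nested
# dict. FunASR's real parsing differs; this just shows the mechanics.
import json

def apply_overrides(config, overrides):
    for item in overrides:
        key, _, raw = item.lstrip("+").partition("=")
        value = raw.strip('"')
        for cast in (int, float):          # best-effort typing
            try:
                value = cast(value)
                break
            except ValueError:
                pass
        if value in ("true", "false"):
            value = (value == "true")
        node = config
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return config

cfg = apply_overrides({}, [
    '+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"',
    '++dataset_conf.batch_size=2',
    '++dataset_conf.batch_type="example"',
    '++train_conf.max_epoch=2',
    '+device="cpu"',
])
print(json.dumps(cfg, indent=2, ensure_ascii=False))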
@@ -15,6 +15,6 @@
spk_model_revision="v2.0.2",
)

res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
hotword='达摩院 魔搭')
print(res)
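The demo now reads its example audio from a hosted URL instead of the bundled wav, keeping the hotword bias. For orientation, a minimal sketch of the surrounding AutoModel call as it would typically look in FunASR 1.0; only the URL and hotword are visible in this hunk, so the model id and revision are reused from finetune.sh as placeholders and every other kwarg is omitted rather than guessed:

# Minimal sketch around the changed generate() call; model id and revision are
# placeholders taken from finetune.sh, not necessarily this demo's own model.
from funasr import AutoModel

model = AutoModel(
    model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    model_revision="v2.0.2",
)

# input accepts a local path or an http(s) URL; hotword biases decoding toward
# the listed phrases ("达摩院 魔搭" = "DAMO Academy ModelScope").
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword="达摩院 魔搭",
)
print(res)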
34 changes: 21 additions & 13 deletions funasr/auto/auto_model.py
@@ -221,7 +221,8 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
disable_pbar = kwargs.get("disable_pbar", False)
pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None
time_speech_total = 0.0
time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
@@ -239,8 +240,7 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **
time2 = time.perf_counter()

asr_result_list.extend(results)
pbar.update(1)


# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
batch_data_time = meta_data.get("batch_data_time", -1)
time_escape = time2 - time1
@@ -252,12 +252,15 @@ def inference(self, input, input_len=None, model=None, kwargs=None, key=None, **
description = (
f"{speed_stats}, "
)
pbar.set_description(description)
if pbar:
pbar.update(1)
pbar.set_description(description)
time_speech_total += batch_data_time
time_escape_total += time_escape

pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")

if pbar:
pbar.update(1)
pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list

@@ -309,8 +312,11 @@ def inference_with_vad(self, input, input_len=None, **cfg):
time_speech_total_per_sample = speech_lengths/16000
time_speech_total_all_samples += time_speech_total_per_sample

pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)

all_segments = []
for j, _ in enumerate(range(0, n)):
pbar_sample.update(1)
batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
if j < n - 1 and (
batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
@@ -319,13 +325,14 @@
batch_size_ms_cum = 0
end_idx = j + 1
speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
results = self.inference(speech_j, input_len=None, model=model, kwargs=kwargs, disable_pbar=True, **cfg)
if self.spk_model is not None:



# compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
for _b in range(len(speech_j)):
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
speech_j[_b]]]
segments = sv_chunk(vad_segments)
all_segments.extend(segments)
@@ -338,12 +345,13 @@ def inference_with_vad(self, input, input_len=None, **cfg):
results_sorted.extend(results)


pbar_total.update(1)

end_asr_total = time.time()
time_escape_total_per_sample = end_asr_total - beg_asr_total
pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")


restored_data = [0] * n
for j in range(n):
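In auto_model.py the batch-level progress bar becomes optional via a disable_pbar kwarg, every pbar call is guarded, and inference_with_vad gains its own per-sample bar while passing disable_pbar=True to the nested self.inference call so the inner batches stay silent. A generic sketch of that outer-bar/inner-silence pattern, using only tqdm; the function names here are illustrative, not FunASR's API:

# Sketch: the outer loop owns the visible progress bar, the inner call can be
# silenced, and every pbar access is guarded so disabling it is safe.
import time
from tqdm import tqdm

def inner_inference(batch, disable_pbar=False):
    pbar = None if disable_pbar else tqdm(total=len(batch), colour="blue", dynamic_ncols=True)
    results = []
    for item in batch:
        time.sleep(0.01)               # stand-in for the model forward pass
        results.append(item * 2)
        if pbar:
            pbar.update(1)
    if pbar:
        pbar.close()
    return results

def outer_inference(samples):
    out = []
    pbar_sample = tqdm(total=len(samples), colour="blue", dynamic_ncols=True)
    for sample in samples:
        pbar_sample.update(1)
        out.extend(inner_inference(sample, disable_pbar=True))   # inner stays quiet
        pbar_sample.set_description(f"processed {len(out)} segments")
    pbar_sample.close()
    return out

print(outer_inference([[1, 2, 3], [4, 5]]))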
17 changes: 12 additions & 5 deletions funasr/bin/train.py
@@ -141,30 +141,37 @@ def main(**kwargs):
scheduler_class = scheduler_classes.get(scheduler)
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

# import pdb;
# pdb.set_trace()

# dataset
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
dataset_val = dataset_class(kwargs.get("valid_data_set_list"), frontend=frontend, tokenizer=tokenizer,
**kwargs.get("dataset_conf"))

# dataloader
batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler_val = None
if batch_sampler is not None:
batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
collate_fn=dataset_tr.collator,
batch_sampler=batch_sampler,
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
pin_memory=True)


dataloader_val = torch.utils.data.DataLoader(dataset_val,
collate_fn=dataset_val.collator,
batch_sampler=batch_sampler_val,
num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
pin_memory=True)
trainer = Trainer(
model=model,
optim=optim,
scheduler=scheduler,
dataloader_train=dataloader_tr,
dataloader_val=None,
dataloader_val=dataloader_val,
local_rank=local_rank,
use_ddp=use_ddp,
use_fsdp=use_fsdp,
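train.py now instantiates the batch sampler twice (the validation one with is_training=False), builds a matching validation DataLoader, and hands it to the Trainer as dataloader_val where None was passed before, so evaluation actually runs during fine-tuning. A stripped-down sketch of that train/validate wiring in plain PyTorch; the datasets, model, and loop here are toy stand-ins, and FunASR's Trainer and dynamic batch sampler carry many more options:

# Sketch: a validation loader wired into the loop so each epoch ends with an
# evaluation pass over held-out data.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

dataset_tr = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
dataset_val = TensorDataset(torch.randn(16, 10), torch.randn(16, 1))
dataloader_tr = DataLoader(dataset_tr, batch_size=8, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=8, shuffle=False)   # no shuffle at eval

model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(2):
    model.train()
    for xb, yb in dataloader_tr:
        optim.zero_grad()
        loss_fn(model(xb), yb).backward()
        optim.step()
    model.eval()
    with torch.no_grad():   # validation only happens because dataloader_val is a real loader
        val_loss = sum(loss_fn(model(xb), yb).item() for xb, yb in dataloader_val)
    print(f"epoch {epoch}: val_loss={val_loss / len(dataloader_val):.4f}")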
6 changes: 5 additions & 1 deletion funasr/datasets/audio_datasets/index_ds.py
@@ -54,7 +54,11 @@ def __len__(self):
return len(self.contents)

def __getitem__(self, index):
return self.contents[index]
try:
data = self.contents[index]
except:
print(index)
return data

def get_source_len(self, data_dict):
return data_dict["source_len"]
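index_ds.py wraps the __getitem__ lookup in a try/except so the failing index gets printed during training. If the goal is to see which sample breaks while still propagating a usable error, a variant like the following (illustrative, not the committed code) logs the index and re-raises instead of falling through to an undefined data variable:

# Illustrative __getitem__: report the failing index, then re-raise so the
# DataLoader surfaces the original exception.
import logging

class IndexDSExample:
    def __init__(self, contents):
        self.contents = contents

    def __len__(self):
        return len(self.contents)

    def __getitem__(self, index):
        try:
            return self.contents[index]
        except Exception:
            logging.exception("failed to fetch sample at index %s", index)
            raise

    def get_source_len(self, data_dict):
        return data_dict["source_len"]

ds = IndexDSExample([{"source": "a.wav", "source_len": 123, "target": "你好"}])
print(ds[0])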
3 changes: 2 additions & 1 deletion funasr/datasets/audio_datasets/samplers.py
@@ -13,6 +13,7 @@ def __init__(self, dataset,
buffer_size: int = 30,
drop_last: bool = False,
shuffle: bool = True,
is_training: bool = True,
**kwargs):

self.drop_last = drop_last
@@ -24,7 +25,7 @@ def __init__(self, dataset,
self.buffer_size = buffer_size
self.max_token_length = kwargs.get("max_token_length", 5000)
self.shuffle_idx = np.arange(self.total_samples)
self.shuffle = shuffle
self.shuffle = shuffle and is_training

def __len__(self):
return self.total_samples
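samplers.py threads an is_training flag into the batch sampler and gates shuffling on it (self.shuffle = shuffle and is_training), so the validation sampler built in train.py iterates in a fixed order. A tiny sketch of the effect; the class name and fields are illustrative:

# Sketch: shuffling only applies during training, so validation passes see a
# deterministic order.
import random

class ToyBatchSampler:
    def __init__(self, num_samples, shuffle=True, is_training=True):
        self.num_samples = num_samples
        self.shuffle = shuffle and is_training

    def __iter__(self):
        idx = list(range(self.num_samples))
        if self.shuffle:
            random.shuffle(idx)
        return iter(idx)

val_sampler = ToyBatchSampler(8, shuffle=True, is_training=False)
assert list(val_sampler) == list(val_sampler)   # same order on every pass
print(list(val_sampler))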
1 change: 1 addition & 0 deletions funasr/models/paraformer/model.py
@@ -164,6 +164,7 @@ def __init__(
self.use_1st_decoder_loss = use_1st_decoder_loss
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
self.error_calculator = None

def forward(
self,
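model.py initializes self.error_calculator = None next to self.beam_search = None, presumably so attribute checks succeed before either object is built. A small sketch of that declare-then-lazily-build pattern; the class name and the toy character-error metric are illustrative:

# Sketch: declare the attribute up front as None, build the heavier object
# only on first use, so earlier code can safely test "is None".
class Recognizer:
    def __init__(self):
        self.beam_search = None
        self.error_calculator = None   # safe to check before it is built

    def cer(self, hyp, ref):
        if self.error_calculator is None:
            # stand-in for constructing a real CER/WER calculator
            self.error_calculator = lambda h, r: sum(a != b for a, b in zip(h, r)) / max(len(r), 1)
        return self.error_calculator(hyp, ref)

model = Recognizer()
print(model.error_calculator)      # None until first use, no AttributeError
print(model.cer("abcd", "abed"))   # 0.25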
1 change: 1 addition & 0 deletions funasr/models/paraformer/template.yaml
@@ -95,6 +95,7 @@ train_conf:
- acc
- max
keep_nbest_models: 10
avg_nbest_model: 5
log_interval: 50

optim: adam
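The template config adds avg_nbest_model: 5 under train_conf, alongside keep_nbest_models: 10: keep the ten best checkpoints by validation accuracy and average the top five into the model used for inference. A hedged sketch of plain state-dict averaging; FunASR's own averaging utility, checkpoint naming, and best-n selection may differ:

# Sketch: average the state_dicts of the n best checkpoints into one model.
# Paths and the "best n" selection below are placeholders.
import torch

def average_checkpoints(ckpt_paths):
    avg = None
    for path in ckpt_paths:
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    return {k: v / len(ckpt_paths) for k, v in avg.items()}

# e.g. the 5 checkpoints with the best validation acc out of the 10 kept:
# best_5 = ["exp/model.ep3.pt", "exp/model.ep5.pt", ...]
# model.load_state_dict(average_checkpoints(best_5))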