funasr1.0 update
LauraGPT committed Jan 21, 2024
2 parents cfff084 + 2cca810 commit 453d118
Showing 5 changed files with 454 additions and 119 deletions.
1 change: 1 addition & 0 deletions funasr/auto/auto_model.py
@@ -133,6 +133,7 @@ def __init__(self, **kwargs):
self.spk_model = spk_model
self.spk_kwargs = spk_kwargs
self.model_path = kwargs.get("model_path")



def build_model(self, **kwargs):
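The auto_model.py change stores the resolved model path on the AutoModel instance. A minimal sketch of how that surfaces to callers, assuming FunASR 1.0's AutoModel API (the "fsmn-vad" model name is only illustrative):

from funasr import AutoModel

# AutoModel resolves the model name to a local directory and keeps the
# result in kwargs["model_path"]; after this commit the resolved path is
# also exposed as an attribute on the instance.
model = AutoModel(model="fsmn-vad")
print(model.model_path)  # local path the model was loaded from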
76 changes: 46 additions & 30 deletions funasr/models/fsmn_vad_streaming/model.py
@@ -15,19 +15,21 @@
from typing import List, Tuple, Dict, Any, Optional

from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank


class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3


class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0


# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
@@ -37,16 +39,19 @@ class AudioChangeState(Enum):
kChangeStateNoBegin = 4
kChangeStateInvalid = 5


class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1


class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""

def __init__(
self,
sample_rate: int = 16000,
@@ -117,6 +122,7 @@ class E2EVadSpeechBufWithDoa(object):
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""

def __init__(self):
self.start_ms = 0
self.end_ms = 0
@@ -140,6 +146,7 @@ class E2EVadFrameProb(object):
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""

def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
@@ -154,6 +161,7 @@ class WindowDetector(object):
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""

def __init__(self, window_size_ms: int,
sil_to_speech_time: int,
speech_to_sil_time: int,
@@ -190,7 +198,7 @@ def Reset(self) -> None:
def GetWinSize(self) -> int:
return int(self.win_size_frame)

def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict = {}) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
@@ -220,13 +228,13 @@ def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
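For orientation, WindowDetector implements a sliding-window vote over per-frame speech/silence states: a transition such as kChangeStateSil2Speech is reported only once enough frames inside the window agree. A simplified sketch of that mechanism, reusing the AudioChangeState enum defined above (illustrative only; the real DetectOneFrame keeps its window state on the instance and in cache):

frame_size_ms, window_size_ms = 10, 200            # illustrative values (ms)
sil_to_speech_time, speech_to_sil_time = 150, 150  # hangover thresholds (ms)

win_size = int(window_size_ms / frame_size_ms)     # frames per window
sil_to_speech_thres = int(sil_to_speech_time / frame_size_ms)
speech_to_sil_thres = int(speech_to_sil_time / frame_size_ms)
win_state = [0] * win_size                         # 1 = speech frame, 0 = silence frame
win_sum, pos, in_speech = 0, 0, False

def smooth_one_frame(is_speech: bool) -> AudioChangeState:
    global win_sum, pos, in_speech
    win_sum += int(is_speech) - win_state[pos]     # slide the window by one frame
    win_state[pos] = int(is_speech)
    pos = (pos + 1) % win_size
    if not in_speech and win_sum >= sil_to_speech_thres:
        in_speech = True                           # enough speech votes in the window
        return AudioChangeState.kChangeStateSil2Speech
    if in_speech and win_size - win_sum >= speech_to_sil_thres:
        in_speech = False                          # enough silence votes in the window
        return AudioChangeState.kChangeStateSpeech2Sil
    return (AudioChangeState.kChangeStateSpeech2Speech if in_speech
            else AudioChangeState.kChangeStateSil2Sil)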


class Stats(object):
def __init__(self,
sil_pdf_ids,
max_end_sil_frame_cnt_thresh,
speech_noise_thres,
):

self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
@@ -263,6 +271,7 @@ class FsmnVADStreaming(nn.Module):
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""

def __init__(self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
@@ -276,7 +285,6 @@ def __init__(self,
encoder = encoder_class(**encoder_conf)
self.encoder = encoder


def ResetDetection(self, cache: dict = {}):
cache["stats"].continous_silence_frame_count = 0
cache["stats"].latest_confirmed_speech_frame = 0
@@ -293,15 +301,17 @@ def ResetDetection(self, cache: dict = {}):
drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
real_drop_frames = drop_frames - cache["stats"].last_drop_frames
cache["stats"].last_drop_frames = drop_frames
cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]

def ComputeDecibel(self, cache: dict = {}) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if cache["stats"].data_buf_all is None:
cache["stats"].data_buf_all = cache["stats"].waveform[0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
cache["stats"].data_buf_all = cache["stats"].waveform[
0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
cache["stats"].data_buf = cache["stats"].data_buf_all
else:
cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
Expand All @@ -320,15 +330,16 @@ def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
else:
cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)

def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None: # need check again
def PopDataBufTillFrame(self, frame_idx: int, cache: dict = {}) -> None: # need check again
while cache["stats"].data_buf_start_frame < frame_idx:
if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
cache["stats"].data_buf_start_frame += 1
cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
cache["stats"].data_buf = cache["stats"].data_buf_all[
(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]

def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict = {}) -> None:
self.PopDataBufTillFrame(start_frm, cache=cache)
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
if last_frm_is_end_point:
@@ -380,14 +391,15 @@ def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
cache["stats"].lastest_confirmed_silence_frame = valid_frame
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame, cache=cache)
# silence_detected_callback_
# pass

def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
# silence_detected_callback_
# pass

def OnVoiceDetected(self, valid_frame: int, cache: dict = {}) -> None:
cache["stats"].latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)

def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache: dict = {}) -> None:
if self.vad_opts.do_start_point_detection:
pass
if cache["stats"].confirmed_start_frame != -1:
@@ -398,7 +410,7 @@ def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={
if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)

def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache: dict = {}) -> None:
for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t, cache=cache)
if self.vad_opts.do_end_point_detection:
@@ -488,7 +500,8 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
segment_batch = []
if len(cache["stats"].output_data_buf) > 0:
for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not
cache["stats"].output_data_buf[
i].contain_seg_end_point):
continue
segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
@@ -501,7 +514,8 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
# self.AllResetDetection()
return segments

def init_cache(self, cache: dict = {}):
def init_cache(self, cache: dict = {}, **kwargs):

cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
@@ -528,12 +542,12 @@ def inference(self,
cache: dict = {},
**kwargs,
):
# cache = kwargs.get("cache", {})

if len(cache) == 0:
self.init_cache(cache)
self.init_cache(cache, **kwargs)

meta_data = {}
chunk_size = kwargs.get("chunk_size", 60000) # 50ms
chunk_size = kwargs.get("chunk_size", 60000) # 50ms
chunk_stride_samples = int(chunk_size * frontend.fs / 1000)

time1 = time.perf_counter()
@@ -580,7 +594,6 @@ def inference(self,
if len(segments_i) > 0:
segments.extend(*segments_i)


cache["prev_samples"] = audio_sample[:-m]
if _is_final:
self.init_cache(cache)
@@ -600,16 +613,15 @@ def inference(self,
if ibest_writer is not None:
ibest_writer["text"][key[0]] = segments


return results, meta_data


def DetectCommonFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
cache=cache)
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)

return 0
@@ -619,15 +631,17 @@ def DetectLastFrames(self, cache: dict = {}) -> int:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames,
cache=cache)
if i != 0:
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
else:
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)

return 0

def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool,
cache: dict = {}) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -644,7 +658,8 @@ def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_f
cache["stats"].pre_end_silence_detected = False
start_frame = 0
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
start_frame = max(cache["stats"].data_buf_start_frame,
cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
self.OnVoiceStart(start_frame, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
@@ -696,7 +711,8 @@ def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_f
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
cache[
"stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
or (is_final_frame and cache["stats"].number_end_time_detected == 0):
for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t, cache=cache)
@@ -707,7 +723,8 @@ def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_f
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache[
"stats"].max_end_sil_frame_cnt_thresh:
lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if self.vad_opts.do_extend:
lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
@@ -733,4 +750,3 @@ def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_f
self.ResetDetection(cache=cache)
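End to end, inference consumes audio chunk by chunk, carries frontend/encoder/detector state in cache, and returns [start_ms, end_ms] segments as they are confirmed. A usage sketch in the style of FunASR 1.0's streaming-VAD examples (example.wav is a placeholder; exact result fields can differ between versions):

import soundfile
from funasr import AutoModel

chunk_size = 200  # ms of audio fed per streaming step
model = AutoModel(model="fsmn-vad")

speech, sample_rate = soundfile.read("example.wav")  # placeholder input file
chunk_stride = int(chunk_size * sample_rate / 1000)

cache = {}  # holds frontend/encoder/windows_detector/stats state across calls
total_chunk_num = int(len(speech) / chunk_stride) + 1
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    is_final = i == total_chunk_num - 1
    res = model.generate(input=speech_chunk, cache=cache,
                         is_final=is_final, chunk_size=chunk_size)
    if len(res[0]["value"]):
        # e.g. [[beg_ms, end_ms], ...]; -1 marks a start or end point that
        # falls outside the current chunk and is still pending
        print(res[0]["value"])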



[Diff not rendered for the remaining 3 changed files.]
