streaming bugfix (#1271)
* funasr1.0 funetine

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>

* update with main (#1264)

* Funasr1.0 (#1261)

* funasr1.0 funetine

* funasr1.0 pbar

* update with main (#1260)

* Update websocket_protocol_zh.md

* update

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>

* bug fix

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>

* funasr1.0 sanm scama

* funasr1.0 infer_after_finetune

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

* funasr1.0 fsmn-vad bug fix

---------

Co-authored-by: Yabin Li <[email protected]>
Co-authored-by: shixian.shi <[email protected]>
3 people authored Jan 18, 2024
1 parent b28f3c9 commit 12496e5
Showing 3 changed files with 5 additions and 5 deletions.
@@ -10,7 +10,6 @@
decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cross-attention

model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.2")
cache = {}
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                     chunk_size=chunk_size,
                     encoder_chunk_look_back=encoder_chunk_look_back,
6 changes: 4 additions & 2 deletions funasr/models/fsmn_vad_streaming/model.py
@@ -501,7 +501,9 @@ def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
        # self.AllResetDetection()
        return segments


    def init_cache(self, cache: dict = {}, **kwargs):

        cache["frontend"] = {}
        cache["prev_samples"] = torch.empty(0)
        cache["encoder"] = {}
@@ -528,7 +530,7 @@ def inference(self,
cache: dict = {},
**kwargs,
):

        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

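The `len(cache) == 0` guard in this hunk initialises streaming state lazily: an empty dict signals the first chunk, and later calls reuse whatever the first call stored. A minimal sketch of that pattern follows. The function names, the `chunks_seen` key, and the list-based sample buffer are simplified stand-ins chosen for illustration, not FunASR's actual code.

```python
def init_cache(cache: dict) -> None:
    """Populate a caller-owned dict in place (hypothetical, simplified keys)."""
    cache["prev_samples"] = []
    cache["chunks_seen"] = 0


def inference(chunk: list, cache: dict) -> int:
    """Process one chunk, keeping state in the caller's dict."""
    # Lazy initialisation: an empty dict means no chunk has been seen yet.
    if len(cache) == 0:
        init_cache(cache)
    cache["prev_samples"].extend(chunk)
    cache["chunks_seen"] += 1
    return cache["chunks_seen"]


cache = {}                              # the caller owns the state
first = inference([0.1, 0.2], cache)    # triggers init_cache
second = inference([0.3], cache)        # reuses the populated cache
```

Because the caller holds the only reference that matters, the state survives between calls only if every update mutates that same dict object.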
@@ -583,7 +585,7 @@ def inference(self,
 
         cache["prev_samples"] = audio_sample[:-m]
         if _is_final:
-            cache = {}
+            self.init_cache(cache)
 
         ibest_writer = None
         if ibest_writer is None and kwargs.get("output_dir") is not None:
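The change from `cache = {}` to `self.init_cache(cache)` matters because assigning to a parameter name inside a function only rebinds the local variable, so the dict held by the caller keeps its old contents. A small sketch of the difference, using hypothetical helper names and a simplified cache layout rather than the project's real one:

```python
def reset_by_rebinding(cache: dict) -> None:
    # Rebinds the local name only; the caller's dict is left untouched.
    cache = {}


def init_cache(cache: dict) -> None:
    # Overwrites keys on the dict object the caller passed in.
    cache["prev_samples"] = []
    cache["encoder"] = {}


caller_cache = {"prev_samples": [1, 2, 3], "encoder": {"state": "stale"}}

reset_by_rebinding(caller_cache)
after_rebind = dict(caller_cache)      # snapshot: still holds the stale state

init_cache(caller_cache)
after_in_place = dict(caller_cache)    # snapshot: reset to fresh values
```

Under these assumptions, only the in-place version leaves the caller with a clean state for the next utterance.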
3 changes: 1 addition & 2 deletions funasr/models/paraformer_streaming/model.py
@@ -502,8 +502,7 @@ def inference(self,
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)



        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
