From ae4dceecf03c643be770e3274a070078e45fc9c7 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Tue, 23 Jan 2024 11:34:03 +0800 Subject: [PATCH] bug fix for punc and umap --- funasr/models/campplus/cluster_backend.py | 2 ++ funasr/models/ct_transformer/model.py | 4 ++-- setup.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/funasr/models/campplus/cluster_backend.py b/funasr/models/campplus/cluster_backend.py index 3bac0a02c..e33a14d6a 100644 --- a/funasr/models/campplus/cluster_backend.py +++ b/funasr/models/campplus/cluster_backend.py @@ -119,6 +119,7 @@ def __init__(self, self.metric = metric def __call__(self, X): + from umap.umap_ import UMAP umap_X = umap.UMAP( n_neighbors=self.n_neighbors, min_dist=0.0, @@ -156,6 +157,7 @@ def forward(self, X, **params): if X.shape[0] < 20: return np.zeros(X.shape[0], dtype='int') if X.shape[0] < 2048 or k is not None: + # unexpected corner case labels = self.spectral_cluster(X, k) else: labels = self.umap_hdbscan_cluster(X) diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py index 330d7e554..1e53aa3d4 100644 --- a/funasr/models/ct_transformer/model.py +++ b/funasr/models/ct_transformer/model.py @@ -336,10 +336,11 @@ def inference(self, elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())!=1: new_mini_sentence_out = new_mini_sentence + "。" new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] + if len(punctuations): punctuations[-1] = 2 elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1: new_mini_sentence_out = new_mini_sentence + "." new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] - + if len(punctuations): punctuations[-1] = 2 # keep a punctuations array for punc segment if punc_array is None: punc_array = punctuations @@ -347,6 +348,5 @@ def inference(self, punc_array = torch.cat([punc_array, punctuations], dim=0) result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} results.append(result_i) - return results, meta_data diff --git a/setup.py b/setup.py index f4439b01e..561dea2b6 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ # "protobuf", "tqdm", "hdbscan", - "umap", + "umap_learn", "jaconv", "hydra-core>=1.3.2", ],