diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index b303418bd..966cc9449 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -705,55 +705,73 @@ def join_chinese_and_english(cls, input_list):
         return line
 
     @classmethod
-    def split_words_jieba(cls, text: str):
-        input_list = text.split()
-        token_list_all = []
-        langauge_list = []
-        token_list_tmp = []
-        language_flag = None
-        for token in input_list:
-            if cls.isEnglish(token) and language_flag == 'Chinese':
-                token_list_all.append(token_list_tmp)
-                langauge_list.append('Chinese')
-                token_list_tmp = []
-            elif not cls.isEnglish(token) and language_flag == 'English':
-                token_list_all.append(token_list_tmp)
-                langauge_list.append('English')
-                token_list_tmp = []
+    def split_words(cls, text: str , seg_jieba: bool):
+        if seg_jieba == True:
+            input_list = text.split()
+            token_list_all = []
+            langauge_list = []
+            token_list_tmp = []
+            language_flag = None
+            for token in input_list:
+                if cls.isEnglish(token) and language_flag == 'Chinese':
+                    token_list_all.append(token_list_tmp)
+                    langauge_list.append('Chinese')
+                    token_list_tmp = []
+                elif not cls.isEnglish(token) and language_flag == 'English':
+                    token_list_all.append(token_list_tmp)
+                    langauge_list.append('English')
+                    token_list_tmp = []
+
+                token_list_tmp.append(token)
+
+                if cls.isEnglish(token):
+                    language_flag = 'English'
+                else:
+                    language_flag = 'Chinese'
 
-            token_list_tmp.append(token)
+            if token_list_tmp:
+                token_list_all.append(token_list_tmp)
+                langauge_list.append(language_flag)
 
-            if cls.isEnglish(token):
-                language_flag = 'English'
-            else:
-                language_flag = 'Chinese'
+            result_list = []
+            for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
+                if language_flag == 'English':
+                    result_list.extend(token_list_tmp)
+                else:
+                    seg_list = jieba.cut(cls.join_chinese_and_english(token_list_tmp), HMM=False)
+                    result_list.extend(seg_list)
 
-        if token_list_tmp:
-            token_list_all.append(token_list_tmp)
-            langauge_list.append(language_flag)
+            return result_list
 
-        result_list = []
-        for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
-            if language_flag == 'English':
-                result_list.extend(token_list_tmp)
-            else:
-                seg_list = jieba.cut(cls.join_chinese_and_english(token_list_tmp), HMM=False)
-                result_list.extend(seg_list)
+        else:
+            words = []
+            segs = text.split()
+            for seg in segs:
+                # There is no space in seg.
+                current_word = ""
+                for c in seg:
+                    if len(c.encode()) == 1:
+                        # This is an ASCII char.
+                        current_word += c
+                    else:
+                        # This is a Chinese char.
+                        if len(current_word) > 0:
+                            words.append(current_word)
+                            current_word = ""
+                        words.append(c)
+                if len(current_word) > 0:
+                    words.append(current_word)
+            return words
 
-        return result_list
 
     def __call__(
             self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
     ) -> Dict[str, Union[list, np.ndarray]]:
         # Split words.
-        if isinstance(data[self.text_name], str):
-            if self.seg_jieba:
-                # jieba.load_userdict(seg_dict_file)
-                split_text = self.split_words_jieba(data[self.text_name])
-            else:
-                split_text = self.split_words(data[self.text_name])
-        else:
-            split_text = data[self.text_name]
+        data_in = data[self.text_name]
+        if isinstance(data[self.text_name], list):
+            data_in = " ".join(data[self.text_name])
+        split_text = self.split_words(data_in, self.seg_jieba)
         data[self.text_name] = " ".join(split_text)
         data = self._speech_process(data)
         data = self._text_process(data)
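For reference, below is a minimal standalone sketch of the two segmentation paths the unified split_words now switches between. The helper name split_words_sketch and the sample sentence are illustrative only, not part of the patch, and the jieba branch here is simplified: the patched method additionally groups tokens into per-language runs and joins Chinese runs via join_chinese_and_english before cutting.

def split_words_sketch(text: str, seg_jieba: bool = False):
    if seg_jieba:
        # Simplified jieba path (assumes `pip install jieba`): cut the whole
        # string and drop the whitespace tokens jieba yields. The real method
        # first groups consecutive Chinese tokens and cuts each group.
        import jieba
        return [w for w in jieba.cut(text, HMM=False) if w.strip()]
    # Character path: keep runs of single-byte (ASCII) characters together as
    # English words and emit every multi-byte (e.g. Chinese) character as its
    # own token.
    words = []
    for seg in text.split():
        current_word = ""
        for c in seg:
            if len(c.encode()) == 1:
                current_word += c
            else:
                if current_word:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if current_word:
            words.append(current_word)
    return words


print(split_words_sketch("这是 a mixed 句子"))
# ['这', '是', 'a', 'mixed', '句', '子']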