funasr1.0
LauraGPT committed Jan 15, 2024
1 parent 831c48a commit 2a0b2c7
Showing 10 changed files with 40 additions and 34 deletions.
8 changes: 4 additions & 4 deletions funasr/bin/inference.py
@@ -175,7 +175,7 @@ def build_model(self, **kwargs):
         # build tokenizer
         tokenizer = kwargs.get("tokenizer", None)
         if tokenizer is not None:
-            tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
             tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
             kwargs["tokenizer"] = tokenizer
             kwargs["token_list"] = tokenizer.token_list
@@ -186,13 +186,13 @@ def build_model(self, **kwargs):
         # build frontend
         frontend = kwargs.get("frontend", None)
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get(frontend.lower())
+            frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])
             kwargs["frontend"] = frontend
             kwargs["input_size"] = frontend.output_size()

         # build model
-        model_class = tables.model_classes.get(kwargs["model"].lower())
+        model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
         model.eval()
         model.to(device)
@@ -443,7 +443,7 @@ def __init__(self, **kwargs):
         # build frontend
         frontend = kwargs.get("frontend", None)
         if frontend is not None:
-            frontend_class = tables.frontend_classes.get(frontend.lower())
+            frontend_class = tables.frontend_classes.get(frontend)
             frontend = frontend_class(**kwargs["frontend_conf"])

         self.frontend = frontend
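With .lower() removed from these lookups, the model, tokenizer, and frontend names arriving through kwargs must match their registered spelling exactly. A minimal sketch of the new behavior, assuming the registry tables are plain dicts as the decorator in funasr/register.py suggests, and that Paraformer is registered under that exact class name (the lowercased key is shown only to illustrate a miss):

    from funasr.register import tables

    model_class = tables.model_classes.get("Paraformer")    # exact registered key -> class
    assert tables.model_classes.get("paraformer") is None   # dict.get finds no lowercased alias
    if model_class is None:
        raise ValueError("model name must match its registered spelling exactly")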
10 changes: 5 additions & 5 deletions funasr/bin/train.py
@@ -64,22 +64,22 @@ def main(**kwargs):

     tokenizer = kwargs.get("tokenizer", None)
     if tokenizer is not None:
-        tokenizer_class = tables.tokenizer_classes.get(tokenizer.lower())
+        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
         tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
         kwargs["tokenizer"] = tokenizer

     # build frontend if frontend is not None
     frontend = kwargs.get("frontend", None)
     if frontend is not None:
-        frontend_class = tables.frontend_classes.get(frontend.lower())
+        frontend_class = tables.frontend_classes.get(frontend)
         frontend = frontend_class(**kwargs["frontend_conf"])
         kwargs["frontend"] = frontend
         kwargs["input_size"] = frontend.output_size()

     # import pdb;
     # pdb.set_trace()
     # build model
-    model_class = tables.model_classes.get(kwargs["model"].lower())
+    model_class = tables.model_classes.get(kwargs["model"])
     model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))


@@ -141,12 +141,12 @@ def main(**kwargs):
     # import pdb;
     # pdb.set_trace()
     # dataset
-    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
+    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
     dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))

     # dataloader
     batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
-    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler.lower())
+    batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
     if batch_sampler is not None:
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
9 changes: 6 additions & 3 deletions funasr/datasets/audio_datasets/datasets.py
@@ -13,6 +13,9 @@

 @tables.register("dataset_classes", "AudioDataset")
 class AudioDataset(torch.utils.data.Dataset):
+    """
+    AudioDataset
+    """
     def __init__(self,
                  path,
                  index_ds: str = None,
@@ -22,16 +25,16 @@ def __init__(self,
                  float_pad_value: float = 0.0,
                  **kwargs):
         super().__init__()
-        index_ds_class = tables.index_ds_classes.get(index_ds.lower())
+        index_ds_class = tables.index_ds_classes.get(index_ds)
         self.index_ds = index_ds_class(path)
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
-            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech.lower())
+            preprocessor_speech_class = tables.preprocessor_speech_classes.get(preprocessor_speech)
             preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
-            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text.lower())
+            preprocessor_text_class = tables.preprocessor_text_classes.get(preprocessor_text)
             preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
         self.preprocessor_text = preprocessor_text

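This register call pairs with the lookup in funasr/bin/train.py above: the default "AudioDataset" now resolves only because it matches the registered key verbatim. A short sketch of the round trip, assuming that importing the module is what triggers registration:

    from funasr.register import tables
    import funasr.datasets.audio_datasets.datasets  # runs @tables.register("dataset_classes", "AudioDataset")

    dataset_class = tables.dataset_classes.get("AudioDataset")  # hit: key stored verbatim
    assert tables.dataset_classes.get("audiodataset") is None   # the old lowercased key no longer exists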
2 changes: 1 addition & 1 deletion funasr/models/ct_transformer/model.py
@@ -46,7 +46,7 @@ def __init__(


         self.embed = nn.Embedding(vocab_size, embed_unit)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)

         self.decoder = nn.Linear(att_unit, punc_size)
2 changes: 1 addition & 1 deletion funasr/models/fsmn_vad_streaming/model.py
@@ -268,7 +268,7 @@ def __init__(self,
         super().__init__()
         self.vad_opts = VADXOptions(**kwargs)

-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder

8 changes: 4 additions & 4 deletions funasr/models/monotonic_aligner/model.py
@@ -41,15 +41,15 @@ def __init__(
         super().__init__()

         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
-        predictor_class = tables.predictor_classes.get(predictor.lower())
+        predictor_class = tables.predictor_classes.get(predictor)
         predictor = predictor_class(**predictor_conf)
         self.specaug = specaug
         self.normalize = normalize
10 changes: 5 additions & 5 deletions funasr/models/paraformer/model.py
@@ -79,17 +79,17 @@ def __init__(
         super().__init__()

         if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug.lower())
+            specaug_class = tables.specaug_classes.get(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize.lower())
+            normalize_class = tables.normalize_classes.get(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder_class = tables.encoder_classes.get(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()

         if decoder is not None:
-            decoder_class = tables.decoder_classes.get(decoder.lower())
+            decoder_class = tables.decoder_classes.get(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
@@ -104,7 +104,7 @@ def __init__(
                 odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
             )
         if predictor is not None:
-            predictor_class = tables.predictor_classes.get(predictor.lower())
+            predictor_class = tables.predictor_classes.get(predictor)
             predictor = predictor_class(**predictor_conf)

         # note that eos is the same as sos (equivalent ID)
2 changes: 1 addition & 1 deletion funasr/models/seaco_paraformer/model.py
@@ -90,7 +90,7 @@ def __init__(
         seaco_decoder = kwargs.get("seaco_decoder", None)
         if seaco_decoder is not None:
             seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
-            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower())
+            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder)
             self.seaco_decoder = seaco_decoder_class(
                 vocab_size=self.vocab_size,
                 encoder_output_size=self.inner_dim,
10 changes: 5 additions & 5 deletions funasr/models/transformer/model.py
@@ -60,19 +60,19 @@ def __init__(
         super().__init__()

         if frontend is not None:
-            frontend_class = tables.frontend_classes.get_class(frontend.lower())
+            frontend_class = tables.frontend_classes.get_class(frontend)
             frontend = frontend_class(**frontend_conf)
         if specaug is not None:
-            specaug_class = tables.specaug_classes.get_class(specaug.lower())
+            specaug_class = tables.specaug_classes.get_class(specaug)
             specaug = specaug_class(**specaug_conf)
         if normalize is not None:
-            normalize_class = tables.normalize_classes.get_class(normalize.lower())
+            normalize_class = tables.normalize_classes.get_class(normalize)
             normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get_class(encoder.lower())
+        encoder_class = tables.encoder_classes.get_class(encoder)
         encoder = encoder_class(input_size=input_size, **encoder_conf)
         encoder_output_size = encoder.output_size()
         if decoder is not None:
-            decoder_class = tables.decoder_classes.get_class(decoder.lower())
+            decoder_class = tables.decoder_classes.get_class(decoder)
             decoder = decoder_class(
                 vocab_size=vocab_size,
                 encoder_output_size=encoder_output_size,
13 changes: 8 additions & 5 deletions funasr/register.py
@@ -1,7 +1,7 @@
 import logging
 import inspect
 from dataclasses import dataclass
-
+import re

 @dataclass
 class RegisterTables:
@@ -29,7 +29,7 @@ def print(self, key=None):
             flag = key in classes_key
             if classes_key.endswith("_meta") and flag:
                 print(f"----------- ** {classes_key.replace('_meta', '')} ** --------------")
-                headers = ["class name", "register name", "class location"]
+                headers = ["class name", "class location"]
                 metas = []
                 for register_key, meta in classes_dict.items():
                     metas.append(meta)
@@ -51,8 +51,7 @@ def decorator(target_class):

             registry = getattr(self, register_tables_key)
             registry_key = key if key is not None else target_class.__name__
-            registry_key = registry_key.lower()
-            # import pdb; pdb.set_trace()
+
             assert not registry_key in registry, "(key: {} / class: {}) has been registered already,in {}".format(
                 registry_key, target_class, register_tables_key)

@@ -63,9 +62,13 @@ def decorator(target_class):
             if not hasattr(self, register_tables_key_meta):
                 setattr(self, register_tables_key_meta, {})
             registry_meta = getattr(self, register_tables_key_meta)
+            # doc = target_class.__doc__
             class_file = inspect.getfile(target_class)
             class_line = inspect.getsourcelines(target_class)[1]
-            meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+            pattern = r'^.+/funasr/'
+            class_file = re.sub(pattern, 'funasr/', class_file)
+            meata_data = [f"{target_class.__name__}", f"{class_file}:{class_line}"]
+            # meata_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
             registry_meta[registry_key] = meata_data
             # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
             return target_class
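Taken together, a hedged sketch of what the updated decorator now does (MyEncoder is hypothetical, registered here only for illustration): the key is stored verbatim rather than lowercased, and the source location recorded in the meta table is trimmed to a path relative to funasr/:

    import re
    from funasr.register import tables

    @tables.register("encoder_classes", "MyEncoder")  # hypothetical class
    class MyEncoder:
        def __init__(self, **conf):
            self.conf = conf

    assert tables.encoder_classes.get("MyEncoder") is MyEncoder
    assert tables.encoder_classes.get("myencoder") is None  # lookups are now case-sensitive

    # The same trimming the decorator applies with re.sub:
    path = "/home/user/FunASR/funasr/models/paraformer/model.py"
    print(re.sub(r'^.+/funasr/', 'funasr/', path))  # -> funasr/models/paraformer/model.py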
