ADD Whisper #82

Open · wants to merge 4 commits into main
Binary file added audio_processing/whisper/demo.wav
Binary file not shown.
151 changes: 151 additions & 0 deletions audio_processing/whisper/export.py
@@ -0,0 +1,151 @@
#!pip install transformers
#!pip install datasets

# %%
import tensorflow as tf

from datasets import load_dataset
from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

generate_saved_model = False   # export the TF SavedModels (base model and generate() wrapper)
generate_tflite_model = False  # run the TFLite conversion below

quantize = True                # int8 post-training quantization if True, float conversion otherwise

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en", predict_timestamps=True)
processor = WhisperProcessor(feature_extractor, tokenizer)
model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
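# All of the above use whisper-tiny.en, the smallest English-only Whisper checkpoint (~39M parameters)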
# Loading dataset
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

inputs = feature_extractor(
    ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="tf"
)
input_features = inputs.input_features

# Generate a reference transcription and export the base SavedModel
if generate_saved_model:
    generated_ids = model.generate(input_features=input_features)
    print(generated_ids)
    transcription = processor.tokenizer.decode(generated_ids[0])
    print(transcription)
    model.save('./content/tf_whisper_saved')

# %% [markdown]
# ## Convert saved model to TFLite model

# %%
import tensorflow as tf

saved_model_dir = './content/tf_whisper_saved'
tflite_model_path = 'whisper.tflite'

# Convert the model
if generate_saved_model:
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
        tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
    ]
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)

# %% [markdown]
# ## Create generation-enabled TF Lite model
#
# The solution consists in defining a model whose serving function is the generation call. Here's an example of how to do it:

# %%
class GenerateModel(tf.Module):
    def __init__(self, model):
        super(GenerateModel, self).__init__()
        self.model = model

    @tf.function(
        # shouldn't need a static batch size, but conversion throws an exception without it (needs to be fixed)
        input_signature=[
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features"),
        ],
    )
    def serving(self, input_features):
        outputs = self.model.generate(
            input_features,
            max_new_tokens=450,  # change as needed
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}

saved_model_dir = './content/tf_whisper_saved'

if generate_saved_model:
    generate_model = GenerateModel(model=model)
    tf.saved_model.save(generate_model, saved_model_dir, signatures={"serving_default": generate_model.serving})

def representative_dataset():
    # Calibration samples for post-training int8 quantization: the converter runs
    # these through the model to estimate activation ranges.
    num_datasets = 1  # max 73
    for i in range(num_datasets):  # change this to 100 and provide 100 different audio files from a known dataset
        inputs = feature_extractor(
            ds[i]["audio"]["array"], sampling_rate=ds[i]["audio"]["sampling_rate"], return_tensors="tf"
        )
        input_features = inputs.input_features
        attention = tf.constant(0, shape=(1, 1), dtype=tf.int32)
        input_ids = tf.constant(0, shape=(1, 1), dtype=tf.int32)
        yield [attention, input_ids, input_features]

# Convert the model
if generate_tflite_model:
    if not quantize:
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        converter.target_spec.supported_ops = [
            tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
            tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
        ]
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        tflite_model = converter.convert()
    else:
        converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.inference_input_type = tf.float32
        converter.inference_output_type = tf.float32  # int32 cannot be selected as the output type for the int8 path
        tflite_model = converter.convert()

# Save the model
if quantize:
    tflite_model_path = 'whisper-tiny-en-int8.tflite'
else:
    tflite_model_path = 'whisper-tiny-en-float.tflite'

if generate_tflite_model:
    with open(tflite_model_path, 'wb') as f:
        f.write(tflite_model)

# %%
# loaded model... now with generate!
interpreter = tf.lite.Interpreter(tflite_model_path)
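# Two inference paths: the int8 model is driven by setting its three input tensors
# (attention, input_ids, input_features) directly, while the float model is called
# through the generate() signature runner saved above.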

if quantize:
    interpreter.allocate_tensors()
    for i in range(450):
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        attention = tf.constant(0, shape=(1, 1), dtype=tf.int32)
        input_ids = tf.constant(0, shape=(1, 1), dtype=tf.int32)
        interpreter.set_tensor(input_details[0]['index'], attention)
        interpreter.set_tensor(input_details[1]['index'], input_ids)
        interpreter.set_tensor(input_details[2]['index'], input_features)
        interpreter.invoke()
        generated_ids = interpreter.get_tensor(output_details[0]['index'])
        print(generated_ids)
else:
    tflite_generate = interpreter.get_signature_runner()
    generated_ids = tflite_generate(input_features=input_features)["sequences"]
    print(generated_ids)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcription)


118 changes: 118 additions & 0 deletions audio_processing/whisper/whisper.py
@@ -0,0 +1,118 @@
import sys
import time

import librosa


# import original modules
sys.path.append('../../util')
from utils import get_base_parser, update_parser  # noqa: E402
from model_utils import check_and_download_models  # noqa: E402


# ======================
# Parameters 1
# ======================
INPUT_PATH = 'demo.wav'


# ======================
# Argument Parser Config
# ======================
parser = get_base_parser(
    'Whisper Speech To Text', INPUT_PATH, None
)
args = update_parser(parser)
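# args.tflite selects the official TensorFlow Lite interpreter; otherwise the
# ailia TFLite runtime is used.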

if args.tflite:
    import tensorflow as tf
else:
    import ailia_tflite


# ======================
# Parameters 2
# ======================
MODEL_PATH = 'whisper-tiny-en.tflite'
REMOTE_PATH = 'https://storage.googleapis.com/ailia-models-tflite/whisper/'


# ======================
# Main functions
# ======================
def recognize_from_audio():
    #from datasets import load_dataset
    from transformers import WhisperProcessor, WhisperFeatureExtractor, TFWhisperForConditionalGeneration, WhisperTokenizer

    feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny.en")
    tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en", predict_timestamps=True)
    processor = WhisperProcessor(feature_extractor, tokenizer)
    model = TFWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
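    # Note: the TF Whisper model loaded above is not used for inference below; only the
    # feature extractor and processor are needed (feature extraction and decoding).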
    # The dummy LibriSpeech dataset from export.py is not needed here; the demo audio file is used instead
    #ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

    audio, sr = librosa.load(INPUT_PATH, sr=16000)
    inputs = feature_extractor(
        audio, sampling_rate=sr, return_tensors="tf"
    )
    input_features = inputs.input_features

    # Load the exported TFLite model and run generation
    tflite_model_path = MODEL_PATH
    if args.tflite:
        interpreter = tf.lite.Interpreter(tflite_model_path)
    else:
        interpreter = ailia_tflite.Interpreter(tflite_model_path)

    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    print(input_details)
    print(output_details)

    if args.benchmark:
        print('BENCHMARK mode')
        average_time = 0
        for i in range(args.benchmark_count):
            start = int(round(time.time() * 1000))
            interpreter.set_tensor(input_details[0]['index'], input_features)
            interpreter.invoke()
            end = int(round(time.time() * 1000))
            average_time = average_time + (end - start)
            print(f'\tailia processing time {end - start} ms')
        print(f'\taverage time {average_time / args.benchmark_count} ms')
    else:
        interpreter.set_tensor(input_details[0]['index'], input_features)
        interpreter.invoke()

    generated_ids = interpreter.get_tensor(output_details[0]['index'])

    print(generated_ids)
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(transcription)

    print('Script finished successfully.')


def main():
    # model files check and download
    check_and_download_models(MODEL_PATH, REMOTE_PATH)

    recognize_from_audio()


if __name__ == '__main__':
    main()
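Usage note: assuming the standard ailia-models-tflite launcher options provided by get_base_parser, the demo would typically be run as `python3 whisper.py -i demo.wav`, adding `--tflite` to use the TensorFlow Lite interpreter instead of the ailia runtime and `-b` to enable benchmark timing.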