diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..2724da0 --- /dev/null +++ b/.flake8 @@ -0,0 +1,42 @@ +[flake8] +max-line-length = 120 +# Following 4 for black compatibility +# E501: line too long +# W503: Line break occurred before a binary operator +# E203: Whitespace before ':' +# D202 No blank lines allowed after function docstring + +# TODO fix flake8 +# D100 Missing docstring in public module +# D101 Missing docstring in public class +# D102 Missing docstring in public method +# D103 Missing docstring in public function +# D104 Missing docstring in public package +# D105 Missing docstring in magic method +# D107 Missing docstring in __init__ +# D200 One-line docstring should fit on one line with quotes +# D205 1 blank line required between summary line and description +# D209 Multi-line docstring closing quotes should be on a separate line +# D400 First line should end with a period +# D401 First line should be in imperative mood + +ignore = + E501, + W503, + E203, + D202, + + D100, + D101, + D102, + D103, + D104, + D105, + D107, + D200, + D205, + D209, + D400, + D401, + +exclude = api_pb2.py diff --git a/.gitignore b/.gitignore index cf3f811..4c43851 100644 --- a/.gitignore +++ b/.gitignore @@ -142,11 +142,16 @@ sdkconfig.* data_training/ trained_models/ +notebooks/*/ -*.npy +training_parameters.yaml +*.npy +*.pb *.ninja +*.tar +*.zip *.flac *.wav -*.tar \ No newline at end of file +*.mp3 \ No newline at end of file diff --git a/README.md b/README.md index c436e29..0bec6e3 100644 --- a/README.md +++ b/README.md @@ -1,58 +1,20 @@ -# microWakeWord +![microWakeWord logo](etc/logo.png) microWakeWord is an open-source wakeword library for detecting custom wake words on low power devices. It produces models that are suitable for using [TensorFlow Lite for Microcontrollers](https://www.tensorflow.org/lite/microcontrollers). The models are suitable for real-world usage with low false accept and false reject rates. -**microWakeword is currently available as a very early release. microWakeWord can generate features and train models. It does not include sample generation or audio augmentations. The training process produces usable models if you manually fine-tune penalty weights.** - - -## Benchmarks - -Benchmarking and comparing wake word models is challenging. It is hard to account for all the different operating environments. [Picovoice](https://github.com/Picovoice/wake-word-benchmark) has provided one benchmark for at least one point of comparison. For a more rigorous false acceptance metric, we also test on the [Dinner Party Corpus](https://www.amazon.science/publications/dipco-dinner-party-corpus) dataset. - -### Okay Nabu - -The following graph depicts the false-accept/false-reject rate for the "Okay Nabu" model. Note that the test clips used in the benchmark are created with Piper sample generator, not real voice samples. -![FPR/FRR curve for "Okay Nabu" pre-trained model](benchmarks/okay_nabu_roc_curve.png) - -The default parameters (probablity cutoff of 0.5 and average window size of 10) has a false rejection rate of 2% and 0.122 false accepts per hour with the Picovoice benchmark dataset. There are 0.187 false accepts per hour on the Dinner Party Corpus with these settings. - -### Hey Jarvis - -The following graph depicts the false-accept/false-reject rate for the "Hey Jarvis" model. Note that the test clips used in the benchmark are created with Piper sample generator, not real voice samples. 
-![FPR/FRR curve for "Hey Jarvis" pre-trained model](benchmarks/hey_jarvis_roc_curve.png) - -The default parameters (probablity cutoff of 0.5 and average window size of 10) has a false rejection rate of 0.67% and 0.081 false accepts per hour with the Picovoice benchmark dataset. There are 0.375 false accepts per hour on the Dinner Party Corpus with these settings. - -### Alexa - -The following graph depicts the false-accept/false-reject rate for the "Alexa" model. The positive samples are real recordings sources from the Picovoice repository. -![FPR/FRR curve for "Alexa" pre-trained model](benchmarks/alexa_roc_curve.png) - -The default parameters (probability cutoff of 0.66 and average window size of 10) has a false rejection rate of 3.49% and 0.486 false accepts per hour with the Picovoice benchmark dataset. There are 0.187 false accepts per hour on the Dinner Party Corpus with these settings. - +**microWakeword is currently available as an early release. Training new models is intended for advanced users. Training a model that works well is still very difficult, as it typically requires experimentation with hyperparameters and sample generation settings. Please share any insights you find for training a good model!** ## Detection Process -We detect the wake word in two stages. Raw audio data is processed into 40 features every 20 ms. These features construct a spectrogram. The streaming inference model uses the newest slice of feature data as input and returns a probability that the wake word is said. If the model consistently predicts the wake word over multiple windows, then we predict that the wake word has been said. - -The first stage processes the raw monochannel audio data at a sample rate of 16 kHz via the [micro_speech preprocessor](https://github.com/tensorflow/tflite-micro/tree/main/tensorflow/lite/micro/examples/micro_speech). The preprocessor generates 40 features over 30 ms (the window duration) of audio data. The preprocessor generates these features every 20 ms (the stride duration), so the first 10 ms of audio data is part of the previous window. This process is similar to calculating a Mel spectrogram for the audio data, but it is lightweight for devices with limited processing power. See the linked TFLite Micro example for full details on how the audio is processed. - -The streaming model performs inferences every 20 ms on the newest audio stride. The model is based on an [inception neural network](https://towardsdatascience.com/a-simple-guide-to-the-versions-of-the-inception-network-7fc52b863202?gi=6bc760f44aef) converted for streaming. Streaming and training the model uses heavily modified open-sourced code from [Google Research](https://github.com/google-research/google-research/tree/master/kws_streaming) found in the paper [Streaming Keyword Spotting on Mobile Devices](https://arxiv.org/pdf/2005.06720.pdf) by Rykabov, Kononenko, Subrahmanya, Visontai, and Laurenzo. +We detect the wake word in two stages. Raw audio data is processed into 40 spectrogram features every 10 ms. The streaming inference model uses the newest slice of feature data as input and returns a probability that the wake word is said. If the model consistently predicts the wake word over multiple windows, then we predict that the wake word has been said. 
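+
+As a rough sketch (not the exact on-device implementation), this post-processing step can be thought of as averaging the most recent probabilities and comparing the result to a cutoff. The `probability_cutoff` and `sliding_window_size` names and values below are illustrative placeholders only:
+
+```python
+from collections import deque
+
+probability_cutoff = 0.5  # illustrative cutoff
+sliding_window_size = 10  # number of recent inferences to average
+
+recent_probabilities = deque(maxlen=sliding_window_size)
+
+
+def update_detection(new_probability: float) -> bool:
+    """Track the newest model output and report whether the wake word is detected."""
+    recent_probabilities.append(new_probability)
+    moving_average = sum(recent_probabilities) / len(recent_probabilities)
+    return moving_average >= probability_cutoff
+```
+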
-## Model and Training Design Notes +The first stage processes the raw monochannel audio data at a sample rate of 16 kHz via the [micro_speech preprocessor](https://github.com/tensorflow/tflite-micro/tree/main/tensorflow/lite/micro/examples/micro_speech). The preprocessor generates 40 features over 30 ms (the window duration) of audio data. The preprocessor generates these features every 10 ms (the stride duration), so the first 20 ms of audio data is part of the previous window. This process is similar to calculating a Mel spectrogram for the audio data, but it includes noise suppression and automatic gain control. This makes it suitable for devices with limited processing power. See the linked TFLite Micro example for full details on how the audio is processed. -### Inception Based Model -- We apply [SubSpectral Normalization](https://arxiv.org/abs/2103.13620) after the initial convolution layer -- Temporal dilatations for later convolutions greatly improve accuracy; these operations are currently not optimized in the TFLite Micro for Espressif's chip, so by default are not configured -- The model doesn't use a Global Average Pooling layer, but rather a larger Fully Connected layer. This improves accuracy, and it is much faster on ESP32 devices. -- Some wake word phrases may not need as large of a model. Adjusting ``cnn1_filters``, ``cnn2_filters1``, and ``cnn2_filters2`` can increase or decrease the model size and latency. -- All convolutions have no padding. The training process ensures the last layer has features representing exactly ``clip_duration_ms``. +The streaming model performs inferences every 30 ms, where the initial convolution layer strides over three 10 ms slices of audio. The model is a neural network using [MixConv](https://arxiv.org/abs/1907.09595) mixed depthwise convolutions suitable for streaming. The streaming and training code is heavily modified from open-source code by [Google Research](https://github.com/google-research/google-research/tree/master/kws_streaming), described in the paper [Streaming Keyword Spotting on Mobile Devices](https://arxiv.org/pdf/2005.06720.pdf) by Rybakov, Kononenko, Subrahmanya, Visontai, and Laurenzo. ### Training Process - We augment the spectrograms in several possible ways during training: - [SpecAugment](https://arxiv.org/pdf/1904.08779.pdf) masks time and frequency features - - [MixUp](https://openreview.net/forum?id=r1Ddp1-Rb) averages two spectrograms and their labels - - [FreqMix](https://arxiv.org/pdf/2204.11479.pdf) combines two spectrograms and their labels using a low-pass and high-pass filter. - The best weights are chosen as a two-step process: 1. The top priority is minimizing a specific metric like the false accepts per hour on ambient background noise first. 2. If the specified minimization target metric is met, then we maximize a different specified metric like accuracy. @@ -61,166 +23,25 @@ The streaming model performs inferences every 20 ms on the newest audio stride. 2. The ``validation_ambient`` and ``testing_ambient`` sets are all negative samples representing real-world background sounds; e.g., music, random household noises, and general speech/conversations. - Generated spectrograms are stored as [Ragged Mmap](https://github.com/hristo-vrigazov/mmap.ninja/tree/master) folders for quick loading from the disk while training. - Each feature set is configured with a ``sampling_weight`` and ``penalty_weight``. 
The ``sampling_weight`` parameter controls oversampling and ``penalty_weight`` controls the weight of incorrect predictions. -- Class weights are also adjustable with the ``positive_class_weight`` and ``negative_class_weight`` parameters. It is useful to increase the ``negative_class_weight`` near the end of the training process to reduce the amount of false accepts. -- We train the model in a non-streaming mode; i.e., it trains on the entire spectrogram. When finished, this is converted to a streaming model that updates every 20 ms. +- Class weights are also adjustable with the ``positive_class_weight`` and ``negative_class_weight`` parameters. It is useful to increase the ``negative_class_weight`` to reduce the number of false acceptances. +- We train the model in a non-streaming mode; i.e., it trains on the entire spectrogram. When finished, this is converted to a streaming model that updates on only the newest spectrogram features. - Not padding the convolutions ensures the non-streaming and streaming models have nearly identical prediction behaviors. - We estimate the false accepts per hour metric during training by splitting long-duration ambient clips into appropriate-sized spectrograms with a 100 ms stride to simulate the streaming model. This is not a perfect estimate of the streaming model's real-world false accepts per hour, but it is sufficient for determining the best weights. -- We should generate spectrogram features over a longer time period than needed for training the model. The preprocessor model applies PCAN and noise reduction, and generating features over a longer time period results in models that are better to generalize. _This is not currently automatically implemented in microWakeWord._ -- We quantize the streaming models to increase performance on low-power devices. This has a small performance penalty that varies from model to model, but it typically lowers accuracy on the test dataset by around 0.05%. +- We should generate spectrogram features over a longer time period than needed for training the model. The preprocessor model applies PCAN and noise reduction, and generating features over a longer time period results in models that generalize better. +- We quantize the streaming models to increase performance on low-power devices. This has a small performance penalty that varies from model to model, but there is typically no reduction in accuracy. ## Model Training Process -We generate positive and negative samples using [openWakeWord](https://github.com/dscripka/openWakeWord), which relies on [Piper sample generator](https://github.com/rhasspy/piper-sample-generator). We also use openWakeWord's data tools to augment the positive and negative samples. Additional data sources are used for negative data. Currently, microWakeWord does support these steps directly. - -Audio samples are converted to features are stored as Ragged Mmaps. Currently, only converting wav audio files are supported and no direct audio augmentations are applied. - -```python -from microwakeword.feature_generation import generate_features_for_folder - -generate_features_for_folder(path_to_audio='audio_samples/training', features_output_dir='audio_features/training', set_name='audio_samples') -``` - -Training configuration options are stored in a yaml file. 
- -```python -# Save a yaml config that controls the training process -import yaml -import os - -config = {} - -config['train_dir'] = 'trained_models/alexa' - -# Each feature_dir should have at least one of the following folders with this structure: -# training/ -# ragged_mmap_folders_ending_in_mmap -# testing/ -# ragged_mmap_folders_ending_in_mmap -# testing_ambient/ -# ragged_mmap_folders_ending_in_mmap -# validation/ -# ragged_mmap_folders_ending_in_mmap -# validation_ambient/ -# ragged_mmap_folders_ending_in_mmap -# -# sampling_weight: Weight for choosing a spectrogram from this set in the batch -# penalty_weight: Penalizing weight for incorrect predictions from this set -# truth: Boolean whether this set has positive samples or negative samples -# truncation_strategy = If spectrograms in the set are longer than necessary for training, how are they truncated -# - random: choose a random portion of the entire spectrogram - useful for long negative samples -# - truncate_start: remove the start of the spectrogram -# - truncate_end: remove the end of the spectrogram -# - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. Only for ambient sets - -config['features'] = [ - { - 'features_dir': '/Volumes/MachineLearning/training_data/alexa_4990ms_spectrogram/generated_positive', - 'sampling_weight': 0.25, - 'penalty_weight': 1, - 'truth': True, - 'truncation_strategy': 'truncate_start' - }, - { - 'features_dir': '/Volumes/MachineLearning/training_data/alexa_4990ms_spectrogram/generated_negative', - 'sampling_weight': 0.25, - 'penalty_weight': 1, - 'truth': False, - 'truncation_strategy': 'truncate_start' - }, - { - 'features_dir': '/Volumes/MachineLearning/training_data/english_speech_background_1970ms', - 'sampling_weight': 0.2, - 'penalty_weight': 3, - 'truth': False, - 'truncation_strategy': 'random' - }, - { - 'features_dir': '/Volumes/MachineLearning/training_data/cv_corpus_background', - 'sampling_weight': 0.10, - 'penalty_weight': 2, - 'truth': False, - 'truncation_strategy': 'random' - }, - { - 'features_dir': '/Volumes/MachineLearning/training_data/no_speech_background_1970ms', - 'sampling_weight': 0.2, - 'penalty_weight': 3, - 'truth': False, - 'truncation_strategy': 'random' - }, - { - 'features_dir': '/Volumes/MachineLearning/training_data/ambient_background', - 'sampling_weight': 0.2, - 'penalty_weight': 2, - 'truth': False, - 'truncation_strategy': 'split' - }, - ] - -# Number of training steps in each iteration - various other settings are configured as lists that corresponds to different steps -config['training_steps'] = [20000, 20000, 20000] - -# Penalizing weight for incorrect class predictions - lists that correspond to training steps -config["positive_class_weight"] = [1] -config["negative_class_weight"] = [1] -config['learning_rates'] = [0.001, 0.0005, 0.00025] # Learning rates for Adam optimizer - list that corresponds to training steps -config['batch_size'] = 100 - -config['mix_up_augmentation_prob'] = [0] # Probability of applying MixUp augmentation - list that corresponds to training steps -config['freq_mix_augmentation_prob'] = [0] # Probability of applying FreqMix augmentation - list that corresponds to training steps -config['time_mask_max_size'] = [5] # SpecAugment - list that corresponds to training steps -config['time_mask_count'] = [2] # SpecAugment - list that corresponds to training steps -config['freq_mask_max_size'] = [5] # SpecAugment - list that corresponds to training steps -config['freq_mask_count'] = [2] # SpecAugment - list 
that corresponds to training steps -config['eval_step_interval'] = 500 # Test the validation sets after every this many steps - -config['clip_duration_ms'] = 1490 # Maximum length of wake word that the streaming model will accept -config['window_stride_ms'] = 20 # Fixed setting for default feature generator -config['window_size_ms'] = 30 # Fixed setting for default feature generator -config['sample_rate'] = 16000 # Fixed setting for default feature generator - -# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization -# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize -# Available metrics: -# - "loss" - cross entropy error on validation set -# - "accuracy" - accuracy of validation set -# - "recall" - recall of validation set -# - "precision" - precision of validation set -# - "false_positive_rate" - false positive rate of validation set -# - "false_negative_rate" - false negative rate of validation set -# - "ambient_false_positives" - count of false positives from the split validation_ambient set -# - "ambient_false_positives_per_hour" - estimated number of false positives per hour on the split validation_ambient set -config['minimization_metric'] = 'ambient_false_positives_per_hour' # Set to N -config['target_minimization'] = 0.5 -config['maximization_metric'] = 'accuracy' - -with open(os.path.join('training_parameters.yaml'), 'w') as file: - documents = yaml.dump(config, file) -``` - -The model's hyperparameters are specified when calling the training script. - -```python -!python -m microwakeword.model_train_eval \ ---training_config='training_parameters.yaml' \ ---train 1 \ ---restore_checkpoint 1 \ ---test_tf_nonstreaming 0 \ ---test_tflite_nonstreaming 0 \ ---test_tflite_streaming 0 \ ---test_tflite_streaming_quantized 1 \ -inception \ ---cnn1_filters '32' \ ---cnn1_kernel_sizes '5' \ ---cnn1_subspectral_groups '4' \ ---cnn2_filters1 '24,24,24' \ ---cnn2_filters2 '32,64,96' \ ---cnn2_kernel_sizes '3,5,5' \ ---cnn2_subspectral_groups '1,1,1' \ ---cnn2_dilation '1,1,1' \ ---dropout 0.8 -``` +We generate samples using [Piper sample generator](https://github.com/rhasspy/piper-sample-generator). + +The generated samples are augmented before or during training to increase variability. There are pre-generated spectrogram features for various negative datasets available on [Hugging Face](https://huggingface.co/datasets/kahrendt/microwakeword). + +Please see the ``basic_training_notebook.ipynb`` notebook to see how a model is trained. This notebook will produce a model, but it will most likely not be usable! Training a usable model requires a lot of experimentation, and that notebook is meant to serve only as a starting point for advanced users. + +## Models + +See https://github.com/esphome/micro-wake-word-models to download the currently available models. ## Acknowledgements @@ -232,4 +53,5 @@ I am very thankful for many people's support to help improve this! 
Thank you, in - [kbx81](https://github.com/kbx81) - [synesthesiam](https://github.com/synesthesiam) - [ESPHome](https://github.com/esphome) - - [Nabu Casa](https://github.com/NabuCasa) \ No newline at end of file + - [Nabu Casa](https://github.com/NabuCasa) + - [Open Home Foundation](https://www.openhomefoundation.org/) \ No newline at end of file diff --git a/etc/logo.png b/etc/logo.png new file mode 100644 index 0000000..4b89598 Binary files /dev/null and b/etc/logo.png differ diff --git a/etc/logo.svg b/etc/logo.svg new file mode 100644 index 0000000..3a19c97 --- /dev/null +++ b/etc/logo.svg @@ -0,0 +1,48 @@ + + + + + + + + + + + + + \ No newline at end of file diff --git a/microwakeword/audio/audio_utils.py b/microwakeword/audio/audio_utils.py new file mode 100644 index 0000000..7862a9e --- /dev/null +++ b/microwakeword/audio/audio_utils.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2024 Kevin Ahrendt. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import tensorflow as tf +import webrtcvad + +from tensorflow.lite.experimental.microfrontend.python.ops import ( + audio_microfrontend_op as frontend_op, +) +from scipy.io import wavfile + +from pymicro_features import MicroFrontend + + +def generate_features_for_clip( + audio_samples: np.ndarray, step_ms: int = 20, use_c: bool = True +): + """Generates spectrogram features for the given audio data. + + Args: + audio_samples (numpy.ndarray): The clip's audio samples. + step_ms (int, optional): The window step size in ms. Defaults to 20. + use_c (bool, optional): Whether to use the C implementation of the microfrontend via pymicro-features. Defaults to True. + + Raises: + ValueError: If the provided audio data is not a 16-bit integer array. + + + Returns: + numpy.ndarray: The spectrogram features for the provided audio clip. + """ + + # Convert any float formatted audio data to an int16 array + if audio_samples.dtype in (np.float32, np.float64): + audio_samples = np.clip((audio_samples * 32768), -32768, 32767).astype(np.int16) + + if use_c: + audio_samples = audio_samples.tobytes() + micro_frontend = MicroFrontend() + features = [] + audio_idx = 0 + num_audio_bytes = len(audio_samples) + while audio_idx + 160 * 2 < num_audio_bytes: + frontend_result = micro_frontend.ProcessSamples( + audio_samples[audio_idx : audio_idx + 160 * 2] + ) + audio_idx += frontend_result.samples_read * 2 + if frontend_result.features: + features.append(frontend_result.features) + + return np.array(features).astype(np.float32) + + with tf.device("/cpu:0"): + # The default settings match the TFLM preprocessor settings. + # Preproccesor model is available from the tflite-micro repository, accessed December 2023. 
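+        # The call below mirrors those settings: 16 kHz int16 audio, a 30 ms analysis
+        # window stepped by `step_ms`, and 40 filterbank channels between 125 Hz and
+        # 7500 Hz with PCAN gain control enabled, returning uint16 features.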
+ micro_frontend = frontend_op.audio_microfrontend( + tf.convert_to_tensor(audio_samples), + sample_rate=16000, + window_size=30, + window_step=step_ms, + num_channels=40, + upper_band_limit=7500, + lower_band_limit=125, + enable_pcan=True, + min_signal_remaining=0.05, + out_scale=1, + out_type=tf.uint16, + ) + + spectrogram = micro_frontend.numpy() + return spectrogram + + +def save_clip(audio_samples: np.ndarray, output_file: str) -> None: + """Saves an audio clip's sample as a wave file. + + Args: + audio_samples (numpy.ndarray): The clip's audio samples. + output_file (str): Path to the desired output file. + """ + if audio_samples.dtype in (np.float32, np.float64): + audio_samples = (audio_samples * 32767).astype(np.int16) + wavfile.write(output_file, 16000, audio_samples) + + +def remove_silence_webrtc( + audio_data: np.ndarray, + frame_duration: float = 0.030, + sample_rate: int = 16000, + min_start: int = 2000, +) -> np.ndarray: + """Uses webrtc voice activity detection to remove silence from the clips + + Args: + audio_data (numpy.ndarray): The input clip's audio samples. + frame_duration (float): The frame_duration for webrtcvad. Defaults to 0.03. + sample_rate (int): The audio's sample rate. Defaults to 16000. + min_start: (int): The number of audio samples from the start of the clip to always include. Defaults to 2000. + + Returns: + numpy.ndarray: Array with the trimmed audio clip's samples. + """ + vad = webrtcvad.Vad(0) + + # webrtcvad expects int16 arrays as input, so convert if audio_data is a float + float_type = audio_data.dtype in (np.float32, np.float64) + if float_type: + audio_data = (audio_data * 32767).astype(np.int16) + + filtered_audio = audio_data[0:min_start].tolist() + + step_size = int(sample_rate * frame_duration) + + for i in range(min_start, audio_data.shape[0] - step_size, step_size): + vad_detected = vad.is_speech( + audio_data[i : i + step_size].tobytes(), sample_rate + ) + if vad_detected: + # If voice activity is detected, add it to filtered_audio + filtered_audio.extend(audio_data[i : i + step_size].tolist()) + + # If the original audio data was a float array, convert back + if float_type: + trimmed_audio = np.array(filtered_audio) + return np.array(trimmed_audio / 32767).astype(np.float32) + + return np.array(filtered_audio).astype(np.int16) diff --git a/microwakeword/audio/augmentation.py b/microwakeword/audio/augmentation.py new file mode 100644 index 0000000..594c665 --- /dev/null +++ b/microwakeword/audio/augmentation.py @@ -0,0 +1,244 @@ +# coding=utf-8 +# Copyright 2024 Kevin Ahrendt. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import audiomentations +import warnings + +import numpy as np + +from typing import List + + +class Augmentation: + """A class that handles applying augmentations to audio clips. + + Args: + augmentation_duration_s (float): The duration of the augmented clip in seconds. + augmentation_probabilities (dict, optional): Dictionary that specifies each augmentation's probability of being applied. 
Defaults to { "SevenBandParametricEQ": 0.0, "TanhDistortion": 0.0, "PitchShift": 0.0, "BandStopFilter": 0.0, "AddColorNoise": 0.25, "AddBackgroundNoise": 0.75, "Gain": 1.0, "GainTransition": 0.25, "RIR": 0.5, }. + impulse_paths (List[str], optional): List of directory paths that contain room impulse responses that the audio clip is reverberated with. If the list is empty, then reverberation is not applied. Defaults to []. + background_paths (List[str], optional): List of directory paths that contain audio clips to be mixed into the audio clip. If the list is empty, then the background augmentation is not applied. Defaults to []. + background_min_snr_db (int, optional): The minimum signal to noise ratio for mixing in background audio. Defaults to -10. + background_max_snr_db (int, optional): The maximum signal to noise ratio for mixing in background audio. Defaults to 10. + color_min_snr_db (int, optional): The minimum signal to noise ratio for the added color noise. Defaults to 10. + color_max_snr_db (int, optional): The maximum signal to noise ratio for the added color noise. Defaults to 30. + min_gain_db (float, optional): The minimum gain for the gain augmentation. Defaults to -45.0. + max_gain_db (float, optional): The maximum gain for the gain augmentation. Defaults to 0.0. + min_gain_transition_db (float, optional): The minimum gain for the gain transition augmentation. Defaults to -10.0. + max_gain_transition_db (float, optional): The maximum gain for the gain transition augmentation. Defaults to 10.0. + min_jitter_s (float, optional): The minimum duration in seconds that the original clip is positioned before the end of the augmented audio. Defaults to 0.0. + max_jitter_s (float, optional): The maximum duration in seconds that the original clip is positioned before the end of the augmented audio. Defaults to 0.0. + truncate_randomly (bool, optional): If true, the clip is truncated to the specified duration randomly. Otherwise, the start of the clip is truncated. Defaults to False. + """ + + def __init__( + self, + augmentation_duration_s: float | None = None, + augmentation_probabilities: dict = { + "SevenBandParametricEQ": 0.0, + "TanhDistortion": 0.0, + "PitchShift": 0.0, + "BandStopFilter": 0.0, + "AddColorNoise": 0.25, + "AddBackgroundNoise": 0.75, + "Gain": 1.0, + "GainTransition": 0.25, + "RIR": 0.5, + }, + impulse_paths: List[str] = [], + background_paths: List[str] = [], + background_min_snr_db: int = -10, + background_max_snr_db: int = 10, + color_min_snr_db: int = 10, + color_max_snr_db: int = 30, + min_gain_db: float = -45, + max_gain_db: float = 0, + min_gain_transition_db: float = -10, + max_gain_transition_db: float = 10, + min_jitter_s: float = 0.0, + max_jitter_s: float = 0.0, + truncate_randomly: bool = False, + ): + self.truncate_randomly = truncate_randomly + ############################################ + # Configure audio duration and positioning # + ############################################ + + self.min_jitter_samples = int(min_jitter_s * 16000) + self.max_jitter_samples = int(max_jitter_s * 16000) + + if augmentation_duration_s is not None: + self.augmented_samples = int(augmentation_duration_s * 16000) + else: + self.augmented_samples = None + + assert ( + self.min_jitter_samples <= self.max_jitter_samples + ), "Minimum jitter must be less than or equal to maximum jitter." 
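+        # Jitter pads silence onto the end of the clip, so the wake word does not
+        # always end on the final frame of the fixed-length training window.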
+ + ####################### + # Setup augmentations # + ####################### + + # If either the background_paths or impulse_paths are not specified, use an identity transform instead + def identity_transform(samples, sample_rate): + return samples + + background_noise_augment = audiomentations.Lambda( + transform=identity_transform, p=0.0 + ) + reverb_augment = audiomentations.Lambda(transform=identity_transform, p=0.0) + + if len(background_paths): + background_noise_augment = audiomentations.AddBackgroundNoise( + p=augmentation_probabilities.get("AddBackgroundNoise", 0.0), + sounds_path=background_paths, + min_snr_db=background_min_snr_db, + max_snr_db=background_max_snr_db, + ) + + if len(impulse_paths) > 0: + reverb_augment = audiomentations.ApplyImpulseResponse( + p=augmentation_probabilities.get("RIR", 0.0), + ir_path=impulse_paths, + ) + + # Based on openWakeWord's augmentations, accessed on February 23, 2024. + self.augment = audiomentations.Compose( + transforms=[ + audiomentations.SevenBandParametricEQ( + p=augmentation_probabilities.get("SevenBandParametricEQ", 0.0), + min_gain_db=-6, + max_gain_db=6, + ), + audiomentations.TanhDistortion( + p=augmentation_probabilities.get("TanhDistortion", 0.0), + min_distortion=0.0001, + max_distortion=0.10, + ), + audiomentations.PitchShift( + p=augmentation_probabilities.get("PitchShift", 0.0), + min_semitones=-3, + max_semitones=3, + ), + audiomentations.BandStopFilter( + p=augmentation_probabilities.get("BandStopFilter", 0.0), + ), + audiomentations.AddColorNoise( + p=augmentation_probabilities.get("AddColorNoise", 0.0), + min_snr_db=color_min_snr_db, + max_snr_db=color_max_snr_db, + ), + background_noise_augment, + audiomentations.Gain( + p=augmentation_probabilities.get("Gain", 0.0), + min_gain_db=min_gain_db, + max_gain_db=max_gain_db, + ), + audiomentations.GainTransition( + p=augmentation_probabilities.get("GainTransition", 0.0), + min_gain_db=min_gain_transition_db, + max_gain_db=max_gain_transition_db, + ), + reverb_augment, + audiomentations.Compose( + transforms=[ + audiomentations.Normalize( + apply_to="only_too_loud_sounds", p=1.0 + ) + ] + ), # If the audio is clipped, normalize + ], + shuffle=False, + ) + + def add_jitter(self, input_audio: np.ndarray): + """Pads the clip on the right by a random duration between the class's min_jitter_s and max_jitter_s paramters. + + Args: + input_audio (numpy.ndarray): Array containing the audio clip's samples. + + Returns: + numpy.ndarray: Array of audio samples with silence added to the end. + """ + if self.min_jitter_samples < self.max_jitter_samples: + jitter_samples = np.random.randint( + self.min_jitter_samples, self.max_jitter_samples + ) + else: + jitter_samples = self.min_jitter_samples + + # Pad audio on the right by jitter samples + return np.pad(input_audio, (0, jitter_samples)) + + def create_fixed_size_clip(self, input_audio: np.ndarray): + """Ensures the input audio clip has a fixced length. If the duration is too long, the start of the clip is removed. If it is too short, the start of the clip is padded with silence. + + Args: + input_audio (numpy.ndarray): Array containing the audio clip's samples. + + Returns: + numpy.ndarray: Array of audio samples with `augmented_duration_s` length. 
+ """ + if self.augmented_samples is None: + return input_audio + + if self.augmented_samples < input_audio.shape[0]: + # Truncate the too long audio by removing the start of the clip + if self.truncate_randomly: + random_start = np.random.randint( + 0, input_audio.shape[0] - self.augmented_samples + ) + input_audio = input_audio[ + random_start : random_start + self.augmented_samples + ] + else: + input_audio = input_audio[-self.augmented_samples :] + else: + # Pad with zeros at start of too short audio clip + left_padding_samples = self.augmented_samples - input_audio.shape[0] + + input_audio = np.pad(input_audio, (left_padding_samples, 0)) + + return input_audio + + def augment_clip(self, input_audio: np.ndarray): + """Augments the input audio after adding jitter and creating a fixed size clip. + + Args: + input_audio (numpy.ndarray): Array containing the audio clip's samples. + + Returns: + numpy.ndarray: The augmented audio of fixed duration. + """ + input_audio = self.add_jitter(input_audio) + input_audio = self.create_fixed_size_clip(input_audio) + + with warnings.catch_warnings(): + warnings.simplefilter( + "ignore" + ) # Suppresses warning about background clip being too quiet... TODO: find better approach! + output_audio = self.augment(input_audio, sample_rate=16000) + + return output_audio + + def augment_generator(self, audio_generator): + """A Python generator that augments clips retrived from the input audio generator. + + Args: + audio_generator (generator): A Python generator that yields audio clips. + + Yields: + numpy.ndarray: The augmented audio clip's samples. + """ + for audio in audio_generator: + yield self.augment_clip(audio) diff --git a/microwakeword/audio/clips.py b/microwakeword/audio/clips.py new file mode 100644 index 0000000..84feba2 --- /dev/null +++ b/microwakeword/audio/clips.py @@ -0,0 +1,241 @@ +# coding=utf-8 +# Copyright 2024 Kevin Ahrendt. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import audio_metadata +import datasets +import math +import os +import random +import wave + +import numpy as np + +from pathlib import Path + +from microwakeword.audio.audio_utils import remove_silence_webrtc + + +class Clips: + """Class for loading audio clips from the specified directory. The clips can first be filtered by their duration using the `min_clip_duration_s` and `max_clip_duration_s` parameters. Clips are retrieved as numpy float arrays via the `get_random_clip` method or via the `audio_generator` or `random_audio_generator` generators. Before retrieval, the audio clip can trim non-voice activiity. Before retrieval, the audio clip can be repeated until it is longer than a specified minimum duration. + + Args: + input_directory (str): Path to audio clip files. + file_pattern (str): File glob pattern for selecting audio clip files. + min_clip_duration_s (float | None, optional): The minimum clip duration (in seconds). Set to None to disable filtering by minimum clip duration. Defaults to None. 
+ max_clip_duration_s (float | None, optional): The maximum clip duration (in seconds). Set to None to disable filtering by maximum clip duration. Defaults to None. + repeat_clip_min_duration_s (float | None, optional): If a clip is shorter than this duration, then it is repeated until it is longer than this duration. Set to None to disable repeating the clip. Defaults to None. + remove_silence (bool, optional): Use webrtcvad to trim non-voice activity in the clip. Defaults to False. + random_split_seed (int | None, optional): The random seed used to split the clips into different sets. Set to None to disable splitting the clips. Defaults to None. + split_count (int | float, optional): The percentage/count of clips to be included in the testing and validation sets. Defaults to 0.1. + trimmed_clip_duration_s: (float | None, optional): The duration of the clips to trim the end of long clips. Set to None to disable trimming. Defaults to None. + trim_zerios: (bool, optional): If true, any leading and trailling zeros are removed. Defaults to false. + """ + + def __init__( + self, + input_directory: str, + file_pattern: str, + min_clip_duration_s: float | None = None, + max_clip_duration_s: float | None = None, + repeat_clip_min_duration_s: float | None = None, + remove_silence: bool = False, + random_split_seed: int | None = None, + split_count: int | float = 0.1, + trimmed_clip_duration_s: float | None = None, + trim_zeros: bool = False, + ): + self.trim_zeros = trim_zeros + self.trimmed_clip_duration_s = trimmed_clip_duration_s + + if min_clip_duration_s is not None: + self.min_clip_duration_s = min_clip_duration_s + else: + self.min_clip_duration_s = 0.0 + + if max_clip_duration_s is not None: + self.max_clip_duration_s = max_clip_duration_s + else: + self.max_clip_duration_s = math.inf + + if repeat_clip_min_duration_s is not None: + self.repeat_clip_min_duration_s = repeat_clip_min_duration_s + else: + self.repeat_clip_min_duration_s = 0.0 + + self.remove_silence = remove_silence + + self.remove_silence_function = remove_silence_webrtc + + paths_to_clips = [str(i) for i in Path(input_directory).glob(file_pattern)] + + if (self.min_clip_duration_s == 0) and (math.isinf(self.max_clip_duration_s)): + # No durations specified, so do not filter by length + filtered_paths = paths_to_clips + else: + # Filter audio clips by length + if file_pattern.endswith("wav"): + # If it is a wave file, assume all wave files have the same parameters and filter by file size. + # Based on openWakeWord's estimate_clip_duration and filter_audio_paths in data.py, accessed March 2, 2024. + with wave.open(paths_to_clips[0], "rb") as input_wav: + channels = input_wav.getnchannels() + sample_width = input_wav.getsampwidth() + sample_rate = input_wav.getframerate() + frames = input_wav.getnframes() + + sizes = [] + sizes.extend([os.path.getsize(i) for i in paths_to_clips]) + + # Correct for the wav file header bytes. Assumes all files in the directory have same parameters. + header_correction = ( + os.path.getsize(paths_to_clips[0]) + - frames * sample_width * channels + ) + + durations = [] + for size in sizes: + durations.append( + (size - header_correction) + / (sample_rate * sample_width * channels) + ) + + filtered_paths = [ + path_to_clip + for path_to_clip, duration in zip(paths_to_clips, durations) + if (self.min_clip_duration_s < duration) + and (duration < self.max_clip_duration_s) + ] + else: + # If not a wave file, use the audio_metadata package to analyze audio file headers for the duration. 
+ # This is slower! + filtered_paths = [] + + if (self.min_clip_duration_s > 0) or ( + not math.isinf(self.max_clip_duration_s) + ): + for audio_file in paths_to_clips: + metadata = audio_metadata.load(audio_file) + duration = metadata["streaminfo"]["duration"] + if (self.min_clip_duration_s < duration) and ( + duration < self.max_clip_duration_s + ): + filtered_paths.append(audio_file) + + # Load all filtered clips + audio_dataset = datasets.Dataset.from_dict( + {"audio": [str(i) for i in filtered_paths]} + ).cast_column("audio", datasets.Audio()) + + # Convert all clips to 16 kHz sampling rate when accessed + audio_dataset = audio_dataset.cast_column( + "audio", datasets.Audio(sampling_rate=16000) + ) + + if random_split_seed is not None: + train_testvalid = audio_dataset.train_test_split( + test_size=2 * split_count, seed=random_split_seed + ) + test_valid = train_testvalid["test"].train_test_split(test_size=0.5) + split_dataset = datasets.DatasetDict( + { + "train": train_testvalid["train"], + "test": test_valid["test"], + "validation": test_valid["train"], + } + ) + self.split_clips = split_dataset + + self.clips = audio_dataset + + def audio_generator(self, split: str | None = None, repeat: int = 1): + """A Python generator that retrieves all loaded audio clips. + + Args: + split (str | None, optional): Specifies which set the clips are retrieved from. If None, all clips are retrieved. Otherwise, it can be set to `train`, `test`, or `validation`. Defaults to None. + repeat (int, optional): The number of times each audio clip will be yielded. Defaults to 1. + + Yields: + numpy.ndarray: Array with the audio clip's samples. + """ + if split is None: + clip_list = self.clips + else: + clip_list = self.split_clips[split] + for _ in range(repeat): + for clip in clip_list: + clip_audio = clip["audio"]["array"] + + if self.remove_silence: + clip_audio = self.remove_silence_function(clip_audio) + + if self.trim_zeros: + clip_audio = np.trim_zeros(clip_audio) + + if self.trimmed_clip_duration_s: + total_samples = int(self.trimmed_clip_duration_s * 16000) + clip_audio = clip_audio[:total_samples] + + clip_audio = self.repeat_clip(clip_audio) + yield clip_audio + + def get_random_clip(self): + """Retrieves a random audio clip. + + Returns: + numpy.ndarray: Array with the audio clip's samples. + """ + rand_audio_entry = random.choice(self.clips) + clip_audio = rand_audio_entry["audio"]["array"] + + if self.remove_silence: + clip_audio = self.remove_silence_function(clip_audio) + + if self.trim_zeros: + clip_audio = np.trim_zeros(clip_audio) + + if self.trimmed_clip_duration_s: + total_samples = int(self.trimmed_clip_duration_s * 16000) + clip_audio = clip_audio[:total_samples] + + clip_audio = self.repeat_clip(clip_audio) + return clip_audio + + def random_audio_generator(self, max_clips: int = math.inf): + """A Python generator that retrieves random audio clips. + + Args: + max_clips (int, optional): The total number of clips the generator will yield before the StopIteration. Defaults to math.inf. + + Yields: + numpy.ndarray: Array with the random audio clip's samples. + """ + while max_clips > 0: + max_clips -= 1 + + yield self.get_random_clip() + + def repeat_clip(self, audio_samples: np.array): + """Repeats the audio clip until its duration exceeds the minimum specified in the class. + + Args: + audio_samples numpy.ndarray: Original audio clip's samples. + + Returns: + numpy.ndarray: Array with duration exceeding self.repeat_clip_min_duration_s. 
+ """ + original_clip = audio_samples + desired_samples = int(self.repeat_clip_min_duration_s * 16000) + while audio_samples.shape[0] < desired_samples: + audio_samples = np.append(audio_samples, original_clip) + return audio_samples diff --git a/microwakeword/audio/spectrograms.py b/microwakeword/audio/spectrograms.py new file mode 100644 index 0000000..5adb585 --- /dev/null +++ b/microwakeword/audio/spectrograms.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2024 Kevin Ahrendt. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from microwakeword.audio.audio_utils import generate_features_for_clip +from microwakeword.audio.augmentation import Augmentation +from microwakeword.audio.clips import Clips + + +class SpectrogramGeneration: + """A class that handles generating spectrogram features for audio clips. Spectrograms can optionally be split into nonoverlapping segments for faster file loading or they can optionally be strided by dropping the last feature windows to simulate a streaming model's sequential inputs. + + Args: + clips (Clips): Object that retrieves audio clips. + augmenter (Augmentation | None, optional): Object that augments audio clips. If None, no augmentations are applied. Defaults to None. + step_ms (int, optional): The window step size in ms for the spectrogram features. Defaults to 20. + split_spectrogram_duration_s (float | None, optional): Splits generated spectrograms to yield nonoverlapping spectrograms with this duration. If None, the entire spectrogram is yielded. Defaults to None. + slide_frames (int | None, optional): Strides the generated spectrograms to yield `slide_frames` overlapping spectrogram by removing features at the end of the spectrogram. If None, the entire spectrogram is yielded. Defaults to None. + """ + + def __init__( + self, + clips: Clips, + augmenter: Augmentation | None = None, + step_ms: int = 20, + split_spectrogram_duration_s: float | None = None, + slide_frames: int | None = None, + ): + + self.clips = clips + self.augmenter = augmenter + self.step_ms = step_ms + self.split_spectrogram_duration_s = split_spectrogram_duration_s + self.slide_frames = slide_frames + + def get_random_spectrogram(self): + """Retrieves a random audio clip's spectrogram that is optionally augmented. + + Returns: + numpy.ndarry: 2D spectrogram array for the random (augmented) audio clip. + """ + clip = self.clips.get_random_clip() + if self.augmenter is not None: + clip = self.augmenter.augment_clip(clip) + + return generate_features_for_clip(clip, self.step_ms) + + def spectrogram_generator(self, random=False, max_clips=None, **kwargs): + """A Python generator that retrieves (augmented) spectrograms. + + Args: + random (bool, optional): Specifies if the source audio clips should be chosen randomly. Defaults to False. + kwargs: Parameters to pass to the clips audio generator. + + Yields: + numpy.ndarry: 2D spectrogram array for the random (augmented) audio clip. 
+ """ + if random: + if max_clips is not None: + clip_generator = self.clips.random_audio_generator(max_clips=max_clips) + else: + clip_generator = self.clips.random_audio_generator() + else: + clip_generator = self.clips.audio_generator(**kwargs) + + if self.augmenter is not None: + augmented_generator = self.augmenter.augment_generator(clip_generator) + else: + augmented_generator = clip_generator + + for augmented_clip in augmented_generator: + spectrogram = generate_features_for_clip(augmented_clip, self.step_ms) + + if self.split_spectrogram_duration_s is not None: + # Splits the resulting spectrogram into non-overlapping spectrograms. The features from the first 20 feature windows are dropped. + desired_spectrogram_length = int( + self.split_spectrogram_duration_s / (self.step_ms / 1000) + ) + + if spectrogram.shape[0] > desired_spectrogram_length + 20: + slided_spectrograms = np.lib.stride_tricks.sliding_window_view( + spectrogram, + window_shape=(desired_spectrogram_length, spectrogram.shape[1]), + )[20::desired_spectrogram_length, ...] + + for i in range(slided_spectrograms.shape[0]): + yield np.squeeze(slided_spectrograms[i]) + else: + yield spectrogram + elif self.slide_frames is not None: + # Generates self.slide_frames spectrograms by shifting over the already generated spectrogram + spectrogram_length = spectrogram.shape[0] - self.slide_frames + 1 + + slided_spectrograms = np.lib.stride_tricks.sliding_window_view( + spectrogram, window_shape=(spectrogram_length, spectrogram.shape[1]) + ) + for i in range(self.slide_frames): + yield np.squeeze(slided_spectrograms[i]) + else: + yield spectrogram diff --git a/microwakeword/data.py b/microwakeword/data.py index 92982d1..d1b2992 100644 --- a/microwakeword/data.py +++ b/microwakeword/data.py @@ -15,7 +15,6 @@ """Functions and classes for loading/augmenting spectrograms""" -import copy import os import random @@ -25,77 +24,17 @@ from pathlib import Path from mmap_ninja.ragged import RaggedMmap - -def mixup_augment( - spectrogram_1, truth_1, weight_1, spectrogram_2, truth_2, weight_2, mix_ratio -): - """Applies MixUp augment to the input spectrograms. - Based on mixup: BEYOND EMPIRICAL RISK MINIMIZATION by H. Zhang, M. Cisse, Y. Dauphin, D. Lopez-Paz - https://openreview.net/pdf?id=r1Ddp1-Rb - - Args: - spectrogram_1: the first spectrogram. - truth_1: the ground truth of the first spectrogram - weight_1: the penalty weight of the first spectrogram - spectrogram_2: the second spectrogram - truth_2: the ground truth of the second spectrogram - weight_2: the penalty weight of the second spectrogram - - Returns: - spectrogram: the blended spectrogram - truth: the blended ground truth - weight: the blended penalty weight - """ - - combined_spectrogram = spectrogram_1 * mix_ratio + spectrogram_2 * (1 - mix_ratio) - combined_truth = float(truth_1) * mix_ratio + float(truth_2) * (1 - mix_ratio) - combined_weight = weight_1 * mix_ratio + weight_2 * (1 - mix_ratio) - - return combined_spectrogram, combined_truth, combined_weight - - -def freqmix_augment( - spectrogram_1, truth_1, weight_1, spectrogram_2, truth_2, weight_2, mix_ratio -): - """Applies FreqMix augment to the input spectrograms. - Based on END-TO-END AUDIO STRIKES BACK: BOOSTING AUGMENTATIONS TOWARDS AN EFFICIENT AUDIO CLASSIFICATION NETWORK by A. Gazneli, G. Zimerman, T. Ridnik, G. Sharir, A. 
Noy - https://arxiv.org/pdf/2204.11479v5.pdf - - Args: - spectrogram_1: the first spectrogram - truth_1: the ground truth of the first spectrogram - weight_1: the penalty weight of the first spectrogram - spectrogram_2: the second spectrogram - truth_2: the ground truth of the second spectrogram - weight_2: the penalty weight of the second spectrogram - - Returns: - spectrogram: the blended spectrogram - truth: the blended ground truth - weight: the blended penalty weight - """ - - freq_bin_cutoff = int(mix_ratio * 40) - - combined_spectrogram = np.concatenate( - (spectrogram_1[:, :freq_bin_cutoff], spectrogram_2[:, freq_bin_cutoff:]), axis=1 - ) - combined_truth = float(truth_1) * (freq_bin_cutoff / 40.0) + float(truth_2) * ( - 1 - freq_bin_cutoff / 40.0 - ) - combined_weight = weight_1 * (freq_bin_cutoff / 40.0) + weight_2 * ( - 1 - freq_bin_cutoff / 40.0 - ) - - return combined_spectrogram, combined_truth, combined_weight +from microwakeword.audio.clips import Clips +from microwakeword.audio.augmentation import Augmentation +from microwakeword.audio.spectrograms import SpectrogramGeneration def spec_augment( - spectrogram, - time_mask_max_size=0, - time_mask_count=1, - freq_mask_max_size=0, - freq_mask_count=1, + spectrogram: np.ndarray, + time_mask_max_size: int = 0, + time_mask_count: int = 0, + freq_mask_max_size: int = 0, + freq_mask_count: int = 0, ): """Applies SpecAugment to the input spectrogram. Based on SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition by D. Park, W. Chan, Y. Zhang, C. Chiu, B. Zoph, E Cubuk, Q Le @@ -103,41 +42,47 @@ def spec_augment( Implementation based on https://github.com/pyyush/SpecAugment/tree/master Args: - spectrogram: the input spectrogram - time_mask_max_size: maximum size of time feature masks - time_mask_count: the total number of separate time masks - freq_mask_max_size: maximum size of frequency feature masks - time_mask_count: the total number of separate feature masks + spectrogram (numpy.ndarray): The input spectrogram. + time_mask_max_size (int): The maximum size of time feature masks. Defaults to 0. + time_mask_count (int): The total number of separate time masks. Defaults to 0. + freq_mask_max_size (int): The maximum size of frequency feature masks. Defaults to 0. + time_mask_count (int): The total number of separate feature masks. Defaults to 0. Returns: - masked spectrogram + numpy.ndarray: The masked spectrogram. """ - freq_bins = spectrogram.shape[0] - time_frames = spectrogram.shape[1] + time_frames = spectrogram.shape[0] + freq_bins = spectrogram.shape[1] - for i in range(freq_mask_count): - f = int(np.random.uniform(0, freq_mask_max_size)) - f0 = random.randint(0, freq_bins - f) - spectrogram[f0 : f0 + f, :] = 0 + # Spectrograms yielded from a generator are read only + augmented_spectrogram = np.copy(spectrogram) for i in range(time_mask_count): t = int(np.random.uniform(0, time_mask_max_size)) t0 = random.randint(0, time_frames - t) - spectrogram[:, t0 : t0 + t] = 0 + augmented_spectrogram[t0 : t0 + t, :] = 0 - return spectrogram + for i in range(freq_mask_count): + f = int(np.random.uniform(0, freq_mask_max_size)) + f0 = random.randint(0, freq_bins - f) + augmented_spectrogram[:, f0 : f0 + f] = 0 + + return augmented_spectrogram def fixed_length_spectrogram( - spectrogram, features_length, truncation_strategy="random" + spectrogram: np.ndarray, + features_length: int, + truncation_strategy: str = "random", + right_cutoff: int = 0, ): - """Returns a spectrogram with specified length. 
Pads with zeros at the start if too short. + """Returns a spectrogram with specified length. Pads with zeros at the start if too short. Removes feature windows following ``truncation_strategy`` if too long. Args: - spectrogram: the spectrogram to truncate or pad - features_length: the desired spectrogram length - truncation_strategy: how to truncate if ``spectrogram`` is longer than ``features_length`` One of: + spectrogram (numpy.ndarray): The spectrogram to truncate or pad. + features_length (int): The desired spectrogram length. + truncation_strategy (str): How to truncate if ``spectrogram`` is longer than ``features_length`` One of: random: choose a random portion of the entire spectrogram - useful for long negative samples truncate_start: remove the start of the spectrogram truncate_end: remove the end of the spectrogram @@ -145,8 +90,9 @@ def fixed_length_spectrogram( Returns: - fixed length spectrogram after truncating or padding + numpy.ndarry: The fixed length spectrogram due to padding or truncation. """ + data_length = spectrogram.shape[0] features_offset = 0 if data_length > features_length: @@ -154,13 +100,13 @@ def fixed_length_spectrogram( features_offset = np.random.randint(0, data_length - features_length) elif truncation_strategy == "none": # return the entire spectrogram - features_offset = 0 features_length = data_length elif truncation_strategy == "truncate_start": features_offset = data_length - features_length elif truncation_strategy == "truncate_end": features_offset = 0 - + elif truncation_strategy == "fixed_right_cutoff": + features_offset = data_length - features_length - right_cutoff else: pad_slices = features_length - data_length @@ -172,64 +118,55 @@ def fixed_length_spectrogram( return spectrogram[features_offset : (features_offset + features_length)] -class FeatureHandler(object): - """Class that handles loading spectrogram features and providing them to the training and testing functions. +class MmapFeatureGenerator(object): + """A class that handles loading spectrograms from Ragged MMaps for training or testing. Args: - config: dictionary containing microWakeWord training configuration + path (str): Input directory to the Ragged MMaps. The Ragged MMap folders should be included in the following file structure: + training/ (spectrograms to use for training the model) + validation/ (spectrograms used to validate the model while training) + testing/ (spectrograms used to test the model after training) + validation_ambient/ (spectrograms of long duration background audio clips that are split and validated while training) + testing_ambient/ (spectrograms of long duration background audio clips to test the model after training) + label (bool): The class each spectrogram represents; i.e., wakeword or not. + sampling_weight (float): The sampling weight for how frequently a spectrogram from this dataset is chosen. + penalty_weight (float): The penalizing weight for incorrect predictions for each spectrogram. + truncation_strategy (str): How to truncate if ``spectrogram`` is too long. + stride (int): The stride in the model's first layer. + step (float): The window step duration (in seconds). + fixed_right_cutoffs (list[int]): List of spectogram slices to cutoff on the right if the truncation strategy is "fixed_right_cutoff". In training mode, its randomly chosen from the list. Otherwise, it yields spectrograms with all cutoffs in the list. 
""" def __init__( self, - config, + path: str, + label: bool, + sampling_weight: float, + penalty_weight: float, + truncation_strategy: str, + stride: int, + step: float, + fixed_right_cutoffs: list[int] = [0], ): - self.features = [] + self.label = float(label) + self.sampling_weight = sampling_weight + self.penalty_weight = penalty_weight + self.truncation_strategy = truncation_strategy + self.fixed_right_cutoffs = fixed_right_cutoffs - logging.info("Loading and analyzing data sets.") + self.stride = stride + self.step = step - features = copy.deepcopy(config["features"]) + self.stats = {} + self.feature_sets = {} - for feature_set in features: - feature_set["testing"] = [] - feature_set["training"] = [] - feature_set["validation"] = [] - feature_set["validation_ambient"] = [] - feature_set["testing_ambient"] = [] - feature_set["loaded_features"] = [] - feature_set["stats"] = {} + self.feature_sets["testing"] = [] + self.feature_sets["training"] = [] + self.feature_sets["validation"] = [] + self.feature_sets["validation_ambient"] = [] + self.feature_sets["testing_ambient"] = [] - self.prepare_data(feature_set) - self.features.append(feature_set) - - modes = [ - "training", - "validation", - "validation_ambient", - "testing", - "testing_ambient", - ] - - for mode in modes: - logging.info( - "%s mode has %d spectrograms representing %.1f hours of audio", - *(mode, self.get_mode_size(mode), self.get_mode_duration(mode) / 3600.0) - ) - - def prepare_data(self, feature_dict): - """Loads data from a feature's config entry. - - Args: - feature_dict: dictionary with keys for: - features_dir: directory containing diffferent mode folders - sampling_weight: weight for choosing a spectrogram from this set in the batch - penalty_weight: penalizing weight for incorrect predictions from this set - truth: boolean representing whether this set has positive samples or negative samples - truncation_strategy: if a spectrogram is longer than necessary for training, how is it truncated - """ - data_dir = feature_dict["features_dir"] - - if not os.path.exists(data_dir): - print("ERROR:" + str(data_dir) + "directory doesn't exist") + self.loaded_features = [] dirs = [ "testing", @@ -243,7 +180,7 @@ def prepare_data(self, feature_dict): duration = 0.0 count = 0 - search_path_directory = os.path.join(data_dir, set_index) + search_path_directory = os.path.join(path, set_index) search_path = [ str(i) for i in Path(os.path.abspath(search_path_directory)).glob("**/*_mmap/") @@ -252,88 +189,338 @@ def prepare_data(self, feature_dict): for mmap_path in search_path: imported_features = RaggedMmap(mmap_path) - feature_dict["loaded_features"].append(imported_features) - feature_index = len(feature_dict["loaded_features"]) - 1 + self.loaded_features.append(imported_features) + feature_index = len(self.loaded_features) - 1 for i in range(0, len(imported_features)): - feature_dict[set_index].append( + self.feature_sets[set_index].append( { "loaded_feature_index": feature_index, "subindex": i, } ) - duration += ( - 0.02 * imported_features[i].shape[0] - ) # Each feature represents 0.02 seconds of audio + duration += step * imported_features[i].shape[0] count += 1 - random.shuffle(feature_dict[set_index]) + random.shuffle(self.feature_sets[set_index]) - feature_dict["stats"][set_index] = { + self.stats[set_index] = { "spectrogram_count": count, "total_duration": duration, } + def get_mode_duration(self, mode: str): + """Retrieves the total duration of the spectrograms in the mode set. + + Args: + mode (str): Specifies the set. 
One of "training", "validation", "testing", "validation_ambient", "testing_ambient". + + Returns: + float: The duration in hours. + """ + return self.stats[mode]["total_duration"] + + def get_mode_size(self, mode): + """Retrieves the total count of the spectrograms in the mode set. + + Args: + mode (str): Specifies the set. One of "training", "validation", "testing", "validation_ambient", "testing_ambient". + + Returns: + int: The spectrogram count. + """ + return self.stats[mode]["spectrogram_count"] + + def get_random_spectrogram( + self, mode: str, features_length: int, truncation_strategy: str + ): + """Retrieves a random spectrogram from the specified mode with specified length after truncation. + + Args: + mode (str): Specifies the set. One of "training", "validation", "testing", "validation_ambient", "testing_ambient". + features_length (int): The length of the spectrogram in feature windows. + truncation_strategy (str): How to truncate if ``spectrogram`` is too long. + + Returns: + numpy.ndarray: A random spectrogram of specified length after truncation. + """ + right_cutoff = 0 + if truncation_strategy == "default": + truncation_strategy = self.truncation_strategy + + if truncation_strategy == "fixed_right_cutoff": + right_cutoff = random.choice(self.fixed_right_cutoffs) + + feature = random.choice(self.feature_sets[mode]) + spectrogram = self.loaded_features[feature["loaded_feature_index"]][ + feature["subindex"] + ] + + spectrogram = fixed_length_spectrogram( + spectrogram, + features_length, + truncation_strategy, + right_cutoff, + ) + + # Spectrograms with type np.uint16 haven't been scaled + if np.issubdtype(spectrogram.dtype, np.uint16): + spectrogram = spectrogram.astype(np.float32) * 0.0390625 + + return spectrogram + + def get_feature_generator( + self, + mode, + features_length, + truncation_strategy="default", + ): + """A Python generator that yields spectrograms from the specified mode of specified length after truncation. + + Args: + mode (str): Specifies the set. One of "training", "validation", "testing", "validation_ambient", "testing_ambient". + features_length (int): The length of the spectrogram in feature windows. + truncation_strategy (str): How to truncate if ``spectrogram`` is too long. + + Yields: + numpy.ndarray: A random spectrogram of specified length after truncation. + """ + if truncation_strategy == "default": + truncation_strategy = self.truncation_strategy + + for feature in self.feature_sets[mode]: + spectrogram = self.loaded_features[feature["loaded_feature_index"]][ + feature["subindex"] + ] + + # Spectrograms with type np.uint16 haven't been scaled + if np.issubdtype(spectrogram.dtype, np.uint16): + spectrogram = spectrogram.astype(np.float32) * 0.0390625 + + if truncation_strategy == "split": + for feature_start_index in range( + 0, + spectrogram.shape[0] - features_length, + int(1000 * self.step * self.stride), + ): # 10*2 features corresponds to 200 ms + split_spectrogram = spectrogram[ + feature_start_index : feature_start_index + features_length + ] + + yield split_spectrogram + else: + for cutoff in self.fixed_right_cutoffs: + fixed_spectrogram = fixed_length_spectrogram( + spectrogram, + features_length, + truncation_strategy, + cutoff, + ) + + yield fixed_spectrogram + + +class ClipsHandlerWrapperGenerator(object): + """A class that handles loading spectrograms from audio files on the disk to use while training. This generates spectrograms with random augmentations applied during the training process. 
+ + Args: + spectrogram_generation (SpectrogramGeneration): Object that handles generating spectrograms from audio files. + label (bool): The class each spectrogram represents; i.e., wakeword or not. + sampling_weight (float): The sampling weight for how frequently a spectrogram from this dataset is chosen. + penalty_weight (float): The penalizing weight for incorrect predictions for each spectrogram. + truncation_strategy (str): How to truncate if ``spectrogram`` is too long. + """ + + def __init__( + self, + spectrogram_generation: SpectrogramGeneration, + label: bool, + sampling_weight: float, + penalty_weight: float, + truncation_strategy: str, + ): + self.spectrogram_generation = spectrogram_generation + self.label = label + self.sampling_weight = sampling_weight + self.penalty_weight = penalty_weight + self.truncation_strategy = truncation_strategy + + self.augmented_generator = self.spectrogram_generation.spectrogram_generator( + random=True + ) + def get_mode_duration(self, mode): + """Function to maintain compatability with the MmapFeatureGenerator class.""" + return 0.0 + + def get_mode_size(self, mode): + """Function to maintain compatability with the MmapFeatureGenerator class. This class is intended only for retrieving spectrograms for training.""" + if mode == "training": + return len(self.spectrogram_generation.clips.clips) + else: + return 0 + + def get_random_spectrogram(self, mode, features_length, truncation_strategy): + """Retrieves a random spectrogram from the specified mode with specified length after truncation. + + Args: + mode (str): Specifies the set, but is ignored for this class. It is assumed the spectrograms will be for training. + features_length (int): The length of the spectrogram in feature windows. + truncation_strategy (str): How to truncate if ``spectrogram`` is too long. + + Returns: + numpy.ndarray: A random spectrogram of specified length after truncation. + """ + + if truncation_strategy == "default": + truncation_strategy = self.truncation_strategy + + spectrogram = next(self.augmented_generator) + + spectrogram = fixed_length_spectrogram( + spectrogram, + features_length, + truncation_strategy, + right_cutoff=0, + ) + + # Spectrograms with type np.uint16 haven't been scaled + if np.issubdtype(spectrogram.dtype, np.uint16): + spectrogram = spectrogram.astype(np.float32) * 0.0390625 + + return spectrogram + + def get_feature_generator( + self, + mode, + features_length, + truncation_strategy="default", + ): + """Function to maintain compatability with the MmapFeatureGenerator class.""" + for x in []: + yield x + + +class FeatureHandler(object): + """Class that handles loading spectrogram features and providing them to the training and testing functions. 
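+
+    Each entry in ``config["features"]`` selects a spectrogram provider through its ``type``
+    key, either "mmap" or "clips". For example, an "mmap" entry might look like the
+    following (illustrative values only):
+
+        {
+            "type": "mmap",
+            "features_dir": "generated_features/wakeword",
+            "truth": True,
+            "sampling_weight": 2.0,
+            "penalty_weight": 1.0,
+            "truncation_strategy": "truncate_start",
+        }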
+ + Args: + config: dictionary containing microWakeWord training configuration + """ + + def __init__( + self, + config: dict, + ): + self.feature_providers = [] + + logging.info("Loading and analyzing data sets.") + + for feature_set in config["features"]: + if feature_set["type"] == "mmap": + self.feature_providers.append( + MmapFeatureGenerator( + feature_set["features_dir"], + feature_set["truth"], + feature_set["sampling_weight"], + feature_set["penalty_weight"], + feature_set["truncation_strategy"], + stride=config["stride"], + step=config["window_step_ms"] / 1000.0, + fixed_right_cutoffs=feature_set.get("fixed_right_cutoffs", [0]), + ) + ) + elif feature_set["type"] == "clips": + clips_handler = Clips(**feature_set["clips_settings"]) + augmentation_applier = Augmentation( + **feature_set["augmentation_settings"] + ) + spectrogram_generator = SpectrogramGeneration( + clips_handler, + augmentation_applier, + **feature_set["spectrogram_generation_settings"], + ) + self.feature_providers.append( + ClipsHandlerWrapperGenerator( + spectrogram_generator, + feature_set["truth"], + feature_set["sampling_weight"], + feature_set["penalty_weight"], + feature_set["truncation_strategy"], + ) + ) + set_modes = [ + "training", + "validation", + "testing", + "validation_ambient", + "testing_ambient", + ] + total_spectrograms = 0 + for set in set_modes: + total_spectrograms += self.feature_providers[-1].get_mode_size(set) + + if total_spectrograms == 0: + logging.warning("No spectrograms found in a configured feature set:") + logging.warning(feature_set) + + def get_mode_duration(self, mode: str): """Returns the durations of all spectrogram features in the given mode. Args: - mode: which training set to compute duration over. One of `training`, `testing`, `testing_ambient`, `validation`, or `validation_ambient` + mode (str): which training set to compute duration over. One of `training`, `testing`, `testing_ambient`, `validation`, or `validation_ambient` Returns: duration, in seconds, of all spectrograms in this mode """ sample_duration = 0 - for feature_set in self.features: - # sample_count += len(feature_set[mode]) - sample_duration += feature_set["stats"][mode]["total_duration"] + for provider in self.feature_providers: + sample_duration += provider.get_mode_duration(mode) return sample_duration - def get_mode_size(self, mode): + def get_mode_size(self, mode: str): """Returns the count of all spectrogram features in the given mode. Args: - mode: which training set to count the spectrograms. One of `training`, `testing`, `testing_ambient`, `validation`, or `validation_ambient` + mode (str): which training set to count the spectrograms. 
One of `training`, `testing`, `testing_ambient`, `validation`, or `validation_ambient`
 
         Returns:
             count of spectrograms in given mode
         """
         sample_count = 0
-        for feature_set in self.features:
-            sample_count += feature_set["stats"][mode]["spectrogram_count"]
+        for provider in self.feature_providers:
+            sample_count += provider.get_mode_size(mode)
         return sample_count
 
     def get_data(
         self,
-        mode,
-        batch_size,
-        features_length,
-        truncation_strategy="default",
-        augmentation_policy={
-            "mix_up_prob": 0.0,
+        mode: str,
+        batch_size: int,
+        features_length: int,
+        truncation_strategy: str = "default",
+        augmentation_policy: dict = {
             "freq_mix_prob": 0.0,
-            "time_mask_max_size": 10,
-            "time_mask_count": 1,
-            "freq_mask_max_size": 1,
-            "freq_mask_count": 3,
+            "time_mask_max_size": 0,
+            "time_mask_count": 0,
+            "freq_mask_max_size": 0,
+            "freq_mask_count": 0,
         },
     ):
         """Gets spectrograms from the appropriate mode. Ensures spectrograms are the appropriate length and optionally applies augmentation.
 
         Args:
-            mode: which training set to count the spectrograms. One of `training`, `testing`, `testing_ambient`, `validation`, or `validation_ambient`
-            batch_size: number of spectrograms in the sample for training mode
-            features_length: the length of the spectrograms
-            truncation_strategy: how to truncate spectrograms longer than `features_length`
-            augmentation_policy: dictionary that specifies augmentation settings. It has the following keys:
-                mix_up_prob: probability that MixUp is applied
+            mode (str): which set to retrieve spectrograms from. One of `training`, `testing`, `testing_ambient`, `validation`, or `validation_ambient`
+            batch_size (int): number of spectrograms in the sample for training mode
+            features_length (int): the length of the spectrograms
+            truncation_strategy (str): how to truncate spectrograms longer than `features_length`
+            augmentation_policy (dict): dictionary that specifies augmentation settings. 
It has the following keys: freq_mix_prob: probability that FreqMix is applied - time_mask_max_size: maximum size of time feature masks for SpecAugment + time_mask_max_size: maximum size of time masks for SpecAugment time_mask_count: the total number of separate time masks applied for SpecAugment freq_mask_max_size: maximum size of frequency feature masks for SpecAugment - time_mask_count: the total number of separate feature masks applied for SpecAugment + freq_mask_count: the total number of separate feature masks applied for SpecAugment Returns: data: spectrograms in a NumPy array (or as a list if in mode is `*_ambient`) @@ -341,216 +528,70 @@ def get_data( weights: penalizing weight for incorrect predictions for each spectrogram """ - combination_augments = (augmentation_policy["mix_up_prob"] > 0) or ( - augmentation_policy["freq_mix_prob"] > 0 - ) - if mode == "training": sample_count = batch_size elif (mode == "validation") or (mode == "testing"): sample_count = self.get_mode_size(mode) - else: - sample_count = 0 - for feature_set in self.features: - features_in_this_mode = feature_set[mode] - - if len(features_in_this_mode) > 0: - if truncation_strategy != "none": - for features in features_in_this_mode: - spectrogram = feature_set["loaded_features"][ - features["loaded_feature_index"] - ][features["subindex"]] - sample_count += ( - spectrogram.shape[0] - features_length - ) // 10 + 1 - else: - sample_count += len(features_in_this_mode) - - spectrogram_shape = (features_length, 40) - - data = np.zeros((sample_count,) + spectrogram_shape) - labels = np.full(sample_count, 0.0) - weights = np.ones(sample_count) - - if mode.endswith("ambient") and truncation_strategy == "none": - # Use a list instead of a numpy array; ambient set's spectrograms may have different lengths - data = [] + + data = [] + labels = [] + weights = [] if mode == "training": - random_feature_sets = random.choices( + random_feature_providers = random.choices( [ - feature_set - for feature_set in self.features - if len(feature_set["training"]) + provider + for provider in self.feature_providers + if provider.get_mode_size("training") ], [ - feature_set["sampling_weight"] - for feature_set in self.features - if len(feature_set["training"]) - ], - k=sample_count, - ) - random_feature_sets2 = random.choices( - [ - feature_set - for feature_set in self.features - if len(feature_set["training"]) - ], - [ - feature_set["sampling_weight"] - for feature_set in self.features - if len(feature_set["training"]) + provider.sampling_weight + for provider in self.feature_providers + if provider.get_mode_size("training") ], k=sample_count, ) - for i in range(sample_count): - feature_set_1 = random_feature_sets[i] - feature_1 = random.choice(feature_set_1["training"]) - spectrogram_1 = feature_set_1["loaded_features"][ - feature_1["loaded_feature_index"] - ][feature_1["subindex"]] - - if truncation_strategy == "default": - truncation_strategy_1 = feature_set_1["truncation_strategy"] - else: - truncation_strategy_1 = truncation_strategy - - spectrogram_1 = fixed_length_spectrogram( - spectrogram_1, - features_length, - truncation_strategy=truncation_strategy_1, + for provider in random_feature_providers: + spectrogram = provider.get_random_spectrogram( + "training", features_length, truncation_strategy ) - - if combination_augments: - feature_set_2 = random_feature_sets2[i] - feature_2 = random.choice(feature_set_2["training"]) - spectrogram_2 = feature_set_2["loaded_features"][ - feature_2["loaded_feature_index"] - 
][feature_2["subindex"]]
-
-                    if truncation_strategy == "default":
-                        truncation_strategy_2 = feature_set_2["truncation_strategy"]
-                    else:
-                        truncation_strategy_2 = truncation_strategy
-
-                    spectrogram_2 = fixed_length_spectrogram(
-                        spectrogram_2,
-                        features_length,
-                        truncation_strategy=truncation_strategy_2,
-                    )
-
-                data[i] = spectrogram_1
-                labels[i] = float(feature_set_1["truth"])
-                weights[i] = float(feature_set_1["penalty_weight"])
-
-                if combination_augments and (
-                    np.random.rand()
-                    < (
-                        augmentation_policy["mix_up_prob"]
-                        + augmentation_policy["freq_mix_prob"]
-                    )
-                ):
-                    mix_ratio = np.random.rand()
-
-                    which_augment = random.choices(
-                        [0, 1],
-                        [
-                            augmentation_policy["mix_up_prob"],
-                            augmentation_policy["freq_mix_prob"],
-                        ],
-                        k=1,
-                    )
-
-                    if which_augment[0] == 0:
-                        data[i], labels[i], weights[i] = mixup_augment(
-                            spectrogram_1,
-                            feature_set_1["truth"],
-                            feature_set_1["penalty_weight"],
-                            spectrogram_2,
-                            feature_set_2["truth"],
-                            feature_set_2["penalty_weight"],
-                            mix_ratio,
-                        )
-                    else:
-                        data[i], labels[i], weights[i] = freqmix_augment(
-                            spectrogram_1,
-                            feature_set_1["truth"],
-                            feature_set_1["penalty_weight"],
-                            spectrogram_2,
-                            feature_set_2["truth"],
-                            feature_set_2["penalty_weight"],
-                            mix_ratio,
-                        )
-
-            data[i] = spec_augment(
-                data[i],
+                spectrogram = spec_augment(
+                    spectrogram,
                     augmentation_policy["time_mask_max_size"],
                     augmentation_policy["time_mask_count"],
                     augmentation_policy["freq_mask_max_size"],
                     augmentation_policy["freq_mask_count"],
                 )
-        elif (mode == "validation") or (mode == "testing"):
-            index = 0
-            for feature_set in self.features:
-                for feature_index in feature_set[mode]:
-                    spectrogram = feature_set["loaded_features"][
-                        feature_index["loaded_feature_index"]
-                    ][feature_index["subindex"]]
 
-                    if truncation_strategy == "default":
-                        truncation_strategy = feature_set["truncation_strategy"]
+                data.append(spectrogram)
+                labels.append(float(provider.label))
+                weights.append(float(provider.penalty_weight))
+        else:
+            for provider in self.feature_providers:
+                generator = provider.get_feature_generator(
+                    mode, features_length, truncation_strategy
+                )
 
-                    data[index] = fixed_length_spectrogram(
-                        spectrogram,
-                        features_length,
-                        truncation_strategy=truncation_strategy,
-                    )
-                    labels[index] = feature_set["truth"]
-                    weights[index] = feature_set["penalty_weight"]
+                for spectrogram in generator:
+                    data.append(spectrogram)
+                    labels.append(provider.label)
+                    weights.append(provider.penalty_weight)
+
+        if truncation_strategy != "none":
+            # Spectrograms are all the same length, convert to numpy array
+            data = np.array(data)
+            labels = np.array(labels)
+            weights = np.array(weights)
 
-                    index += 1
+        if truncation_strategy == "none":
+            # Spectrograms may be of different length
+            return data, np.array(labels), np.array(weights)
 
-            # Randomize the order of the testing and validation sets
-            indices = np.arange(data.shape[0])
+        indices = np.arange(labels.shape[0])
+
+        if mode in ("testing", "validation"):
+            # Randomize the order of the data, weights, and labels
             np.random.shuffle(indices)
-            data = data[indices]
-            labels = labels[indices]
-            weights = weights[indices]
-        else:
-            # ambient testing, split the long spectrograms into overlapping chunks
-            index = 0
-            for feature_set in self.features:
-                features_in_this_mode = feature_set[mode]
-
-                if len(features_in_this_mode) > 0:
-                    for features in features_in_this_mode:
-                        spectrogram = feature_set["loaded_features"][
-                            features["loaded_feature_index"]
-                        ][features["subindex"]]
-
-                        if truncation_strategy == "default":
-                            
truncation_strategy = feature_set["truncation_strategy"] - - if truncation_strategy == "split": - for subset_feature_index in range( - 0, spectrogram.shape[0], 10 - ): - if ( - subset_feature_index + features_length - < spectrogram.shape[0] - ): - data[index] = spectrogram[ - subset_feature_index : subset_feature_index - + features_length - ] - labels[index] = 0.0 - weights[index] = 1.0 - index += 1 - else: - data.append(spectrogram) - labels[index] = 0.0 - weights[index] = 1.0 - - return data, labels, weights + return data[indices], labels[indices], weights[indices] diff --git a/microwakeword/feature_generation.py b/microwakeword/feature_generation.py deleted file mode 100644 index c56919f..0000000 --- a/microwakeword/feature_generation.py +++ /dev/null @@ -1,403 +0,0 @@ -# coding=utf-8 -# Copyright 2024 Kevin Ahrendt. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Function to generate spectrograms and a class for loading and augmenting audio clips""" - -import audiomentations -import audio_metadata -import datasets -import math -import os -import random -import wave - -import numpy as np -import tensorflow as tf - -from mmap_ninja.ragged import RaggedMmap -from pathlib import Path -from tensorflow.lite.experimental.microfrontend.python.ops import ( - audio_microfrontend_op as frontend_op, -) - - -def generate_features_for_clip(audio, desired_spectrogram_length=None): - """Generates spectrogram features for the given audio data. - - Args: - clip (ndarray): audio data with sample rate 16 kHz and 16-bit samples - desired_spectrogram_length (int, optional): Number of time features to include in the spectrogram. - Truncates earlier time features. Set to None to disable. - - Returns: - (ndarray): spectrogram audio features - """ - with tf.device("/cpu:0"): - # The default settings match the TFLM preprocessor settings. - # Preproccesor model is available from the tflite-micro repository, accessed December 2023. - micro_frontend = frontend_op.audio_microfrontend( - tf.convert_to_tensor(audio), - sample_rate=16000, - window_size=30, - window_step=20, - num_channels=40, - upper_band_limit=7500, - lower_band_limit=125, - enable_pcan=True, - min_signal_remaining=0.05, - out_scale=1, - out_type=tf.float32, - ) - output = tf.multiply(micro_frontend, 0.0390625) - - spectrogram = output.numpy() - if desired_spectrogram_length is not None: - return spectrogram[ - -desired_spectrogram_length: - ] # truncate to match desired spectrogram size - return spectrogram - - -class ClipsHandler: - """ClipsHandler object that loads audio files from the disk, optionally filters by length, augments clips, and - generates spectrogram features for augmented clips. - - Default augmentation settings and probabilities are borrowed from openWakeWord's data.py, accessed on February 23, 2024. - - Args: - input_path (str): The path to audio files to be augmented. - input_glob (str): The glob to choose audio files in `input_path`. 
Most audio types are supported, as clips - will automatically be converted to the appropriate format. - impulse_paths (List[str], optional): The paths to room impulse response files. - Set to None to disable the room impulse response augmentation. - background_paths (List[str], optional): The paths to background audio files. - Set to None to disable the background noise augmentaion. - augmentation_probabilities (dict, optional): The individual probabilities of each augmentation. If all probabilities - are zero, the input audio files will simply be padded with silence. The - default values are: - { - "SevenBandParametricEQ": 0.25, - "TanhDistortion": 0.25, - "PitchShift": 0.25, - "BandStopFilter": 0.25, - "AddBackgroundNoise": 0.75, - "Gain": 1.0, - "RIR": 0.5 - } - min_clip_duration_s (float, optional): The minimum clip duration (in seconds) of the input audio clips. - Set to None to not filter clips. - max_clip_duration_s (float, optional): The maximum clip duration (in seconds) of the input audio clips. - Set to None to not filter clips. - augmented_duration_s (float, optional): The final duration (in seconds) of the augmented file. - Set to None to let spectrogram represent the clips actual duration. - max_start_time_from_right_s (float, optional): The maximum time (in seconds) that the clip should start from the right. - Only used if augmented_duration_s is set. - """ - - def __init__( - self, - input_path, - input_glob, - impulse_paths=None, - background_paths=None, - augmentation_probabilities: dict = { - "SevenBandParametricEQ": 0.25, - "TanhDistortion": 0.25, - "PitchShift": 0.25, - "BandStopFilter": 0.25, - "AddBackgroundNoise": 0.75, - "Gain": 1.0, - "RIR": 0.5, - }, - min_clip_duration_s=None, - max_clip_duration_s=None, - augmented_duration_s=None, - max_start_time_from_right_s=None, - ): - ####################### - # Setup augmentations # - ####################### - - # If either the the background_paths or impulse_paths are not specified, use an identity transform instead - def identity_transform(samples, sample_rate): - return samples - - background_noise_augment = audiomentations.Lambda( - transform=identity_transform, p=0.0 - ) - reverb_augment = audiomentations.Lambda(transform=identity_transform, p=0.0) - - if background_paths is not None: - background_noise_augment = audiomentations.AddBackgroundNoise( - p=augmentation_probabilities["AddBackgroundNoise"], - sounds_path=background_paths, - min_snr_in_db=-10, - max_snr_in_db=15, - ) - - if impulse_paths is not None: - reverb_augment = audiomentations.ApplyImpulseResponse( - p=augmentation_probabilities["RIR"], - ir_path=impulse_paths, - ) - - # Based on openWakeWord's augmentations, accessed on February 23, 2024. 
- self.augment = audiomentations.Compose( - transforms=[ - audiomentations.SevenBandParametricEQ( - p=augmentation_probabilities["SevenBandParametricEQ"], - min_gain_db=-6, - max_gain_db=6, - ), - audiomentations.TanhDistortion( - p=augmentation_probabilities["TanhDistortion"], - min_distortion=0.0001, - max_distortion=0.10, - ), - audiomentations.PitchShift( - p=augmentation_probabilities["PitchShift"], - min_semitones=-3, - max_semitones=3, - ), - audiomentations.BandStopFilter( - p=augmentation_probabilities["BandStopFilter"] - ), - background_noise_augment, - audiomentations.Gain( - p=augmentation_probabilities["Gain"], - min_gain_in_db=-12, - max_gain_in_db=0, - ), - reverb_augment, - ] - ) - - ##################################################### - # Load clips and optionally filter them by duration # - ##################################################### - if min_clip_duration_s is None: - min_clip_duration_s = 0 - - if max_clip_duration_s is None: - max_clip_duration_s = float("inf") - - if max_start_time_from_right_s is not None: - max_clip_duration_s = min( - max_clip_duration_s, max_start_time_from_right_s - ) - - if augmented_duration_s is not None: - max_clip_duration_s = min(max_clip_duration_s, augmented_duration_s) - - self.max_start_time_from_right_s = max_start_time_from_right_s - self.augmented_duration_s = augmented_duration_s - - if (self.max_start_time_from_right_s is not None) and ( - self.augmented_duration_s is None - ): - raise ValueError( - "max_start_time_from_right_s cannot be specified if augmented_duration_s is not configured." - ) - - if ( - (self.max_start_time_from_right_s is not None) - and (self.augmented_duration_s is not None) - and (self.max_start_time_from_right_s > self.augmented_duration_s) - ): - raise ValueError( - "max_start_time_from_right_s cannot be greater than augmented_duration_s." - ) - - if self.augmented_duration_s is not None: - self.desired_samples = int(augmented_duration_s * 16000) - else: - self.desired_samples = None - - paths_to_clips = [str(i) for i in Path(input_path).glob(input_glob)] - - # Filter audio clips by length - if input_glob.endswith("wav"): - # If it is a wave file, assume all wave files have the same parameters and filter by file size. - # Based on openWakeWord's estimate_clip_duration and filter_audio_paths in data.py, accessed March 2, 2024. - with wave.open(paths_to_clips[0], "rb") as input_wav: - channels = input_wav.getnchannels() - sample_width = input_wav.getsampwidth() - sample_rate = input_wav.getframerate() - frames = input_wav.getnframes() - - if (min_clip_duration_s > 0) or (not math.isinf(max_clip_duration_s)): - sizes = [] - sizes.extend([os.path.getsize(i) for i in paths_to_clips]) - - # Correct for the wav file header bytes. Assumes all files in the directory have same parameters. - header_correction = ( - os.path.getsize(paths_to_clips[0]) - - frames * sample_width * channels - ) - - durations = [] - for size in sizes: - durations.append( - (size - header_correction) - / (sample_rate * sample_width * channels) - ) - - filtered_paths = [ - path_to_clip - for path_to_clip, duration in zip(paths_to_clips, durations) - if (min_clip_duration_s < duration) - and (duration < max_clip_duration_s) - ] - else: - filtered_paths = paths_to_clips - else: - # If not a wave file, use the audio_metadata package to analyze audio file headers for the duration. - # This is slower! 
- filtered_paths = [] - - if (min_clip_duration_s > 0) or (not math.isinf(max_clip_duration_s)): - for audio_file in paths_to_clips: - metadata = audio_metadata.load(audio_file) - duration = metadata["streaminfo"]["duration"] - if (min_clip_duration_s < duration) and ( - duration < max_clip_duration_s - ): - filtered_paths.append(audio_file) - else: - filtered_paths = paths_to_clips - - # Load all filtered clips - audio_dataset = datasets.Dataset.from_dict( - {"audio": [str(i) for i in filtered_paths]} - ).cast_column("audio", datasets.Audio()) - - # Convert all clips to 16 kHz sampling rate when accessed - audio_dataset = audio_dataset.cast_column( - "audio", datasets.Audio(sampling_rate=16000) - ) - self.clips = audio_dataset - - def augment_clip(self, input_audio): - """Augments the input audio, optionally creating a fixed sized clip first. - - Args: - input_audio (ndarray): audio data with sample rate 16 kHz and 16-bit samples - - Returns: - (ndarray): the augmented audio with sample rate 16 kHz and 16-bit samples - """ - if self.augmented_duration_s is not None: - input_audio = self.create_fixed_size_clip(input_audio) - output_audio = self.augment(input_audio, sample_rate=16000) - - return (output_audio * 32767).astype(np.int16) - - def augment_random_clip(self): - """Augments a random loaded clip. - - Returns: - (ndarray): a random clip's augmented audio with sample rate 16 kHz and 16-bit samples - """ - rand_audio = random.choice(self.clips) - return self.augment_clip(rand_audio["audio"]["array"]) - - def save_random_augmented_clip(self, output_file): - """Saves a random augmented clip. - - Args: - output_file (str): file name to save augmented clip to with sample rate 16 kHz and 16-bit samples - """ - augmented_audio = self.augment_random_clip() - with wave.open(output_file, "wb") as output_wav_file: - output_wav_file.setframerate(16000) - output_wav_file.setsampwidth(2) - output_wav_file.setnchannels(1) - output_wav_file.writeframes(augmented_audio) - - def generate_augmented_spectrogram(self, input_audio): - """Generates the spectrogram of the input audio after augmenting. - - Args: - input_audio (ndarray): audio data with sample rate 16 kHz and 16-bit samples - - Returns: - (ndarray): the spectrogram of the augmented audio - """ - augmented_audio = self.augment_clip(input_audio) - return generate_features_for_clip(augmented_audio) - - def generate_random_augmented_feature(self): - """Generates the spectrogram of a random audio clip after augmenting. - - Returns: - (ndarray): the spectrogram of the augmented audio from a random clip - """ - rand_augmented_clip = self.augment_random_clip() - return self.generate_augmented_feature(rand_augmented_clip) - - def augmented_features_generator(self): - """Generator function for augmenting all loaded clips and computing their spectrograms - - Yields: - (ndarray): the spectrogram of an augmented audio clip - """ - for clip in self.clips: - audio = clip["audio"]["array"] - - yield self.generate_augmented_spectrogram(audio) - - def save_augmented_features(self, mmap_output_dir): - """Saves all augmented features in a RaggedMmap format - - Args: - mmap_output_dir (str): Path to saved the RaggedMmap data - """ - RaggedMmap.from_generator( - out_dir=mmap_output_dir, - sample_generator=self.augmented_features_generator(), - batch_size=10, - verbose=True, - ) - - def create_fixed_size_clip(self, x, sr=16000): - """Create a fixed-length clip with self.desired_samples samples. 
- - If self.augmented_duration_s and self.max_start_time_from_right_s are specified, the entire clip is - inserted randomly up to self.max_start_time_from_right_s duration from the right. - - Based on openWakeWord's data.py create_fixed_size_clip function, accessed on February 23, 2024 - - Args: - x (ndarray): The input audio to pad to a fixed size - sr (int): The sample rate of the audio - - Returns: - ndarray: A new array of audio data of the specified length - """ - - dat = np.zeros(self.desired_samples) - - if self.max_start_time_from_right_s is not None: - max_samples_from_end = int(self.max_start_time_from_right_s * sr) - else: - max_samples_from_end = self.desired_samples - - assert max_samples_from_end > len(x) - - samples_from_end = np.random.randint(len(x), max_samples_from_end) + 1 - - dat[-samples_from_end : -samples_from_end + len(x)] = x - - return dat diff --git a/microwakeword/inception.py b/microwakeword/inception.py index 57a0992..e25b40b 100644 --- a/microwakeword/inception.py +++ b/microwakeword/inception.py @@ -250,7 +250,7 @@ def model(flags, shape, batch_size): net = input_audio # [batch, time, feature] - net = tf.keras.backend.expand_dims(net, axis=2) + net = tf.keras.ops.expand_dims(net, axis=2) # [batch, time, 1, feature] for filters, kernel_size, subgroups in zip( diff --git a/microwakeword/inference.py b/microwakeword/inference.py index 91bd349..3a9c910 100644 --- a/microwakeword/inference.py +++ b/microwakeword/inference.py @@ -19,7 +19,7 @@ # imports import numpy as np import tensorflow as tf -from microwakeword.feature_generation import generate_features_for_clip +from microwakeword.audio.audio_utils import generate_features_for_clip class Model: @@ -27,10 +27,11 @@ class Model: Class for loading and running tflite microwakeword models Args: - tflite_model_path (str): path to tflite model file + tflite_model_path (str): Path to tflite model file. + stride (int | None, optional): Time dimension's stride. If None, then the stride is the input tensor's time dimension. Defaults to None. """ - def __init__(self, tflite_model_path): + def __init__(self, tflite_model_path: str, stride: int | None = None): # Load tflite model interpreter = tf.lite.Interpreter( model_path=tflite_model_path, @@ -43,6 +44,11 @@ def __init__(self, tflite_model_path): self.is_quantized_model = self.input_details[0]["dtype"] == np.int8 self.input_feature_slices = self.input_details[0]["shape"][1] + if stride is None: + self.stride = self.input_feature_slices + else: + self.stride = stride + for s in range(len(self.input_details)): if self.is_quantized_model: interpreter.set_tensor( @@ -57,42 +63,51 @@ def __init__(self, tflite_model_path): self.model = interpreter - def predict_clip(self, data): + def predict_clip(self, data: np.ndarray, step_ms: int = 20): """Run the model on a single clip of audio data Args: - data (np.ndarray): input data for the model (16 khz, 16-bit PCM audio data) + data (numpy.ndarray): input data for the model (16 khz, 16-bit PCM audio data) + step_ms (int): The window step sized used for generating the spectrogram in ms. Defaults to 20. 
Returns: list: model predictions for the input audio data """ # Get the spectrogram - spec = generate_features_for_clip(data) + spectrogram = generate_features_for_clip(data, stride_ms=step_ms) - return self.predict_spectrogram(spec) + return self.predict_spectrogram(spectrogram) - def predict_spectrogram(self, spec): - """Run the model on a single clip of audio data + def predict_spectrogram(self, spectrogram: np.ndarray): + """Run the model on a single spectrogram Args: - spec (np.ndarray): input spectrogram + spectrogram (numpy.ndarray): Input spectrogram. Returns: list: model predictions for the input audio data """ + # Spectrograms with type np.uint16 haven't been scaled + if np.issubdtype(spectrogram.dtype, np.uint16): + spectrogram = spectrogram.astype(np.float32) * 0.0390625 + elif np.issubdtype(spectrogram.dtype, np.float64): + spectrogram = spectrogram.astype(np.float32) + # Slice the input data into the required number of chunks chunks = [] - for i in range(0, len(spec), self.input_feature_slices): - chunk = spec[i : i + self.input_feature_slices] + for last_index in range( + self.input_feature_slices, len(spectrogram) + 1, self.stride + ): + chunk = spectrogram[last_index - self.input_feature_slices : last_index] if len(chunk) == self.input_feature_slices: chunks.append(chunk) # Get the prediction for each chunk predictions = [] for chunk in chunks: - if self.is_quantized_model: + if self.is_quantized_model and spectrogram.dtype != np.int8: chunk = self.quantize_input_data(chunk, self.input_details[0]) self.model.set_tensor( @@ -109,15 +124,15 @@ def predict_spectrogram(self, spec): return predictions - def quantize_input_data(self, data, input_details) -> np.ndarray: + def quantize_input_data(self, data: np.ndarray, input_details: dict) -> np.ndarray: """quantize the input data using scale and zero point Args: - data (np.array in float): input data for the interpreter - input_details : output of get_input_details from the tflm interpreter. + data (numpy.array in float): input data for the interpreter + input_details (dict): output of get_input_details from the tflm interpreter. 
Returns: - np.ndarray: quantized data as int8 dtype + numpy.ndarray: quantized data as int8 dtype """ # Get input quantization parameters data_type = input_details["dtype"] @@ -137,18 +152,19 @@ def dequantize_output_data( """Dequantize the model output Args: - data: integer data to be dequantized - output_details: TFLM interpreter model output details + data (numpy.ndarray): integer data to be dequantized + output_details (dict): TFLM interpreter model output details Returns: - np.ndarray: dequantized data as float32 dtype + numpy.ndarray: dequantized data as float32 dtype """ output_quantization_parameters = output_details["quantization_parameters"] - output_scale = output_quantization_parameters["scales"][0] + output_scale = 255.0 # assume (u)int8 quantization output_zero_point = output_quantization_parameters["zero_points"][0] # Caveat: tflm_output_quant need to be converted to float to avoid integer # overflow during dequantization # e.g., (tflm_output_quant -output_zero_point) and # (tflm_output_quant + (-output_zero_point)) # can produce different results (int8 calculation) - return output_scale * (data.astype(np.float32) - output_zero_point) + # return output_scale * (data.astype(np.float32) - output_zero_point) + return 1 / output_scale * (data.astype(np.float32) - output_zero_point) diff --git a/microwakeword/layers/delay.py b/microwakeword/layers/delay.py index 25d7377..9cfaea5 100644 --- a/microwakeword/layers/delay.py +++ b/microwakeword/layers/delay.py @@ -112,7 +112,7 @@ def get_config(self): return config def _streaming_internal_state(self, inputs): - memory = tf.keras.backend.concatenate([self.states, inputs], 1) + memory = tf.keras.layers.concatenate([self.states, inputs], 1) outputs = memory[:, : inputs.shape.as_list()[1]] new_memory = memory[:, -self.delay :] assign_states = self.states.assign(new_memory) @@ -121,7 +121,7 @@ def _streaming_internal_state(self, inputs): return tf.identity(outputs) def _streaming_external_state(self, inputs, states): - memory = tf.keras.backend.concatenate([states, inputs], 1) + memory = tf.keras.layers.concatenate([states, inputs], 1) outputs = memory[:, : inputs.shape.as_list()[1]] new_memory = memory[:, -self.delay :] return outputs, new_memory diff --git a/microwakeword/layers/modes.py b/microwakeword/layers/modes.py index 7e74206..6cb8433 100644 --- a/microwakeword/layers/modes.py +++ b/microwakeword/layers/modes.py @@ -59,5 +59,6 @@ def get_input_data_shape(config, mode): if mode in (Modes.TRAINING, Modes.NON_STREAM_INFERENCE): data_shape = (config["spectrogram_length"], 40) else: - data_shape = (1, 40) + stride = config['stride'] + data_shape = (stride, 40) return data_shape diff --git a/microwakeword/layers/stream.py b/microwakeword/layers/stream.py index 4d82670..37b7770 100644 --- a/microwakeword/layers/stream.py +++ b/microwakeword/layers/stream.py @@ -559,9 +559,9 @@ def _streaming_internal_state(self, inputs): with tf.control_dependencies([assign_states]): if self.transposed_conv_crop_output: - return tf.identity(outputs[:, 0 : self.output_time_dim, :]) + return tf.keras.layers.Identity()(outputs[:, 0 : self.output_time_dim, :]) else: - return tf.identity(outputs) + return tf.keras.layers.Identity()(outputs) else: if self.use_one_step: # The time dimenstion always has to equal 1 in streaming mode. 
@@ -572,7 +572,7 @@ def _streaming_internal_state(self, inputs): memory = self.states[:, 1 : self.ring_buffer_size_in_time_dim, :] # add new row [batch_size, memory_size, feature_dim, channel] - memory = tf.keras.backend.concatenate([memory, inputs], 1) + memory = tf.keras.layers.concatenate([memory, inputs], 1) assign_states = self.states.assign(memory) @@ -581,7 +581,7 @@ def _streaming_internal_state(self, inputs): else: # add new row [batch_size, memory_size, feature_dim, channel] if self.ring_buffer_size_in_time_dim: - memory = tf.keras.backend.concatenate([self.states, inputs], 1) + memory = tf.keras.layers.concatenate([self.states, inputs], 1) state_update = memory[ :, -self.ring_buffer_size_in_time_dim :, : @@ -636,13 +636,13 @@ def _streaming_external_state(self, inputs, state): memory = state[:, 1 : self.ring_buffer_size_in_time_dim, :] # add new row [batch_size, memory_size, feature_dim, channel] - memory = tf.keras.backend.concatenate([memory, inputs], 1) + memory = tf.keras.layers.concatenate([memory, inputs], 1) output = self.cell(memory) return output, memory else: # add new row [batch_size, memory_size, feature_dim, channel] - memory = tf.keras.backend.concatenate([state, inputs], 1) + memory = tf.keras.layers.concatenate([state, inputs], 1) state_update = memory[ :, -self.ring_buffer_size_in_time_dim :, : diff --git a/microwakeword/layers/strided_drop.py b/microwakeword/layers/strided_drop.py index f3b3145..a0dbc1e 100644 --- a/microwakeword/layers/strided_drop.py +++ b/microwakeword/layers/strided_drop.py @@ -17,6 +17,7 @@ from microwakeword.layers import modes + class StridedDrop(tf.keras.layers.Layer): """StridedDrop @@ -34,6 +35,7 @@ def __init__( super(StridedDrop, self).__init__(**kwargs) self.time_slices_to_drop = time_slices_to_drop self.mode = mode + self.state_shape = [] def call(self, inputs): if self.mode == modes.Modes.NON_STREAM_INFERENCE: @@ -48,3 +50,49 @@ def get_config(self): } base_config = super(StridedDrop, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + def get_input_state(self): + return [] + + def get_output_state(self): + return [] + + +class StridedKeep(tf.keras.layers.Layer): + """StridedKeep + + Keeps the specified audio feature slices in streaming mode only. + Used for splitting a single streaming ring buffer into multiple branches with minimal overhead. + + Attributes: + time_sclices_to_keep: number of audio feature slices to keep + mode: inference mode; e.g., non-streaming, internal streaming + """ + + def __init__( + self, time_slices_to_keep, mode=modes.Modes.NON_STREAM_INFERENCE, **kwargs + ): + super(StridedKeep, self).__init__(**kwargs) + self.time_slices_to_keep = max(time_slices_to_keep, 1) + self.mode = mode + self.state_shape = [] + + def call(self, inputs): + if self.mode != modes.Modes.NON_STREAM_INFERENCE: + return inputs[:, -self.time_slices_to_keep :, :, :] + + return inputs + + def get_config(self): + config = { + "time_slices_to_keep": self.time_slices_to_keep, + "mode": self.mode, + } + base_config = super(StridedKeep, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def get_input_state(self): + return [] + + def get_output_state(self): + return [] diff --git a/microwakeword/mixednet.py b/microwakeword/mixednet.py new file mode 100644 index 0000000..076d467 --- /dev/null +++ b/microwakeword/mixednet.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2024 Kevin Ahrendt. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model based on 1D depthwise MixedConvs and 1x1 convolutions in time + residual.""" + +from microwakeword.layers import stream +from microwakeword.layers import strided_drop + +import ast +import tensorflow as tf + + +def parse(text): + """Parse model parameters. + + Args: + text: string with layer parameters: '128,128' or "'relu','relu'". + + Returns: + list of parsed parameters + """ + if not text: + return [] + res = ast.literal_eval(text) + if isinstance(res, tuple): + return res + else: + return [res] + + +def model_parameters(parser_nn): + """MixedNet model parameters.""" + + parser_nn.add_argument( + "--pointwise_filters", + type=str, + default="48, 48, 48, 48", + help="Number of filters in every MixConv block's pointwise convolution", + ) + parser_nn.add_argument( + "--residual_connection", + type=str, + default="0,0,0,0,0", + help="Use a residual connection in each MixConv block", + ) + parser_nn.add_argument( + "--repeat_in_block", + type=str, + default="1,1,1,1", + help="Number of repeating conv blocks inside of residual block", + ) + parser_nn.add_argument( + "--mixconv_kernel_sizes", + type=str, + default="[5], [9], [13], [21]", + help="Kernel size lists for DepthwiseConv1D in time dim for every MixConv block", + ) + parser_nn.add_argument( + "--max_pool", + type=int, + default=0, + help="apply max pool instead of average pool before final convolution and sigmoid activation", + ) + parser_nn.add_argument( + "--first_conv_filters", + type=int, + default=32, + help="Number of filters on initial convolution layer. Set to 0 to disable.", + ) + parser_nn.add_argument( + "--first_conv_kernel_size", + type=int, + default="3", + help="Temporal kernel size for the initial convolution layer.", + ) + parser_nn.add_argument( + "--spatial_attention", + type=int, + default=0, + help="Add a spatial attention layer before the final pooling layer", + ) + parser_nn.add_argument( + "--pooled", + type=int, + default=0, + help="Pool the temporal dimension before the final fully connected layer. Uses average pooling or max pooling depending on the max_pool argument", + ) + parser_nn.add_argument( + "--stride", + type=int, + default=1, + help="Striding in the time dimension of the initial convolution layer", + ) + + +def spectrogram_slices_dropped(flags): + """Computes the number of spectrogram slices dropped due to valid padding. 
+ + Args: + flags: data/model parameters + + Returns: + int: number of spectrogram slices dropped + """ + spectrogram_slices_dropped = 0 + + if flags.first_conv_filters > 0: + spectrogram_slices_dropped += flags.first_conv_kernel_size - 1 + + for repeat, ksize in zip( + parse(flags.repeat_in_block), + parse(flags.mixconv_kernel_sizes), + ): + spectrogram_slices_dropped += (repeat * (max(ksize) - 1)) * flags.stride + + # spectrogram_slices_dropped *= flags.stride + return spectrogram_slices_dropped + + +def _split_channels(total_filters, num_groups): + """Helper for MixConv""" + split = [total_filters // num_groups for _ in range(num_groups)] + split[0] += total_filters - sum(split) + return split + + +def _get_shape_value(maybe_v2_shape): + """Helper for MixConv""" + if maybe_v2_shape is None: + return None + elif isinstance(maybe_v2_shape, int): + return maybe_v2_shape + else: + return maybe_v2_shape.value + + +class ChannelSplit(tf.keras.layers.Layer): + def __init__(self, splits, axis=-1, **kwargs): + super().__init__(**kwargs) + self.splits = splits + self.axis = axis + + def call(self, inputs): + return tf.split(inputs, self.splits, axis=self.axis) + + def compute_output_shape(self, input_shape): + output_shapes = [] + for split in self.splits: + new_shape = list(input_shape) + new_shape[self.axis] = split + output_shapes.append(tuple(new_shape)) + return output_shapes + + + +class MixConv: + """MixConv with mixed depthwise convolutional kernels. + + MDConv is an improved depthwise convolution that mixes multiple kernels (e.g. + 3x1, 5x1, etc). Right now, we use an naive implementation that split channels + into multiple groups and perform different kernels for each group. + + See Mixnet paper for more details. + """ + + def __init__(self, kernel_size, **kwargs): + """Initialize the layer. + + Most of args are the same as tf.keras.layers.DepthwiseConv2D. + + Args: + kernel_size: An integer or a list. If it is a single integer, then it is + same as the original tf.keras.layers.DepthwiseConv2D. If it is a list, + then we split the channels and perform different kernel for each group. + strides: An integer or tuple/list of 2 integers, specifying the strides of + the convolution along the height and width. + **kwargs: other parameters passed to the original depthwise_conv layer. 
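+
+        Example (sketch; ``net`` is assumed to be a [batch, time, 1, features]
+        tensor, matching how ``model`` below calls this class):
+
+            net = MixConv(kernel_size=[5, 9, 13])(net)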
+ """ + self._channel_axis = -1 + + self.ring_buffer_length = max(kernel_size) - 1 + + self.kernel_sizes = kernel_size + + def __call__(self, inputs): + # We manually handle the streaming ring buffer for each layer + # - There is some latency overhead on the esp devices for loading each ring buffer's data + # - This avoids variable's holding redundant information + # - Reduces the necessary size of the tensor arena + net = stream.Stream( + cell=tf.keras.layers.Identity(), + ring_buffer_size_in_time_dim=self.ring_buffer_length, + use_one_step=False, + )(inputs) + + if len(self.kernel_sizes) == 1: + return tf.keras.layers.DepthwiseConv2D( + (self.kernel_sizes[0], 1), strides=1, padding="valid" + )(net) + + filters = _get_shape_value(net.shape[self._channel_axis]) + splits = _split_channels(filters, len(self.kernel_sizes)) + x_splits = ChannelSplit(splits, axis=self._channel_axis)(net) + + x_outputs = [] + for x, ks in zip(x_splits, self.kernel_sizes): + fit = strided_drop.StridedKeep(ks)(x) + x_outputs.append( + tf.keras.layers.DepthwiseConv2D((ks, 1), strides=1, padding="valid")( + fit + ) + ) + + for i, output in enumerate(x_outputs): + features_drop = output.shape[1] - x_outputs[-1].shape[1] + x_outputs[i] = strided_drop.StridedDrop(features_drop)(output) + + x = tf.keras.layers.concatenate(x_outputs, axis=self._channel_axis) + return x + + +class SpatialAttention(tf.keras.layers.Layer): + """Spatial Attention Layer based on CBAM: Convolutional Block Attention Module + https://arxiv.org/pdf/1807.06521v2 + + Args: + object (_type_): _description_ + """ + + def __init__(self, kernel_size, ring_buffer_size, **kwargs): + super().__init__(**kwargs) + + self.kernel_size = kernel_size + self.ring_buffer_size = ring_buffer_size + + def call(self, inputs): + tranposed = tf.transpose(inputs, perm=[0, 1, 3, 2]) + channel_avg = tf.keras.layers.AveragePooling2D( + pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2]) + )(tranposed) + channel_max = tf.keras.layers.MaxPooling2D( + pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2]) + )(tranposed) + pooled = tf.keras.layers.Concatenate(axis=-1)([channel_avg, channel_max]) + attention = stream.Stream( + cell=tf.keras.layers.Conv2D( + 1, + (self.kernel_size, 1), + strides=(1, 1), + padding="valid", + use_bias=False, + activation="sigmoid", + ), + use_one_step=False, + )(pooled) + + net = stream.Stream( + cell=tf.keras.layers.Identity(), + ring_buffer_size_in_time_dim=self.ring_buffer_size, + use_one_step=False, + )(inputs) + net = net[:, -attention.shape[1] :, :, :] + + return net * attention + + def get_config(self): + return { + "kernel_size": self.kernel_size, + "ring_buffer_size": self.ring_buffer_size, + } + + +def model(flags, shape, batch_size): + """MixedNet model. 
+ + It is based on the paper + MixConv: Mixed Depthwise Convolutional Kernels + https://arxiv.org/abs/1907.09595 + Args: + flags: data/model parameters + shape: shape of the input vector + config: dictionary containing microWakeWord training configuration + + Returns: + Keras model for training + """ + + pointwise_filters = parse(flags.pointwise_filters) + repeat_in_block = parse(flags.repeat_in_block) + mixconv_kernel_sizes = parse(flags.mixconv_kernel_sizes) + residual_connections = parse(flags.residual_connection) + + for list in ( + pointwise_filters, + repeat_in_block, + mixconv_kernel_sizes, + residual_connections, + ): + if len(pointwise_filters) != len(list): + raise ValueError("all input lists have to be the same length") + + input_audio = tf.keras.layers.Input( + shape=shape, + batch_size=batch_size, + ) + net = input_audio + + # make it [batch, time, 1, feature] + net = tf.keras.ops.expand_dims(net, axis=2) + + # Streaming Conv2D with 'valid' padding + if flags.first_conv_filters > 0: + net = stream.Stream( + cell=tf.keras.layers.Conv2D( + flags.first_conv_filters, + (flags.first_conv_kernel_size, 1), + strides=(flags.stride, 1), + padding="valid", + use_bias=False, + ), + use_one_step=False, + pad_time_dim=None, + pad_freq_dim="valid", + )(net) + + net = tf.keras.layers.Activation("relu")(net) + + # encoder + for filters, repeat, ksize, res in zip( + pointwise_filters, + repeat_in_block, + mixconv_kernel_sizes, + residual_connections, + ): + if res: + residual = tf.keras.layers.Conv2D( + filters=filters, kernel_size=1, use_bias=False, padding="same" + )(net) + residual = tf.keras.layers.BatchNormalization()(residual) + + for _ in range(repeat): + if max(ksize) > 1: + net = MixConv(kernel_size=ksize)(net) + net = tf.keras.layers.Conv2D( + filters=filters, kernel_size=1, use_bias=False, padding="same" + )(net) + net = tf.keras.layers.BatchNormalization()(net) + + if res: + residual = strided_drop.StridedDrop(residual.shape[1] - net.shape[1])( + residual + ) + net = net + residual + + net = tf.keras.layers.Activation("relu")(net) + + if net.shape[1] > 1: + if flags.spatial_attention: + net = SpatialAttention(4, net.shape[1] - 1)(net) + else: + net = stream.Stream( + cell=tf.keras.layers.Identity(), + ring_buffer_size_in_time_dim=net.shape[1] - 1, + use_one_step=False, + )(net) + + if flags.pooled: + # We want to use either Global Max Pooling or Global Average Pooling, but the esp-nn operator optimizations only benefit regular pooling operations + + if flags.max_pool: + net = tf.keras.layers.MaxPooling2D(pool_size=(net.shape[1], 1))(net) + else: + net = tf.keras.layers.AveragePooling2D(pool_size=(net.shape[1], 1))(net) + + net = tf.keras.layers.Flatten()(net) + net = tf.keras.layers.Dense(1, activation="sigmoid")(net) + + return tf.keras.Model(input_audio, net) diff --git a/microwakeword/model_train_eval.py b/microwakeword/model_train_eval.py index 9d61c8b..d359508 100644 --- a/microwakeword/model_train_eval.py +++ b/microwakeword/model_train_eval.py @@ -16,15 +16,28 @@ import argparse import os +import sys import yaml +import platform from absl import logging +import tensorflow as tf + +# Disable GPU by default on ARM Macs, it's slower than just using the CPU +if os.environ.get("CUDA_VISIBLE_DEVICES") == "-1" or ( + sys.platform == "darwin" + and platform.processor() == "arm" + and "CUDA_VISIBLE_DEVICES" not in os.environ +): + tf.config.set_visible_devices([], "GPU") + import microwakeword.data as input_data import microwakeword.train as train import microwakeword.test as test 
import microwakeword.utils as utils import microwakeword.inception as inception +import microwakeword.mixednet as mixednet from microwakeword.layers import modes @@ -40,23 +53,25 @@ def load_config(flags, model_module): dict: dictionary containing training configuration """ config_filename = flags.training_config + config = yaml.load(open(config_filename, "r").read(), yaml.Loader) + + config["summaries_dir"] = os.path.join(config["train_dir"], "logs/") + + config["stride"] = flags.__dict__.get("stride", 1) + config["window_step_ms"] = config.get("window_step_ms", 20) # Default preprocessor settings preprocessor_sample_rate = 16000 # Hz preprocessor_window_size = 30 # ms - preprocessor_window_stride = 20 # ms - - config = yaml.load(open(config_filename, "r").read(), yaml.Loader) - - config["summaries_dir"] = os.path.join(config["train_dir"], "logs/") + preprocessor_window_step = config["window_step_ms"] # ms desired_samples = int(preprocessor_sample_rate * config["clip_duration_ms"] / 1000) window_size_samples = int( preprocessor_sample_rate * preprocessor_window_size / 1000 ) - window_stride_samples = int( - preprocessor_sample_rate * preprocessor_window_stride / 1000 + window_step_samples = int( + config["stride"] * preprocessor_sample_rate * preprocessor_window_step / 1000 ) length_minus_window = desired_samples - window_size_samples @@ -65,7 +80,7 @@ def load_config(flags, model_module): config["spectrogram_length_final_layer"] = 0 else: config["spectrogram_length_final_layer"] = 1 + int( - length_minus_window / window_stride_samples + length_minus_window / window_step_samples ) config["spectrogram_length"] = config[ @@ -96,7 +111,7 @@ def train_model(config, model, data_processor, restore_checkpoint): try: os.makedirs(config["train_dir"]) os.mkdir(config["summaries_dir"]) - except OSError as e: + except OSError: if restore_checkpoint: pass else: @@ -119,6 +134,7 @@ def evaluate_model( data_processor, test_tf_nonstreaming, test_tflite_nonstreaming, + test_tflite_nonstreaming_quantized, test_tflite_streaming, test_tflite_streaming_quantized, ): @@ -132,19 +148,25 @@ def evaluate_model( model (Keras model): model (with loaded weights) to test data_processor (FeatureHandler): feature handler that loads spectrogram data test_tf_nonstreaming (bool): Evaluate the nonstreaming SavedModel + test_tflite_nonstreaming_quantized (bool): Convert and evaluate quantized nonstreaming TFLite model test_tflite_nonstreaming (bool): Convert and evaluate nonstreaming TFLite model test_tflite_streaming (bool): Convert and evaluate streaming TFLite model test_tflite_streaming_quantized (bool): Convert and evaluate quantized streaming TFLite model """ - if test_tf_nonstreaming or test_tflite_nonstreaming: + + if ( + test_tf_nonstreaming + or test_tflite_nonstreaming + or test_tflite_nonstreaming_quantized + ): # Save the nonstreaming model to disk logging.info("Saving nonstreaming model") utils.convert_model_saved( model, config, - "non_stream", - modes.Modes.NON_STREAM_INFERENCE, + folder="non_stream", + mode=modes.Modes.NON_STREAM_INFERENCE, ) if test_tflite_streaming or test_tflite_streaming_quantized: @@ -154,8 +176,8 @@ def evaluate_model( utils.convert_model_saved( model, config, - "stream_state_internal", - modes.Modes.STREAM_INTERNAL_STATE_INFERENCE, + folder="stream_state_internal", + mode=modes.Modes.STREAM_INTERNAL_STATE_INFERENCE, ) if test_tf_nonstreaming: @@ -170,75 +192,86 @@ def evaluate_model( accuracy_name="testing_set_metrics.txt", ) - tflite_log_strings = [] - tflite_source_folders = [] - 
tflite_output_folders = [] - tflite_filenames = [] - tflite_testing_datasets = [] - tflite_quantize = [] + tflite_configs = [] if test_tflite_nonstreaming: - tflite_log_strings.append("nonstreaming model") - tflite_source_folders.append("non_stream") - tflite_output_folders.append("tflite_non_stream") - tflite_filenames.append("non_stream.tflite") - tflite_testing_datasets.append(["testing"]) - tflite_quantize.append(False) + tflite_configs.append( + { + "log_string": "nonstreaming model", + "source_folder": "non_stream", + "output_folder": "tflite_non_stream", + "filename": "non_stream.tflite", + "testing_dataset": "testing", + "testing_ambient_dataset": "testing_ambient", + "quantize": False, + } + ) + + if test_tflite_nonstreaming_quantized: + tflite_configs.append( + { + "log_string": "quantized nonstreaming model", + "source_folder": "non_stream", + "output_folder": "tflite_non_stream_quant", + "filename": "non_stream_quant.tflite", + "testing_dataset": "testing", + "testing_ambient_dataset": "testing_ambient", + "quantize": True, + } + ) if test_tflite_streaming: - tflite_log_strings.append("streaming model") - tflite_source_folders.append("stream_state_internal") - tflite_output_folders.append("tflite_stream_state_internal") - tflite_filenames.append("stream_state_internal.tflite") - tflite_testing_datasets.append(["testing", "testing_ambient"]) - tflite_quantize.append(False) + tflite_configs.append( + { + "log_string": "streaming model", + "source_folder": "stream_state_internal", + "output_folder": "tflite_stream_state_internal", + "filename": "stream_state_internal.tflite", + "testing_dataset": "testing", + "testing_ambient_dataset": "testing_ambient", + "quantize": False, + } + ) if test_tflite_streaming_quantized: - tflite_log_strings.append("quantized streaming model") - tflite_source_folders.append("stream_state_internal") - tflite_output_folders.append("tflite_stream_state_internal_quant") - tflite_filenames.append("stream_state_internal_quant.tflite") - tflite_testing_datasets.append(["testing", "testing_ambient"]) - tflite_quantize.append(True) - - for ( - log_string, - source_folder, - output_folder, - filename, - testing_datasets, - quantize, - ) in zip( - tflite_log_strings, - tflite_source_folders, - tflite_output_folders, - tflite_filenames, - tflite_testing_datasets, - tflite_quantize, - ): - logging.info("Converting " + log_string + " to TFLite") + tflite_configs.append( + { + "log_string": "quantized streaming model", + "source_folder": "stream_state_internal", + "output_folder": "tflite_stream_state_internal_quant", + "filename": "stream_state_internal_quant.tflite", + "testing_dataset": "testing", + "testing_ambient_dataset": "testing_ambient", + "quantize": True, + } + ) + + for tflite_config in tflite_configs: + logging.info("Converting %s to TFLite", tflite_config["log_string"]) utils.convert_saved_model_to_tflite( config, - data_processor, - os.path.join(config["train_dir"], source_folder), - os.path.join(config["train_dir"], output_folder), - filename, - quantize=quantize, + audio_processor=data_processor, + path_to_model=os.path.join(config["train_dir"], tflite_config["source_folder"]), + folder=os.path.join(config["train_dir"], tflite_config["output_folder"]), + fname=tflite_config["filename"], + quantize=tflite_config["quantize"], ) - for dataset in testing_datasets: - logging.info( - "Testing the TFLite " + log_string + " on the " + dataset + " set" - ) - test.tflite_model_accuracy( - config, - output_folder, - data_processor, - data_set=dataset, - 
tflite_model_name=filename, - accuracy_name=dataset + "_set_metrics.txt", - ) + logging.info( + "Testing the TFLite %s false accept per hour and false rejection rates at various cutoffs.", + tflite_config["log_string"], + ) + + test.tflite_streaming_model_roc( + config, + tflite_config["output_folder"], + data_processor, + data_set=tflite_config["testing_dataset"], + ambient_set=tflite_config["testing_ambient_dataset"], + tflite_model_name=tflite_config["filename"], + accuracy_name="tflite_streaming_roc.txt", + ) if __name__ == "__main__": @@ -269,6 +302,12 @@ def evaluate_model( default=0, help="Save the TFLite nonstreaming model and test on the test datasets", ) + parser.add_argument( + "--test_tflite_nonstreaming_quantized", + type=int, + default=0, + help="Save the TFLite quantized nonstreaming model and test on the test datasets", + ) parser.add_argument( "--test_tflite_streaming", type=int, @@ -341,17 +380,23 @@ def verbosity_arg(value): parser_inception = subparsers.add_parser("inception") inception.model_parameters(parser_inception) + # mixednet model settings + parser_mixednet = subparsers.add_parser("mixednet") + mixednet.model_parameters(parser_mixednet) + flags, unparsed = parser.parse_known_args() if unparsed: raise ValueError("Unknown argument: {}".format(unparsed)) - logging.set_verbosity(flags.verbosity) - if flags.model_name == "inception": model_module = inception + elif flags.model_name == "mixednet": + model_module = mixednet else: raise ValueError("Unknown model type: {}".format(flags.model_name)) + logging.set_verbosity(flags.verbosity) + config = load_config(flags, model_module) data_processor = input_data.FeatureHandler(config) @@ -377,8 +422,10 @@ def verbosity_arg(value): ) model.load_weights( - os.path.join(config["train_dir"], flags.use_weights) - ).expect_partial() + os.path.join(config["train_dir"], flags.use_weights) + ".weights.h5" + ) + + logging.info(model.summary()) evaluate_model( config, @@ -386,6 +433,7 @@ def verbosity_arg(value): data_processor, flags.test_tf_nonstreaming, flags.test_tflite_nonstreaming, + flags.test_tflite_nonstreaming_quantized, flags.test_tflite_streaming, flags.test_tflite_streaming_quantized, ) diff --git a/microwakeword/test.py b/microwakeword/test.py index 8417ed9..7052986 100644 --- a/microwakeword/test.py +++ b/microwakeword/test.py @@ -17,11 +17,14 @@ """Test utility functions for accuracy evaluation.""" import os -from absl import logging + import numpy as np import tensorflow as tf +from absl import logging +from typing import List from microwakeword.inference import Model +from numpy.lib.stride_tricks import sliding_window_view def compute_metrics(true_positives, true_negatives, false_positives, false_negatives): @@ -88,6 +91,119 @@ def metrics_to_string(metrics): ) +def compute_false_accepts_per_hour( + streaming_probabilities_list: List[np.ndarray], + cutoffs: np.array, + ignore_slices_after_accept: int = 75, + stride: int = 1, + step_s: float = 0.02, +): + """Computes the false accept per hour rates at various cutoffs given a list of streaming probabilities. + + Args: + streaming_probabilities_list (List[numpy.ndarray]): A list containing streaming probabilities from negative audio clips + cutoffs (numpy.array): An array of cutoffs/thresholds to test the false accpet rate at. + ignore_slices_after_accept (int, optional): The number of probabililities slices to ignore after a false accept. Defaults to 75. + stride (int, optional): The stride of the input layer. Defaults to 1. 
+        step_s (float, optional): The duration between each probability in seconds. Defaults to 0.02.
+
+    Returns:
+        numpy.ndarray: The false accepts per hour corresponding to thresholds in `cutoffs`.
+    """
+    cutoffs_count = cutoffs.shape[0]
+
+    false_accepts_at_cutoffs = np.zeros(cutoffs_count)
+    probabilities_duration_h = 0
+
+    for track_probabilities in streaming_probabilities_list:
+        probabilities_duration_h += len(track_probabilities) * stride * step_s / 3600.0
+
+        cooldown_at_cutoffs = np.ones(cutoffs_count) * ignore_slices_after_accept
+
+        for wakeword_probability in track_probabilities:
+            # Decrease each cutoff's cooldown counter by 1, with a minimum value of 0
+            cooldown_at_cutoffs = np.maximum(
+                cooldown_at_cutoffs - 1, np.zeros(cutoffs_count)
+            )
+            detection_boolean = (
+                wakeword_probability > cutoffs
+            )  # a list of detection states at each cutoff
+
+            for index in range(cutoffs_count):
+                if cooldown_at_cutoffs[index] == 0 and detection_boolean[index]:
+                    false_accepts_at_cutoffs[index] += 1
+                    cooldown_at_cutoffs[index] = ignore_slices_after_accept
+
+    return false_accepts_at_cutoffs / probabilities_duration_h
+
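To make the cooldown logic above concrete, here is a small synthetic usage sketch (the numbers are invented): a single short burst of high probabilities in an hour of otherwise quiet audio counts as exactly one false accept, because the slices that follow an accept are ignored for `ignore_slices_after_accept` steps.

```python
import numpy as np

# One hour of streaming probabilities at 20 ms per slice, all zero except for a
# single 3-slice spurious activation at probability 0.95.
step_s = 0.02
track = np.zeros(int(3600 / step_s))
track[1000:1003] = 0.95

cutoffs = np.arange(0, 1.01, 0.01)
faph = compute_false_accepts_per_hour(
    [track], cutoffs, ignore_slices_after_accept=75, stride=1, step_s=step_s
)

print(faph[50])  # cutoff 0.50 -> 1.0 false accept per hour (the burst fires once)
print(faph[99])  # cutoff 0.99 -> 0.0, since 0.95 never exceeds 0.99
```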
+
+def generate_roc_curve(
+    false_accepts_per_hour: np.ndarray,
+    false_rejections: np.ndarray,
+    # positive_samples_probabilities: np.ndarray,
+    cutoffs: np.ndarray,
+    max_faph: float = 2.0,
+):
+    """Generates the coordinates for an ROC curve plotting false accepts per hour vs. false rejections. Computes the false rejection rate at the specified cutoffs.
+
+    Args:
+        false_accepts_per_hour (numpy.ndarray): False accepts per hour rates for each threshold in `cutoffs`.
+        false_rejections (numpy.ndarray): False rejection rates for each threshold in `cutoffs`.
+        cutoffs (numpy.ndarray): Thresholds used for `false_accepts_per_hour`.
+        max_faph (float, optional): The maximum false accepts per hour rate to include in the curve's coordinates. Defaults to 2.0.
+
+    Returns:
+        (numpy.ndarray, numpy.ndarray, numpy.ndarray): (false accept per hour coordinates, false rejection rate coordinates, cutoffs for each coordinate)
+    """
+
+    if false_accepts_per_hour[0] > max_faph:
+        # Use linear interpolation to estimate the false negative rate at max_faph
+
+        # Increase the index until we find a faph less than max_faph
+        index_of_first_viable = 1
+        while false_accepts_per_hour[index_of_first_viable] > max_faph:
+            index_of_first_viable += 1
+
+        x0 = false_accepts_per_hour[index_of_first_viable - 1]
+        y0 = false_rejections[index_of_first_viable - 1]
+        x1 = false_accepts_per_hour[index_of_first_viable]
+        y1 = false_rejections[index_of_first_viable]
+
+        fnr_at_max_faph = (y0 * (x1 - max_faph) + y1 * (max_faph - x0)) / (x1 - x0)
+        cutoff_at_max_faph = (
+            cutoffs[index_of_first_viable] + cutoffs[index_of_first_viable - 1]
+        ) / 2.0
+    else:
+        # Smallest faph is less than max_faph, so assume the false negative rate is constant
+        index_of_first_viable = 0
+        fnr_at_max_faph = false_rejections[index_of_first_viable]
+        cutoff_at_max_faph = cutoffs[index_of_first_viable]
+
+    horizontal_coordinates = [max_faph]
+    vertical_coordinates = [fnr_at_max_faph]
+    cutoffs_at_coordinate = [cutoff_at_max_faph]
+
+    for index in range(index_of_first_viable, len(false_rejections)):
+        if false_accepts_per_hour[index] != horizontal_coordinates[-1]:
+            # Only add a point if it is a new faph
+            # This ensures that if a faph rate is repeated, we use the smallest false negative rate
+            horizontal_coordinates.append(false_accepts_per_hour[index])
+            vertical_coordinates.append(false_rejections[index])
+            cutoffs_at_coordinate.append(cutoffs[index])
+
+    if horizontal_coordinates[-1] > 0:
+        # If there isn't a cutoff with 0 faph, then add a coordinate at (0, 1)
+        horizontal_coordinates.append(0.0)
+        vertical_coordinates.append(1.0)
+        cutoffs_at_coordinate.append(0.0)
+
+    # The points on the curve are listed in descending order, flip them before returning
+    horizontal_coordinates = np.flip(horizontal_coordinates)
+    vertical_coordinates = np.flip(vertical_coordinates)
+    cutoffs_at_coordinate = np.flip(cutoffs_at_coordinate)
+    return horizontal_coordinates, vertical_coordinates, cutoffs_at_coordinate
+
+
 def tf_model_accuracy(
     config,
     folder,
@@ -97,6 +213,8 @@ def tf_model_accuracy(
 ):
     """Function to test a TF model on a specified data set.
 
+    NOTE: This assumes the wake word is at the end of the spectrogram. The ``tflite_streaming_model_roc`` method does not make this assumption, so you may get vastly different results depending on how the word is positioned in the spectrogram in the data set.
+
     Arguments:
         config: dictionary containing microWakeWord training configuration
         folder: folder containing the TF model
@@ -172,6 +290,119 @@ def tf_model_accuracy(
     return metrics
 
 
+def tflite_streaming_model_roc(
+    config,
+    folder,
+    audio_processor,
+    data_set="testing",
+    ambient_set="testing_ambient",
+    tflite_model_name="stream_state_internal.tflite",
+    accuracy_name="tflite_streaming_roc.txt",
+    sliding_window_length=5,
+    ignore_slices_after_accept=25,
+):
+    """Function to test a TFLite model's false accepts per hour and false rejection rates.
+
+    Model can be streaming or nonstreaming. Nonstreaming models are strided by 1 spectrogram feature in the time dimension.
+ + Args: + config (dict): dictionary containing microWakeWord training configuration + folder (str): folder containing the TFLite model + audio_processor (FeatureHandler): microWakeWord FeatureHandler object for retrieving spectrograms + data_set (str, optional): Dataset for testing recall. Defaults to "testing". + ambient_set (str, optional): Dataset for testing false accepts per hour. Defaults to "testing_ambient". + tflite_model_name (str, optional): filename of the TFLite model. Defaults to "stream_state_internal.tflite". + accuracy_name (str, optional): filename to save metrics at various cutoffs. Defaults to "tflite_streaming_roc.txt". + sliding_window_length (int, optional): the length of the sliding window for computing average probabilities. Defaults to 1. + + Returns: + float: The Area under the false accept per hour vs. false rejection curve. + """ + stride = config["stride"] + model = Model( + os.path.join(config["train_dir"], folder, tflite_model_name), stride=stride + ) + + test_ambient_fingerprints, _, _ = audio_processor.get_data( + ambient_set, + batch_size=config["batch_size"], + features_length=config["spectrogram_length"], + truncation_strategy="none", + ) + + logging.info("Testing the " + ambient_set + " set.") + ambient_streaming_probabilities = [] + for spectrogram_track in test_ambient_fingerprints: + streaming_probabilities = model.predict_spectrogram(spectrogram_track) + sliding_window_probabilities = sliding_window_view( + streaming_probabilities, sliding_window_length + ) + moving_average = sliding_window_probabilities.mean(axis=-1) + ambient_streaming_probabilities.append(moving_average) + + cutoffs = np.arange(0, 1.01, 0.01) + # ignore_slices_after_accept = 25 + + faph = compute_false_accepts_per_hour( + ambient_streaming_probabilities, + cutoffs, + ignore_slices_after_accept, + stride=config["stride"], + step_s=config["window_step_ms"] / 1000, + ) + + test_fingerprints, test_ground_truth, _ = audio_processor.get_data( + data_set, + batch_size=config["batch_size"], + features_length=config["spectrogram_length"], + truncation_strategy="none", + ) + + logging.info("Testing the " + data_set + " set.") + + positive_sample_streaming_probabilities = [] + for i in range(len(test_fingerprints)): + if test_ground_truth[i]: + # Only test positive samples + streaming_probabilities = model.predict_spectrogram(test_fingerprints[i]) + sliding_window_probabilities = sliding_window_view( + streaming_probabilities[ignore_slices_after_accept:], + sliding_window_length, + ) + moving_average = sliding_window_probabilities.mean(axis=-1) + positive_sample_streaming_probabilities.append(np.max(moving_average)) + + # Compute the false negative rates at each cutoff + false_negative_rate_at_cutoffs = [] + for cutoff in cutoffs: + true_accepts = sum(i > cutoff for i in positive_sample_streaming_probabilities) + false_negative_rate_at_cutoffs.append( + 1 - true_accepts / len(positive_sample_streaming_probabilities) + ) + + x_coordinates, y_coordinates, cutoffs_at_points = generate_roc_curve( + false_accepts_per_hour=faph, + false_rejections=false_negative_rate_at_cutoffs, + cutoffs=cutoffs, + ) + + path = os.path.join(config["train_dir"], folder) + with open(os.path.join(path, accuracy_name), "wt") as fd: + auc = np.trapz(y_coordinates, x_coordinates) + auc_string = "AUC {:.5f}".format(auc) + logging.info(auc_string) + fd.write(auc_string + "\n") + + for i in range(0, x_coordinates.shape[0]): + cutoff_string = "Cutoff {:.2f}: frr={:.4f}; faph={:.3f}".format( + cutoffs_at_points[i], 
y_coordinates[i], x_coordinates[i] + ) + logging.info(cutoff_string) + fd.write(cutoff_string + "\n") + + return auc + + def tflite_model_accuracy( config, folder, @@ -182,6 +413,8 @@ def tflite_model_accuracy( ): """Function to test a TFLite model on a specified data set. + NOTE: This assumes the wakeword is at the end of the spectrogram. The ``tflite_streaming_model_roc`` method does not make this assumption, and you may get vastly different results depending on how word is positioned in the spectrogram in the data set. + Model can be streaming or nonstreaming. If tested on an "_ambient" set, it detects a false accept if the previous probability was less than 0.5 and the current probability is greater than 0.5. diff --git a/microwakeword/train.py b/microwakeword/train.py index 5ae8d73..b1d42d3 100644 --- a/microwakeword/train.py +++ b/microwakeword/train.py @@ -15,14 +15,27 @@ # limitations under the License. import os +import platform +import contextlib from absl import logging -from collections import deque import numpy as np import tensorflow as tf -import microwakeword.test as test +from tensorflow.python.util import tf_decorator + + +@contextlib.contextmanager +def swap_attribute(obj, attr, temp_value): + """Temporarily swap an attribute of an object.""" + original_value = getattr(obj, attr) + setattr(obj, attr, temp_value) + + try: + yield + finally: + setattr(obj, attr, original_value) def validate_nonstreaming(config, data_processor, model, test_set): @@ -32,30 +45,32 @@ def validate_nonstreaming(config, data_processor, model, test_set): features_length=config["spectrogram_length"], truncation_strategy="truncate_start", ) + testing_ground_truth = testing_ground_truth.reshape(-1, 1) - test_batch_size = 1000 - - for i in range(0, len(testing_fingerprints), test_batch_size): - result = model.test_on_batch( - testing_fingerprints[i : i + test_batch_size], - testing_ground_truth[i : i + test_batch_size], - reset_metrics=(i == 0), - ) - - true_positives = result[4] - false_positives = result[5] - true_negatives = result[6] - false_negatives = result[7] + model.reset_metrics() - metrics = test.compute_metrics( - true_positives, true_negatives, false_positives, false_negatives + result = model.evaluate( + testing_fingerprints, + testing_ground_truth, + batch_size=1024, + return_dict=True, + verbose=0, ) - metrics["loss"] = result[0] - metrics["auc"] = result[8] + metrics = {} + metrics["accuracy"] = result["accuracy"] + metrics["recall"] = result["recall"] + metrics["precision"] = result["precision"] - ambient_false_positives = 0 # float("nan") - estimated_ambient_false_positives_per_hour = 0 # float("nan") + metrics["auc"] = result["auc"] + metrics["loss"] = result["loss"] + metrics["recall_at_no_faph"] = 0 + metrics["cutoff_for_no_faph"] = 0 + metrics["ambient_false_positives"] = 0 + metrics["ambient_false_positives_per_hour"] = 0 + metrics["average_viable_recall"] = 0 + + test_set_fp = result["fp"].numpy() if data_processor.get_mode_size("validation_ambient") > 0: ( @@ -68,30 +83,87 @@ def validate_nonstreaming(config, data_processor, model, test_set): features_length=config["spectrogram_length"], truncation_strategy="split", ) - - for i in range(0, len(ambient_testing_fingerprints), test_batch_size): - ambient_result = model.test_on_batch( - ambient_testing_fingerprints[i : i + test_batch_size], - ambient_testing_ground_truth[i : i + test_batch_size], - reset_metrics=(i == 0), + ambient_testing_ground_truth = ambient_testing_ground_truth.reshape(-1, 1) + + # XXX: tf no longer 
provides a way to evaluate a model without updating metrics + with swap_attribute(model, "reset_metrics", lambda: None): + ambient_predictions = model.evaluate( + ambient_testing_fingerprints, + ambient_testing_ground_truth, + batch_size=1024, + return_dict=True, + verbose=0, ) - ambient_false_positives = ambient_result[5] - - estimated_ambient_false_positives_per_hour = ambient_false_positives / ( + duration_of_ambient_set = ( data_processor.get_mode_duration("validation_ambient") / 3600.0 ) - metrics["ambient_false_positives"] = ambient_false_positives - metrics["ambient_false_positives_per_hour"] = ( - estimated_ambient_false_positives_per_hour - ) + # Other than the false positive rate, all other metrics are accumulated across + # both test sets + all_true_positives = ambient_predictions["tp"].numpy() + ambient_false_positives = ambient_predictions["fp"].numpy() - test_set_fp + all_false_negatives = ambient_predictions["fn"].numpy() + + metrics["auc"] = ambient_predictions["auc"] + metrics["loss"] = ambient_predictions["loss"] + + recall_at_cutoffs = ( + all_true_positives / (all_true_positives + all_false_negatives) + ) + faph_at_cutoffs = ambient_false_positives / duration_of_ambient_set + + target_faph_cutoff_probability = 1.0 + for index, cutoff in enumerate(np.linspace(0.0, 1.0, 101)): + if faph_at_cutoffs[index] == 0: + target_faph_cutoff_probability = cutoff + recall_at_no_faph = recall_at_cutoffs[index] + break + + if faph_at_cutoffs[0] > 2: + # Use linear interpolation to estimate recall at 2 faph + + # Increase index until we find a faph less than 2 + index_of_first_viable = 1 + while faph_at_cutoffs[index_of_first_viable] > 2: + index_of_first_viable += 1 + + x0 = faph_at_cutoffs[index_of_first_viable - 1] + y0 = recall_at_cutoffs[index_of_first_viable - 1] + x1 = faph_at_cutoffs[index_of_first_viable] + y1 = recall_at_cutoffs[index_of_first_viable] + + recall_at_2faph = (y0 * (x1 - 2.0) + y1 * (2.0 - x0)) / (x1 - x0) + else: + # Lowest faph is already under 2, assume the recall is constant before this + index_of_first_viable = 0 + recall_at_2faph = recall_at_cutoffs[0] + + x_coordinates = [2.0] + y_coordinates = [recall_at_2faph] + + for index in range(index_of_first_viable, len(recall_at_cutoffs)): + if faph_at_cutoffs[index] != x_coordinates[-1]: + # Only add a point if it is a new faph + # This ensures if a faph rate is repeated, we use the highest recall + x_coordinates.append(faph_at_cutoffs[index]) + y_coordinates.append(recall_at_cutoffs[index]) + + # Use trapezoid rule to estimate the area under the curve, then divide by 2.0 to get the average recall + average_viable_recall = ( + np.trapz(np.flip(y_coordinates), np.flip(x_coordinates)) / 2.0 + ) + + metrics["recall_at_no_faph"] = recall_at_no_faph + metrics["cutoff_for_no_faph"] = target_faph_cutoff_probability + metrics["ambient_false_positives"] = ambient_false_positives[50] + metrics["ambient_false_positives_per_hour"] = faph_at_cutoffs[50] + metrics["average_viable_recall"] = average_viable_recall return metrics def train(model, config, data_processor): - # Assign default training settings if not set in the configuration yaml if not (training_steps_list := config.get("training_steps")): training_steps_list = [20000] @@ -132,21 +204,28 @@ def pad_list_with_last_entry(list_to_pad, desired_length): pad_list_with_last_entry(negative_class_weight_list, training_step_iterations) loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) - optimizer = tf.keras.optimizers.legacy.Adam() + optimizer = 
tf.keras.optimizers.Adam() + + cutoffs = np.linspace(0.0, 1.0, 101).tolist() metrics = [ tf.keras.metrics.BinaryAccuracy(name="accuracy"), tf.keras.metrics.Recall(name="recall"), tf.keras.metrics.Precision(name="precision"), - tf.keras.metrics.TruePositives(name="tp"), - tf.keras.metrics.FalsePositives(name="fp"), - tf.keras.metrics.TrueNegatives(name="tn"), - tf.keras.metrics.FalseNegatives(name="fn"), + tf.keras.metrics.TruePositives(name="tp", thresholds=cutoffs), + tf.keras.metrics.FalsePositives(name="fp", thresholds=cutoffs), + tf.keras.metrics.TrueNegatives(name="tn", thresholds=cutoffs), + tf.keras.metrics.FalseNegatives(name="fn", thresholds=cutoffs), tf.keras.metrics.AUC(name="auc"), + tf.keras.metrics.BinaryCrossentropy(name="loss"), ] model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + # We un-decorate the `tf.function`, it's very slow to manually run training batches + model.make_train_function() + _, model.train_function = tf_decorator.unwrap(model.train_function) + # Configure checkpointer and restore if available checkpoint_directory = os.path.join(config["train_dir"], "restore/") checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") @@ -165,8 +244,7 @@ def pad_list_with_last_entry(list_to_pad, desired_length): best_minimization_quantity = 10000 best_maximization_quantity = 0.0 - - results_deque = deque([]) + best_no_faph_cutoff = 1.0 for training_step in range(1, training_steps_max + 1): training_steps_sum = 0 @@ -184,7 +262,7 @@ def pad_list_with_last_entry(list_to_pad, desired_length): negative_class_weight = negative_class_weight_list[i] break - tf.keras.backend.set_value(model.optimizer.lr, learning_rate) + model.optimizer.learning_rate.assign(learning_rate) augmentation_policy = { "mix_up_prob": mix_up_prob, @@ -207,82 +285,76 @@ def pad_list_with_last_entry(list_to_pad, desired_length): augmentation_policy=augmentation_policy, ) + train_ground_truth = train_ground_truth.reshape(-1, 1) + class_weights = {0: negative_class_weight, 1: positive_class_weight} + combined_weights = train_sample_weights * np.vectorize(class_weights.get)( + train_ground_truth + ) result = model.train_on_batch( train_fingerprints, train_ground_truth, - sample_weight=train_sample_weights, - class_weight=class_weights, + sample_weight=combined_weights, ) - with train_writer.as_default(): - metrics = test.compute_metrics( - true_positives=result[4], - false_positives=result[5], - true_negatives=result[6], - false_negatives=result[7], - ) - - tf.summary.scalar("loss", result[0], step=training_step) - tf.summary.scalar("accuracy", result[1], step=training_step) - tf.summary.scalar("recall", result[2], step=training_step) - tf.summary.scalar("precision", result[3], step=training_step) - tf.summary.scalar("fpr", metrics["false_positive_rate"], step=training_step) - tf.summary.scalar("fnr", metrics["false_negative_rate"], step=training_step) - tf.summary.scalar("auc", result[8], step=training_step) - - if not training_step % 25: - train_writer.flush() - - if len(results_deque) >= 5: - results_deque.popleft() - - results_deque.append(result) - - if not training_step % 5: - loss = 0.0 - accuracy = 0.0 - recall = 0.0 - precision = 0.0 - for i in range(0, 5): - loss += results_deque[i][0] - accuracy += results_deque[i][1] - recall += results_deque[i][2] - precision += results_deque[i][3] + # Print the running statistics in the current validation epoch + print( + "Validation Batch #{:d}: Accuracy = {:.3f}; Recall = {:.3f}; Precision = {:.3f}; Loss = {:.4f}; Mini-Batch #{:d}".format( + 
(training_step // config["eval_step_interval"] + 1), + result[1], + result[2], + result[3], + result[9], + (training_step % config["eval_step_interval"]), + ), + end="\r", + ) + is_last_step = training_step == training_steps_max + if (training_step % config["eval_step_interval"]) == 0 or is_last_step: logging.info( "Step #%d: rate %f, accuracy %.2f%%, recall %.2f%%, precision %.2f%%, cross entropy %f", *( training_step, learning_rate, - accuracy / 5.0 * 100, - recall / 5.0 * 100, - precision / 5.0 * 100, - loss / 5.0, + result[1] * 100, + result[2] * 100, + result[3] * 100, + result[9], ), ) - is_last_step = training_step == training_steps_max - if (training_step % config["eval_step_interval"]) == 0 or is_last_step: - model.save_weights(os.path.join(config["train_dir"], "last_weights")) + with train_writer.as_default(): + tf.summary.scalar("loss", result[9], step=training_step) + tf.summary.scalar("accuracy", result[1], step=training_step) + tf.summary.scalar("recall", result[2], step=training_step) + tf.summary.scalar("precision", result[3], step=training_step) + tf.summary.scalar("auc", result[8], step=training_step) + train_writer.flush() + + model.save_weights( + os.path.join(config["train_dir"], "last_weights.weights.h5") + ) nonstreaming_metrics = validate_nonstreaming( config, data_processor, model, "validation" ) + model.reset_metrics() # reset metrics for next validation epoch of training logging.info( - "Step %d (nonstreaming): Validation accuracy = %.2f%%, recall = %.2f%%, precision = %.2f%%, fpr = %.2f%%, fnr = %.2f%%, ambient false positives = %d, estimated false positives per hour = %.5f, loss = %.5f, auc = %.5f,", + "Step %d (nonstreaming): Validation: recall at no faph = %.3f with cutoff %.2f, accuracy = %.2f%%, recall = %.2f%%, precision = %.2f%%, ambient false positives = %d, estimated false positives per hour = %.5f, loss = %.5f, auc = %.5f, average viable recall = %.9f", *( training_step, + nonstreaming_metrics["recall_at_no_faph"] * 100, + nonstreaming_metrics["cutoff_for_no_faph"], nonstreaming_metrics["accuracy"] * 100, nonstreaming_metrics["recall"] * 100, nonstreaming_metrics["precision"] * 100, - nonstreaming_metrics["false_positive_rate"] * 100, - nonstreaming_metrics["false_negative_rate"] * 100, nonstreaming_metrics["ambient_false_positives"], nonstreaming_metrics["ambient_false_positives_per_hour"], nonstreaming_metrics["loss"], nonstreaming_metrics["auc"], + nonstreaming_metrics["average_viable_recall"], ), ) @@ -300,34 +372,29 @@ def pad_list_with_last_entry(list_to_pad, desired_length): "precision", nonstreaming_metrics["precision"], step=training_step ) tf.summary.scalar( - "fpr", - nonstreaming_metrics["false_positive_rate"], + "recall_at_no_faph", + nonstreaming_metrics["recall_at_no_faph"], step=training_step, ) tf.summary.scalar( - "fnr", - nonstreaming_metrics["false_negative_rate"], - step=training_step, - ) - tf.summary.scalar( - "faph", - nonstreaming_metrics["ambient_false_positives_per_hour"], + "auc", + nonstreaming_metrics["auc"], step=training_step, ) tf.summary.scalar( - "auc", - nonstreaming_metrics["auc"], + "average_viable_recall", + nonstreaming_metrics["average_viable_recall"], step=training_step, ) validation_writer.flush() + os.makedirs(os.path.join(config["train_dir"], "train"), exist_ok=True) + model.save_weights( os.path.join( config["train_dir"], - "train/", - str(int(best_minimization_quantity * 10000)) - + "weights_" - + str(training_step), + "train", + f"{int(best_minimization_quantity * 
10000)}_weights_{training_step}.weights.h5", ) ) @@ -339,6 +406,7 @@ def pad_list_with_last_entry(list_to_pad, desired_length): current_maximization_quantity = nonstreaming_metrics[ config["maximization_metric"] ] + current_no_faph_cutoff = nonstreaming_metrics["cutoff_for_no_faph"] # Save model weights if this is a new best model if ( @@ -374,46 +442,21 @@ def pad_list_with_last_entry(list_to_pad, desired_length): ): best_minimization_quantity = current_minimization_quantity best_maximization_quantity = current_maximization_quantity + best_no_faph_cutoff = current_no_faph_cutoff # overwrite the best model weights - model.save_weights(os.path.join(config["train_dir"], "best_weights")) + model.save_weights( + os.path.join(config["train_dir"], "best_weights.weights.h5") + ) checkpoint.save(file_prefix=checkpoint_prefix) logging.info( - "So far the best minimization quantity is %.3f with best maximization quantity of %.5f%%", + "So far the best minimization quantity is %.3f with best maximization quantity of %.5f%%; no faph cutoff is %.2f", best_minimization_quantity, (best_maximization_quantity * 100), + best_no_faph_cutoff, ) # Save checkpoint after training checkpoint.save(file_prefix=checkpoint_prefix) - - testing_fingerprints, testing_ground_truth, _ = data_processor.get_data( - "testing", - batch_size=config["batch_size"], - features_length=config["spectrogram_length"], - truncation_strategy="truncate_start", - ) - - for i in range(0, len(testing_fingerprints), config["batch_size"]): - result = model.test_on_batch( - testing_fingerprints[i : i + config["batch_size"]], - testing_ground_truth[i : i + config["batch_size"]], - reset_metrics=(i == 0), - ) - - true_positives = result[4] - false_positives = result[5] - true_negatives = result[6] - false_negatives = result[7] - - metrics = test.compute_metrics( - true_positives, true_negatives, false_positives, false_negatives - ) - metrics_string = test.metrics_to_string(metrics) - - logging.info("Last weights on testing set: " + metrics_string) - - with open(os.path.join(config["train_dir"], "metrics_last.txt"), "wt") as fd: - fd.write(metrics_string) - model.save_weights(os.path.join(config["train_dir"], "last_weights")) + model.save_weights(os.path.join(config["train_dir"], "last_weights.weights.h5")) diff --git a/microwakeword/utils.py b/microwakeword/utils.py index f783ff2..154cd0f 100644 --- a/microwakeword/utils.py +++ b/microwakeword/utils.py @@ -20,9 +20,8 @@ import tensorflow as tf from absl import logging -from typing import Sequence -from microwakeword.layers import modes +from microwakeword.layers import modes, stream, strided_drop def _set_mode(model, mode): @@ -35,8 +34,12 @@ def _recursive_set_layer_mode(layer, mode): config = layer.get_config() # for every layer set mode, if it has it if "mode" in config: + assert isinstance( + layer, + (stream.Stream, strided_drop.StridedDrop, strided_drop.StridedKeep), + ) layer.mode = mode - # with any mode of inference - training is False + # with any mode of inference - training is False if "training" in config: layer.training = False if mode == modes.Modes.NON_STREAM_INFERENCE: @@ -48,23 +51,6 @@ def _recursive_set_layer_mode(layer, mode): return model -def _get_input_output_states(model): - """Get input/output states of model with external states.""" - input_states = [] - output_states = [] - for i in range(len(model.layers)): - config = model.layers[i].get_config() - # input output states exist only in layers with property 'mode' - if "mode" in config: - input_state = 
model.layers[i].get_input_state() - if input_state not in ([], [None]): - input_states.append(model.layers[i].get_input_state()) - output_state = model.layers[i].get_output_state() - if output_state not in ([], [None]): - output_states.append(output_state) - return input_states, output_states - - def _copy_weights(new_model, model): """Copy weights of trained model to an inference one.""" @@ -142,21 +128,6 @@ def _same_weights(weight, new_weight): return new_model -def _flatten_nested_sequence(sequence): - """Returns a flattened list of sequence's elements.""" - if not isinstance(sequence, Sequence): - return [sequence] - result = [] - for value in sequence: - result.extend(_flatten_nested_sequence(value)) - return result - - -def _get_state_shapes(model_states): - """Converts a nested list of states in to a flat list of their shapes.""" - return [state.shape for state in _flatten_nested_sequence(model_states)] - - def save_model_summary(model, path, file_name="model_summary.txt"): """Saves model topology/summary in text format. @@ -197,6 +168,7 @@ def convert_to_inference_model(model, input_tensors, mode): # scope is introduced for simplifiyng access to weights by names scope_name = "streaming" + with tf.name_scope(scope_name): if not isinstance(model, tf.keras.Model): raise ValueError( @@ -209,34 +181,11 @@ def convert_to_inference_model(model, input_tensors, mode): "got a `Sequential` instance instead:", model, ) - # pylint: disable=protected-access - if not model._is_graph_network: - raise ValueError( - "Expected `model` argument " - "to be a functional `Model` instance, " - "but got a subclass model instead." - ) - # pylint: enable=protected-access model = _set_mode(model, mode) - new_model = tf.keras.models.clone_model( - model, input_tensors - ) # _clone_model(model, input_tensors) + new_model = tf.keras.models.clone_model(model, input_tensors) if mode == modes.Modes.STREAM_INTERNAL_STATE_INFERENCE: return _copy_weights(new_model, model) - elif mode == modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE: - input_states, output_states = _get_input_output_states(new_model) - all_inputs = new_model.inputs + input_states - all_outputs = new_model.outputs + output_states - new_streaming_model = tf.keras.Model(all_inputs, all_outputs) - new_streaming_model.input_shapes = _get_state_shapes(all_inputs) - new_streaming_model.output_shapes = _get_state_shapes(all_outputs) - - # inference streaming model with external states - # has the same number of weights with - # non streaming model so we can use set_weights directly - new_streaming_model.set_weights(model.get_weights()) - return new_streaming_model elif mode == modes.Modes.NON_STREAM_INFERENCE: new_model.set_weights(model.get_weights()) return new_model @@ -250,8 +199,8 @@ def to_streaming_inference(model_non_stream, config, mode): Args: model_non_stream: trained Keras model non streamable config: dictionary containing microWakeWord training configuration - mode: it supports Non streaming inference, Streaming inference with internal - states, Streaming inference with external states + mode: it supports Non streaming inference or Streaming inference with internal + states Returns: Keras inference model of inference_type @@ -265,6 +214,7 @@ def to_streaming_inference(model_non_stream, config, mode): else: dtype = model_non_stream.input.dtype + # For streaming, set the batch size to 1 input_tensors = [ tf.keras.layers.Input( shape=input_data_shape, batch_size=1, dtype=dtype, name="input_audio" @@ -280,6 +230,7 @@ def 
to_streaming_inference(model_non_stream, config, mode): "Maximum number of inputs supported is 2 (input_audio and " "cond_features), but got %d inputs" % len(model_non_stream.input) ) + input_tensors.append( tf.keras.layers.Input( shape=config["cond_shape"], @@ -289,14 +240,22 @@ def to_streaming_inference(model_non_stream, config, mode): ) ) - model_inference = convert_to_inference_model(model_non_stream, input_tensors, mode) + # Input tensors must have the same shape as the original + if isinstance(model_non_stream.input, (tuple, list)): + model_inference = convert_to_inference_model( + model_non_stream, input_tensors, mode + ) + else: + model_inference = convert_to_inference_model( + model_non_stream, input_tensors[0], mode + ) + return model_inference def model_to_saved( model_non_stream, config, - save_model_path, mode=modes.Modes.STREAM_INTERNAL_STATE_INFERENCE, ): """Convert Keras model to SavedModel. @@ -308,8 +267,7 @@ def model_to_saved( Args: model_non_stream: Keras non streamable model config: dictionary containing microWakeWord training configuration - save_model_path: path where saved model representation with be stored - mode: inference mode it can be streaming with external state or non + mode: inference mode it can be streaming with internal state or non streaming """ @@ -325,8 +283,7 @@ def model_to_saved( # convert non streaming Keras model to Keras streaming model, internal state model = to_streaming_inference(model_non_stream, config, mode) - save_model_summary(model, save_model_path) - model.save(save_model_path, include_optimizer=False, save_format="tf") + return model def convert_saved_model_to_tflite( @@ -342,40 +299,48 @@ def convert_saved_model_to_tflite( fname: output filename for TFLite file quantize: boolean selecting whether to quantize the model """ - if not os.path.exists(folder): - os.makedirs(folder) - sample_fingerprints, _, _ = audio_processor.get_data( - "training", 500, features_length=config["spectrogram_length"] - ) + def representative_dataset_gen(): + sample_fingerprints, _, _ = audio_processor.get_data( + "training", 500, features_length=config["spectrogram_length"] + ) - sample_fingerprints[0][0, 0] = 0.0 # guarantee one pixel is the preprocessor min - sample_fingerprints[0][0, 1] = 26.0 # guarantee one pixel is the preprocessor max + sample_fingerprints[0][ + 0, 0 + ] = 0.0 # guarantee one pixel is the preprocessor min + sample_fingerprints[0][ + 0, 1 + ] = 26.0 # guarantee one pixel is the preprocessor max + + # for spectrogram in sample_fingerprints: + # yield spectrogram + + stride = config["stride"] - def representative_dataset_gen(): for spectrogram in sample_fingerprints: - for i in range(spectrogram.shape[0]): - yield [spectrogram[i, :].astype(np.float32)] + assert spectrogram.shape[0] % stride == 0 - converter = tf.compat.v2.lite.TFLiteConverter.from_saved_model(path_to_model) - converter.experimental_new_quantizer = True - converter.experimental_enable_resource_variables = True - converter.experimental_new_converter = True - converter._experimental_variable_quantization = True - converter.optimizations = [tf.lite.Optimize.DEFAULT] + for i in range(0, spectrogram.shape[0] - stride, stride): + sample = spectrogram[i : i + stride, :].astype(np.float32) + yield [sample] - if quantize: - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter = tf.lite.TFLiteConverter.from_saved_model(path_to_model) + converter.optimizations = {tf.lite.Optimize.DEFAULT} - converter.inference_type = tf.int8 + if quantize: + 
converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8} converter.inference_input_type = tf.int8 converter.inference_output_type = tf.uint8 + converter.representative_dataset = tf.lite.RepresentativeDataset( + representative_dataset_gen + ) - converter.representative_dataset = representative_dataset_gen + if not os.path.exists(folder): + os.makedirs(folder) - tflite_model = converter.convert() - path_to_output = os.path.join(folder, fname) - open(path_to_output, "wb").write(tflite_model) + with open(os.path.join(folder, fname), "wb") as f: + tflite_model = converter.convert() + f.write(tflite_model) def convert_model_saved(model, config, folder, mode): @@ -391,10 +356,24 @@ def convert_model_saved(model, config, folder, mode): path_model = os.path.join(config["train_dir"], folder) if not os.path.exists(path_model): os.makedirs(path_model) - try: - # Convert trained model to SavedModel - model_to_saved(model, config, path_model, mode) - except IOError as e: - logging.warning("FAILED to write file: %s", e) - except (ValueError, AttributeError, RuntimeError, TypeError, AssertionError) as e: - logging.warning("WARNING: failed to convert to SavedModel: %s", e) + + # Convert trained model to SavedModel + converted_model = model_to_saved(model, config, mode) + converted_model.summary() + + assert converted_model.input.shape[0] is not None + + # XXX: Using `converted_model.export(path_model)` results in obscure errors during + # quantization, we create an export archive directly instead. + export_archive = tf.keras.export.ExportArchive() + export_archive.track(converted_model) + export_archive.add_endpoint( + name="serve", + fn=converted_model.call, + input_signature=[tf.TensorSpec(shape=converted_model.input.shape, dtype=tf.float32)], + ) + export_archive.write_out(path_model) + + save_model_summary(converted_model, path_model) + + return converted_model diff --git a/models/alexa.json b/models/alexa.json deleted file mode 100644 index 7842da1..0000000 --- a/models/alexa.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "type": "micro", - "wake_word": "alexa", - "author": "Kevin Ahrendt", - "website": "https://www.kevinahrendt.com/", - "model": "./alexa.tflite", - "version": 1, - "micro": { - "probability_cutoff": 0.66, - "sliding_window_average_size": 10 - } -} diff --git a/models/alexa.tflite b/models/alexa.tflite deleted file mode 100644 index 3dd7438..0000000 Binary files a/models/alexa.tflite and /dev/null differ diff --git a/models/hey_jarvis.json b/models/hey_jarvis.json deleted file mode 100644 index ec064a9..0000000 --- a/models/hey_jarvis.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "type": "micro", - "wake_word": "hey jarvis", - "author": "Kevin Ahrendt", - "website": "https://www.kevinahrendt.com/", - "model": "./hey_jarvis.tflite", - "version": 1, - "micro": { - "probability_cutoff": 0.5, - "sliding_window_average_size": 10 - } -} diff --git a/models/hey_jarvis.tflite b/models/hey_jarvis.tflite deleted file mode 100644 index fedaf60..0000000 Binary files a/models/hey_jarvis.tflite and /dev/null differ diff --git a/models/okay_nabu.json b/models/okay_nabu.json deleted file mode 100644 index a215971..0000000 --- a/models/okay_nabu.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "type": "micro", - "wake_word": "okay nabu", - "author": "Kevin Ahrendt", - "website": "https://www.kevinahrendt.com/", - "model": "okay_nabu.tflite", - "version": 1, - "micro": { - "probability_cutoff": 0.5, - "sliding_window_average_size": 10 - } -} diff --git a/models/okay_nabu.tflite b/models/okay_nabu.tflite 
deleted file mode 100644 index 5f91d01..0000000 Binary files a/models/okay_nabu.tflite and /dev/null differ diff --git a/notebooks/basic_training_notebook.ipynb b/notebooks/basic_training_notebook.ipynb new file mode 100644 index 0000000..253e106 --- /dev/null +++ b/notebooks/basic_training_notebook.ipynb @@ -0,0 +1,559 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "r11cNiLqvWC6" + }, + "source": [ + "# Training a microWakeWord Model\n", + "\n", + "This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.\n", + "\n", + "**The model generated will most likely not be usable for everyday use; it may be difficult to trigger or falsely activates too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**\n", + "\n", + "In the comment at the start of certain blocks, I note some specific settings to consider modifying.\n", + "\n", + "This runs on Google Colab, but is extremely slow compared to training on a local GPU. If you must use Colab, be sure to Change the runtime type to a GPU. Even then, it still slow!\n", + "\n", + "At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "BFf6511E65ff" + }, + "outputs": [], + "source": [ + "# Installs microWakeWord. Be sure to restart the session after this is finished.\n", + "import platform\n", + "\n", + "if platform.system() == \"Darwin\":\n", + " # `pymicro-features` is installed from a fork to support building on macOS\n", + " !pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'\n", + "\n", + "# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n", + "!pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'\n", + "\n", + "!git clone -b november-update https://github.com/kahrendt/microWakeWord\n", + "!pip install -e ./microWakeWord" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dEluu7nL7ywd" + }, + "outputs": [], + "source": [ + "# Generates 1 sample of the target word for manual verification.\n", + "\n", + "target_word = 'khum_puter' # Phonetic spellings may produce better samples\n", + "\n", + "import os\n", + "import sys\n", + "import platform\n", + "\n", + "from IPython.display import Audio\n", + "\n", + "if not os.path.exists(\"./piper-sample-generator\"):\n", + " if platform.system() == \"Darwin\":\n", + " !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n", + " else:\n", + " !git clone https://github.com/rhasspy/piper-sample-generator\n", + "\n", + " !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n", + "\n", + " # Install system dependencies\n", + " !pip install torch torchaudio piper-phonemize-cross==1.2.1\n", + "\n", + " if \"piper-sample-generator/\" not in sys.path:\n", + " sys.path.append(\"piper-sample-generator/\")\n", + "\n", + "!python3 
piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", + "--max-samples 1 \\\n", + "--batch-size 1 \\\n", + "--output-dir generated_samples\n", + "\n", + "Audio(\"generated_samples/0.wav\", autoplay=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-SvGtCCM9akR" + }, + "outputs": [], + "source": [ + "# Generates a larger amount of wake word samples.\n", + "# Start here when trying to improve your model.\n", + "# See https://github.com/rhasspy/piper-sample-generator for the full set of\n", + "# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n", + "# generating negative samples similar to the wake word, and generating many more\n", + "# wake word samples, possibly with different phonetic pronunciations.\n", + "\n", + "!python3 piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", + "--max-samples 1000 \\\n", + "--batch-size 100 \\\n", + "--output-dir generated_samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YJRG4Qvo9nXG" + }, + "outputs": [], + "source": [ + "# Downloads audio data for augmentation. This can be slow!\n", + "# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n", + "#\n", + "# **Important note!** The data downloaded here has a mixture of difference\n", + "# licenses and usage restrictions. As such, any custom models trained with this\n", + "# data should be considered as appropriate for **non-commercial** personal use only.\n", + "\n", + "\n", + "import datasets\n", + "import scipy\n", + "import os\n", + "\n", + "import numpy as np\n", + "\n", + "from pathlib import Path\n", + "from tqdm import tqdm\n", + "\n", + "## Download MIR RIR data\n", + "\n", + "output_dir = \"./mit_rirs\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", + " # Save clips to 16-bit PCM wav files\n", + " for row in tqdm(rir_dataset):\n", + " name = row['audio']['path'].split('/')[-1]\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", + "\n", + "## Download noise and background audio\n", + "\n", + "# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n", + "# Download one part of the audioset .tar files, extract, and convert to 16khz\n", + "# For full-scale training, it's recommended to download the entire dataset from\n", + "# https://huggingface.co/datasets/agkphysics/AudioSet, and\n", + "# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n", + "\n", + "if not os.path.exists(\"audioset\"):\n", + " os.mkdir(\"audioset\")\n", + "\n", + " fname = \"bal_train09.tar\"\n", + " out_dir = f\"audioset/{fname}\"\n", + " link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n", + " !wget -O {out_dir} {link}\n", + " !cd audioset && tar -xf bal_train09.tar\n", + "\n", + " output_dir = \"./audioset_16k\"\n", + " if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + " # Save clips to 16-bit PCM wav files\n", + " audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n", + " audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", + " for row in tqdm(audioset_dataset):\n", + 
" name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", + "\n", + "# Free Music Archive dataset\n", + "# https://github.com/mdeff/fma\n", + "# (Third-party mchl914 extra small set)\n", + "\n", + "output_dir = \"./fma\"\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + " fname = \"fma_xs.zip\"\n", + " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", + " out_dir = f\"fma/{fname}\"\n", + " !wget -O {out_dir} {link}\n", + " !cd {output_dir} && unzip -q {fname}\n", + "\n", + " output_dir = \"./fma_16k\"\n", + " if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + " # Save clips to 16-bit PCM wav files\n", + " fma_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"fma/fma_small\").glob(\"**/*.mp3\")]})\n", + " audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", + " for row in tqdm(audioset_dataset):\n", + " name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", + " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XW3bmbI5-JAz" + }, + "outputs": [], + "source": [ + "# Sets up the augmentations.\n", + "# To improve your model, experiment with these settings and use more sources of\n", + "# background clips.\n", + "\n", + "from microwakeword.audio.augmentation import Augmentation\n", + "from microwakeword.audio.clips import Clips\n", + "from microwakeword.audio.spectrograms import SpectrogramGeneration\n", + "\n", + "clips = Clips(input_directory='generated_samples',\n", + " file_pattern='*.wav',\n", + " max_clip_duration_s=None,\n", + " remove_silence=False,\n", + " random_split_seed=10,\n", + " split_count=0.1,\n", + " )\n", + "augmenter = Augmentation(augmentation_duration_s=3.2,\n", + " augmentation_probabilities = {\n", + " \"SevenBandParametricEQ\": 0.1,\n", + " \"TanhDistortion\": 0.1,\n", + " \"PitchShift\": 0.1,\n", + " \"BandStopFilter\": 0.1,\n", + " \"AddColorNoise\": 0.1,\n", + " \"AddBackgroundNoise\": 0.75,\n", + " \"Gain\": 1.0,\n", + " \"RIR\": 0.5,\n", + " },\n", + " impulse_paths = ['mit_rirs'],\n", + " background_paths = ['fma_16k', 'audioset_16k'],\n", + " background_min_snr_db = -5,\n", + " background_max_snr_db = 10,\n", + " min_jitter_s = 0.195,\n", + " max_jitter_s = 0.205,\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V5UsJfKKD1k9" + }, + "outputs": [], + "source": [ + "# Augment a random clip and play it back to verify it works well\n", + "\n", + "from IPython.display import Audio\n", + "from microwakeword.audio.audio_utils import save_clip\n", + "\n", + "random_clip = clips.get_random_clip()\n", + "augmented_clip = augmenter.augment_clip(random_clip)\n", + "save_clip(augmented_clip, 'augmented_clip.wav')\n", + "\n", + "Audio(\"augmented_clip.wav\", autoplay=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "D7BHcY1mEGbK" + }, + "outputs": [], + "source": [ + "# Augment samples and save the training, validation, and testing sets.\n", + "# Validating and testing samples generated the same way can make the model\n", + "# benchmark better than it performs in real-word use. 
Use real samples or TTS\n", + "# samples generated with a different TTS engine to potentially get more accurate\n", + "# benchmarks.\n", + "\n", + "import os\n", + "from mmap_ninja.ragged import RaggedMmap\n", + "\n", + "output_dir = 'generated_augmented_features'\n", + "\n", + "if not os.path.exists(output_dir):\n", + " os.mkdir(output_dir)\n", + "\n", + "splits = [\"training\", \"validation\", \"testing\"]\n", + "for split in splits:\n", + " out_dir = os.path.join(output_dir, split)\n", + " if not os.path.exists(out_dir):\n", + " os.mkdir(out_dir)\n", + "\n", + "\n", + " split_name = \"train\"\n", + " repetition = 2\n", + "\n", + " spectrograms = SpectrogramGeneration(clips=clips,\n", + " augmenter=augmenter,\n", + " slide_frames=10, # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.\n", + " step_ms=10,\n", + " )\n", + " if split == \"validation\":\n", + " split_name = \"validation\"\n", + " repetition = 1\n", + " elif split == \"testing\":\n", + " split_name = \"test\"\n", + " repetition = 1\n", + " spectrograms = SpectrogramGeneration(clips=clips,\n", + " augmenter=augmenter,\n", + " slide_frames=1, # The testing set uses the streaming version of the model, so no artificial repetition is necessary\n", + " step_ms=10,\n", + " )\n", + "\n", + " RaggedMmap.from_generator(\n", + " out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n", + " sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),\n", + " batch_size=100,\n", + " verbose=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1pGuJDPyp3ax" + }, + "outputs": [], + "source": [ + "# Downloads pre-generated spectrogram features (made for microWakeWord in\n", + "# particular) for various negative datasets. 
This can be slow!\n", + "\n", + "output_dir = './negative_datasets'\n", + "if not os.path.exists(output_dir):\n", + "    os.mkdir(output_dir)\n", + "    link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n", + "    filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n", + "    for fname in filenames:\n", + "        link = link_root + fname\n", + "\n", + "        zip_path = f\"negative_datasets/{fname}\"\n", + "        !wget -O {zip_path} {link}\n", + "        !unzip -q {zip_path} -d {output_dir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ii1A14GsGVQT" + }, + "outputs": [], + "source": [ + "# Save a yaml config that controls the training process.\n", + "# These hyperparameters can make a huge difference in model quality.\n", + "# Experiment with sampling and penalty weights and increasing the number of\n", + "# training steps.\n", + "\n", + "import yaml\n", + "import os\n", + "\n", + "config = {}\n", + "\n", + "config[\"window_step_ms\"] = 10\n", + "\n", + "config[\"train_dir\"] = (\n", + "    \"trained_models/wakeword\"\n", + ")\n", + "\n", + "\n", + "# Each feature_dir should have at least one of the following folders with this structure:\n", + "# training/\n", + "#     ragged_mmap_folders_ending_in_mmap\n", + "# testing/\n", + "#     ragged_mmap_folders_ending_in_mmap\n", + "# testing_ambient/\n", + "#     ragged_mmap_folders_ending_in_mmap\n", + "# validation/\n", + "#     ragged_mmap_folders_ending_in_mmap\n", + "# validation_ambient/\n", + "#     ragged_mmap_folders_ending_in_mmap\n", + "#\n", + "# sampling_weight: Weight for choosing a spectrogram from this set in the batch\n", + "# penalty_weight: Penalizing weight for incorrect predictions from this set\n", + "# truth: Boolean indicating whether this set has positive samples or negative samples\n", + "# truncation_strategy: If spectrograms in the set are longer than necessary for training, how they are truncated\n", + "#     - random: choose a random portion of the entire spectrogram - useful for long negative samples\n", + "#     - truncate_start: remove the start of the spectrogram\n", + "#     - truncate_end: remove the end of the spectrogram\n", + "#     - split: Split the longer spectrogram into separate spectrograms offset by 100 ms.
Only for ambient sets\n", + "\n", + "config[\"features\"] = [\n", + "    {\n", + "        \"features_dir\": \"generated_augmented_features\",\n", + "        \"sampling_weight\": 2.0,\n", + "        \"penalty_weight\": 1.0,\n", + "        \"truth\": True,\n", + "        \"truncation_strategy\": \"truncate_start\",\n", + "        \"type\": \"mmap\",\n", + "    },\n", + "    {\n", + "        \"features_dir\": \"negative_datasets/speech\",\n", + "        \"sampling_weight\": 10.0,\n", + "        \"penalty_weight\": 1.0,\n", + "        \"truth\": False,\n", + "        \"truncation_strategy\": \"random\",\n", + "        \"type\": \"mmap\",\n", + "    },\n", + "    {\n", + "        \"features_dir\": \"negative_datasets/dinner_party\",\n", + "        \"sampling_weight\": 10.0,\n", + "        \"penalty_weight\": 1.0,\n", + "        \"truth\": False,\n", + "        \"truncation_strategy\": \"random\",\n", + "        \"type\": \"mmap\",\n", + "    },\n", + "    {\n", + "        \"features_dir\": \"negative_datasets/no_speech\",\n", + "        \"sampling_weight\": 5.0,\n", + "        \"penalty_weight\": 1.0,\n", + "        \"truth\": False,\n", + "        \"truncation_strategy\": \"random\",\n", + "        \"type\": \"mmap\",\n", + "    },\n", + "    {  # Only used for validation and testing\n", + "        \"features_dir\": \"negative_datasets/dinner_party_eval\",\n", + "        \"sampling_weight\": 0.0,\n", + "        \"penalty_weight\": 1.0,\n", + "        \"truth\": False,\n", + "        \"truncation_strategy\": \"split\",\n", + "        \"type\": \"mmap\",\n", + "    },\n", + "]\n", + "\n", + "# Number of training steps in each iteration - various other settings are configured as lists that correspond to the different steps\n", + "config[\"training_steps\"] = [10000]\n", + "\n", + "# Penalizing weight for incorrect class predictions - lists that correspond to training steps\n", + "config[\"positive_class_weight\"] = [1]\n", + "config[\"negative_class_weight\"] = [20]\n", + "\n", + "config[\"learning_rates\"] = [\n", + "    0.001,\n", + "]  # Learning rates for Adam optimizer - list that corresponds to training steps\n", + "config[\"batch_size\"] = 128\n", + "\n", + "config[\"time_mask_max_size\"] = [\n", + "    0\n", + "]  # SpecAugment - list that corresponds to training steps\n", + "config[\"time_mask_count\"] = [0]  # SpecAugment - list that corresponds to training steps\n", + "config[\"freq_mask_max_size\"] = [\n", + "    0\n", + "]  # SpecAugment - list that corresponds to training steps\n", + "config[\"freq_mask_count\"] = [0]  # SpecAugment - list that corresponds to training steps\n", + "\n", + "config[\"eval_step_interval\"] = (\n", + "    500  # Test the validation sets after this many training steps\n", + ")\n", + "config[\"clip_duration_ms\"] = (\n", + "    1500  # Maximum length of wake word that the streaming model will accept\n", + ")\n", + "\n", + "# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization\n", + "# Once the target has been met, it chooses the maximum of the maximization metric.
Set 'minimization_metric' to None to only maximize\n", + "# Available metrics:\n", + "# - \"loss\" - cross entropy error on validation set\n", + "# - \"accuracy\" - accuracy of validation set\n", + "# - \"recall\" - recall of validation set\n", + "# - \"precision\" - precision of validation set\n", + "# - \"false_positive_rate\" - false positive rate of validation set\n", + "# - \"false_negative_rate\" - false negative rate of validation set\n", + "# - \"ambient_false_positives\" - count of false positives from the split validation_ambient set\n", + "# - \"ambient_false_positives_per_hour\" - estimated number of false positives per hour on the split validation_ambient set\n", + "config[\"target_minimization\"] = 0.9\n", + "config[\"minimization_metric\"] = None  # Set to None to disable\n", + "\n", + "config[\"maximization_metric\"] = \"average_viable_recall\"\n", + "\n", + "with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n", + "    documents = yaml.dump(config, file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WoEXJBaiC9mf" + }, + "outputs": [], + "source": [ + "# Trains a model. When finished, it will quantize and convert the model to a\n", + "# streaming version suitable for on-device detection.\n", + "# If stopped, it will resume from the saved checkpoint, but the training step\n", + "# schedule configured in the yaml file starts over from the beginning.\n", + "# Set --train to 0 to skip training and only convert and test the model with the best weights.\n", + "# On Google Colab, it doesn't print the mini-batch results, so it may appear\n", + "# stuck for several minutes! Additionally, it is very slow compared to training\n", + "# on a local GPU.\n", + "\n", + "!python -m microwakeword.model_train_eval \\\n", + "--training_config='training_parameters.yaml' \\\n", + "--train 1 \\\n", + "--restore_checkpoint 1 \\\n", + "--test_tf_nonstreaming 0 \\\n", + "--test_tflite_nonstreaming 0 \\\n", + "--test_tflite_nonstreaming_quantized 0 \\\n", + "--test_tflite_streaming 0 \\\n", + "--test_tflite_streaming_quantized 1 \\\n", + "--use_weights \"best_weights\" \\\n", + "mixednet \\\n", + "--pointwise_filters \"64,64,64,64\" \\\n", + "--repeat_in_block \"1, 1, 1, 1\" \\\n", + "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n", + "--residual_connection \"0,0,0,0\" \\\n", + "--first_conv_filters 32 \\\n", + "--first_conv_kernel_size 5 \\\n", + "--stride 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ex_UIWvwtjAN" + }, + "outputs": [], + "source": [ + "# Downloads the tflite model file. To use it on the device, you need to write a\n", + "# Model JSON file. See https://esphome.io/components/micro_wake_word for the\n", + "# documentation and\n", + "# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for\n", + "# examples. Adjust the probability threshold based on the test results obtained\n", + "# after training is finished.
You may also need to increase the Tensor arena\n", + "# model size if the model fails to load.\n", + "\n", + "from google.colab import files\n", + "\n", + "files.download(f\"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\")" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/notebooks/feature_generation.ipynb b/notebooks/feature_generation.ipynb deleted file mode 100644 index ab906ee..0000000 --- a/notebooks/feature_generation.ipynb +++ /dev/null @@ -1,183 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Downloads audio data for augmentation\n", - "# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n", - "\n", - "import datasets\n", - "import scipy\n", - "import os\n", - "\n", - "import numpy as np\n", - "\n", - "from pathlib import Path\n", - "from tqdm import tqdm\n", - "\n", - "## Download MIR RIR data\n", - "\n", - "output_dir = \"./mit_rirs\"\n", - "if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", - " # Save clips to 16-bit PCM wav files\n", - " for row in tqdm(rir_dataset):\n", - " name = row['audio']['path'].split('/')[-1]\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", - "\n", - "## Download noise and background audio\n", - "\n", - "# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n", - "# Download one part of the audioset .tar files, extract, and convert to 16khz\n", - "# For full-scale training, it's recommended to download the entire dataset from\n", - "# https://huggingface.co/datasets/agkphysics/AudioSet, and\n", - "# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n", - "\n", - "if not os.path.exists(\"audioset\"):\n", - " os.mkdir(\"audioset\")\n", - "\n", - " fname = \"bal_train09.tar\"\n", - " out_dir = f\"audioset/{fname}\"\n", - " link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n", - " !wget -O {out_dir} {link}\n", - " !cd audioset && tar -xvf bal_train09.tar\n", - "\n", - " output_dir = \"./audioset_16k\"\n", - " if not os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - "\n", - " # Save clips to 16-bit PCM wav files\n", - " audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n", - " audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", - " for row in tqdm(audioset_dataset):\n", - " name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", - "\n", - "# Free Music Archive dataset\n", - "# https://github.com/mdeff/fma\n", - "\n", - "output_dir = \"./fma\"\n", - "if not 
os.path.exists(output_dir):\n", - " os.mkdir(output_dir)\n", - " fma_dataset = datasets.load_dataset(\"rudraml/fma\", name=\"small\", split=\"train\", streaming=True)\n", - " fma_dataset = iter(fma_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000)))\n", - "\n", - " # Save clips to 16-bit PCM wav files\n", - " n_hours = 1 # use only 1 hour of clips for this example notebook, recommend increasing for full-scale training\n", - " for i in tqdm(range(n_hours*3600//30)): # this works because the FMA dataset is all 30 second clips\n", - " row = next(fma_dataset)\n", - " name = row['audio']['path'].split('/')[-1].replace(\".mp3\", \".wav\")\n", - " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", - " i += 1\n", - " if i == n_hours*3600//30:\n", - " break\n", - " \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Specify parameters for augmentation\n", - "\n", - "audio_config = {}\n", - "\n", - "audio_config['features_output_dir'] = 'augmented_features_mmap'\n", - "\n", - "audio_config['input_path'] = 'generated_samples'\n", - "audio_config['input_glob'] = '**/*.wav'\n", - "audio_config['impulse_paths'] = ['mit_rirs']\n", - "audio_config['background_paths'] = ['fma', 'audioset_16k']\n", - "audio_config['min_clip_duration_s'] = None\n", - "audio_config['max_clip_duration_s'] = 1.39\n", - "audio_config['max_start_time_from_right_s'] = 1.49\n", - "audio_config['augmented_duration_s'] = 3.0\n", - "\n", - "from microwakeword.feature_generation import ClipsHandler" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load audio clips and prepare them for augmentation\n", - "\n", - "clips_handler = ClipsHandler(\n", - " input_path=audio_config['input_path'],\n", - " input_glob=audio_config['input_glob'],\n", - " impulse_paths=audio_config['impulse_paths'], \n", - " background_paths=audio_config['background_paths'], \n", - " augmentation_probabilities = {\n", - " \"SevenBandParametricEQ\": 0.25,\n", - " \"TanhDistortion\": 0.25,\n", - " \"PitchShift\": 0.25,\n", - " \"BandStopFilter\": 0.25,\n", - " \"AddBackgroundNoise\": 0.75,\n", - " \"Gain\": 1.0,\n", - " \"RIR\": 0.5,\n", - " },\n", - " augmented_duration_s = audio_config['augmented_duration_s'],\n", - " max_start_time_from_right_s = audio_config['max_start_time_from_right_s'],\n", - " max_clip_duration_s = audio_config['max_clip_duration_s'], \n", - " min_clip_duration_s = audio_config['min_clip_duration_s'], \n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Test by playing a randomly augmented clip\n", - "\n", - "import IPython\n", - "\n", - "clips_handler.save_random_augmented_clip(\"augmented_clip.wav\")\n", - "\n", - "IPython.display.display(IPython.display.Audio(\"augmented_clip.wav\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Save all the clip's augmented features in a Ragged Mmap\n", - "\n", - "clips_handler.save_augmented_features(audio_config['features_output_dir'])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": 
"python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/training_notebook.ipynb b/notebooks/training_notebook.ipynb deleted file mode 100644 index bb574f2..0000000 --- a/notebooks/training_notebook.ipynb +++ /dev/null @@ -1,202 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Install necessary packages\n", - "!pip install tensorflow\n", - "!pip install mmap_ninja\n", - "!pip install pyyaml\n", - "!pip install datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Save a yaml config that controls the training process\n", - "import yaml\n", - "import os\n", - "\n", - "config = {}\n", - "\n", - "config['train_dir'] = 'trained_models/alexa_space_in_phonemes'\n", - "\n", - "# Each feature_dir should have at least one of the following folders with this structure:\n", - "# training/\n", - "# ragged_mmap_folders_ending_in_mmap\n", - "# testing/\n", - "# ragged_mmap_folders_ending_in_mmap\n", - "# testing_ambient/\n", - "# ragged_mmap_folders_ending_in_mmap\n", - "# validation/\n", - "# ragged_mmap_folders_ending_in_mmap\n", - "# validation_ambient/\n", - "# ragged_mmap_folders_ending_in_mmap\n", - "#\n", - "# sampling_weight: Weight for choosing a spectrogram from this set in the batch\n", - "# penalty_weight: Penalizing weight for incorrect predictions from this set\n", - "# truth: Boolean whether this set has positive samples or negative samples\n", - "# truncation_strategy = If spectrograms in the set are longer than necessary for training, how are they truncated\n", - "# - random: choose a random portion of the entire spectrogram - useful for long negative samples\n", - "# - truncate_start: remove the start of the spectrogram\n", - "# - truncate_end: remove the end of the spectrogram\n", - "# - split: Split the longer spectrogram into separate spectrograms offset by 100 ms. 
Only for ambient sets\n", - "\n", - "config['features'] = [\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/alexa_phonetic/generated_positive',\n", - " 'sampling_weight': 1,\n", - " 'penalty_weight': 1,\n", - " 'truth': True,\n", - " 'truncation_strategy': 'truncate_start',\n", - " 'type': \"mmap\",\n", - " },\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/alexa_phonetic/generated_negative',\n", - " 'sampling_weight': 1,\n", - " 'penalty_weight': 1,\n", - " 'truth': False,\n", - " 'truncation_strategy': 'truncate_start',\n", - " 'type': \"mmap\", \n", - " },\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/negative_datasets/english_speech_background',\n", - " 'sampling_weight': 2,\n", - " 'penalty_weight': 1,\n", - " 'truth': False,\n", - " 'truncation_strategy': 'random',\n", - " 'type': \"mmap\", \n", - " },\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/negative_datasets/non_english_speech_background',\n", - " 'sampling_weight': 1,\n", - " 'penalty_weight': 1,\n", - " 'truth': False,\n", - " 'truncation_strategy': 'random',\n", - " 'type': \"mmap\", \n", - " },\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/negative_datasets/dinner_party_background',\n", - " 'sampling_weight': 2,\n", - " 'penalty_weight': 3,\n", - " 'truth': False,\n", - " 'truncation_strategy': 'random',\n", - " 'type': \"mmap\", \n", - " },\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/negative_datasets/no_speech_background',\n", - " 'sampling_weight': 1,\n", - " 'penalty_weight': 1,\n", - " 'truth': False,\n", - " 'truncation_strategy': 'random',\n", - " 'type': \"mmap\", \n", - " },\n", - " {\n", - " 'features_dir': '/Volumes/MachineLearning/training_data/negative_datasets/ambient_background',\n", - " 'sampling_weight': 0.0,\n", - " 'penalty_weight': 1,\n", - " 'truth': False,\n", - " 'truncation_strategy': 'split',\n", - " 'type': \"mmap\", \n", - " },\n", - " ]\n", - "\n", - "# Number of training steps in each iteration - various other settings are configured as lists that corresponds to different steps\n", - "config['training_steps'] = [30000, 30000,20000,20000] \n", - "\n", - "# Penalizing weight for incorrect class predictions - lists that correspond to training steps\n", - "config[\"positive_class_weight\"] = [1] \n", - "config[\"negative_class_weight\"] = [1]\n", - "\n", - "config['learning_rates'] = [0.001, 0.0005,0.0002,0.0001] # Learning rates for Adam optimizer - list that corresponds to training steps\n", - "config['batch_size'] = 128\n", - "\n", - "config['mix_up_augmentation_prob'] = [0] # Probability of applying MixUp augmentation - list that corresponds to training steps\n", - "config['freq_mix_augmentation_prob'] = [0] # Probability of applying FreqMix augmentation - list that corresponds to training steps\n", - "config['time_mask_max_size'] = [5] # SpecAugment - list that corresponds to training steps\n", - "config['time_mask_count'] = [2] # SpecAugment - list that corresponds to training steps\n", - "config['freq_mask_max_size'] = [5] # SpecAugment - list that corresponds to training steps\n", - "config['freq_mask_count'] = [2] # SpecAugment - list that corresponds to training steps\n", - "config['eval_step_interval'] = 500 # Test the validation sets after every this many steps\n", - "\n", - "config['clip_duration_ms'] = 1490 # Maximum length of wake word that the streaming model will accept\n", - "config['window_stride_ms'] = 20 # Fixed 
setting for default feature generator\n", - "config['window_size_ms'] = 30 # Fixed setting for default feature generator\n", - "config['sample_rate'] = 16000 # Fixed setting for default feature generator\n", - "\n", - "# The best model weights are chosen first by minimizing the specified minimization metric below the specified target_minimization\n", - "# Once the target has been met, it chooses the maximum of the maximization metric. Set 'minimization_metric' to None to only maximize\n", - "# Available metrics:\n", - "# - \"loss\" - cross entropy error on validation set\n", - "# - \"accuracy\" - accuracy of validation set\n", - "# - \"recall\" - recall of validation set\n", - "# - \"precision\" - precision of validation set\n", - "# - \"false_positive_rate\" - false positive rate of validation set\n", - "# - \"false_negative_rate\" - false negative rate of validation set\n", - "# - \"ambient_false_positives\" - count of false positives from the split validation_ambient set\n", - "# - \"ambient_false_positives_per_hour\" - estimated number of false positives per hour on the split validation_ambient set\n", - "config['minimization_metric'] = 'ambient_false_positives_per_hour' # Set to None to disable\n", - "config['target_minimization'] = 0.5\n", - "config['maximization_metric'] = 'recall'\n", - "config['binary_classification'] = False\n", - "\n", - "with open(os.path.join('training_parameters.yaml'), 'w') as file:\n", - " documents = yaml.dump(config, file)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "!python -m microwakeword.model_train_eval \\\n", - "--training_config='training_parameters.yaml' \\\n", - "--train 0 \\\n", - "--restore_checkpoint 1 \\\n", - "--test_tf_nonstreaming 0 \\\n", - "--test_tflite_nonstreaming 0 \\\n", - "--test_tflite_streaming 0 \\\n", - "--test_tflite_streaming_quantized 1 \\\n", - "--use_weights \"last_weights\" \\\n", - "inception \\\n", - "--cnn1_filters '32' \\\n", - "--cnn1_kernel_sizes '5' \\\n", - "--cnn1_subspectral_groups '1' \\\n", - "--cnn2_filters1 '24,24,24' \\\n", - "--cnn2_filters2 '32,64,96' \\\n", - "--cnn2_kernel_sizes '3,5,5' \\\n", - "--cnn2_subspectral_groups '1,1,1' \\\n", - "--cnn2_dilation '1,1,1' \\\n", - "--dropout 0.8\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index 7f08048..0955125 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,13 @@ build-backend = "setuptools.build_meta" [project] name = "microwakeword" -version = "0.0.1" +version = "0.1.0" authors = [ { name="Kevin Ahrendt", email="kahrendt@gmail.com" }, ] description = "A TensorFlow based wake word detection training framework using synthetic sample generation suitable for certain microcontrollers." 
readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.10, <3.11" dynamic = ["dependencies"] classifiers = [ "Programming Language :: Python :: 3", @@ -20,4 +20,8 @@ classifiers = [ [project.urls] Homepage = "https://github.com/kahrendt/microWakeWord" -Issues = "https://github.com/kahrendt/microWakeWord/issues" \ No newline at end of file +Issues = "https://github.com/kahrendt/microWakeWord/issues" + +[tool.black] +target-version = ["py310"] +exclude = 'generated' \ No newline at end of file diff --git a/setup.py b/setup.py index 7ce1e01..932ceed 100644 --- a/setup.py +++ b/setup.py @@ -5,15 +5,17 @@ setuptools.setup( name="microwakeword", - version="0.0.1", + version="0.1.0", install_requires=[ "audiomentations", "audio_metadata", "datasets", "mmap_ninja", "numpy", + "pymicro-features", "pyyaml", - "tensorflow>=2.14", + "tensorflow>=2.16", + "webrtcvad", ], author="Kevin Ahrendt", author_email="kahrendt@gmail.com", @@ -31,5 +33,5 @@ ], packages=setuptools.find_packages(), include_package_data=True, - python_requires=">=3.8", + python_requires=">=3.10, <3.11", )