Allow Clips providers to be used for validation #40

Draft
wants to merge 8 commits into main
8 changes: 7 additions & 1 deletion microwakeword/audio/augmentation.py
@@ -38,6 +38,7 @@ class Augmentation:
min_jitter_s (float, optional): The minimum duration in seconds that the original clip is positioned before the end of the augmented audio. Defaults to 0.0.
max_jitter_s (float, optional): The maximum duration in seconds that the original clip is positioned before the end of the augmented audio. Defaults to 0.0.
truncate_randomly (bool, optional): If True, the clip is truncated to the specified duration at a random position. Otherwise, samples are removed from the start of the clip.
disable_augmentations (bool, optional): If True, all augmentations and jitter are skipped; the clip is only padded or truncated to the fixed duration. Defaults to False.
"""

def __init__(
@@ -67,8 +68,8 @@ def __init__(
min_jitter_s: float = 0.0,
max_jitter_s: float = 0.0,
truncate_randomly: bool = False,
disable_augmentations: bool = False,
):
self.truncate_randomly = truncate_randomly
############################################
# Configure audio duration and positioning #
############################################
@@ -88,6 +89,8 @@ def __init__(
#######################
# Setup augmentations #
#######################
self.disable_augmentations = disable_augmentations
self.truncate_randomly = truncate_randomly

# If either the background_paths or impulse_paths are not specified, use an identity transform instead
def identity_transform(samples, sample_rate):
@@ -220,6 +223,9 @@ def augment_clip(self, input_audio: np.ndarray):
Returns:
numpy.ndarray: The augmented audio of fixed duration.
"""
if self.disable_augmentations:
return self.create_fixed_size_clip(input_audio)

input_audio = self.add_jitter(input_audio)
input_audio = self.create_fixed_size_clip(input_audio)

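With this flag, the same Augmentation configuration can produce deterministic clips for validation. A minimal sketch of the intended usage; `disable_augmentations` comes from this diff, while `augmentation_duration_s` is an assumed name for the fixed-duration parameter not shown here:

import numpy as np

from microwakeword.audio.augmentation import Augmentation

# Validation-time "augmenter": jitter and all audio augmentations are skipped,
# and augment_clip only pads/truncates the input to the fixed duration.
validation_augmenter = Augmentation(
    augmentation_duration_s=3.2,  # assumed parameter name, not shown in this diff
    disable_augmentations=True,
)

audio = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence at 16 kHz
fixed_clip = validation_augmenter.augment_clip(audio)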
151 changes: 97 additions & 54 deletions microwakeword/audio/clips.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import audio_metadata
import datasets
import math
@@ -32,27 +34,33 @@ class Clips:

Args:
input_directory (str): Path to audio clip files.
file_pattern (str): File glob pattern for selecting audio clip files.
file_pattern (str | list[str]): File glob pattern(s) for selecting audio clip files.
min_clip_duration_s (float | None, optional): The minimum clip duration (in seconds). Set to None to disable filtering by minimum clip duration. Defaults to None.
max_clip_duration_s (float | None, optional): The maximum clip duration (in seconds). Set to None to disable filtering by maximum clip duration. Defaults to None.
repeat_clip_min_duration_s (float | None, optional): If a clip is shorter than this duration, then it is repeated until it is longer than this duration. Set to None to disable repeating the clip. Defaults to None.
remove_silence (bool, optional): Use webrtcvad to trim non-voice activity in the clip. Defaults to False.
random_splits (dict, optional): Maps each split name to the fraction of clips assigned to it; the nonzero fractions must sum to 1. If more than one fraction is nonzero, `random_split_seed` must also be set. Defaults to an 80/10/10 training/testing/validation split.
random_split_seed (int | None, optional): The random seed used to split the clips into different sets. Set to None to disable splitting the clips. Defaults to None.
split_count (int | float, optional): The percentage/count of clips to be included in the testing and validation sets. Defaults to 0.1.
trimmed_clip_duration_s (float | None, optional): The duration (in seconds) to which long clips are trimmed by removing audio from the end. Set to None to disable trimming. Defaults to None.
trim_zeros (bool, optional): If True, any leading and trailing zeros are removed. Defaults to False.
"""

def __init__(
self,
input_directory: str,
file_pattern: str,
file_pattern: str | list[str],
min_clip_duration_s: float | None = None,
max_clip_duration_s: float | None = None,
repeat_clip_min_duration_s: float | None = None,
remove_silence: bool = False,
random_splits: dict[str, float] = {
"training": 0.8,
"testing": 0.1,
"validation": 0.1,
"testing_ambient": 0,
"validation_ambient": 0,
},
random_split_seed: int | None = None,
split_count: int | float = 0.1,
trimmed_clip_duration_s: float | None = None,
trim_zeros: bool = False,
):
@@ -75,17 +83,24 @@ def __init__(
self.repeat_clip_min_duration_s = 0.0

self.remove_silence = remove_silence

self.remove_silence_function = remove_silence_webrtc

paths_to_clips = [str(i) for i in Path(input_directory).glob(file_pattern)]
self.input_directory = input_directory

if isinstance(file_pattern, str):
file_pattern = [file_pattern]

paths_to_clips = []

for pattern in file_pattern:
paths_to_clips.extend([str(i) for i in Path(input_directory).glob(pattern)])

if (self.min_clip_duration_s == 0) and (math.isinf(self.max_clip_duration_s)):
# No durations specified, so do not filter by length
filtered_paths = paths_to_clips
else:
# Filter audio clips by length
if file_pattern.endswith("wav"):
if file_pattern[0].endswith("wav"):
# If it is a wave file, assume all wave files have the same parameters and filter by file size.
# Based on openWakeWord's estimate_clip_duration and filter_audio_paths in data.py, accessed March 2, 2024.
with wave.open(paths_to_clips[0], "rb") as input_wav:
@@ -142,21 +157,72 @@ def __init__(
"audio", datasets.Audio(sampling_rate=16000)
)

if random_split_seed is not None:
train_testvalid = audio_dataset.train_test_split(
test_size=2 * split_count, seed=random_split_seed
dataset_splits = {
"training": [],
"testing": [],
"validation": [],
"testing_ambient": [],
"validation_ambient": [],
}

assigned_splits = [(k, v) for k, v in random_splits.items() if v > 0]
assert abs(sum(dict(assigned_splits).values()) - 1.0) < 1e-6, "Nonzero split fractions must sum to 1"

self.single_split = None

if len(assigned_splits) == 1:
# With a single split assigned, the whole dataset is used as-is
self.single_split = assigned_splits[0][0]
dataset_splits[self.single_split] = audio_dataset
elif random_split_seed is None:
raise ValueError("Random split seed must be set to split the dataset")

if len(assigned_splits) == 2:
# With two splits, a single train_test_split suffices
split1, split2 = assigned_splits
split_dataset = audio_dataset.train_test_split(
train_size=split1[1], seed=random_split_seed
)
test_valid = train_testvalid["test"].train_test_split(test_size=0.5)
split_dataset = datasets.DatasetDict(
{
"train": train_testvalid["train"],
"test": test_valid["test"],
"validation": test_valid["train"],
}
dataset_splits[split1[0]] = split_dataset["train"]
dataset_splits[split2[0]] = split_dataset["test"]
elif len(assigned_splits) == 3:
# Three splits require two consecutive train_test_split calls
split1, split2, split3 = assigned_splits
split_dataset1 = audio_dataset.train_test_split(
train_size=split1[1] + split2[1], seed=random_split_seed
)
self.split_clips = split_dataset
split_dataset2 = split_dataset1["train"].train_test_split(
train_size=split1[1] / (split1[1] + split2[1]), seed=random_split_seed
)
dataset_splits[split3[0]] = split_dataset1["test"]
dataset_splits[split1[0]] = split_dataset2["train"]
dataset_splits[split2[0]] = split_dataset2["test"]
else:
raise ValueError(f"Only up to three dataset splits are supported: {assigned_splits}")

self.split_clips = datasets.DatasetDict(dataset_splits)
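# Worked example: with the default fractions (training=0.8, testing=0.1,
# validation=0.1), the first split uses train_size = 0.8 + 0.1 = 0.9, sending
# 10% of clips to validation; the second split then takes
# train_size = 0.8 / 0.9 ≈ 0.889 of the remaining 90%, yielding the intended
# 80/10/10 proportions overall.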

def _process_clip(self, clip_audio):
if self.remove_silence:
clip_audio = self.remove_silence_function(clip_audio)

if self.trim_zeros:
clip_audio = np.trim_zeros(clip_audio)

if self.trimmed_clip_duration_s:
total_samples = int(self.trimmed_clip_duration_s * 16000)
clip_audio = clip_audio[:total_samples]

return self.repeat_clip(clip_audio)

def _get_clips_from_split(self, split: str | None = None):
if split is None:
if self.single_split is None:
raise ValueError("`split` must be provided for multi-split Clips")

split = self.single_split

self.clips = audio_dataset
return self.split_clips[split]

def audio_generator(self, split: str | None = None, repeat: int = 1):
"""A Python generator that retrieves all loaded audio clips.
@@ -168,50 +234,24 @@ def audio_generator(self, split: str | None = None, repeat: int = 1):
Yields:
numpy.ndarray: Array with the audio clip's samples.
"""
if split is None:
clip_list = self.clips
else:
clip_list = self.split_clips[split]
clip_list = self._get_clips_from_split(split)

for _ in range(repeat):
for clip in clip_list:
clip_audio = clip["audio"]["array"]
yield self._process_clip(clip["audio"]["array"])

if self.remove_silence:
clip_audio = self.remove_silence_function(clip_audio)

if self.trim_zeros:
clip_audio = np.trim_zeros(clip_audio)

if self.trimmed_clip_duration_s:
total_samples = int(self.trimmed_clip_duration_s * 16000)
clip_audio = clip_audio[:total_samples]

clip_audio = self.repeat_clip(clip_audio)
yield clip_audio

def get_random_clip(self):
def get_random_clip(self, split: str | None = None):
"""Retrieves a random audio clip.

Args:
split (str | None, optional): The split to draw the clip from. May be omitted when only a single split is assigned. Defaults to None.

Returns:
numpy.ndarray: Array with the audio clip's samples.
"""
rand_audio_entry = random.choice(self.clips)
clip_audio = rand_audio_entry["audio"]["array"]

if self.remove_silence:
clip_audio = self.remove_silence_function(clip_audio)

if self.trim_zeros:
clip_audio = np.trim_zeros(clip_audio)
clip_list = self._get_clips_from_split(split)
rand_audio_entry = random.choice(clip_list)

if self.trimmed_clip_duration_s:
total_samples = int(self.trimmed_clip_duration_s * 16000)
clip_audio = clip_audio[:total_samples]
return self._process_clip(rand_audio_entry["audio"]["array"])

clip_audio = self.repeat_clip(clip_audio)
return clip_audio

def random_audio_generator(self, max_clips: int = math.inf):
def random_audio_generator(self, split: str | None = None, max_clips: int = math.inf):
"""A Python generator that retrieves random audio clips.

Args:
@@ -220,10 +260,13 @@ def random_audio_generator(self, max_clips: int = math.inf):
Yields:
numpy.ndarray: Array with the random audio clip's samples.
"""
clip_list = self._get_clips_from_split(split)

while max_clips > 0:
max_clips -= 1

yield self.get_random_clip()
# TODO: Sampling with replacement isn't good for small datasets
yield self.get_random_clip(split=split)

def repeat_clip(self, audio_samples: np.array):
"""Repeats the audio clip until its duration exceeds the minimum specified in the class.
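Taken together, these changes let a single Clips instance serve named splits, or let a dedicated directory act as a standalone validation provider. A sketch of both patterns; the directory names are illustrative, while the parameters and calls follow the diff above:

from microwakeword.audio.clips import Clips

# One directory split 80/10/10; a seed is required for multi-way splits.
clips = Clips(
    input_directory="generated_samples",  # illustrative path
    file_pattern=["*.wav", "*.mp3"],      # file_pattern may now be a list of globs
    random_splits={"training": 0.8, "testing": 0.1, "validation": 0.1},
    random_split_seed=10,
)
for audio in clips.audio_generator(split="validation"):
    pass  # consume deterministic validation clips

# A dedicated validation directory: a single nonzero split needs no seed,
# and `split` may be omitted from the generator calls.
validation_clips = Clips(
    input_directory="validation_samples",  # illustrative path
    file_pattern="*.wav",
    random_splits={"validation": 1.0},
)
for audio in validation_clips.audio_generator():
    pass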
12 changes: 6 additions & 6 deletions microwakeword/audio/spectrograms.py
@@ -46,19 +46,19 @@ def __init__(
self.split_spectrogram_duration_s = split_spectrogram_duration_s
self.slide_frames = slide_frames

def get_random_spectrogram(self):
def get_random_spectrogram(self, split: str | None = None):
"""Retrieves a random audio clip's spectrogram that is optionally augmented.

Args:
split (str | None, optional): The split to draw the clip from. Defaults to None.

Returns:
numpy.ndarray: 2D spectrogram array for the random (augmented) audio clip.
"""
clip = self.clips.get_random_clip()
clip = self.clips.get_random_clip(split=split)
if self.augmenter is not None:
clip = self.augmenter.augment_clip(clip)

return generate_features_for_clip(clip, self.step_ms)

def spectrogram_generator(self, random=False, max_clips=None, **kwargs):
def spectrogram_generator(self, split: str | None = None, random=False, max_clips=None, **kwargs):
"""A Python generator that retrieves (augmented) spectrograms.

Args:
@@ -70,11 +70,11 @@ def spectrogram_generator(self, random=False, max_clips=None, **kwargs):
"""
if random:
if max_clips is not None:
clip_generator = self.clips.random_audio_generator(max_clips=max_clips)
clip_generator = self.clips.random_audio_generator(split=split, max_clips=max_clips)
else:
clip_generator = self.clips.random_audio_generator()
clip_generator = self.clips.random_audio_generator(split=split)
else:
clip_generator = self.clips.audio_generator(**kwargs)
clip_generator = self.clips.audio_generator(split=split, **kwargs)

if self.augmenter is not None:
augmented_generator = self.augmenter.augment_generator(clip_generator)
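Finally, the `split` argument threads through spectrogram generation, so training and validation features can come from the same provider. A sketch under two assumptions: the enclosing class is this module's SpectrogramGeneration, and its constructor parameters mirror the self.* attributes visible in the hunk above:

from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration  # assumed class name

clips = Clips(
    input_directory="generated_samples",  # illustrative path
    file_pattern="*.wav",
    random_splits={"training": 0.8, "testing": 0.1, "validation": 0.1},
    random_split_seed=10,
)

spectrograms = SpectrogramGeneration(
    clips=clips,
    augmenter=None,  # or an Augmentation(..., disable_augmentations=True)
    step_ms=10,  # parameter names inferred from the attributes in the hunk above
    split_spectrogram_duration_s=None,
    slide_frames=1,
)

# Ordered, deterministic pass over the validation split:
for spectrogram in spectrograms.spectrogram_generator(split="validation"):
    pass

# Random sampling (with replacement) from the training split, capped at 1000 clips:
for spectrogram in spectrograms.spectrogram_generator(split="training", random=True, max_clips=1000):
    pass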