From a6a80896d5a491639d8443429ed906a486fa3038 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 23 Oct 2024 06:03:12 -0700 Subject: [PATCH] add sliding window --- ..._from_hubert_base.py => extract_kmeans.py} | 104 ++++++++++++++---- egs/librilight/SSL/prepare.sh | 6 +- 2 files changed, 85 insertions(+), 25 deletions(-) rename egs/librilight/SSL/local/{extract_kmeans_from_hubert_base.py => extract_kmeans.py} (72%) diff --git a/egs/librilight/SSL/local/extract_kmeans_from_hubert_base.py b/egs/librilight/SSL/local/extract_kmeans.py similarity index 72% rename from egs/librilight/SSL/local/extract_kmeans_from_hubert_base.py rename to egs/librilight/SSL/local/extract_kmeans.py index a3b01b2311..b8f6fcba9e 100755 --- a/egs/librilight/SSL/local/extract_kmeans_from_hubert_base.py +++ b/egs/librilight/SSL/local/extract_kmeans.py @@ -17,6 +17,7 @@ import argparse import logging +import math from pathlib import Path from typing import Optional @@ -98,44 +99,90 @@ def get_args(): help="Stop processing pieces until this number (exclusive).", ) + parser.add_argument( + "--window-duration", + type=float, + default=300.0, + ) + + parser.add_argument( + "--shift-duration", + type=float, + default=250.0, + ) + return parser.parse_args() +@torch.no_grad() def extract_and_save_one_cuts( - raw_cuts_path, cuts_path, model, apply_kmeans, do_normalize, device + raw_cuts_path, + cuts_path, + model, + apply_kmeans, + do_normalize, + window_duration, + shift_duration, ): logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) logging.info("Extracting kmeans") cuts = [] + + assert window_duration >= shift_duration + window_size = int(window_duration * 16000) + shift_size = int(shift_duration * 16000) + overlap_size = window_size - shift_size + out_overlap_size = get_out_length(overlap_size) + for cut in tqdm(cut_set): assert cut.sampling_rate == 16000, f"Sampling rate: {cut.sampling_rate}" + audio = cut.load_audio() - offsets = 0 - if True: - x = torch.from_numpy(audio).float().to(device) + T = audio.shape[1] + start = 0 + kmeans = [] + while start < T: + real_window_size = min(window_size, T - start) + audio_window = audio[:, start : start + real_window_size] + + x = ( + torch.from_numpy(audio_window) + .float() + .to(next(model.parameters()).device) + ) + if do_normalize: + x = torch.nn.functional.layer_norm(x, x.shape) + + feature, _ = model.extract_features( + source=x, + padding_mask=None, + mask=False, + output_layer=9, + ) + feature = feature.squeeze(0) - with torch.no_grad(): - if do_normalize: - x = torch.nn.functional.layer_norm(x, x.shape) + current_kmeans = apply_kmeans(feature).tolist() - feature, _ = model.extract_features( - source=x, - padding_mask=None, - mask=False, - output_layer=9, - ) - feature = feature.squeeze(0) + if start == 0: + kmeans.extend(current_kmeans) + else: + kmeans.extend(current_kmeans[out_overlap_size:]) - kmeans = " ".join(map(str, apply_kmeans(feature).tolist())) + if T - start <= window_size: + break - cut_with_kmeans = fastcopy( - cut, - custom={"kmeans": kmeans}, - ) - cuts.append(cut_with_kmeans) + start += shift_size + + kmeans = " ".join(map(str, kmeans)) + + cut_with_kmeans = fastcopy( + cut, + custom={"kmeans": kmeans}, + ) + cuts.append(cut_with_kmeans) cuts = CutSet(cuts) @@ -166,6 +213,9 @@ def extract_kmeans(args): model = model[0].eval().to(device) do_normalize = task.cfg.normalize + window_duration = args.window_duration + shift_duration = args.shift_duration + if args.subset == "small": cuts_path = output_dir / f"{prefix}_cuts_{args.subset}.jsonl.gz" if cuts_path.is_file(): @@ -183,7 +233,8 @@ def extract_kmeans(args): model, apply_kmeans, do_normalize, - device, + window_duration, + shift_duration, ) else: num_digits = 8 # num_digits is fixed by lhotse split-lazy @@ -213,10 +264,19 @@ def extract_kmeans(args): model, apply_kmeans, do_normalize, - device, + window_duration, + shift_duration, ) +def get_out_length(T): + conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + for i, (out_channels, kernel_size, stride) in enumerate(conv_layers): + T = math.floor((T - kernel_size) / stride) + 1 + + return max(0, T) + + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librilight/SSL/prepare.sh b/egs/librilight/SSL/prepare.sh index 24bb9b8dee..e0a293b2c9 100755 --- a/egs/librilight/SSL/prepare.sh +++ b/egs/librilight/SSL/prepare.sh @@ -86,15 +86,15 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then wget https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin -P download fi if [ ! -e data/kmeans/.extract_small.done ]; then - ./local/extract_kmeans_from_hubert_base.py --subset small + ./local/extract_kmeans.py --subset small touch data/kmeans/.extract_small.done fi if [ ! -e data/kmeans/.extract_medium.done ]; then - ./local/extract_kmeans_from_hubert_base.py --subset medium + ./local/extract_kmeans.py --subset medium touch data/kmeans/.extract_medium.done fi if [ ! -e data/kmeans/.extract_large.done ]; then - ./local/extract_kmeans_from_hubert_base.py --subset large + ./local/extract_kmeans.py --subset large touch data/kmeans/.extract_large.done fi fi