From a686c696d73aa3b35897f9340ba364dee01d7fc5 Mon Sep 17 00:00:00 2001 From: Chau Date: Wed, 9 Nov 2022 06:17:35 +0100 Subject: [PATCH] removed yt interfacing, removed dependencies --- requirements.txt | 6 +- setup.py | 2 +- simple_diarizer/cluster.py | 82 +++++----- simple_diarizer/diarizer.py | 296 +++++++++++++++++------------------- simple_diarizer/utils.py | 181 ++++++---------------- 5 files changed, 229 insertions(+), 338 deletions(-) diff --git a/requirements.txt b/requirements.txt index be9759a..2dd02a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,5 @@ -beautifulsoup4>=4.10.0 matplotlib>=3.5.1 pandas>=1.3.5 -pytube>=11.0.2 scikit-learn>=1.0.2 speechbrain>=0.5.11 -torchaudio>=0.10.1 -validators>=0.18.2 -youtube-dl>=2021.12.17 +torchaudio>=0.10.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 0398a87..ddf9dc5 100644 --- a/setup.py +++ b/setup.py @@ -12,5 +12,5 @@ version=__version__, install_requires=install_requires, long_description=long_description, - long_description_content_type="text/markdown" + long_description_content_type="text/markdown", ) diff --git a/simple_diarizer/cluster.py b/simple_diarizer/cluster.py index e28ad73..737e5a7 100644 --- a/simple_diarizer/cluster.py +++ b/simple_diarizer/cluster.py @@ -6,11 +6,12 @@ from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering from sklearn.metrics import pairwise_distances -def similarity_matrix(embeds, metric='cosine'): + +def similarity_matrix(embeds, metric="cosine"): return pairwise_distances(embeds, metric=metric) -def cluster_AHC(embeds, n_clusters=None, threshold=None, - metric='cosine', **kwargs): + +def cluster_AHC(embeds, n_clusters=None, threshold=None, metric="cosine", **kwargs): """ Cluster embeds using Agglomerative Hierarchical Clustering """ @@ -18,31 +19,33 @@ def cluster_AHC(embeds, n_clusters=None, threshold=None, assert threshold, "If num_clusters is not defined, threshold must be defined" S = similarity_matrix(embeds, metric=metric) - + if n_clusters is None: - cluster_model = AgglomerativeClustering(n_clusters=None, - affinity='precomputed', - linkage='average', - compute_full_tree=True, - distance_threshold=threshold) + cluster_model = AgglomerativeClustering( + n_clusters=None, + affinity="precomputed", + linkage="average", + compute_full_tree=True, + distance_threshold=threshold, + ) return cluster_model.fit_predict(S) else: - cluster_model = AgglomerativeClustering(n_clusters=n_clusters, - affinity='precomputed', - linkage='average') + cluster_model = AgglomerativeClustering( + n_clusters=n_clusters, affinity="precomputed", linkage="average" + ) return cluster_model.fit_predict(S) ########################################## # Spectral clustering -# A lot of these methods are lifted from +# A lot of these methods are lifted from # https://github.com/wq2012/SpectralCluster ########################################## -def cluster_SC(embeds, n_clusters=None, threshold=None, - enhance_sim=True, **kwargs): + +def cluster_SC(embeds, n_clusters=None, threshold=None, enhance_sim=True, **kwargs): """ Cluster embeds using Spectral Clustering """ @@ -52,7 +55,7 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, S = compute_affinity_matrix(embeds) if enhance_sim: S = sim_enhancement(S) - + if n_clusters is None: (eigenvalues, eigenvectors) = compute_sorted_eigenvectors(S) # Get number of clusters. 
@@ -67,18 +70,18 @@ def cluster_SC(embeds, n_clusters=None, threshold=None, # This implemention from scikit-learn does NOT, which is inconsistent # with the paper. kmeans_clusterer = KMeans( - n_clusters=k, - init="k-means++", - max_iter=300, - random_state=0) + n_clusters=k, init="k-means++", max_iter=300, random_state=0 + ) labels = kmeans_clusterer.fit_predict(spectral_embeddings) return labels else: - cluster_model = SpectralClustering(n_clusters=n_clusters, - affinity='precomputed') + cluster_model = SpectralClustering( + n_clusters=n_clusters, affinity="precomputed" + ) return cluster_model.fit_predict(S) + def diagonal_fill(A): """ Sets the diagonal elemnts of the matrix to the max of each row @@ -87,54 +90,61 @@ def diagonal_fill(A): A[np.diag_indices(A.shape[0])] = np.max(A, axis=1) return A + def gaussian_blur(A, sigma=1.0): """ Does a gaussian blur on the affinity matrix """ return gaussian_filter(A, sigma=sigma) + def row_threshold_mult(A, p=0.95, mult=0.01): """ For each row multiply elements smaller than the row's p'th percentile by mult """ - percentiles = np.percentile(A, p*100, axis=1) - mask = A < percentiles[:,np.newaxis] + percentiles = np.percentile(A, p * 100, axis=1) + mask = A < percentiles[:, np.newaxis] - A = (mask * mult * A) + (~mask * A) + A = (mask * mult * A) + (~mask * A) return A + def symmetrization(A): """ Symmeterization: Y_{i,j} = max(S_{ij}, S_{ji}) """ return np.maximum(A, A.T) + def diffusion(A): """ Diffusion: Y <- YY^T """ return np.dot(A, A.T) + def row_max_norm(A): """ Row-wise max normalization: S_{ij} = Y_{ij} / max_k(Y_{ik}) """ maxes = np.amax(A, axis=1) - return A/maxes + return A / maxes + def sim_enhancement(A): func_order = [ - diagonal_fill, - gaussian_blur, - row_threshold_mult, - symmetrization, - diffusion, - row_max_norm - ] + diagonal_fill, + gaussian_blur, + row_threshold_mult, + symmetrization, + diffusion, + row_max_norm, + ] for f in func_order: A = f(A) return A + def compute_affinity_matrix(X): """Compute the affinity matrix from data. Note that the range of affinity is [0,1]. @@ -175,8 +185,7 @@ def compute_sorted_eigenvectors(A): return w, v -def compute_number_of_clusters( - eigenvalues, max_clusters=None, stop_eigenvalue=1e-2): +def compute_number_of_clusters(eigenvalues, max_clusters=None, stop_eigenvalue=1e-2): """Compute number of clusters using EigenGap principle. 
Args: eigenvalues: sorted eigenvalues of the affinity matrix @@ -198,6 +207,3 @@ def compute_number_of_clusters( max_delta = delta max_delta_index = i return max_delta_index - - - diff --git a/simple_diarizer/diarizer.py b/simple_diarizer/diarizer.py index 8960322..f9a6de7 100644 --- a/simple_diarizer/diarizer.py +++ b/simple_diarizer/diarizer.py @@ -1,5 +1,4 @@ import os -import subprocess import sys from copy import deepcopy @@ -7,54 +6,58 @@ import pandas as pd import torch import torchaudio -from sklearn.cluster import AgglomerativeClustering from speechbrain.pretrained import EncoderClassifier from tqdm.autonotebook import tqdm from .cluster import cluster_AHC, cluster_SC -from .utils import (check_wav_16khz_mono, convert_wavfile, - download_youtube_files, download_youtube_ttml, - download_youtube_wav, get_youtube_id, parse_ttml) +from .utils import check_wav_16khz_mono, convert_wavfile class Diarizer: - - def __init__(self, - embed_model='xvec', - cluster_method='sc', - window=1.5, - period=0.75): + def __init__( + self, embed_model="xvec", cluster_method="sc", window=1.5, period=0.75 + ): assert embed_model in [ - 'xvec', 'ecapa'], "Only xvec and ecapa are supported options" + "xvec", + "ecapa", + ], "Only xvec and ecapa are supported options" assert cluster_method in [ - 'ahc', 'sc'], "Only ahc and sc in the supported clustering options" + "ahc", + "sc", + ], "Only ahc and sc in the supported clustering options" - if cluster_method == 'ahc': + if cluster_method == "ahc": self.cluster = cluster_AHC - if cluster_method == 'sc': + if cluster_method == "sc": self.cluster = cluster_SC self.vad_model, self.get_speech_ts = self.setup_VAD() - self.run_opts = {"device": "cuda:0"} if torch.cuda.is_available() else { - "device": "cpu"} - - if embed_model == 'xvec': - self.embed_model = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb", - savedir="pretrained_models/spkrec-xvect-voxceleb", - run_opts=self.run_opts) - if embed_model == 'ecapa': - self.embed_model = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb", - savedir="pretrained_models/spkrec-ecapa-voxceleb", - run_opts=self.run_opts) + self.run_opts = ( + {"device": "cuda:0"} if torch.cuda.is_available() else {"device": "cpu"} + ) + + if embed_model == "xvec": + self.embed_model = EncoderClassifier.from_hparams( + source="speechbrain/spkrec-xvect-voxceleb", + savedir="pretrained_models/spkrec-xvect-voxceleb", + run_opts=self.run_opts, + ) + if embed_model == "ecapa": + self.embed_model = EncoderClassifier.from_hparams( + source="speechbrain/spkrec-ecapa-voxceleb", + savedir="pretrained_models/spkrec-ecapa-voxceleb", + run_opts=self.run_opts, + ) self.window = window self.period = period def setup_VAD(self): - model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', - model='silero_vad') + model, utils = torch.hub.load( + repo_or_dir="snakers4/silero-vad", model="silero_vad" + ) # force_reload=True) get_speech_ts = utils[0] @@ -83,10 +86,10 @@ def windowed_embeds(self, signal, fs, window=1.5, period=0.75): segments = [] start = 0 while start + len_window < len_signal: - segments.append([start, start+len_window]) + segments.append([start, start + len_window]) start += len_period - segments.append([start, len_signal-1]) + segments.append([start, len_signal - 1]) embeds = [] with torch.no_grad(): @@ -107,15 +110,14 @@ def recording_embeds(self, signal, fs, speech_ts): all_embeds = [] all_segments = [] - for utt in tqdm(speech_ts, desc='Utterances', position=0): - start = utt['start'] - 
end = utt['end'] + for utt in tqdm(speech_ts, desc="Utterances", position=0): + start = utt["start"] + end = utt["end"] utt_signal = signal[:, start:end] - utt_embeds, utt_segments = self.windowed_embeds(utt_signal, - fs, - self.window, - self.period) + utt_embeds, utt_segments = self.windowed_embeds( + utt_signal, fs, self.window, self.period + ) all_embeds.append(utt_embeds) all_segments.append(utt_segments + start) @@ -134,31 +136,29 @@ def join_segments(cluster_labels, segments, tolerance=5): """ assert len(cluster_labels) == len(segments) - new_segments = [{'start': segments[0][0], - 'end': segments[0][1], - 'label': cluster_labels[0]}] + new_segments = [ + {"start": segments[0][0], "end": segments[0][1], "label": cluster_labels[0]} + ] for l, seg in zip(cluster_labels[1:], segments[1:]): start = seg[0] end = seg[1] - protoseg = {'start': seg[0], - 'end': seg[1], - 'label': l} + protoseg = {"start": seg[0], "end": seg[1], "label": l} - if start <= new_segments[-1]['end']: + if start <= new_segments[-1]["end"]: # If segments overlap - if l == new_segments[-1]['label']: + if l == new_segments[-1]["label"]: # If overlapping segment has same label - new_segments[-1]['end'] = end + new_segments[-1]["end"] = end else: # If overlapping segment has diff label # Resolve by setting new start to midpoint # And setting last segment end to midpoint - overlap = new_segments[-1]['end'] - start - midpoint = start + overlap//2 - new_segments[-1]['end'] = midpoint - protoseg['start'] = midpoint + overlap = new_segments[-1]["end"] - start + midpoint = start + overlap // 2 + new_segments[-1]["end"] = midpoint + protoseg["start"] = midpoint new_segments.append(protoseg) else: # If there's no overlap just append @@ -172,34 +172,36 @@ def make_output_seconds(cleaned_segments, fs): Convert cleaned segments to readable format in seconds """ for seg in cleaned_segments: - seg['start_sample'] = seg['start'] - seg['end_sample'] = seg['end'] - seg['start'] = seg['start']/fs - seg['end'] = seg['end']/fs + seg["start_sample"] = seg["start"] + seg["end_sample"] = seg["end"] + seg["start"] = seg["start"] / fs + seg["end"] = seg["end"] / fs return cleaned_segments - def diarize(self, - wav_file, - num_speakers=2, - threshold=None, - silence_tolerance=0.2, - enhance_sim=True, - extra_info=False, - outfile=None): + def diarize( + self, + wav_file, + num_speakers=2, + threshold=None, + silence_tolerance=0.2, + enhance_sim=True, + extra_info=False, + outfile=None, + ): """ Diarize a 16khz mono wav file, produces list of segments Inputs: wav_file (path): Path to input audio file num_speakers (int) or NoneType: Number of speakers to cluster to - threshold (float) or NoneType: Threshold to cluster to if + threshold (float) or NoneType: Threshold to cluster to if num_speakers is not defined silence_tolerance (float): Same speaker segments which are close enough together by silence_tolerance will be joined into a single segment enhance_sim (bool): Whether or not to perform affinity matrix enhancement during spectral clustering If self.cluster_method is 'ahc' this option does nothing. 
- extra_info (bool): Whether or not to return the embeddings and raw segments + extra_info (bool): Whether or not to return the embeddings and raw segments in addition to segments outfile (path): If specified will output an RTTM file @@ -222,31 +224,38 @@ def diarize(self, signal, fs = torchaudio.load(wav_file) else: print("Converting audio file to single channel WAV using ffmpeg...") - converted_wavfile = os.path.join(os.path.dirname( - wav_file), '{}_converted.wav'.format(recname)) + converted_wavfile = os.path.join( + os.path.dirname(wav_file), "{}_converted.wav".format(recname) + ) convert_wavfile(wav_file, converted_wavfile) assert os.path.isfile( - converted_wavfile), "Couldn't find converted wav file, failed for some reason" + converted_wavfile + ), "Couldn't find converted wav file, failed for some reason" signal, fs = torchaudio.load(converted_wavfile) - print('Running VAD...') + print("Running VAD...") speech_ts = self.vad(signal[0]) - print('Splitting by silence found {} utterances'.format(len(speech_ts))) + print("Splitting by silence found {} utterances".format(len(speech_ts))) assert len(speech_ts) >= 1, "Couldn't find any speech during VAD" - print('Extracting embeddings...') + print("Extracting embeddings...") embeds, segments = self.recording_embeds(signal, fs, speech_ts) - print('Clustering to {} speakers...'.format(num_speakers)) - cluster_labels = self.cluster(embeds, n_clusters=num_speakers, - threshold=threshold, enhance_sim=enhance_sim) + print("Clustering to {} speakers...".format(num_speakers)) + cluster_labels = self.cluster( + embeds, + n_clusters=num_speakers, + threshold=threshold, + enhance_sim=enhance_sim, + ) - print('Cleaning up output...') + print("Cleaning up output...") cleaned_segments = self.join_segments(cluster_labels, segments) cleaned_segments = self.make_output_seconds(cleaned_segments, fs) - cleaned_segments = self.join_samespeaker_segments(cleaned_segments, - silence_tolerance=silence_tolerance) - print('Done!') + cleaned_segments = self.join_samespeaker_segments( + cleaned_segments, silence_tolerance=silence_tolerance + ) + print("Done!") if outfile: self.rttm_output(cleaned_segments, recname, outfile=outfile) @@ -255,57 +264,22 @@ def diarize(self, else: return cleaned_segments, embeds, segments - def diarize_youtube(self, - youtube_url, - num_speakers=2, - threshold=None, - lang='en', - outfolder='./', - overwrite=False, - silence_tolerance=0.2, - enhance_sim=True, - outfile=None): - """ - Diarize a YouTube URL - """ - youtube_id = get_youtube_id(youtube_url) - converted_wavfile, text_segments = download_youtube_files(youtube_url, - overwrite=overwrite, - lang=lang, - outfolder=outfolder) - - segments = self.diarize(converted_wavfile, - num_speakers=num_speakers, - threshold=threshold, - silence_tolerance=silence_tolerance, - enhance_sim=enhance_sim, - extra_info=False, - outfile=os.path.join(outfolder, '{}.rttm'.format(youtube_id))) - - worded_segments = self.match_diarization_to_transcript( - segments, text_segments) - - self.nice_text_output(worded_segments, os.path.join( - outfolder, '{}_transcript.txt'.format(youtube_id))) - - return segments, worded_segments, converted_wavfile - @staticmethod def rttm_output(segments, recname, outfile=None): assert outfile, "Please specify an outfile" rttm_line = "SPEAKER {} 0 {} {} {} \n" - with open(outfile, 'w') as fp: + with open(outfile, "w") as fp: for seg in segments: - start = seg['start'] - offset = seg['end'] - seg['start'] - label = seg['label'] + start = seg["start"] + offset = seg["end"] - 
seg["start"] + label = seg["label"] line = rttm_line.format(recname, start, offset, label) fp.write(line) @staticmethod def join_samespeaker_segments(segments, silence_tolerance=0.5): """ - Join up segments that belong to the same speaker, + Join up segments that belong to the same speaker, even if there is a duration of silence in between them. If the silence is greater than silence_tolerance, does not join up @@ -313,10 +287,10 @@ def join_samespeaker_segments(segments, silence_tolerance=0.5): new_segments = [segments[0]] for seg in segments[1:]: - if seg['label'] == new_segments[-1]['label']: - if new_segments[-1]['end'] + silence_tolerance >= seg['start']: - new_segments[-1]['end'] = seg['end'] - new_segments[-1]['end_sample'] = seg['end_sample'] + if seg["label"] == new_segments[-1]["label"]: + if new_segments[-1]["end"] + silence_tolerance >= seg["start"]: + new_segments[-1]["end"] = seg["end"] + new_segments[-1]["end_sample"] = seg["end_sample"] else: new_segments.append(seg) else: @@ -331,23 +305,23 @@ def match_diarization_to_transcript(segments, text_segments): text_starts, text_ends, text_segs = [], [], [] for s in text_segments: - text_starts.append(s['start']) - text_ends.append(s['end']) - text_segs.append(s['text']) + text_starts.append(s["start"]) + text_ends.append(s["end"]) + text_segs.append(s["text"]) text_starts = np.array(text_starts) text_ends = np.array(text_ends) text_segs = np.array(text_segs) # Get the earliest start from either diar output or asr output - earliest_start = np.min([text_starts[0], segments[0]['start']]) + earliest_start = np.min([text_starts[0], segments[0]["start"]]) worded_segments = segments.copy() - worded_segments[0]['start'] = earliest_start + worded_segments[0]["start"] = earliest_start cutoffs = [] for seg in worded_segments: - end_idx = np.searchsorted(text_ends, seg['end'], side='left') - 1 + end_idx = np.searchsorted(text_ends, seg["end"], side="left") - 1 cutoffs.append(end_idx) indexes = [[0, cutoffs[0]]] @@ -362,34 +336,35 @@ def match_diarization_to_transcript(segments, text_segments): s_idx, e_idx = indexes[i] words = text_segs[s_idx:e_idx] newseg = deepcopy(seg) - newseg['words'] = ' '.join(words) + newseg["words"] = " ".join(words) final_segments.append(newseg) - return final_segments + return final_segments def match_diarization_to_transcript_ctm(self, segments, ctm_file): """ Match the output of .diarize to a ctm file produced by asr """ - ctm_df = pd.read_csv(ctm_file, delimiter=' ', - names=['utt', 'channel', 'start', 'offset', 'word', 'confidence']) - ctm_df['end'] = ctm_df['start'] + ctm_df['offset'] + ctm_df = pd.read_csv( + ctm_file, + delimiter=" ", + names=["utt", "channel", "start", "offset", "word", "confidence"], + ) + ctm_df["end"] = ctm_df["start"] + ctm_df["offset"] - starts = ctm_df['start'].values - ends = ctm_df['end'].values - words = ctm_df['word'].values + starts = ctm_df["start"].values + ends = ctm_df["end"].values + words = ctm_df["word"].values # Get the earliest start from either diar output or asr output - earliest_start = np.min( - [ctm_df['start'].values[0], segments[0]['start']]) + earliest_start = np.min([ctm_df["start"].values[0], segments[0]["start"]]) worded_segments = self.join_samespeaker_segments(segments) - worded_segments[0]['start'] = earliest_start + worded_segments[0]["start"] = earliest_start cutoffs = [] for seg in worded_segments: - end_idx = np.searchsorted( - ctm_df['end'].values, seg['end'], side='left') - 1 + end_idx = np.searchsorted(ctm_df["end"].values, seg["end"], side="left") 
- 1 cutoffs.append(end_idx) indexes = [[0, cutoffs[0]]] @@ -402,25 +377,29 @@ def match_diarization_to_transcript_ctm(self, segments, ctm_file): for i, seg in enumerate(worded_segments): s_idx, e_idx = indexes[i] - words = ctm_df['word'].values[s_idx:e_idx] - seg['words'] = ' '.join(words) + words = ctm_df["word"].values[s_idx:e_idx] + seg["words"] = " ".join(words) if len(words) >= 1: final_segments.append(seg) else: - print('Removed segment between {} and {} as no words were matched'.format( - seg['start'], seg['end'])) + print( + "Removed segment between {} and {} as no words were matched".format( + seg["start"], seg["end"] + ) + ) return final_segments @staticmethod def nice_text_output(worded_segments, outfile): - with open(outfile, 'w') as fp: + with open(outfile, "w") as fp: for seg in worded_segments: - fp.write('[{} to {}] Speaker {}: \n'.format(round(seg['start'], 2), - round( - seg['end'], 2), - seg['label'])) - fp.write('{}\n\n'.format(seg['words'])) + fp.write( + "[{} to {}] Speaker {}: \n".format( + round(seg["start"], 2), round(seg["end"], 2), seg["label"] + ) + ) + fp.write("{}\n\n".format(seg["words"])) if __name__ == "__main__": @@ -436,16 +415,17 @@ def nice_text_output(worded_segments, outfile): if check_wav_16khz_mono(wavfile): correct_wav = wavfile else: - correct_wav = os.path.join( - outfolder, '{}_converted.wav'.format(recname)) + correct_wav = os.path.join(outfolder, "{}_converted.wav".format(recname)) convert_wavfile(wavfile, correct_wav) diar = Diarizer( - embed_model='xvec', # supported types: ['xvec', 'ecapa'] - cluster_method='ahc', # supported types: ['ahc', 'sc'] + embed_model="ecapa", # supported types: ['xvec', 'ecapa'] + cluster_method="sc", # supported types: ['ahc', 'sc'] window=1.5, # size of window to extract embeddings (in seconds) - period=0.75 # hop of window (in seconds) + period=0.75, # hop of window (in seconds) + ) + segments = diar.diarize( + correct_wav, + num_speakers=num_speakers, + outfile=os.path.join(outfolder, "hyp.rttm"), ) - segments = diar.diarize(correct_wav, - num_speakers=num_speakers, - outfile=os.path.join(outfolder, 'hyp.rttm')) diff --git a/simple_diarizer/utils.py b/simple_diarizer/utils.py index 2af1fcc..de70954 100644 --- a/simple_diarizer/utils.py +++ b/simple_diarizer/utils.py @@ -1,15 +1,11 @@ import datetime -import os import subprocess from pprint import pprint import matplotlib.pyplot as plt import numpy as np import torchaudio -import validators -from bs4 import BeautifulSoup from IPython.display import Audio, display -from pytube.extract import video_id ################## @@ -37,7 +33,8 @@ def convert_wavfile(wavfile, outfile): Converts file to 16khz single channel mono wav """ cmd = "ffmpeg -y -i {} -acodec pcm_s16le -ar 16000 -ac 1 {}".format( - wavfile, outfile) + wavfile, outfile + ) subprocess.Popen(cmd, shell=True).wait() return outfile @@ -47,123 +44,32 @@ def check_ffmpeg(): Returns True if ffmpeg is installed """ try: - subprocess.check_output("ffmpeg", - stderr=subprocess.STDOUT) + subprocess.check_output("ffmpeg", stderr=subprocess.STDOUT) return True except OSError as e: return False -################## -# YouTube DL utils -################## -def get_youtube_id(url): - """ - Returns the youtube id for a youtube URL - """ - return video_id(url) - -def download_youtube_wav(youtube_id, outfolder='./', overwrite=True): - """ - Download the audio for a YouTube id/URL - """ - if validators.url(youtube_id): - youtube_id = video_id(youtube_id) - - os.makedirs(outfolder, exist_ok=True) - - outfile = 
os.path.join(outfolder, '{}.wav'.format(youtube_id)) - if not overwrite: - if os.path.isfile(outfile): - return outfile - - cmd = "youtube-dl --no-continue --extract-audio --audio-format wav -o '{}' {}".format( - outfile, youtube_id) - subprocess.Popen(cmd, shell=True).wait() - - assert os.path.isfile(outfile), "Couldn't find expected outfile, something went wrong" - return outfile - - -def download_youtube_ttml(youtube_id, outfolder='./', lang='en', overwrite=True): - """ - Download the autogenerated ttml for a YouTube id/URL - """ - if validators.url(youtube_id): - youtube_id = video_id(youtube_id) - - os.makedirs(outfolder, exist_ok=True) - - expected_outfile = os.path.join(outfolder, '{}.{}.ttml'.format(youtube_id, lang)) - if not overwrite: - if os.path.isfile(expected_outfile): - return expected_outfile - - outfile = os.path.join(outfolder, '%(id)s.%(ext)s') - cmd = "youtube-dl --write-auto-sub --sub-lang {} --sub-format ttml --skip-download -o '{}' {}".format( - lang, outfile, youtube_id) - subprocess.Popen(cmd, shell=True).wait() - - assert os.path.isfile(expected_outfile), "Couldn't find expected outfile, something went wrong" - return expected_outfile - -def download_youtube_files(url, overwrite=True, lang='en', outfolder='./'): - youtube_id = get_youtube_id(url) - - # Download files - wav_file = download_youtube_wav( - youtube_id, outfolder=outfolder, overwrite=overwrite) - - converted_wavfile = convert_wavfile(wav_file, os.path.join( - outfolder, '{}_converted.wav'.format(youtube_id))) - - print('Downloaded audio and converted to: {}'.format(converted_wavfile)) - - ttml_file = download_youtube_ttml( - youtube_id, outfolder=outfolder, lang=lang, overwrite=overwrite) - - return converted_wavfile, parse_ttml(ttml_file) - - -################## -# TTML utils -################## -def parse_timestamp(timestr): - starttime = datetime.datetime.strptime("00:00:00.000", "%H:%M:%S.%f") - date_time = datetime.datetime.strptime(timestr, "%H:%M:%S.%f") - delta = date_time - starttime - return delta.total_seconds() - - -def parse_entry(entry): - start = parse_timestamp(entry['begin']) - end = parse_timestamp(entry['end']) - text = entry.text - return {'start': start, 'end': end, 'text': text} - - -def parse_ttml(file): - with open(file) as fp: - soup = BeautifulSoup(fp, 'lxml') - starttime = datetime.datetime.strptime("00:00:00.000", "%H:%M:%S.%f") - entries = soup.findAll('p') - segments = [] - for e in entries: - seg = parse_entry(e) - segments.append(seg) - return segments - - ################## # Plotting utils ################## -colors = np.array(['tab:blue', 'tab:orange', 'tab:green', - 'tab:red', 'tab:purple', 'tab:brown', - 'tab:pink', 'tab:gray', 'tab:olive', - 'tab:cyan']) - - -def waveplot(signal, fs, start_idx=0, figsize=(5, 3), color='tab:blue'): +colors = np.array( + [ + "tab:blue", + "tab:orange", + "tab:green", + "tab:red", + "tab:purple", + "tab:brown", + "tab:pink", + "tab:gray", + "tab:olive", + "tab:cyan", + ] +) + + +def waveplot(signal, fs, start_idx=0, figsize=(5, 3), color="tab:blue"): """ A waveform plot for a signal @@ -181,9 +87,8 @@ def waveplot(signal, fs, start_idx=0, figsize=(5, 3), color='tab:blue'): start_time = start_idx / fs end_time = start_time + (len(signal) / fs) - plt.plot(np.linspace(start_time, end_time, - len(signal)), signal, color=color) - plt.xlabel('Time (s)') + plt.plot(np.linspace(start_time, end_time, len(signal)), signal, color=color) + plt.xlabel("Time (s)") plt.xlim([start_time, end_time]) max_amp = np.max(np.abs([np.max(signal), 
np.min(signal)])) @@ -199,34 +104,38 @@ def combined_waveplot(signal, fs, segments, figsize=(10, 3), tick_interval=60): Inputs: - Signal (array): The waveform (1D) - - fs (int): The frequency in Hz + - fs (int): The frequency in Hz (should be 16000 for the models in this repo) - segments (list): The diarization outputs (segment information) - figsize (tuple): Figsize passed into plt.figure() - - tick_interval (float): Where to place ticks for xlabel + - tick_interval (float): Where to place ticks for xlabel Outputs: - The matplotlib figure """ plt.figure(figsize=figsize) for seg in segments: - start = seg['start_sample'] - end = seg['end_sample'] + start = seg["start_sample"] + end = seg["end_sample"] speech = signal[start:end] - color = colors[seg['label']] + color = colors[seg["label"]] - linelabel = 'Speaker {}'.format(seg['label']) - plt.plot(np.linspace(seg['start'], seg['end'], len( - speech)), speech, color=color, label=linelabel) + linelabel = "Speaker {}".format(seg["label"]) + plt.plot( + np.linspace(seg["start"], seg["end"], len(speech)), + speech, + color=color, + label=linelabel, + ) handles, labels = plt.gca().get_legend_handles_labels() by_label = dict(zip(labels, handles)) - plt.legend(by_label.values(), by_label.keys(), loc='lower right') + plt.legend(by_label.values(), by_label.keys(), loc="lower right") - plt.xlabel('Time') - plt.xlim([0, len(signal)/fs]) + plt.xlabel("Time") + plt.xlim([0, len(signal) / fs]) - xticks = np.arange(0, (len(signal)//fs)+1, tick_interval) + xticks = np.arange(0, (len(signal) // fs) + 1, tick_interval) xtick_labels = [str(datetime.timedelta(seconds=int(x))) for x in xticks] plt.xticks(ticks=xticks, labels=xtick_labels) @@ -245,14 +154,14 @@ def waveplot_perspeaker(signal, fs, segments): Designed to be run in a jupyter notebook """ for seg in segments: - start = seg['start_sample'] - end = seg['end_sample'] + start = seg["start_sample"] + end = seg["end_sample"] speech = signal[start:end] - color = colors[seg['label']] + color = colors[seg["label"]] waveplot(speech, fs, start_idx=start, color=color) plt.show() - print('Speaker {} ({}s - {}s)'.format(seg['label'], seg['start'], seg['end'])) - if 'words' in seg: - pprint(seg['words']) + print("Speaker {} ({}s - {}s)".format(seg["label"], seg["start"], seg["end"])) + if "words" in seg: + pprint(seg["words"]) display(Audio(speech, rate=fs)) - print('='*40 + '\n') + print("=" * 40 + "\n")
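
Usage after this patch: with the YouTube helpers gone, the local-file path is the only entry point. A minimal sketch mirroring the __main__ block in diarizer.py; the wav path and output names are placeholders, and ffmpeg is assumed to be on PATH for the conversion branch:

    from simple_diarizer.diarizer import Diarizer
    from simple_diarizer.utils import check_wav_16khz_mono, convert_wavfile

    WAV = "meeting.wav"  # placeholder path; 16 kHz mono expected

    # Convert with ffmpeg if the input is not already 16 kHz single-channel
    if not check_wav_16khz_mono(WAV):
        WAV = convert_wavfile(WAV, "meeting_converted.wav")

    diar = Diarizer(embed_model="ecapa", cluster_method="sc", window=1.5, period=0.75)
    segments = diar.diarize(WAV, num_speakers=2, outfile="hyp.rttm")

    for seg in segments:
        print(seg["label"], round(seg["start"], 2), round(seg["end"], 2))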
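
The clustering helpers in cluster.py can also be called directly on a matrix of embeddings; a sketch using random vectors as a stand-in for real x-vector/ECAPA embeddings (shapes are illustrative only):

    import numpy as np
    from simple_diarizer.cluster import cluster_AHC, cluster_SC

    embeds = np.random.randn(40, 192)  # stand-in for (num_windows, embed_dim) speaker embeddings

    labels_fixed = cluster_AHC(embeds, n_clusters=2)        # agglomerative, known speaker count
    labels_thresh = cluster_AHC(embeds, threshold=0.5)      # agglomerative, distance threshold instead
    labels_sc = cluster_SC(embeds, n_clusters=2, enhance_sim=True)  # spectral clustering path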
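
The affinity-enhancement chain and the EigenGap speaker-count estimate used inside cluster_SC can be inspected on their own; a sketch, again with dummy embeddings:

    import numpy as np
    from simple_diarizer.cluster import (
        compute_affinity_matrix,
        compute_number_of_clusters,
        compute_sorted_eigenvectors,
        sim_enhancement,
    )

    embeds = np.random.randn(30, 192)
    A = compute_affinity_matrix(embeds)  # cosine affinity mapped into [0, 1]
    A = sim_enhancement(A)               # diagonal fill, blur, row threshold, symmetrize, diffuse, row-max norm
    eigenvalues, _ = compute_sorted_eigenvectors(A)
    k = compute_number_of_clusters(eigenvalues, max_clusters=8)
    print("EigenGap estimate of speaker count:", k)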
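
match_diarization_to_transcript_ctm expects a space-separated CTM file with columns utt, channel, start, offset, word, confidence. A sketch assuming such a file ("meeting.ctm", hypothetical) was produced by an ASR system for the same recording:

    from simple_diarizer.diarizer import Diarizer

    diar = Diarizer(embed_model="ecapa", cluster_method="sc")
    segments = diar.diarize("meeting.wav", num_speakers=2)  # placeholder wav path
    worded = diar.match_diarization_to_transcript_ctm(segments, "meeting.ctm")
    Diarizer.nice_text_output(worded, "meeting_transcript.txt")  # writes "[start to end] Speaker N:" blocks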
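
The plotting utilities are untouched apart from formatting; a notebook-oriented sketch, reusing the segments list from the usage example above and a placeholder wav path:

    import matplotlib.pyplot as plt
    import torchaudio
    from simple_diarizer.utils import combined_waveplot, waveplot_perspeaker

    # 'segments' is the list returned by Diarizer.diarize(...) in the earlier sketch
    signal, fs = torchaudio.load("meeting_converted.wav")  # placeholder path, 16 kHz mono
    combined_waveplot(signal[0].numpy(), fs, segments, figsize=(10, 3), tick_interval=60)
    plt.show()
    # waveplot_perspeaker(signal[0].numpy(), fs, segments) also embeds a per-segment audio player (Jupyter only)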