This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

Add device control #2

Open · wants to merge 4 commits into base: NME_SC
Changes from all commits
10 changes: 8 additions & 2 deletions main.py
@@ -14,18 +14,24 @@
parser.add_argument("--max_speakers", dest='max_speakers', default=25, type=int, help="Maximum number of speakers (if number of speaker is unknown)")
parser.add_argument("--embed_model", dest='embed_model', default="ecapa", type=str, help="Name of embedding")
parser.add_argument("--cluster_method", dest='cluster_method', default="nme-sc", type=str, help="Clustering method")
parser.add_argument("--device", dest='device', default=None, type=str, help="choise of cpu or cuda")
args = parser.parse_args()

diar = Diarizer(
embed_model=args.embed_model, # 'xvec' and 'ecapa' supported
cluster_method=args.cluster_method # 'ahc' 'sc' and 'nme-sc' supported
cluster_method=args.cluster_method, # 'ahc' 'sc' and 'nme-sc' supported
device=args.device
)

WAV_FILE=args.audio_name
num_speakers=args.number_of_speaker if args.number_of_speaker != "None" else None
max_spk= args.max_speakers
output_file=args.outputfile





t0 = time.time()

segments = diar.diarize(WAV_FILE, num_speakers=num_speakers,max_speakers=max_spk,outfile=output_file)
@@ -73,5 +79,5 @@
json["speakers"] = list(_speakers.values())
json["segments"] = _segments

pprint.pprint(json)
#pprint.pprint(json)
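With this change the device can be chosen from the command line, e.g. the hypothetical invocation `python main.py --audio_name meeting.wav --device cuda` (file name illustrative). A minimal programmatic sketch of the same, assuming the package is importable as simple_diarizer:

from simple_diarizer.diarizer import Diarizer

# device=None lets the constructor auto-select: CUDA when available, else CPU.
diar = Diarizer(
    embed_model="ecapa",      # 'xvec' and 'ecapa' supported
    cluster_method="nme-sc",  # 'ahc', 'sc' and 'nme-sc' supported
    device="cuda",            # device for the embedding model
)
segments = diar.diarize("meeting.wav", num_speakers=None, max_speakers=25)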

7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,8 +1,9 @@
# ipython>=7.9.0
# matplotlib>=3.5.1
pandas>=1.3.5
# pandas>=1.3.5
scikit-learn>=1.0.2
speechbrain>=0.5.11
torchaudio>=0.10.1
onnxruntime>=1.14.0
scipy<=1.8.1 # newer version can provoke segmentation faults
onnxruntime-gpu>=1.14.0
scipy<=1.8.1 # newer version can provoke segmentation faults

147 changes: 54 additions & 93 deletions simple_diarizer/diarizer.py
@@ -3,7 +3,6 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import torch
import torchaudio
from speechbrain.inference.speaker import EncoderClassifier
@@ -15,7 +14,13 @@

class Diarizer:
def __init__(
self, embed_model="xvec", cluster_method="sc", window=1.5, period=0.75
self,
embed_model="xvec",
cluster_method="sc",
window=1.5,
period=0.75,
device=None,
device_vad="cpu",
):

assert embed_model in [
@@ -35,33 +40,44 @@ def __init__(
if cluster_method == "nme-sc":
self.cluster = cluster_NME_SC

default_device = "cuda" if torch.cuda.is_available() else "cpu"
if device_vad is None:
device_vad = default_device

self.vad_model, self.get_speech_ts = self.setup_VAD()
self.vad_model, self.get_speech_ts = self.setup_VAD(device_vad)

self.run_opts = (
{"device": "cuda:0"} if torch.cuda.is_available() else {"device": "cpu"}
)
if device is None:
device = default_device

print(f"Devices: VAD={device_vad}, embedding={device}")

if embed_model == "xvec":
self.embed_model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-xvect-voxceleb",
savedir="pretrained_models/spkrec-xvect-voxceleb",
run_opts=self.run_opts,
run_opts={"device": device},
)
if embed_model == "ecapa":
elif embed_model == "ecapa":
self.embed_model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir="pretrained_models/spkrec-ecapa-voxceleb",
run_opts=self.run_opts,
run_opts={"device": device},
)

self.window = window
self.period = period

def setup_VAD(self):
def setup_VAD(self, device):
self.device_vad = device
use_gpu = device != "cpu"
model, utils = torch.hub.load(
Member Author commented: The device should also be used here

repo_or_dir="snakers4/silero-vad", model="silero_vad", onnx=True
repo_or_dir="snakers4/silero-vad",
model="silero_vad",
onnx=not use_gpu,
# map_location=device
)
if use_gpu:
model = model.to(device)
# force_reload=True)

get_speech_ts = utils[0]
@@ -71,7 +87,7 @@ def vad(self, signal):
"""
Runs the VAD model on the signal
"""
return self.get_speech_ts(signal, self.vad_model)
return self.get_speech_ts(signal.to(self.device_vad), self.vad_model)

def windowed_embeds(self, signal, fs, window=1.5, period=0.75):
"""
@@ -238,7 +254,7 @@ def diarize(
Uses AHC/SC/NME-SC to cluster
"""
recname = os.path.splitext(os.path.basename(wav_file))[0]

if check_wav_16khz_mono(wav_file):
signal, fs = torchaudio.load(wav_file)
else:
@@ -255,50 +271,55 @@
print("Running VAD...")
speech_ts = self.vad(signal[0])
print("Splitting by silence found {} utterances".format(len(speech_ts)))
#assert len(speech_ts) >= 1, "Couldn't find any speech during VAD"
# assert len(speech_ts) >= 1, "Couldn't find any speech during VAD"

if len(speech_ts) >= 1:
print("Extracting embeddings...")
embeds, segments = self.recording_embeds(signal, fs, speech_ts)

[w,k]=embeds.shape
if w >= 2:
print('Clustering to {} speakers...'.format(num_speakers))
cluster_labels = self.cluster(embeds, n_clusters=num_speakers,max_speakers=max_speakers,
threshold=threshold, enhance_sim=enhance_sim)
[w, k] = embeds.shape
if w >= 2:
print("Clustering to {} speakers...".format(num_speakers))
cluster_labels = self.cluster(
embeds,
n_clusters=num_speakers,
max_speakers=max_speakers,
threshold=threshold,
enhance_sim=enhance_sim,
)



cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)
cleaned_segments = self.join_samespeaker_segments(cleaned_segments,
silence_tolerance=silence_tolerance)

cleaned_segments = self.join_samespeaker_segments(
cleaned_segments, silence_tolerance=silence_tolerance
)

else:
cluster_labels =[ 1]
cluster_labels = [1]
cleaned_segments = self.join_segments(cluster_labels, segments)
cleaned_segments = self.make_output_seconds(cleaned_segments, fs)

else:
cleaned_segments = []

print("Done!")
if outfile:
self.rttm_output(cleaned_segments, recname, outfile=outfile)

if not extra_info:
return cleaned_segments
else:
return {"clean_segments": cleaned_segments,
"embeds": embeds,
"segments": segments,
"cluster_labels": cluster_labels}
return {
"clean_segments": cleaned_segments,
"embeds": embeds,
"segments": segments,
"cluster_labels": cluster_labels,
}

@staticmethod
def rttm_output(segments, recname, outfile=None, channel=0):
assert outfile, "Please specify an outfile"
rttm_line = "SPEAKER {} "+str(channel)+" {} {} <NA> <NA> {} <NA> <NA>\n"
rttm_line = "SPEAKER {} " + str(channel) + " {} {} <NA> <NA> {} <NA> <NA>\n"
with open(outfile, "w") as fp:
for seg in segments:
start = seg["start"]
@@ -372,66 +393,6 @@ def match_diarization_to_transcript(segments, text_segments):

return final_segments

def match_diarization_to_transcript_ctm(self, segments, ctm_file):
"""
Match the output of .diarize to a ctm file produced by asr
"""
ctm_df = pd.read_csv(
ctm_file,
delimiter=" ",
names=["utt", "channel", "start", "offset", "word", "confidence"],
)
ctm_df["end"] = ctm_df["start"] + ctm_df["offset"]

starts = ctm_df["start"].values
ends = ctm_df["end"].values
words = ctm_df["word"].values

# Get the earliest start from either diar output or asr output
earliest_start = np.min([ctm_df["start"].values[0], segments[0]["start"]])

worded_segments = self.join_samespeaker_segments(segments)
worded_segments[0]["start"] = earliest_start
cutoffs = []

for seg in worded_segments:
end_idx = np.searchsorted(ctm_df["end"].values, seg["end"], side="left") - 1
cutoffs.append(end_idx)

indexes = [[0, cutoffs[0]]]
for c in cutoffs[1:]:
indexes.append([indexes[-1][-1], c])

indexes[-1][-1] = len(words)

final_segments = []

for i, seg in enumerate(worded_segments):
s_idx, e_idx = indexes[i]
words = ctm_df["word"].values[s_idx:e_idx]
seg["words"] = " ".join(words)
if len(words) >= 1:
final_segments.append(seg)
else:
print(
"Removed segment between {} and {} as no words were matched".format(
seg["start"], seg["end"]
)
)

return final_segments

@staticmethod
def nice_text_output(worded_segments, outfile):
with open(outfile, "w") as fp:
for seg in worded_segments:
fp.write(
"[{} to {}] Speaker {}: \n".format(
round(seg["start"], 2), round(seg["end"], 2), seg["label"]
)
)
fp.write("{}\n\n".format(seg["words"]))


if __name__ == "__main__":
wavfile = sys.argv[1]
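Taken together, the new constructor arguments let the VAD and embedding models sit on different devices. A hedged usage sketch (WAV path illustrative; assumes a CUDA-capable installation):

from simple_diarizer.diarizer import Diarizer

# Keep the silero VAD on CPU (ONNX runtime) while the x-vector model uses the GPU.
diar = Diarizer(embed_model="xvec", device="cuda", device_vad="cpu")
segments = diar.diarize("meeting.wav", outfile="meeting.rttm")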
43 changes: 25 additions & 18 deletions simple_diarizer/spectral_clustering.py
@@ -1,6 +1,7 @@
import numpy as np
import scipy
from sklearn.cluster import SpectralClustering
import torch

# NME low-level operations
# These functions are taken from the Kaldi scripts.
@@ -38,6 +39,9 @@ def Eigengap(S):
S = sorted(S)
return np.diff(S)

def getLamdaGaplist(lambdas):
lambdas = np.real(lambdas)
return list(lambdas[1:] - lambdas[:-1])

# Computes parameters of normalized eigenmaps for automatic thresholding selection
def ComputeNMEParameters(A, p, max_num_clusters):
@@ -48,23 +52,27 @@ def ComputeNMEParameters(A, p, max_num_clusters):
# Laplacian matrix computation
Lp = Laplacian(Ap)
# Get max_num_clusters+1 smallest eigenvalues
S = scipy.sparse.linalg.eigsh(
Lp,
k=max_num_clusters + 1,
which="SA",
tol=1e-6,
return_eigenvectors=False,
mode="buckling",
)
# Get largest eigenvalue
Smax = scipy.sparse.linalg.eigsh(
Lp, k=1, which="LA", tol=1e-6, return_eigenvectors=False, mode="buckling"
)
from torch.linalg import eigh

# Dense eigendecomposition, on GPU when available; eigenvalues come back in
# ascending order either way.
if torch.cuda.is_available():
    laplacian = torch.from_numpy(Lp).float().to("cuda")
    lambdas, _ = eigh(laplacian)
    S = lambdas.cpu().numpy()
else:
    laplacian = torch.from_numpy(Lp).float()
    lambdas, _ = eigh(laplacian)
    S = lambdas.numpy()

# Eigengap computation
e = Eigengap(S)
g = np.max(e[:max_num_clusters]) / (Smax + 1e-10)
r = p / g
k = np.argmax(e[:max_num_clusters])

e = np.sort(S)
g = getLamdaGaplist(e)
k = np.argmax(g[: min(max_num_clusters, len(g))])
arg_sorted_idx = np.argsort(g[: max_num_clusters])[::-1]
max_key = arg_sorted_idx[0]
max_eig_gap = g[max_key] / (max(e) + 1e-10)
r = (p / A.shape[0]) / (max_eig_gap + 1e-10)


return (e, g, k, r)
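The hunk above replaces SciPy's shift-invert eigsh with a dense torch.linalg.eigh and estimates the cluster count from the largest eigengap. The heuristic can be reproduced in isolation; a toy sketch (matrix and variable names are illustrative, not from the PR):

import numpy as np
import torch

# Toy graph Laplacian with two connected components, hence two zero eigenvalues.
Lp = np.array([[ 2., -1., -1.,  0.],
               [-1.,  2., -1.,  0.],
               [-1., -1.,  2.,  0.],
               [ 0.,  0.,  0.,  0.]])

lambdas, _ = torch.linalg.eigh(torch.from_numpy(Lp).float())
e = np.sort(lambdas.numpy())
gaps = e[1:] - e[:-1]          # same quantity as getLamdaGaplist
k = int(np.argmax(gaps[:3]))   # index of the largest eigengap
print("estimated clusters:", k + 1)  # -> 2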


@@ -96,8 +104,7 @@ def NME_SpectralClustering(
if rbest is None or rbest > r:
rbest = r
pbest = p
kbest = k

kbest = k
num_clusters = num_clusters if num_clusters is not None else (kbest + 1)
return NME_SpectralClustering_sklearn(
A, num_clusters, pbest