Skip to content
This repository has been archived by the owner on Mar 25, 2024. It is now read-only.

Add device control #2

Open
wants to merge 4 commits into
base: NME_SC
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@
parser.add_argument("--max_speakers", dest='max_speakers', default=25, type=int, help="Maximum number of speakers (if number of speaker is unknown)")
parser.add_argument("--embed_model", dest='embed_model', default="ecapa", type=str, help="Name of embedding")
parser.add_argument("--cluster_method", dest='cluster_method', default="nme-sc", type=str, help="Clustering method")
parser.add_argument("--device", dest='device', default=None, type=str, help="choice of cpu or cuda")
args = parser.parse_args()

diar = Diarizer(
embed_model=args.embed_model, # 'xvec' and 'ecapa' supported
cluster_method=args.cluster_method # 'ahc' 'sc' and 'nme-sc' supported
cluster_method=args.cluster_method, # 'ahc' 'sc' and 'nme-sc' supported
device=args.device
)

WAV_FILE=args.audio_name
num_speakers=args.number_of_speaker if args.number_of_speaker != "None" else None
max_spk= args.max_speakers
output_file=args.outputfile





t0 = time.time()

segments = diar.diarize(WAV_FILE, num_speakers=num_speakers,max_speakers=max_spk,outfile=output_file)
Expand Down Expand Up @@ -73,5 +79,5 @@
json["speakers"] = list(_speakers.values())
json["segments"] = _segments

pprint.pprint(json)
#pprint.pprint(json)

7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# ipython>=7.9.0
# matplotlib>=3.5.1
pandas>=1.3.5
# pandas>=1.3.5
scikit-learn>=1.0.2
speechbrain>=0.5.11
torchaudio>=0.10.1
onnxruntime>=1.14.0
scipy<=1.8.1 # newer version can provoke segmentation faults
onnxruntime-gpu>=1.14.0
scipy<=1.8.1 # newer version can provoke segmentation faults

28 changes: 20 additions & 8 deletions simple_diarizer/diarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class Diarizer:
def __init__(
self, embed_model="xvec", cluster_method="sc", window=1.5, period=0.75
self, embed_model="xvec", cluster_method="sc", window=1.5, period=0.75, device=None
):

assert embed_model in [
Expand All @@ -38,25 +38,33 @@ def __init__(

self.vad_model, self.get_speech_ts = self.setup_VAD()

self.run_opts = (
{"device": "cuda:0"} if torch.cuda.is_available() else {"device": "cpu"}
)

if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
elif device=="cpu":
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
self.device = "cpu"
elif device=="cuda":
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device_num=torch.cuda.current_device()
to_cuda = f'cuda:{device_num}'
self.device = to_cuda
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two `elif` branches are unnecessary.
Remove them and keep it simple.


if embed_model == "xvec":
self.embed_model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-xvect-voxceleb",
savedir="pretrained_models/spkrec-xvect-voxceleb",
run_opts=self.run_opts,
run_opts={"device": self.device},
)
if embed_model == "ecapa":
self.embed_model = EncoderClassifier.from_hparams(
source="speechbrain/spkrec-ecapa-voxceleb",
savedir="pretrained_models/spkrec-ecapa-voxceleb",
run_opts=self.run_opts,
run_opts={"device": self.device},
)

self.window = window
self.period = period


def setup_VAD(self):
model, utils = torch.hub.load(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The device should also be used here

Expand Down Expand Up @@ -95,7 +103,7 @@ def windowed_embeds(self, signal, fs, window=1.5, period=0.75):

segments.append([start, len_signal - 1])
embeds = []

with torch.no_grad():
for i, j in segments:
signal_seg = signal[:, i:j]
Expand Down Expand Up @@ -192,7 +200,11 @@ def diarize(
enhance_sim=True,
extra_info=False,
outfile=None,

):



"""
Diarize a 16khz mono wav file, produces list of segments

Expand Down