Merge pull request #134 from JSchmie/fix-audio-torch-device-setting
Improve Torch Device Configuration for Greater User Control
mahenning authored Oct 24, 2024
2 parents f0989a5 + 101e913 commit 3fe1380
Showing 5 changed files with 26 additions and 35 deletions.
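Summary of the device-handling change: device selection now funnels through a single SCRAIBE_TORCH_DEVICE default (an environment variable with a CUDA-aware fallback, defined in scraibe/misc.py below), which callers can still override per object via a device keyword argument. A minimal sketch of the resolution order, using only names that appear in this diff:

    import os
    import torch

    # Mirrors scraibe/misc.py after this change: the SCRAIBE_TORCH_DEVICE
    # environment variable wins; otherwise fall back to CUDA when available.
    SCRAIBE_TORCH_DEVICE = os.getenv(
        "SCRAIBE_TORCH_DEVICE",
        "cuda" if torch.cuda.is_available() else "cpu")

    def resolve_device(**kwargs) -> str:
        # Mirrors the lookup added to Scraibe.__init__ in autotranscript.py:
        # an explicit device kwarg overrides the module-level default.
        return kwargs.get("device", SCRAIBE_TORCH_DEVICE)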
14 changes: 4 additions & 10 deletions scraibe/audio.py
@@ -41,26 +41,20 @@ class AudioProcessor:
         The sample rate of the audio.
     """
 
-    def __init__(self, waveform: torch.Tensor, sr: int = SAMPLE_RATE,
-                 *args, **kwargs) -> None:
+    def __init__(self, waveform: torch.Tensor,
+                 sr: int = SAMPLE_RATE) -> None:
         """
         Initialize the AudioProcessor object.
 
         Args:
             waveform (torch.Tensor): The audio waveform tensor.
             sr (int, optional): The sample rate of the audio. Defaults to SAMPLE_RATE.
-            args: Additional arguments.
-            kwargs: Additional keyword arguments, e.g., device to use for processing.
-                    If CUDA is available, it defaults to CUDA.
 
         Raises:
             ValueError: If the provided sample rate is not of type int.
         """
-        device = kwargs.get(
-            "device", "cuda" if torch.cuda.is_available() else "cpu")
-
-        self.waveform = waveform.to(device)
+        self.waveform = waveform
         self.sr = sr
 
         if not isinstance(self.sr, int):
@@ -147,6 +141,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
                 np.float32) / NORMALIZATION_FACTOR
 
         return out, sr
 
     def __repr__(self) -> str:
         return f'TorchAudioProcessor(waveform={len(self.waveform)}, sr={int(self.sr)})'
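Note on the audio.py change: AudioProcessor no longer accepts or guesses a device; the waveform stays wherever the caller created it, and device placement happens downstream (see the .to(self.device) calls in autotranscript.py below). A usage sketch under that assumption; the 16 kHz tensor is made-up example data:

    import torch
    from scraibe.audio import AudioProcessor  # module path as in this diff

    waveform = torch.zeros(16000)              # one second of silence at 16 kHz (illustrative)
    audio = AudioProcessor(waveform, sr=16000)
    # __init__ no longer calls waveform.to(device); move the tensor explicitly when needed:
    gpu_waveform = audio.waveform.to("cuda") if torch.cuda.is_available() else audio.waveform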
15 changes: 8 additions & 7 deletions scraibe/autotranscript.py
@@ -40,6 +40,7 @@
 from .diarisation import Diariser
 from .transcriber import Transcriber, load_transcriber, whisper
 from .transcript_exporter import Transcript
+from .misc import SCRAIBE_TORCH_DEVICE
 
 
 DiarisationType = TypeVar('DiarisationType')
@@ -115,6 +116,9 @@ def __init__(self,
                                           **kwargs)
         else:
             self.params = {}
 
+        self.device = kwargs.get(
+            "device", SCRAIBE_TORCH_DEVICE)
+
     def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
                        remove_original: bool = False,
@@ -141,10 +145,10 @@ def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
 
         # Prepare waveform and sample rate for diarization
         dia_audio = {
-            "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
+            "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
             "sample_rate": audio_file.sr
         }
 
         if self.verbose:
             print("Starting diarisation.")
 
@@ -165,8 +169,6 @@ def autotranscribe(self, audio_file: Union[str, torch.Tensor, ndarray],
         if self.verbose:
             print("Diarisation finished. Starting transcription.")
 
-        audio_file.sr = torch.Tensor([audio_file.sr]).to(
-            audio_file.waveform.device)
-
         # Transcribe each segment and store the results
         final_transcript = dict()
@@ -213,7 +215,7 @@ def diarization(self, audio_file: Union[str, torch.Tensor, ndarray],
 
         # Prepare waveform and sample rate for diarization
         dia_audio = {
-            "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)),
+            "waveform": audio_file.waveform.reshape(1, len(audio_file.waveform)).to(self.device),
             "sample_rate": audio_file.sr
         }
 
@@ -323,8 +325,7 @@ def remove_audio_file(audio_file: str,
             print(f"Audiofile {audio_file} removed.")
 
     @staticmethod
-    def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray],
-                       *args, **kwargs) -> AudioProcessor:
+    def get_audio_file(audio_file: Union[str, torch.Tensor, ndarray]) -> AudioProcessor:
         """Gets an audio file as TorchAudioProcessor.
 
         Args:
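With the pieces above, Scraibe stores self.device at construction time and moves the reshaped waveform onto it before diarisation in both autotranscribe() and diarization(). A hedged end-to-end sketch; the top-level import path and the audio filename are assumptions for illustration:

    from scraibe import Scraibe  # import path assumed

    model = Scraibe(device="cuda:1")                     # per-instance override of SCRAIBE_TORCH_DEVICE
    transcript = model.autotranscribe("interview.wav")   # hypothetical input file
    print(transcript)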
11 changes: 3 additions & 8 deletions scraibe/diarisation.py
@@ -37,11 +37,11 @@
 from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
 from torch import Tensor
 from torch import device as torch_device
-from torch.cuda import is_available
 
 from huggingface_hub import HfApi
 from huggingface_hub.utils import RepositoryNotFoundError
 
-from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG
+from .misc import PYANNOTE_DEFAULT_PATH, PYANNOTE_DEFAULT_CONFIG, SCRAIBE_TORCH_DEVICE
 Annotation = TypeVar('Annotation')
 
 TOKEN_PATH = os.path.join(os.path.dirname(
@@ -190,8 +190,7 @@ def load_model(cls,
                    cache_token: bool = False,
                    cache_dir: Union[Path, str] = PYANNOTE_DEFAULT_PATH,
                    hparams_file: Union[str, Path] = None,
-                   device: str = None,
-                   *args, **kwargs
+                   device: str = SCRAIBE_TORCH_DEVICE,
                    ) -> Pipeline:
         """
         Loads a pretrained model from pyannote.audio,
@@ -283,10 +282,6 @@ def load_model(cls,
                              'or from huggingface.co models. Please check your token'
                              'or your local model path')
 
-        # try to move the model to the device
-        if device is None:
-            device = "cuda" if is_available() else "cpu"
-
         # torch_device is renamed from torch.device to avoid name conflict
         _model = _model.to(torch_device(device))
 
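Diariser.load_model now takes its default device from SCRAIBE_TORCH_DEVICE instead of probing torch.cuda.is_available() inside the function, and the *args, **kwargs catch-all is gone, so misspelled keyword arguments fail loudly instead of being silently swallowed. A hedged call sketch, assuming the parameters not shown in the signature fragment above also have defaults:

    from scraibe.diarisation import Diariser

    # Omitting device uses SCRAIBE_TORCH_DEVICE; passing it overrides per call.
    pipeline = Diariser.load_model(device="cpu")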
2 changes: 2 additions & 0 deletions scraibe/misc.py
@@ -2,6 +2,7 @@
 import yaml
 from argparse import Action
 from ast import literal_eval
+from torch.cuda import is_available
 
 CACHE_DIR = os.getenv(
     "AUTOT_CACHE",
@@ -18,6 +19,7 @@
 if os.path.exists(os.path.join(PYANNOTE_DEFAULT_PATH, "config.yaml")) \
     else ('Jaikinator/ScrAIbe', 'pyannote/speaker-diarization-3.1')
 
+SCRAIBE_TORCH_DEVICE = os.getenv("SCRAIBE_TORCH_DEVICE", "cuda" if is_available() else "cpu")
 
 def config_diarization_yaml(file_path: str, path_to_segmentation: str = None) -> None:
     """Configure diarization pipeline from a YAML file.
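Because SCRAIBE_TORCH_DEVICE is evaluated at module import time (a plain os.getenv at the top level of misc.py), the environment variable must be set before scraibe is first imported. A small sketch:

    import os

    # Must happen before the first 'import scraibe'; equivalently,
    # export SCRAIBE_TORCH_DEVICE=cpu in the shell before launching Python.
    os.environ["SCRAIBE_TORCH_DEVICE"] = "cpu"

    import scraibe  # misc.py now picks up the value set above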
19 changes: 9 additions & 10 deletions scraibe/transcriber.py
@@ -31,13 +31,12 @@
 from faster_whisper.tokenizer import _LANGUAGE_CODES as FASTER_WHISPER_LANGUAGE_CODES
 from typing import TypeVar, Union, Optional
 from torch import Tensor, device
-from torch.cuda import is_available as cuda_is_available
 from numpy import ndarray
 from inspect import signature
 from abc import abstractmethod
 import warnings
 
-from .misc import WHISPER_DEFAULT_PATH
+from .misc import WHISPER_DEFAULT_PATH, SCRAIBE_TORCH_DEVICE
 whisper = TypeVar('whisper')
 

@@ -124,7 +123,7 @@ def load_model(cls,
                    model: str = "medium",
                    whisper_type: str = 'whisper',
                    download_root: str = WHISPER_DEFAULT_PATH,
-                   device: Optional[Union[str, device]] = None,
+                   device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
                    in_memory: bool = False,
                    *args, **kwargs
                    ) -> None:
@@ -206,7 +205,7 @@ def transcribe(self, audio: Union[str, Tensor, ndarray],
     def load_model(cls,
                    model: str = "medium",
                    download_root: str = WHISPER_DEFAULT_PATH,
-                   device: Optional[Union[str, device]] = None,
+                   device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
                    in_memory: bool = False,
                    *args, **kwargs
                    ) -> 'WhisperTranscriber':
@@ -305,7 +304,7 @@ def transcribe(self, audio: Union[str, Tensor, ndarray],
     def load_model(cls,
                    model: str = "medium",
                    download_root: str = WHISPER_DEFAULT_PATH,
-                   device: Optional[Union[str, device]] = None,
+                   device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
                    *args, **kwargs
                    ) -> 'FasterWhisperModel':
         """
@@ -330,7 +329,7 @@ def load_model(cls,
                 Defaults to WHISPER_DEFAULT_PATH.
             device (Optional[Union[str, torch.device]], optional):
-                Device to load model on. Defaults to None.
+                Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
             in_memory (bool, optional): Whether to load model in memory.
                 Defaults to False.
             args: Additional arguments only to avoid errors.
@@ -339,10 +338,10 @@ def load_model(cls,
         Returns:
             Transcriber: A Transcriber object initialized with the specified model.
         """
-        if device is None:
-            device = "cuda" if cuda_is_available() else "cpu"
-
         if not isinstance(device, str):
             device = str(device)
 
         compute_type = kwargs.get('compute_type', 'float16')
         if device == 'cpu' and compute_type == 'float16':
             warnings.warn(f'Compute type {compute_type} not compatible with '
@@ -412,7 +411,7 @@ def __repr__(self) -> str:
 def load_transcriber(model: str = "medium",
                      whisper_type: str = 'whisper',
                      download_root: str = WHISPER_DEFAULT_PATH,
-                     device: Optional[Union[str, device]] = None,
+                     device: Optional[Union[str, device]] = SCRAIBE_TORCH_DEVICE,
                      in_memory: bool = False,
                      *args, **kwargs
                      ) -> Union[WhisperTranscriber, FasterWhisperTranscriber]:
@@ -438,7 +437,7 @@ def load_transcriber(model: str = "medium",
         download_root (str, optional): Path to download the model.
             Defaults to WHISPER_DEFAULT_PATH.
         device (Optional[Union[str, torch.device]], optional):
-            Device to load model on. Defaults to None.
+            Device to load model on. Defaults to SCRAIBE_TORCH_DEVICE.
         in_memory (bool, optional): Whether to load model in memory.
             Defaults to False.
         args: Additional arguments only to avoid errors.
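At the transcriber layer the same default applies to the abstract Transcriber.load_model, both concrete load_model implementations, and the load_transcriber factory, while the existing guard still warns when a CPU device meets the float16 compute type (end of the FasterWhisper hunk above). A hedged usage sketch; the model size and audio path are illustrative:

    from scraibe.transcriber import load_transcriber

    # device defaults to SCRAIBE_TORCH_DEVICE; pass it explicitly to override.
    transcriber = load_transcriber(model="medium",
                                   whisper_type="whisper",
                                   device="cpu")
    text = transcriber.transcribe("interview.wav")  # str/Tensor/ndarray accepted per the signatures above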
