Make writing audio frames optional #216

Merged
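For context, here is a minimal usage sketch of the option this PR introduces (parameter names as they appear in the README diff below; this assumes a WhisperLive server is already listening on localhost:9090):

```python
from whisper_live.client import TranscriptionClient

# save_output_recording and output_recording_filename are the new optional
# parameters added by this PR; saving the microphone input defaults to off.
client = TranscriptionClient(
    "localhost",
    9090,
    lang="en",
    translate=False,
    model="small",
    use_vad=False,
    save_output_recording=True,  # opt in to saving the mic input
    output_recording_filename="./output_recording.wav",  # must end in .wav
)

client()  # transcribe from the microphone; the recording is written on exit
```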
README.md (14 changes: 11 additions & 3 deletions)
@@ -62,7 +62,13 @@ python3 run_server.py --port 9090 \
```

### Running the Client
- Initializing the client:
- Initializing the client with the following parameters:
- `lang`: Language of the input audio, applicable only if using a multilingual model.
- `translate`: If set to `True` then translate from any language to `en`.
- `model`: Whisper model size.
- `use_vad`: Whether to use `Voice Activity Detection` on the server.
- `save_output_recording`: Set to `True` to save the microphone input as a `.wav` file during live transcription. This option is helpful for recording sessions for later playback or analysis. Defaults to `False`.
- `output_recording_filename`: Specifies the `.wav` file path where the microphone input will be saved if `save_output_recording` is set to `True`.
```python
from whisper_live.client import TranscriptionClient
client = TranscriptionClient(
@@ -72,11 +78,13 @@ client = TranscriptionClient(
translate=False,
model="small",
use_vad=False,
save_output_recording=True, # Only used for microphone input, False by Default
output_recording_filename="./output_recording.wav" # Only used for microphone input
)
```
It connects to the server running on localhost at port 9090. With a multilingual model, the transcription language is detected automatically. You can also use the `lang` option to specify the target language for the transcription, in this case English (`"en"`). Set `translate` to `True` to translate from the source language to English, or `False` to transcribe in the source language.

- Trancribe an audio file:
- Transcribe an audio file:
```python
client("tests/jfk.wav")
```
@@ -86,7 +94,7 @@ client("tests/jfk.wav")
client()
```

- TO transcribe from a RTSP stream:
- To transcribe from a RTSP stream:
```python
client(rtsp_url="rtsp://admin:[email protected]/rtsp")
```
whisper_live/client.py (111 changes: 77 additions & 34 deletions)
@@ -1,4 +1,5 @@
import os
import shutil
import wave

import numpy as np
@@ -259,6 +260,7 @@ def wait_before_disconnect(self):
while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
continue


class TranscriptionTeeClient:
"""
Client for handling audio recording, streaming, and transcription tasks via one or more
@@ -272,7 +274,7 @@ class TranscriptionTeeClient:
Attributes:
clients (list): the underlying Client instances responsible for handling WebSocket connections.
"""
def __init__(self, clients):
def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav"):
self.clients = clients
if not self.clients:
raise Exception("At least one client is required.")
@@ -281,6 +283,8 @@ def __init__(self, clients):
self.channels = 1
self.rate = 16000
self.record_seconds = 60000
self.save_output_recording = save_output_recording
self.output_recording_filename = output_recording_filename
self.frames = b""
self.p = pyaudio.PyAudio()
try:
@@ -473,7 +477,43 @@ def get_hls_ffmpeg_process(self, hls_url, save_file):

return process

def record(self, out_file="output_recording.wav"):
def save_chunk(self, n_audio_file):
"""
Saves the current audio frames to a WAV file in a separate thread.

Args:
n_audio_file (int): The index of the audio file which determines the filename.
This helps in maintaining the order and uniqueness of each chunk.
"""
t = threading.Thread(
target=self.write_audio_frames_to_file,
args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
)
t.start()

def finalize_recording(self, n_audio_file):
"""
Finalizes the recording process by saving any remaining audio frames,
closing the audio stream, and terminating the process.

Args:
n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
This index is incremented before use if the last chunk is saved.
"""
if self.save_output_recording and len(self.frames):
self.write_audio_frames_to_file(
self.frames[:], f"chunks/{n_audio_file}.wav"
)
n_audio_file += 1
self.stream.stop_stream()
self.stream.close()
self.p.terminate()
self.close_all_clients()
if self.save_output_recording:
self.write_output_recording(n_audio_file)
self.write_all_clients_srt()

def record(self):
"""
Record audio data from the input stream and save it to a WAV file.

@@ -485,15 +525,12 @@ def record(self, out_file="output_recording.wav"):
The recording will continue until the specified duration is reached or until the `RECORDING` flag is set to `False`.
The recording process can be interrupted by sending a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
the method combines all the saved audio chunks into the specified `out_file`.

Args:
out_file (str, optional): The name of the output WAV file to save the entire recording.
Default is "output_recording.wav".

"""
n_audio_file = 0
if not os.path.exists("chunks"):
os.makedirs("chunks", exist_ok=True)
if self.save_output_recording:
if os.path.exists("chunks"):
shutil.rmtree("chunks")
os.makedirs("chunks")
try:
for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
if not any(client.recording for client in self.clients):
@@ -507,31 +544,14 @@

# save frames if more than a minute
if len(self.frames) > 60 * self.rate:
t = threading.Thread(
target=self.write_audio_frames_to_file,
args=(
self.frames[:],
f"chunks/{n_audio_file}.wav",
),
)
t.start()
n_audio_file += 1
if self.save_output_recording:
self.save_chunk(n_audio_file)
n_audio_file += 1
self.frames = b""
self.write_all_clients_srt()

except KeyboardInterrupt:
if len(self.frames):
self.write_audio_frames_to_file(
self.frames[:], f"chunks/{n_audio_file}.wav"
)
n_audio_file += 1
self.stream.stop_stream()
self.stream.close()
self.p.terminate()
self.close_all_clients()

self.write_output_recording(n_audio_file, out_file)
self.write_all_clients_srt()
self.finalize_recording(n_audio_file)

def write_audio_frames_to_file(self, frames, file_name):
"""
@@ -552,7 +572,7 @@ def write_audio_frames_to_file(self, frames, file_name):
wavfile.setframerate(self.rate)
wavfile.writeframes(frames)

def write_output_recording(self, n_audio_file, out_file):
def write_output_recording(self, n_audio_file):
"""
Combine and save recorded audio chunks into a single WAV file.

@@ -571,7 +591,7 @@
for i in range(n_audio_file)
if os.path.exists(f"chunks/{i}.wav")
]
with wave.open(out_file, "wb") as wavfile:
with wave.open(self.output_recording_filename, "wb") as wavfile:
wavfile: wave.Wave_write
wavfile.setnchannels(self.channels)
wavfile.setsampwidth(2)
@@ -586,6 +606,9 @@
# remove this file
os.remove(in_file)
wavfile.close()
# clean up the temporary directory used to store chunks
if os.path.exists("chunks"):
shutil.rmtree("chunks")

@staticmethod
def bytes_to_float_array(audio_bytes):
@@ -604,6 +627,7 @@
raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
return raw_data.astype(np.float32) / 32768.0


class TranscriptionClient(TranscriptionTeeClient):
"""
Client for handling audio transcription tasks via a single WebSocket connection.
@@ -616,6 +640,8 @@
port (int): The port number to connect to on the server.
lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
translate (bool, optional): Indicates whether translation tasks are required (default is False).
save_output_recording (bool, optional): Indicates whether to save recording from microphone.
output_recording_filename (str, optional): File to save the output recording.

Attributes:
client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
@@ -627,6 +653,23 @@
transcription_client()
```
"""
def __init__(self, host, port, lang=None, translate=False, model="small", use_vad=True):
def __init__(
self,
host,
port,
lang=None,
translate=False,
model="small",
use_vad=True,
save_output_recording=False,
output_recording_filename="./output_recording.wav"
):
self.client = Client(host, port, lang, translate, model, srt_file_path="output.srt", use_vad=use_vad)
TranscriptionTeeClient.__init__(self, [self.client])
if save_output_recording and not output_recording_filename.endswith(".wav"):
raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
TranscriptionTeeClient.__init__(
self,
[self.client],
save_output_recording=save_output_recording,
output_recording_filename=output_recording_filename
)
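To make the control flow above easier to follow outside the diff, here is a self-contained sketch of the same chunk-and-merge pattern that `record`, `save_chunk`, and `write_output_recording` implement (a simplified, hypothetical standalone script, not the library code; the `chunks/` directory and the 16 kHz mono 16-bit format mirror the constants in the diff):

```python
import os
import shutil
import wave

RATE, CHANNELS, SAMPWIDTH = 16000, 1, 2  # mirrors self.rate, self.channels, 16-bit samples

def write_chunk(frames: bytes, path: str) -> None:
    # roughly one chunk per minute of audio, as in write_audio_frames_to_file
    with wave.open(path, "wb") as wavfile:
        wavfile.setnchannels(CHANNELS)
        wavfile.setsampwidth(SAMPWIDTH)
        wavfile.setframerate(RATE)
        wavfile.writeframes(frames)

def merge_chunks(n_chunks: int, out_file: str) -> None:
    # same shape as write_output_recording: concatenate chunks, then clean up
    with wave.open(out_file, "wb") as out:
        out.setnchannels(CHANNELS)
        out.setsampwidth(SAMPWIDTH)
        out.setframerate(RATE)
        for i in range(n_chunks):
            in_file = f"chunks/{i}.wav"
            if not os.path.exists(in_file):
                continue
            with wave.open(in_file, "rb") as chunk:
                out.writeframes(chunk.readframes(chunk.getnframes()))
            os.remove(in_file)
    if os.path.exists("chunks"):
        shutil.rmtree("chunks")

if __name__ == "__main__":
    os.makedirs("chunks", exist_ok=True)
    one_second_of_silence = b"\x00\x00" * RATE  # 16-bit mono silence
    for i in range(2):
        write_chunk(one_second_of_silence * 60, f"chunks/{i}.wav")
    merge_chunks(2, "./output_recording.wav")
```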
whisper_live/transcriber.py (8 changes: 4 additions & 4 deletions)
@@ -496,7 +496,7 @@ def generate_segments(
content_frames - seek,
seek_clip_end - seek,
)
segment = features[:, seek : seek + segment_size]
segment = features[:, seek:seek + segment_size]
segment_duration = segment_size * self.feature_extractor.time_per_frame
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)

@@ -685,7 +685,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
current_segments[si + 1:]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
@@ -909,7 +909,7 @@ def get_prompt(

if previous_tokens:
prompt.append(tokenizer.sot_prev)
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
prompt.extend(previous_tokens[-(self.max_length // 2 - 1):])

prompt.extend(tokenizer.sot_sequence)

@@ -926,7 +926,7 @@

return prompt

def add_word_timestamps( # noqa: C901
def add_word_timestamps(
self,
segments: List[dict],
tokenizer: Tokenizer,