Skip to content

Commit

Permalink
fixed nonspeech_skip
Browse files Browse the repository at this point in the history
-fixed `nonspeech_skip` causing alignment to skip sections of speech
-fixed "'last_ts' referenced before assignment" error for alignment (#429)
  • Loading branch information
jianfch committed Jan 23, 2025
1 parent e7ff3dd commit fc0d0da
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions stable_whisper/non_whisper/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def align(
desc='Align'
) as tqdm_pbar:
result: List[BasicWordTiming] = []

last_ts = 0.0
while self._all_word_tokens:

self._time_offset = self._seek_sample / self.sample_rate
Expand Down Expand Up @@ -876,42 +876,50 @@ def _skip_nonspeech(

segment_samples = audio_segment.size(-1)

max_time_offset = self._time_offset + self.options.post.min_word_dur
min_time_offset = self._time_offset - self.options.post.min_word_dur

if (
(segment_nonspeech_timings[0][0] <= self._time_offset + self.options.post.min_word_dur) and
(
segment_nonspeech_timings[1][0]
>=
self._time_offset + segment_samples - self.options.post.min_word_dur
)
(segment_nonspeech_timings[0][0] < max_time_offset) and
(segment_nonspeech_timings[1][0] > min_time_offset + segment_samples)
):
# entire audio segment is within first nonspeech section
self._seek_sample += segment_samples
return

timing_indices = (segment_nonspeech_timings[1] - segment_nonspeech_timings[0]) >= self.nonspeech_skip
if not timing_indices.any():
# mask for valid nonspeech sections (i.e. sections with duration >= ``nonspeech_skip``)
valid_sections = (segment_nonspeech_timings[1] - segment_nonspeech_timings[0]) >= self.nonspeech_skip
if not valid_sections.any():
# no valid nonspeech sections
return audio_segment

nonspeech_starts = segment_nonspeech_timings[0][timing_indices]
nonspeech_ends = segment_nonspeech_timings[1][timing_indices]

if nonspeech_ends[0] <= round(self._time_offset, 3) >= nonspeech_starts[0]:
nonspeech_starts = segment_nonspeech_timings[0, valid_sections]
if max_time_offset < nonspeech_starts[0]:
# current time is before the first valid nonspeech section
return audio_segment

nonspeech_ends = segment_nonspeech_timings[1, valid_sections]
curr_total_samples = self.audio_loader.get_total_samples()

# skip to end of the first nonspeech section
self._seek_sample = round(nonspeech_ends[0] * self.sample_rate)
if self._seek_sample + (self.options.post.min_word_dur * self.sample_rate) >= curr_total_samples:
if self._seek_sample + (self.options.post.min_word_dur * self.sample_rate) > curr_total_samples:
# new time is over total duration of the audio
self._seek_sample = curr_total_samples
return

self._time_offset = self._seek_sample / self.sample_rate

# try to load audio segment from the new timestamp
audio_segment = self.audio_loader.next_chunk(self._seek_sample, self.n_samples)
if audio_segment is None:
# reached eof
return

# recompute nonspeech sections for the new audio segment for later use
self._nonspeech_preds = self.nonspeech_predictor.predict(audio=audio_segment, offset=self._time_offset)
if len(nonspeech_starts) > 1:
# remove all audio samples after start of second valid nonspeech section
new_sample_count = round((nonspeech_starts[1] - nonspeech_ends[0]) * self.sample_rate)
audio_segment = audio_segment[:new_sample_count]

Expand Down

0 comments on commit fc0d0da

Please sign in to comment.