Make the recognition service more robust
soupslurpr committed Apr 2, 2024
1 parent d93cfde commit e764559
Showing 1 changed file with 124 additions and 123 deletions.
@@ -32,7 +32,8 @@ import java.lang.System.currentTimeMillis
import java.util.concurrent.atomic.AtomicBoolean

private data class Transcription(
val start: Double?,
val audioData: MutableList<Short> = mutableListOf(),
var start: Double?,
var end: Double?,
var text: String?,
)
@@ -82,7 +83,7 @@ class MainRecognitionService : RecognitionService() {
val START_THRESHOLD = 0.6f
val END_THRESHOLD = 0.45f
val MIN_SILENCE_DURATION_MS = 3000
val SPEECH_PAD_MS = 1500
val SPEECH_PAD_MS = 0

val model =
this@MainRecognitionService.assets.open("models/silero_vad/silero_vad.with_runtime_opt.ort")
@@ -107,6 +108,8 @@

override fun onStartListening(recognizerIntent: Intent?, listener: Callback?) {
val autoStopRecognition = recognizerIntent?.extras?.getBoolean(EXTRA_AUTO_STOP) ?: true
val isPartialResults = recognizerIntent?.extras?.getBoolean(RecognizerIntent.EXTRA_PARTIAL_RESULTS)
val speechStartPadMs = 24000

if (recordAndTranscribeJob?.isActive == true) {
listener?.error(SpeechRecognizer.ERROR_RECOGNIZER_BUSY)
@@ -242,10 +245,7 @@
val audioRmsScope = CoroutineScope(Dispatchers.IO)
transcribeJobs.clear()
val transcriptions = mutableMapOf<Int, Transcription>()
var transcriptionIndex = 0

val audioData: MutableList<Short> = mutableListOf()
var audioDataSurminus by mutableIntStateOf(0)
var transcriptionIndex by mutableIntStateOf(0)

listener?.readyForSpeech(Bundle())

@@ -255,7 +255,10 @@
val numberOfShorts = audioRecord.read(buffer, 0, bufferSize)

for (i in 0 until numberOfShorts) {
audioData.add(buffer[i])
if (transcriptions[transcriptionIndex] == null) {
transcriptions[transcriptionIndex] = Transcription(start = null, end = null, text = null)
}
transcriptions[transcriptionIndex]!!.audioData.add(buffer[i])
}

if (!isActive) {
@@ -264,165 +267,163 @@
return@recordAndTranscribe
}

if (audioData.size >= bufferSize) {
val currentTranscriptionIndex = transcriptionIndex
audioRmsScope.launch vadAndTranscribe@{
if (stopListening) {
isRecording.set(false)
listener?.endOfSpeech()

if ((transcriptionIndex != 0) && !autoStopRecognition && !isSpeaking) {
// break if we already transcribed once or more,
// it is long form, and we are not speaking,
// because then we just want to stop with the already transcribed parts.
// if there were no transcriptions, we assume VAD didn't work,
// so we transcribe everything all at once.
// if there were transcriptions, then VAD was working.
// in that case we can break, since if it was working at first
// it was probably working afterward, so processing the audio
// could be detrimental.
return@vadAndTranscribe
}

if ((transcriptionIndex == 0) && !isSpeaking) {
transcriptionIndex += 1
transcriptions[transcriptionIndex] =
Transcription(0.toDouble(), null, null)
}
audioRmsScope.launch vadAndTranscribe@{
if (stopListening) {
isRecording.set(false)
listener?.endOfSpeech()

if ((transcriptionIndex != 0) && !autoStopRecognition && !isSpeaking) {
// break if we already transcribed once or more,
// auto stop recognition is off, and we are not speaking,
// because then we just want to stop with the already transcribed parts.
// if there were no transcriptions, we assume VAD didn't work,
// so we transcribe everything all at once.
// if there were transcriptions, then VAD was working.
// in that case we can break, since if it was working at first
// it was probably working afterward, so processing the audio
// could be detrimental.
return@vadAndTranscribe
}

transcriptions[transcriptionIndex]!!.end = (audioData.size - 1).toDouble()
val transcription = transcriptions[transcriptionIndex]!!

val transcribeJob = transcribeScope.launch {
val timeBeforeTranscription = currentTimeMillis()
if ((transcriptionIndex == 0) && !isSpeaking) {
transcription.start = 0.toDouble()
}

val transcription = transcriptions[transcriptionIndex]!!
transcriptions[transcriptionIndex]!!.end =
(transcriptions[transcriptionIndex]!!.audioData.size - 1).toDouble()

if (transcription.text != null) {
return@launch
}
val transcribeJob = transcribeScope.launch {
val timeBeforeTranscription = currentTimeMillis()

val transcriptionText =
whisperRepository.transcribeAudio(
audioData.slice((transcription.start!!.toInt() - audioDataSurminus)..transcription.end!!.toInt())
.toShortArray(),
val transcriptionText =
whisperRepository.transcribeAudio(
transcription.audioData.slice(
((transcription.start!!.toInt() - speechStartPadMs).coerceAtLeast(
0
))..((transcription.end!!.toInt()).coerceAtMost(transcription.audioData.size - 1))
)
.toShortArray(),
)

if (transcription.text != null) {
return@launch
}

transcription.text = transcriptionText
transcription.text = transcriptionText

totalTranscriptionTime += currentTimeMillis() - timeBeforeTranscription
totalTranscriptionTime += currentTimeMillis() - timeBeforeTranscription

for (job in transcribeJobs.iterator().withIndex()) {
if ((job.index + 1) < transcriptionIndex) {
job.value.join()
}
for (job in transcribeJobs.iterator().withIndex()) {
if (job.index < transcriptionIndex - 1) {
job.value.join()
}
}

if (recognizerIntent?.extras?.getBoolean(RecognizerIntent.EXTRA_PARTIAL_RESULTS) == true) {
val bundle = Bundle().apply {
putStringArrayList(
SpeechRecognizer.RESULTS_RECOGNITION,
arrayListOf(
transcription.text!!
)
)
}
transcription.audioData.clear()

if (!isActive) {
return@launch
}
if (isPartialResults == true) {
val bundle = Bundle().apply {
putStringArrayList(
SpeechRecognizer.RESULTS_RECOGNITION,
arrayListOf(
transcription.text!!
)
)
}

listener?.partialResults(bundle)
if (!isActive) {
return@launch
}

listener?.partialResults(bundle)
}
}

transcribeJobs.add(transcribeJob)
transcribeJobs.add(transcribeJob)

transcribeJob.start() // start transcribing this segment right now
} else {
sileroVadRepository.detect(buffer)?.forEach {
when (it.key) {
"start" -> {
isSpeaking = true
listener?.beginningOfSpeech()
transcribeJob.start() // start transcribing this segment right now
} else {
sileroVadRepository.detect(buffer)?.forEach {
when (it.key) {
"start" -> {
isSpeaking = true
listener?.beginningOfSpeech()

transcriptionIndex += 1
if (transcriptions[transcriptionIndex] == null) {
transcriptions[transcriptionIndex] =
Transcription(it.value, null, null)
Transcription(start = it.value, end = null, text = null)
} else {
transcriptions[transcriptionIndex]!!.start =
it.value
}
}

"end" -> {
if (transcriptions[currentTranscriptionIndex]?.end == null) {
isSpeaking = false
listener?.endOfSpeech()
"end" -> {
val transcription = transcriptions[transcriptionIndex]!!

transcriptions[currentTranscriptionIndex]?.end = it.value
transcriptionIndex += 1
sileroVadRepository.reset()

val transcribeJob = transcribeScope.launch {
val timeBeforeTranscription = currentTimeMillis()
if (transcription.end == null) {
isSpeaking = false
listener?.endOfSpeech()

val transcription = transcriptions[currentTranscriptionIndex]!!
transcription.end = it.value

transcription.text =
whisperRepository.transcribeAudio(
audioData.slice((transcription.start!!.toInt() - audioDataSurminus)..(transcription.end!!.toInt() - audioDataSurminus))
.toShortArray(),
val transcribeJob = transcribeScope.launch {
val timeBeforeTranscription = currentTimeMillis()

transcription.text =
whisperRepository.transcribeAudio(
transcription.audioData.slice(
((transcription.start!!.toInt() - speechStartPadMs).coerceAtLeast(
0
))..((transcription.end!!.toInt()).coerceAtMost(transcription.audioData.size - 1))
)
.toShortArray(),
)

totalTranscriptionTime += currentTimeMillis() - timeBeforeTranscription
totalTranscriptionTime += currentTimeMillis() - timeBeforeTranscription

for (job in transcribeJobs.iterator().withIndex()) {
if ((job.index + 1) < currentTranscriptionIndex) {
job.value.join()
}
for (job in transcribeJobs.iterator().withIndex()) {
if (job.index < transcriptionIndex - 1) {
job.value.join()
}
}

audioData.subList(0, (transcription.end!!.toInt()) - audioDataSurminus)
.clear()
audioDataSurminus += transcription.end!!.toInt() - audioDataSurminus

if (recognizerIntent?.extras?.getBoolean(RecognizerIntent.EXTRA_PARTIAL_RESULTS) == true) {
val bundle = Bundle().apply {
putStringArrayList(
SpeechRecognizer.RESULTS_RECOGNITION,
arrayListOf(
transcription.text!!
)
)
}
transcription.audioData.clear()

if (!isActive) {
return@launch
}
if (isPartialResults == true) {
val bundle = Bundle().apply {
putStringArrayList(
SpeechRecognizer.RESULTS_RECOGNITION,
arrayListOf(
transcription.text!!
)
)
}

listener?.partialResults(bundle)
if (!isActive) {
return@launch
}

listener?.partialResults(bundle)
}
}

transcribeJobs.add(transcribeJob)
transcribeJobs.add(transcribeJob)

transcribeJob.start() // start transcribing this segment right now
transcribeJob.start() // start transcribing this segment right now

if (autoStopRecognition) {
isRecording.set(false)
}
if (autoStopRecognition) {
isRecording.set(false)
}
}
}
}

if (!isSpeaking && !autoStopRecognition && (currentTranscriptionIndex > 0) && (audioData.size > 480000) && (transcriptions[currentTranscriptionIndex]?.end != null)) {
audioData.subList(0, 240000).clear()
audioDataSurminus += 240000
}
}
}
if (stopListening) {
break
}
}
if (stopListening) {
break
}
}

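The core change above: each Transcription now owns its audioData buffer, and the speech-start padding moves out of the VAD config (SPEECH_PAD_MS drops from 1500 to 0) into the slice itself, where speechStartPadMs = 24000 samples of lead-in are kept and both bounds are clamped with coerceAtLeast/coerceAtMost so the range can never index outside the buffer. A minimal standalone sketch of that clamped slice, assuming 16 kHz mono PCM (so 24000 samples is 1.5 s, matching the removed 1500 ms pad); padClampedSlice and its parameter names are illustrative, not from the commit:

// Sketch: pad a detected speech segment backwards by up to padSamples,
// clamping both ends so the slice stays inside the recorded buffer.
fun padClampedSlice(
    audio: List<Short>,
    startSample: Int,
    endSample: Int,
    padSamples: Int = 24000, // hypothetical default mirroring speechStartPadMs
): ShortArray {
    val from = (startSample - padSamples).coerceAtLeast(0) // never before sample 0
    val to = endSample.coerceAtMost(audio.size - 1) // never past the last sample
    return audio.slice(from..to).toShortArray()
}

Clamping both ends is what lets the commit drop the old shared audioData list and its audioDataSurminus offset bookkeeping: a stale or early start index can no longer produce an out-of-range slice.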

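For context on isPartialResults above: the service now reads RecognizerIntent.EXTRA_PARTIAL_RESULTS once from the incoming intent and, when it is true, streams each finished segment through listener?.partialResults(bundle). A caller opts in roughly like this (a sketch against the standard Android SpeechRecognizer API, not code from this repository):

import android.content.Intent
import android.speech.RecognizerIntent

// Sketch: build a recognizer intent that asks for partial results.
val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
    putExtra(RecognizerIntent.EXTRA_PARTIAL_RESULTS, true)
}
// speechRecognizer.startListening(intent) would then deliver
// onPartialResults(...) callbacks as each segment is transcribed.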