From ed89ada4391332c0ccebc71ea9f05265d90145a4 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 22 Jan 2024 18:02:46 +0800 Subject: [PATCH] Add JNI binding for speaker embedding extractor. --- .../speaker/identification/MainActivity.kt | 1 + .../onnx/speaker/identification/Speaker.kt | 90 ++++++++++ .../SpeakerEmbeddingExtractor.kt | 53 ------ .../identification/screens/Register.kt | 30 +++- kotlin-api-examples/Main.kt | 32 ++++ kotlin-api-examples/Speaker.kt | 1 + kotlin-api-examples/run.sh | 18 +- sherpa-onnx/jni/jni.cc | 163 ++++++++++++++++-- 8 files changed, 321 insertions(+), 67 deletions(-) create mode 100644 android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt delete mode 100644 android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/SpeakerEmbeddingExtractor.kt create mode 120000 kotlin-api-examples/Speaker.kt diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt index 7307b8623e..10dff4a881 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/MainActivity.kt @@ -58,6 +58,7 @@ class MainActivity : ComponentActivity() { } ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION) + } @Deprecated("Deprecated in Java") diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt new file mode 100644 index 0000000000..24db7ed925 --- /dev/null +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt @@ -0,0 +1,90 @@ +package com.k2fsa.sherpa.onnx + +import android.content.res.AssetManager + + +data class SpeakerEmbeddingExtractorConfig( + val model: String, + var numThreads: Int = 1, + var debug: Boolean = false, + var provider: String = "cpu", +) + +class SpeakerEmbeddingExtractorStream(var ptr: Long) { + fun acceptWaveform(samples: FloatArray, sampleRate: Int) = acceptWaveform(ptr, samples, sampleRate) + + fun inputFinished() = inputFinished(ptr) + + protected fun finalize() { + delete(ptr) + ptr = 0 + } + + private external fun myTest(ptr: Long, v: Array) + + fun release() = finalize() + private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) + + private external fun inputFinished(ptr: Long) + + private external fun delete(ptr: Long) + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} + +class SpeakerEmbeddingExtractor( + assetManager: AssetManager? = null, + config: SpeakerEmbeddingExtractorConfig, +) { + private var ptr: Long + + init { + ptr = if (assetManager != null) { + new(assetManager, config) + } else { + newFromFile(config) + } + } + + protected fun finalize() { + delete(ptr) + ptr = 0 + } + + fun release() = finalize() + + fun createStream(): SpeakerEmbeddingExtractorStream { + val p = createStream(ptr) + return SpeakerEmbeddingExtractorStream(p) + } + + fun isReady(stream: SpeakerEmbeddingExtractorStream) = isReady(ptr, stream.ptr) + fun compute(stream: SpeakerEmbeddingExtractorStream) = compute(ptr, stream.ptr) + + private external fun new( + assetManager: AssetManager, + config: SpeakerEmbeddingExtractorConfig, + ): Long + + private external fun newFromFile( + config: SpeakerEmbeddingExtractorConfig, + ): Long + + private external fun delete(ptr: Long) + + private external fun createStream(ptr: Long): Long + + private external fun isReady(ptr: Long, streamPtr: Long): Boolean + + + private external fun compute(ptr: Long, streamPtr: Long): FloatArray + + companion object { + init { + System.loadLibrary("sherpa-onnx-jni") + } + } +} diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/SpeakerEmbeddingExtractor.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/SpeakerEmbeddingExtractor.kt deleted file mode 100644 index 1d5354ca61..0000000000 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/SpeakerEmbeddingExtractor.kt +++ /dev/null @@ -1,53 +0,0 @@ -package com.k2fsa.sherpa.onnx.speaker.identification - -import android.content.res.AssetManager - -data class SpeakerEmbeddingExtractorConfig( - val model: String, -) - -class SpeakerEmbeddingExtractor( - assetManager: AssetManager? = null, - config: SpeakerEmbeddingExtractorConfig, -) { - private val ptr: Long - - init { - ptr = if (assetManager != null) { - new(assetManager, config) - } else { - newFromFile(config) - } - } - - protected fun finalize() { - delete(ptr) - } - - private external fun new( - assetManager: AssetManager, - config: SpeakerEmbeddingExtractorConfig, - ): Long - - private external fun newFromFile( - config: SpeakerEmbeddingExtractorConfig, - ): Long - - private external fun delete(ptr: Long) - - private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int) - - private external fun inputFinished(ptr: Long) - - private external fun isReady(ptr: Long): Boolean - - private external fun compute(ptr: Long): FloatArray - - private external fun reset(ptr: Long) - - companion object { - init { - System.loadLibrary("sherpa-onnx-jni") - } - } -} \ No newline at end of file diff --git a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt index 64ee83ec71..15638d37b2 100644 --- a/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt +++ b/android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/screens/Register.kt @@ -1,6 +1,7 @@ package com.k2fsa.sherpa.onnx.speaker.identification.screens import android.Manifest +import android.annotation.SuppressLint import android.app.Activity import android.content.pm.PackageManager import android.media.AudioAttributes @@ -20,6 +21,7 @@ import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.width import androidx.compose.material3.Button +import androidx.compose.material3.MaterialTheme import androidx.compose.material3.OutlinedTextField import androidx.compose.material3.Text import androidx.compose.runtime.Composable @@ -31,6 +33,7 @@ import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.res.stringResource +import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import androidx.core.app.ActivityCompat @@ -42,11 +45,26 @@ private var audioRecord: AudioRecord? = null private var sampleList: MutableList? = null +private var allSampleList: MutableList>? = null + +private var number = 0 + +@SuppressLint("UnrememberedMutableState") @Preview @Composable fun RegisterScreen(modifier: Modifier = Modifier) { val activity = LocalContext.current as Activity + var firstTime by remember { mutableStateOf(true) } + if (firstTime) { + firstTime = false + // clear states + + number = 0 + } + + var numberAudio by mutableStateOf(number) + Box( modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.TopCenter @@ -108,6 +126,8 @@ fun RegisterScreen(modifier: Modifier = Modifier) { } Log.i(TAG, "Recording is stopped. ${sampleList?.count()}") + + ++number } } } else { @@ -155,6 +175,12 @@ fun RegisterScreen(modifier: Modifier = Modifier) { Column(horizontalAlignment = Alignment.CenterHorizontally) { SpeakerNameRow(speakerName = speakerName, onValueChange = onSpeakerNameChange) + Text( + "Number of recordings: ${numberAudio}", + modifier = modifier.padding(24.dp), + style = MaterialTheme.typography.headlineMedium, + fontWeight = FontWeight.Bold, + ) RegisterSpeakerButtonRow( modifier, isStarted = isStarted, @@ -177,7 +203,9 @@ fun SpeakerNameRow( Text("Please input the speaker name") }, singleLine = true, - modifier = modifier.fillMaxWidth().padding(8.dp) + modifier = modifier + .fillMaxWidth() + .padding(8.dp) ) } diff --git a/kotlin-api-examples/Main.kt b/kotlin-api-examples/Main.kt index 220997202a..0459808e18 100644 --- a/kotlin-api-examples/Main.kt +++ b/kotlin-api-examples/Main.kt @@ -7,11 +7,43 @@ fun callback(samples: FloatArray): Unit { } fun main() { + testSpeakerRecognition() testTts() testAsr("transducer") testAsr("zipformer2-ctc") } +fun computeEmbedding(extractor: SpeakerEmbeddingExtractor, filename: String): FloatArray { + var objArray = WaveReader.readWaveFromFile( + filename = filename, + ) + var samples: FloatArray = objArray[0] as FloatArray + var sampleRate: Int = objArray[1] as Int + + val stream = extractor.createStream() + stream.acceptWaveform(sampleRate = sampleRate, samples=samples) + stream.inputFinished() + check(extractor.isReady(stream)) + + val embedding = extractor.compute(stream) + + stream.release() + + return embedding +} + +fun testSpeakerRecognition() { + val config = SpeakerEmbeddingExtractorConfig( + model="./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx", + ) + val extractor = SpeakerEmbeddingExtractor(config = config) + + val embedding1a = computeEmbedding(extractor, "./speaker1_a_cn_16k.wav") + val embedding2a = computeEmbedding(extractor, "./speaker2_a_cn_16k.wav") + val embedding1b = computeEmbedding(extractor, "./speaker1_b_cn_16k.wav") + println(embedding1a.count()) +} + fun testTts() { // see https://github.com/k2-fsa/sherpa-onnx/releases/tag/tts-models // https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-en_US-amy-low.tar.bz2 diff --git a/kotlin-api-examples/Speaker.kt b/kotlin-api-examples/Speaker.kt new file mode 120000 index 0000000000..5a1f0d51cb --- /dev/null +++ b/kotlin-api-examples/Speaker.kt @@ -0,0 +1 @@ +../android/SherpaOnnxSpeakerIdentification/app/src/main/java/com/k2fsa/sherpa/onnx/speaker/identification/Speaker.kt \ No newline at end of file diff --git a/kotlin-api-examples/run.sh b/kotlin-api-examples/run.sh index 283e8f44d8..c6c4c41aa1 100755 --- a/kotlin-api-examples/run.sh +++ b/kotlin-api-examples/run.sh @@ -29,6 +29,22 @@ export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH cd ../kotlin-api-examples +if [ ! -f ./3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx ]; then + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx +fi + +if [ ! -f ./speaker1_a_cn_16k.wav ]; then + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav +fi + +if [ ! -f ./speaker1_b_cn_16k.wav ]; then + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav +fi + +if [ ! -f ./speaker2_a_cn_16k.wav ]; then + wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav +fi + if [ ! -f ./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt ]; then git lfs install git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21 @@ -46,7 +62,7 @@ if [ ! -f ./vits-piper-en_US-amy-low/en_US-amy-low.onnx ]; then rm vits-piper-en_US-amy-low.tar.bz2 fi -kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt +kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt faked-asset-manager.kt Tts.kt Speaker.kt ls -lh main.jar diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index 846217eea5..80e078e8cc 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -209,6 +209,25 @@ class SherpaOnnxKws { int32_t input_sample_rate_ = -1; }; +class SherpaOnnxSpeakerEmbeddingExtractorStream { + public: + explicit SherpaOnnxSpeakerEmbeddingExtractorStream( + std::unique_ptr stream) + : stream_(std::move(stream)) {} + + void AcceptWaveform(int32_t sample_rate, const float *samples, + int32_t n) const { + stream_->AcceptWaveform(sample_rate, samples, n); + } + + void InputFinished() const { stream_->InputFinished(); } + + OnlineStream *Get() const { return stream_.get(); } + + private: + std::unique_ptr stream_; +}; + class SherpaOnnxSpeakerEmbeddingExtractor { public: #if __ANDROID_API__ >= 9 @@ -219,23 +238,26 @@ class SherpaOnnxSpeakerEmbeddingExtractor { explicit SherpaOnnxSpeakerEmbeddingExtractor( const SpeakerEmbeddingExtractorConfig &config) - : extractor_(config), stream_(extractor_.CreateStream()) {} + : extractor_(config) {} int32_t Dim() const { return extractor_.Dim(); } - bool IsReady() const { return extractor_.IsReady(stream_.get()); } - - std::vector Compute() const { - return extractor_.Compute(stream_.get()); + bool IsReady(const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const { + return extractor_.IsReady(stream->Get()); } - void Reset() { stream_ = extractor_.CreateStream(); } + SherpaOnnxSpeakerEmbeddingExtractorStream *CreateStream() const { + return new SherpaOnnxSpeakerEmbeddingExtractorStream( + extractor_.CreateStream()); + } - void InputFinished() const { stream_->InputFinished(); } + std::vector Compute( + const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const { + return extractor_.Compute(stream->Get()); + } private: SpeakerEmbeddingExtractor extractor_; - std::unique_ptr stream_; }; static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig( @@ -251,6 +273,18 @@ static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig( ans.model = p; env->ReleaseStringUTFChars(s, p); + fid = env->GetFieldID(cls, "numThreads", "I"); + ans.num_threads = env->GetIntField(config, fid); + + fid = env->GetFieldID(cls, "debug", "Z"); + ans.debug = env->GetBooleanField(config, fid); + + fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.provider = p; + env->ReleaseStringUTFChars(s, p); + return ans; } @@ -819,8 +853,10 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) { SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL -Java_com_k2fsa_sherpa_onnx_speaker_identification_SpeakerEmbeddingExtractor_new( - JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env, + jobject /*obj*/, + jobject asset_manager, + jobject _config) { #if __ANDROID_API__ >= 9 AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { @@ -834,13 +870,111 @@ Java_com_k2fsa_sherpa_onnx_speaker_identification_SpeakerEmbeddingExtractor_new( SHERPA_ONNX_LOGE("Errors found in config!"); } - auto tts = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor( + auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor( #if __ANDROID_API__ >= 9 mgr, #endif config); - return (jlong)tts; + return (jlong)extractor; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile( + JNIEnv *env, jobject /*obj*/, jobject _config) { + auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + } + + auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(config); + + return (jlong)extractor; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env, + jobject /*obj*/, + jlong ptr) { + delete reinterpret_cast( + ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jlong JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto stream = + reinterpret_cast(ptr) + ->CreateStream(); + + return (jlong)stream; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jboolean JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto extractor = + reinterpret_cast(ptr); + auto stream = reinterpret_cast< + sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr); + return extractor->IsReady(stream); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT jfloatArray JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env, + jobject /*obj*/, + jlong ptr, + jlong stream_ptr) { + auto extractor = + reinterpret_cast(ptr); + auto stream = reinterpret_cast< + sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr); + + std::vector embedding = extractor->Compute(stream); + jfloatArray embedding_arr = env->NewFloatArray(embedding.size()); + env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(), + embedding.data()); + return embedding_arr; +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_delete( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + delete reinterpret_cast< + sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_acceptWaveform( + JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, + jint sample_rate) { + auto stream = reinterpret_cast< + sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr); + + jfloat *p = env->GetFloatArrayElements(samples, nullptr); + jsize n = env->GetArrayLength(samples); + stream->AcceptWaveform(sample_rate, p, n); + env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); +} + +SHERPA_ONNX_EXTERN_C +JNIEXPORT void JNICALL +Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_inputFinished( + JNIEnv *env, jobject /*obj*/, jlong ptr) { + auto stream = reinterpret_cast< + sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr); + stream->InputFinished(); } SHERPA_ONNX_EXTERN_C @@ -873,6 +1007,11 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile( JNIEnv *env, jobject /*obj*/, jobject _config) { auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + if (!config.Validate()) { + SHERPA_ONNX_LOGE("Errors found in config!"); + } + auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config); return (jlong)tts;