Skip to content

Commit

Permalink
Flutter OnlinePunctuation (#1854)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dokotela authored Feb 13, 2025
1 parent ce7c03b commit 115e9c2
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 1 deletion.
3 changes: 2 additions & 1 deletion flutter/sherpa_onnx/lib/sherpa_onnx.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import 'dart:ffi';
export 'src/audio_tagging.dart';
export 'src/feature_config.dart';
export 'src/keyword_spotter.dart';
export 'src/offline_punctuation.dart';
export 'src/offline_recognizer.dart';
export 'src/offline_speaker_diarization.dart';
export 'src/offline_stream.dart';
export 'src/online_punctuation.dart';
export 'src/online_recognizer.dart';
export 'src/online_stream.dart';
export 'src/punctuation.dart';
export 'src/speaker_identification.dart';
export 'src/tts.dart';
export 'src/vad.dart';
Expand Down
99 changes: 99 additions & 0 deletions flutter/sherpa_onnx/lib/src/online_punctuation.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import 'dart:ffi';
import 'package:ffi/ffi.dart';

import './sherpa_onnx_bindings.dart';

class OnlinePunctuationModelConfig {
OnlinePunctuationModelConfig(
{required this.cnnBiLstm,
required this.bpeVocab,
this.numThreads = 1,
this.provider = 'cpu',
this.debug = true});

@override
String toString() {
return 'OnlinePunctuationModelConfig(cnnBiLstm: $cnnBiLstm, '
'bpeVocab: $bpeVocab, numThreads: $numThreads, '
'provider: $provider, debug: $debug)';
}

final String cnnBiLstm;
final String bpeVocab;
final int numThreads;
final String provider;
final bool debug;
}

class OnlinePunctuationConfig {
OnlinePunctuationConfig({
required this.model,
});

@override
String toString() {
return 'OnlinePunctuationConfig(model: $model)';
}

final OnlinePunctuationModelConfig model;
}

class OnlinePunctuation {
OnlinePunctuation.fromPtr({required this.ptr, required this.config});

OnlinePunctuation._({required this.ptr, required this.config});

// The user has to invoke OnlinePunctuation.free() to avoid memory leak.
factory OnlinePunctuation({required OnlinePunctuationConfig config}) {
final c = calloc<SherpaOnnxOnlinePunctuationConfig>();

final cnnBiLstmPtr = config.model.cnnBiLstm.toNativeUtf8();
final bpeVocabPtr = config.model.bpeVocab.toNativeUtf8();
c.ref.model.cnnBiLstm = cnnBiLstmPtr;
c.ref.model.bpeVocab = bpeVocabPtr;
c.ref.model.numThreads = config.model.numThreads;
c.ref.model.debug = config.model.debug ? 1 : 0;

final providerPtr = config.model.provider.toNativeUtf8();
c.ref.model.provider = providerPtr;

final ptr = SherpaOnnxBindings.sherpaOnnxCreateOnlinePunctuation?.call(c) ??
nullptr;

// Free the allocated strings and struct memory
calloc.free(providerPtr);
calloc.free(cnnBiLstmPtr);
calloc.free(bpeVocabPtr);
calloc.free(c);

return OnlinePunctuation._(ptr: ptr, config: config);
}

void free() {
SherpaOnnxBindings.sherpaOnnxDestroyOnlinePunctuation?.call(ptr);
ptr = nullptr;
}

String addPunct(String text) {
final textPtr = text.toNativeUtf8();

final p = SherpaOnnxBindings.sherpaOnnxOnlinePunctuationAddPunct
?.call(ptr, textPtr) ??
nullptr;

calloc.free(textPtr);

if (p == nullptr) {
return '';
}

final ans = p.toDartString();

SherpaOnnxBindings.sherpaOnnxOnlinePunctuationFreeText?.call(p);

return ans;
}

Pointer<SherpaOnnxOnlinePunctuation> ptr;
final OnlinePunctuationConfig config;
}
67 changes: 67 additions & 0 deletions flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ final class SherpaOnnxOfflinePunctuationConfig extends Struct {
external SherpaOnnxOfflinePunctuationModelConfig model;
}

final class SherpaOnnxOnlinePunctuationModelConfig extends Struct {
external Pointer<Utf8> cnnBiLstm;
external Pointer<Utf8> bpeVocab;
@Int32()
external int numThreads;
@Int32()
external int debug;
external Pointer<Utf8> provider;
}

final class SherpaOnnxOnlinePunctuationConfig extends Struct {
external SherpaOnnxOnlinePunctuationModelConfig model;
}

final class SherpaOnnxOfflineZipformerAudioTaggingModelConfig extends Struct {
external Pointer<Utf8> model;
}
Expand Down Expand Up @@ -469,6 +483,8 @@ final class SherpaOnnxKeywordSpotterConfig extends Struct {

final class SherpaOnnxOfflinePunctuation extends Opaque {}

final class SherpaOnnxOnlinePunctuation extends Opaque {}

final class SherpaOnnxAudioTagging extends Opaque {}

final class SherpaOnnxKeywordSpotter extends Opaque {}
Expand Down Expand Up @@ -512,6 +528,10 @@ typedef SherpaOnnxCreateOfflinePunctuationNative
= Pointer<SherpaOnnxOfflinePunctuation> Function(
Pointer<SherpaOnnxOfflinePunctuationConfig>);

typedef SherpaOnnxCreateOnlinePunctuationNative
= Pointer<SherpaOnnxOnlinePunctuation> Function(
Pointer<SherpaOnnxOnlinePunctuationConfig>);

typedef SherpaOnnxOfflineSpeakerDiarizationGetSampleRateNative = Int32 Function(
Pointer<SherpaOnnxOfflineSpeakerDiarization>);

Expand Down Expand Up @@ -605,6 +625,26 @@ typedef SherpaOfflinePunctuationFreeTextNative = Void Function(Pointer<Utf8>);

typedef SherpaOfflinePunctuationFreeText = void Function(Pointer<Utf8>);

typedef SherpaOnnxCreateOnlinePunctuation
= SherpaOnnxCreateOnlinePunctuationNative;

typedef SherpaOnnxDestroyOnlinePunctuationNative = Void Function(
Pointer<SherpaOnnxOnlinePunctuation>);

typedef SherpaOnnxDestroyOnlinePunctuation = void Function(
Pointer<SherpaOnnxOnlinePunctuation>);

typedef SherpaOnnxOnlinePunctuationAddPunctNative = Pointer<Utf8> Function(
Pointer<SherpaOnnxOnlinePunctuation>, Pointer<Utf8>);

typedef SherpaOnnxOnlinePunctuationAddPunct
= SherpaOnnxOnlinePunctuationAddPunctNative;

typedef SherpaOnnxOnlinePunctuationFreeTextNative = Void Function(
Pointer<Utf8>);

typedef SherpaOnnxOnlinePunctuationFreeText = void Function(Pointer<Utf8>);

typedef SherpaOnnxCreateAudioTaggingNative = Pointer<SherpaOnnxAudioTagging>
Function(Pointer<SherpaOnnxAudioTaggingConfig>);

Expand Down Expand Up @@ -1155,6 +1195,13 @@ class SherpaOnnxBindings {
static SherpaOfflinePunctuationAddPunct? sherpaOfflinePunctuationAddPunct;
static SherpaOfflinePunctuationFreeText? sherpaOfflinePunctuationFreeText;

static SherpaOnnxCreateOnlinePunctuation? sherpaOnnxCreateOnlinePunctuation;
static SherpaOnnxDestroyOnlinePunctuation? sherpaOnnxDestroyOnlinePunctuation;
static SherpaOnnxOnlinePunctuationAddPunct?
sherpaOnnxOnlinePunctuationAddPunct;
static SherpaOnnxOnlinePunctuationFreeText?
sherpaOnnxOnlinePunctuationFreeText;

static SherpaOnnxCreateAudioTagging? sherpaOnnxCreateAudioTagging;
static SherpaOnnxDestroyAudioTagging? sherpaOnnxDestroyAudioTagging;
static SherpaOnnxAudioTaggingCreateOfflineStream?
Expand Down Expand Up @@ -1414,6 +1461,26 @@ class SherpaOnnxBindings {
'SherpaOfflinePunctuationFreeText')
.asFunction();

sherpaOnnxCreateOnlinePunctuation ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxCreateOnlinePunctuationNative>>(
'SherpaOnnxCreateOnlinePunctuation')
.asFunction();

sherpaOnnxDestroyOnlinePunctuation ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxDestroyOnlinePunctuationNative>>(
'SherpaOnnxDestroyOnlinePunctuation')
.asFunction();

sherpaOnnxOnlinePunctuationAddPunct ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxOnlinePunctuationAddPunctNative>>(
'SherpaOnnxOnlinePunctuationAddPunct')
.asFunction();

sherpaOnnxOnlinePunctuationFreeText ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxOnlinePunctuationFreeTextNative>>(
'SherpaOnnxOnlinePunctuationFreeText')
.asFunction();

sherpaOnnxCreateAudioTagging ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxCreateAudioTaggingNative>>(
'SherpaOnnxCreateAudioTagging')
Expand Down

0 comments on commit 115e9c2

Please sign in to comment.