Skip to content

Commit

Permalink
refactor(embeddingref): replace STactor(embedding): replace STEmbeddi…
Browse files Browse the repository at this point in the history
…ng withEmbedding with LocalEmbedding

 LocalEmbedding

UpdatedUpdated the the Embed EmbeddingEngine to use LocalEmbedding instead of STEmbedding for better consistency and clarity.dingEngine to use LocalEmbedding instead of STEmbedding for better consistency and clarity.
  • Loading branch information
phodal committed Jul 6, 2024
1 parent 769a258 commit fb267c5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@ class STEmbedding(
private val session: OrtSession,
private val env: OrtEnvironment,
) : LocalEmbedding(tokenizer, session, env) {

companion object {
fun create(): LocalEmbedding {
return LocalEmbedding.create()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package cc.unitmesh.rag

import cc.unitmesh.cf.STEmbedding
import cc.unitmesh.cf.LocalEmbedding
import cc.unitmesh.nlp.embedding.Embedding
import cc.unitmesh.nlp.embedding.EmbeddingProvider
import cc.unitmesh.nlp.embedding.text.EnglishTextEmbeddingProvider
Expand All @@ -13,14 +13,14 @@ enum class EngineType {

class EmbeddingEngine(private val engine: EngineType = EngineType.SentenceTransformers) {
var provider: EmbeddingProvider = when (engine) {
EngineType.SentenceTransformers -> SentenceTransformersEmbedding()
EngineType.SentenceTransformers -> LocalTransformersEmbedding()
EngineType.EnglishTextEmbedding -> EnglishTextEmbeddingProvider()
EngineType.TextEmbeddingAda -> TODO()
}
}

class SentenceTransformersEmbedding : EmbeddingProvider {
private val semantic = STEmbedding.create()
class LocalTransformersEmbedding : EmbeddingProvider {
private val semantic = LocalEmbedding.create()
override fun embed(texts: List<String>): List<Embedding> {
return texts.map {
semantic.embed(it).toList()
Expand Down
3 changes: 2 additions & 1 deletion server/src/test/kotlin/RagIntegrationTests.kt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import cc.unitmesh.cf.LocalEmbedding
import cc.unitmesh.cf.STEmbedding
import cc.unitmesh.cf.infrastructure.llms.embedding.SentenceTransformersEmbedding
import cc.unitmesh.nlp.embedding.Embedding
Expand All @@ -13,7 +14,7 @@ import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.Test

class RagIntegrationTests {
val semantic = STEmbedding.create()
val semantic = LocalEmbedding.create()

private val embeddingProvider = SentenceTransformersEmbedding()

Expand Down

0 comments on commit fb267c5

Please sign in to comment.