Skip to content

Commit

Permalink
feat: make spec workflow works
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Sep 14, 2023
1 parent 2173bd1 commit 208775d
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cc.unitmesh.rag.splitter
package cc.unitmesh.nlp.embedding

interface EncodingTokenizer {
fun encode(text: String): List<Int>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package cc.unitmesh.rag.splitter
package cc.unitmesh.nlp.embedding

import com.knuddels.jtokkit.Encodings
import com.knuddels.jtokkit.api.Encoding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package cc.unitmesh.rag.splitter

import cc.unitmesh.nlp.embedding.EncodingTokenizer
import cc.unitmesh.nlp.embedding.OpenAiEncoding
import kotlin.math.max
import kotlin.math.min

Expand Down
32 changes: 23 additions & 9 deletions src/main/kotlin/cc/unitmesh/cf/domains/spec/SpecRelevantSearch.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cc.unitmesh.cf.domains.spec

import cc.unitmesh.nlp.embedding.Embedding
import cc.unitmesh.nlp.embedding.EmbeddingProvider
import cc.unitmesh.nlp.embedding.EncodingTokenizer
import cc.unitmesh.rag.document.Document
import cc.unitmesh.rag.retriever.EmbeddingStoreRetriever
import cc.unitmesh.rag.splitter.MarkdownHeaderTextSplitter
Expand All @@ -15,40 +16,53 @@ class SpecRelevantSearch(val embeddingProvider: EmbeddingProvider) {
private lateinit var vectorStoreRetriever: EmbeddingStoreRetriever

// cached for performance
private val searchCache: MutableMap<String, List<String>> = mutableMapOf()
private val searchCache: MutableMap<String, List<SearchResult>> = mutableMapOf()

init {
val text = javaClass.getResourceAsStream("/be/specification.md")!!.bufferedReader().readText()
val headersToSplitOn: List<Pair<String, String>> = listOf(
Pair("#", "Header 1"),
Pair("##", "Header 2"),
Pair("###", "Header 3"),
Pair("#", "H1"),
Pair("##", "H2"),
)

val documents = MarkdownHeaderTextSplitter(headersToSplitOn)
.splitText(text)

val documentList = TokenTextSplitter(chunkSize = 384).apply(documents)
val documentList = documents.map {
val header = "${it.metadata["H1"]} > ${it.metadata["H2"]}"
val withHeader = it.copy(text = "$header ${it.text}")
TokenTextSplitter(chunkSize = 384).apply(listOf(withHeader)).first()
}

val vectorStore: EmbeddingStore<Document> = InMemoryEmbeddingStore()
val embeddings: List<Embedding> = documentList.map {
embeddingProvider.embed(it.text)
}
vectorStore.addAll(embeddings, documentList)

this.vectorStoreRetriever = EmbeddingStoreRetriever(vectorStore)
this.vectorStoreRetriever = EmbeddingStoreRetriever(vectorStore, 5, 0.6)
}

// TODO: change to search engine
fun search(query: String): List<String> {
fun search(query: String): List<SearchResult> {
if (searchCache.containsKey(query)) {
return searchCache[query]!!
}

val queryEmbedding = embeddingProvider.embed(query)
val similarDocuments = vectorStoreRetriever.retrieve(queryEmbedding)
val results = similarDocuments.map { it.embedded.text }
val results = similarDocuments.map {
SearchResult(
source = it.embedded.metadata.toString(),
content = it.embedded.text
)
}
searchCache[query] = results
return results
}
}
}

data class SearchResult(
val source: String,
val content: String
)
30 changes: 22 additions & 8 deletions src/main/kotlin/cc/unitmesh/cf/domains/spec/SpecWorkflow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,18 @@ class SpecWorkflow : Workflow() {
)

override fun execute(prompt: StageContext, chatWebContext: ChatWebContext): Flowable<WorkflowResult> {
// TODO clarify user question, 如系统包含了这些规范,你需要哪些规范?
val specs = relevantSearch.search(chatWebContext.messages.last {
val question = chatWebContext.messages.last {
it.role == LlmMsg.ChatRole.User.value
}.content)
}.content

// TODO clarify user question, 如系统包含了这些规范,你需要哪些规范?
val specs = relevantSearch.search(question)

val userMsg = EXECUTE.format()
.replace("${'$'}{specs}", specs.joinToString("\n"))
.replace("${'$'}{question}", chatWebContext.messages[0].content)
.replace("${'$'}{specs}", specs.map {
"source: ${it.source} content: ${it.content}"
}.joinToString("\n"))
.replace("${'$'}{question}", question)

val flowable = llmProvider.streamCompletion(listOf(
LlmMsg.ChatMessage(LlmMsg.ChatRole.System, userMsg),
Expand Down Expand Up @@ -67,17 +71,27 @@ class SpecWorkflow : Workflow() {
|
|- 如果规范缺少对应的信息,你不要回答。
|- 你必须回答用户的问题。
|- 请根据客户的问题,返回对应的规范,并返回对应的 source 相关信息。
|
|
|已有规范信息:
|
|```design
|${'$'}{specs}
|```
|用户的问题:
|${'$'}{question}
|
|现在请你根据规范信息,回答用户的问题。
|示例:
|用户的问题:哪些规范包含了架构设计?
|回答:
|###
|
|出处:后端代码规范的命名规范章节 // 这里根据规范的 source 项信息,写出对应的来源
|// 这里,你需要返回是规范中的详细信息,而不是规范的标题。
|###
|
|现在请你根据规范信息,回答用户的问题。
|用户的问题:
|${'$'}{question}
|""".trimMargin()
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.springframework.stereotype.Component
import cc.unitmesh.cf.STSemantic
import cc.unitmesh.nlp.embedding.Embedding
import cc.unitmesh.nlp.embedding.EmbeddingProvider
import cc.unitmesh.rag.splitter.EncodingTokenizer
import cc.unitmesh.nlp.embedding.EncodingTokenizer

@Component
class SentenceTransformersEmbedding : EmbeddingProvider, EncodingTokenizer {
Expand All @@ -26,8 +26,7 @@ class SentenceTransformersEmbedding : EmbeddingProvider, EncodingTokenizer {

override fun decode(tokens: List<Int>): String {
val map = tokens.map { it.toLong() }.toLongArray()
val output = tokenizer.decode(map)
// output will be "[CLS] blog [SEP]" for input "blog", so we need to remove the first and last token
return output
return tokenizer.decode(map)
}
}

0 comments on commit 208775d

Please sign in to comment.