Skip to content

Commit

Permalink
add token limit catch to ansys gpt (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixKuhnAnsys authored Jan 15, 2025
1 parent dcf342f commit 0b0334c
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 12 deletions.
25 changes: 13 additions & 12 deletions pkg/externalfunctions/externalfunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ package externalfunctions

var ExternalFunctionsMap = map[string]interface{}{
// llm handler
"PerformVectorEmbeddingRequest": PerformVectorEmbeddingRequest,
"PerformBatchEmbeddingRequest": PerformBatchEmbeddingRequest,
"PerformKeywordExtractionRequest": PerformKeywordExtractionRequest,
"PerformGeneralRequest": PerformGeneralRequest,
"PerformGeneralRequestWithImages": PerformGeneralRequestWithImages,
"PerformGeneralRequestSpecificModel": PerformGeneralRequestSpecificModel,
"PerformCodeLLMRequest": PerformCodeLLMRequest,
"BuildLibraryContext": BuildLibraryContext,
"BuildFinalQueryForGeneralLLMRequest": BuildFinalQueryForGeneralLLMRequest,
"BuildFinalQueryForCodeLLMRequest": BuildFinalQueryForCodeLLMRequest,
"AppendMessageHistory": AppendMessageHistory,
"ShortenMessageHistory": ShortenMessageHistory,
"PerformVectorEmbeddingRequest": PerformVectorEmbeddingRequest,
"PerformVectorEmbeddingRequestWithTokenLimitCatch": PerformVectorEmbeddingRequestWithTokenLimitCatch,
"PerformBatchEmbeddingRequest": PerformBatchEmbeddingRequest,
"PerformKeywordExtractionRequest": PerformKeywordExtractionRequest,
"PerformGeneralRequest": PerformGeneralRequest,
"PerformGeneralRequestWithImages": PerformGeneralRequestWithImages,
"PerformGeneralRequestSpecificModel": PerformGeneralRequestSpecificModel,
"PerformCodeLLMRequest": PerformCodeLLMRequest,
"BuildLibraryContext": BuildLibraryContext,
"BuildFinalQueryForGeneralLLMRequest": BuildFinalQueryForGeneralLLMRequest,
"BuildFinalQueryForCodeLLMRequest": BuildFinalQueryForCodeLLMRequest,
"AppendMessageHistory": AppendMessageHistory,
"ShortenMessageHistory": ShortenMessageHistory,

// knowledge db
"SendVectorsToKnowledgeDB": SendVectorsToKnowledgeDB,
Expand Down
64 changes: 64 additions & 0 deletions pkg/externalfunctions/llmhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package externalfunctions
import (
"encoding/json"
"fmt"
"strings"

"github.com/ansys/allie-sharedtypes/pkg/config"
"github.com/ansys/allie-sharedtypes/pkg/logging"
Expand Down Expand Up @@ -67,6 +68,69 @@ func PerformVectorEmbeddingRequest(input string) (embeddedVector []float32) {
return embedding32
}

// PerformVectorEmbeddingRequestWithTokenLimitCatch performs a vector embedding request to LLM
// and catches the token limit error message. Instead of panicking when the LLM handler
// rejects the input for exceeding its token limit, it returns the caller-supplied
// tokenLimitMessage so the caller can surface a friendly response.
//
// Tags:
//   - @displayName: Embeddings with Token Limit Catch
//
// Parameters:
//   - input: the input string
//   - tokenLimitMessage: the message to return to the caller when the token limit is reached
//
// Returns:
//   - embeddedVector: the embedded vector in float32 format (nil if the token limit was reached)
//   - tokenLimitReached: true if the LLM handler reported a token limit error
//   - responseMessage: tokenLimitMessage when the token limit was reached, otherwise empty
func PerformVectorEmbeddingRequestWithTokenLimitCatch(input string, tokenLimitMessage string) (embeddedVector []float32, tokenLimitReached bool, responseMessage string) {
	// get the LLM handler endpoint
	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT

	// Set up WebSocket connection with LLM and send embeddings request
	responseChannel := sendEmbeddingsRequest(input, llmHandlerEndpoint, nil)

	// Process the first response and close the channel
	var embedding32 []float32
	for response := range responseChannel {
		// Check if the response is an error
		if response.Type == "error" {
			// A token limit error is reported back to the caller instead of panicking.
			// NOTE(review): substring matching on "tokens" is fragile — confirm the
			// handler's token-limit error message always contains that word.
			if strings.Contains(response.Error.Message, "tokens") {
				return nil, true, tokenLimitMessage
			}
			panic(response.Error)
		}

		// Log LLM response
		logging.Log.Debugf(&logging.ContextMap{}, "Received embeddings response.")

		// Get embedded vector array
		interfaceArray, ok := response.EmbeddedData.([]interface{})
		if !ok {
			errMessage := "error converting embedded data to interface array"
			logging.Log.Error(&logging.ContextMap{}, errMessage)
			panic(errMessage)
		}
		var err error
		embedding32, err = convertToFloat32Slice(interfaceArray)
		if err != nil {
			errMessage := fmt.Sprintf("error converting embedded data to float32 slice: %v", err)
			logging.Log.Error(&logging.ContextMap{}, errMessage)
			panic(errMessage)
		}

		// Only the first response is needed; exit the loop after processing it.
		break
	}

	// Close the response channel.
	// NOTE(review): closing a channel from the receiving side panics if the sender
	// sends again afterwards; kept for consistency with the sibling embedding
	// functions in this file — confirm sendEmbeddingsRequest stops sending after
	// the first response.
	close(responseChannel)

	return embedding32, false, ""
}

// PerformBatchEmbeddingRequest performs a batch vector embedding request to LLM
//
// Tags:
Expand Down

0 comments on commit 0b0334c

Please sign in to comment.