Skip to content

Commit

Permalink
add additional indexes for ansys gpt (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixKuhnAnsys authored Dec 13, 2024
1 parent cd4bf4a commit 9f67ca2
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 28 deletions.
18 changes: 16 additions & 2 deletions pkg/externalfunctions/ansysgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,11 @@ func AnsysGPTACSSemanticHybridSearchs(

output = make([]sharedtypes.ACSSearchResponse, 0)
for _, indexName := range indexList {
partOutput := ansysGPTACSSemanticHybridSearch(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, indexName, filter, topK, false, nil)
partOutput, err := ansysGPTACSSemanticHybridSearch(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, indexName, filter, topK, false, nil)
if err != nil {
logging.Log.Errorf(internalstates.Ctx, "Error in semantic hybrid search: %v", err)
panic(err)
}
output = append(output, partOutput...)
}

Expand Down Expand Up @@ -588,9 +592,15 @@ func AisReturnIndexList(accessPoint string) (indexList []string) {
indexList = append(indexList, "lsdyna-documentation-r14")
indexList = append(indexList, "scade-documentation-2023r2")
indexList = append(indexList, "external-marketing")
indexList = append(indexList, "external-learning-hub")
indexList = append(indexList, "external-crtech-thermal-desktop")
indexList = append(indexList, "external-release-notes")
indexList = append(indexList, "external-zemax-websites")
// indexList = append(indexList, "ansysgpt-alh")
case "ansysgpt-scbu":
indexList = append(indexList, "ansysgpt-scbu")
indexList = append(indexList, "external-scbu-learning-hub")
indexList = append(indexList, "scbu-data-except-alh")
default:
logging.Log.Errorf(internalstates.Ctx, "Invalid accessPoint: %v\n", accessPoint)
return
Expand Down Expand Up @@ -635,7 +645,11 @@ func AisAcsSemanticHybridSearchs(
go func(idx string) {
defer wg.Done()
// Run the search for this index
result := ansysGPTACSSemanticHybridSearch(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, idx, nil, topK, true, physics)
result, err := ansysGPTACSSemanticHybridSearch(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, idx, nil, topK, true, physics)
if err != nil {
logging.Log.Errorf(internalstates.Ctx, "Error in semantic hybrid search: %v", err)
return
}
resultChan <- result
}(indexName)
}
Expand Down
32 changes: 19 additions & 13 deletions pkg/externalfunctions/privatefunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ func ansysGPTACSSemanticHybridSearch(
filter map[string]string,
topK int,
isAis bool,
physics []string) (output []sharedtypes.ACSSearchResponse) {
physics []string) (output []sharedtypes.ACSSearchResponse, err error) {

// Create the URL
url := fmt.Sprintf("https://%s.search.windows.net/indexes/%s/docs/search?api-version=%s", acsEndpoint, indexName, acsApiVersion)
Expand Down Expand Up @@ -784,17 +784,17 @@ func ansysGPTACSSemanticHybridSearch(
// Marshal the search request
requestBody, err := json.Marshal(searchRequest)
if err != nil {
errMessage := fmt.Sprintf("failed to marshal search request to ACS: %v", err)
errMessage := fmt.Errorf("failed to marshal search request to ACS: %v", err)
logging.Log.Error(internalstates.Ctx, errMessage)
panic(errMessage)
return nil, errMessage
}

// Create the HTTP request
req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody))
if err != nil {
errMessage := fmt.Sprintf("failed to create POST request for ACS: %v", err)
errMessage := fmt.Errorf("failed to create POST request for ACS: %v", err)
logging.Log.Error(internalstates.Ctx, errMessage)
panic(errMessage)
return nil, errMessage
}

req.Header.Set("Content-Type", "application/json")
Expand All @@ -804,25 +804,25 @@ func ansysGPTACSSemanticHybridSearch(
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
errMessage := fmt.Sprintf("failed to send POST request to ACS: %v", err)
errMessage := fmt.Errorf("failed to send POST request to ACS: %v", err)
logging.Log.Error(internalstates.Ctx, errMessage)
panic(errMessage)
return nil, errMessage
}
defer resp.Body.Close()

// Read and return the response body
body, err := io.ReadAll(resp.Body)
if err != nil {
errMessage := fmt.Sprintf("failed to read response body from ACS: %v", err)
errMessage := fmt.Errorf("failed to read response body from ACS: %v", err)
logging.Log.Error(internalstates.Ctx, errMessage)
panic(errMessage)
return nil, errMessage
}

// check if the reponse is an error
if resp.StatusCode != 200 {
errMessage := fmt.Sprintf("Error in ACS semantic hybrid search for index %v: %s", indexName, string(body))
errMessage := fmt.Errorf("error in ACS semantic hybrid search for index %v: %s", indexName, string(body))
logging.Log.Error(internalstates.Ctx, errMessage)
panic(errMessage)
return nil, errMessage
}

// extract and convert the response
Expand All @@ -836,7 +836,7 @@ func ansysGPTACSSemanticHybridSearch(
output[i].IndexName = indexName
}

return output
return output, nil
}

// getFieldsAndReturnProperties returns the searchedEmbeddedFields and returnedProperties based on the index name
Expand Down Expand Up @@ -870,6 +870,12 @@ func getFieldsAndReturnProperties(indexName string) (searchedEmbeddedFields stri
case "external-crtech-thermal-desktop":
searchedEmbeddedFields = "contentVector, sourceTitle_lvl1_vctr, sourceTitle_lvl2_vctr, sourceTitle_lvl3_vctr"
returnedProperties = "token_size, physics, typeOFasset, product, version, weight, bridge_id, content, sourceTitle_lvl2, sourceURL_lvl2, sourceTitle_lvl3, sourceURL_lvl3"
case "external-learning-hub", "external-release-notes", "external-zemax-websites", "external-scbu-learning-hub":
searchedEmbeddedFields = "contentVector, sourceTitle_lvl1_vctr, sourceTitle_lvl2_vctr, sourceTitle_lvl3_vctr"
returnedProperties = "token_size, physics, typeOFasset, product, version, weight, content, sourceTitle_lvl2, sourceURL_lvl2, sourceTitle_lvl3, sourceURL_lvl3"
case "scbu-data-except-alh":
searchedEmbeddedFields = "content_vctr, sourceTitle_lvl1_vctr, sourceTitle_lvl2_vctr, sourceTitle_lvl3_vctr"
returnedProperties = "token_size, physics, typeOFasset, product, index_connection_id, version, weight, content, sourceTitle_lvl2, sourceURL_lvl2, sourceTitle_lvl3, sourceURL_lvl3"
default:
errMessage := fmt.Sprintf("Index name not found: %s", indexName)
logging.Log.Error(internalstates.Ctx, errMessage)
Expand Down Expand Up @@ -951,7 +957,7 @@ func extractAndConvertACSResponse(body []byte, indexName string) (output []share
})
}

case "external-crtech-thermal-desktop":
case "external-learning-hub", "external-crtech-thermal-desktop", "external-release-notes", "external-zemax-websites", "external-scbu-learning-hub", "scbu-data-except-alh":
respObjectCrtech := ACSSearchResponseStructCrtech{}
err := json.Unmarshal(body, &respObjectCrtech)
if err != nil {
Expand Down
27 changes: 14 additions & 13 deletions pkg/functiontesting/functiontesting.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"

"github.com/ansys/allie-flowkit/pkg/externalfunctions"
"github.com/ansys/allie-sharedtypes/pkg/sharedtypes"
)

// TestAnsysGPTACSSearchIndex tests the AnsysGPTACSSemanticHybridSearch function.
Expand All @@ -19,17 +18,17 @@ import (
func TestAnsysGPTACSSearchIndex(indexName string, query string) {
embeddedQuery := externalfunctions.PerformVectorEmbeddingRequest(query)

defaultFields := []sharedtypes.AnsysGPTDefaultFields{
{QueryWord: "course", FieldName: "type_of_asset", FieldDefaultValue: "aic"},
{QueryWord: "apdl", FieldName: "product", FieldDefaultValue: "mechanical apdl"},
{QueryWord: "lsdyna", FieldName: "product", FieldDefaultValue: "ls-dyna"},
}
// defaultFields := []sharedtypes.AnsysGPTDefaultFields{
// {QueryWord: "course", FieldName: "type_of_asset", FieldDefaultValue: "aic"},
// {QueryWord: "apdl", FieldName: "product", FieldDefaultValue: "mechanical apdl"},
// {QueryWord: "lsdyna", FieldName: "product", FieldDefaultValue: "ls-dyna"},
// }

filedValues := map[string][]string{
"physics": {"structures", "fluids", "electronics", "structural mechanics", "discovery", "optics", "photonics", "python", "scade", "materials", "stem", "student", "fluid dynamics", "semiconductors"},
"type_of_asset": {"aic", "km", "documentation", "youtube", "general_faq", "alh", "article", "white-paper", "brochure"},
"product": {"additive prep", "additive print", "autodyn", "avxcelerate", "cfx", "cfx pre", "cfx solver", "cfx turbogrid", "clock jitter flow", "cloud direct", "composite cure sim", "composite preppost", "designmodeler", "designxplorer", "diakopto", "discovery", "embedded software", "ensight", "exalto", "fluent", "forte", "gateway", "granta", "hfss", "icem cfd", "icepak", "ls-dyna", "lsdyna", "lumerical", "maxwell", "mechanical", "mechanical apdl", "medini", "meshing", "minerva", "motion", "ncode designlife", "pathfinder", "pathfinder-sc", "powerartist", "pragonx", "primex", "raptorh", "raptorx", "redhawk-sc", "redhawk-sc electrothermal", "redhawk-sc security", "rocky", "scade", "sherlock", "siwave", "spaceclaim", "spaceclaim directmodeler", "stk", "totem", "totem-sc", "twin builder", "velocerf", "voltage-timing", "workbench platform"},
}
// filedValues := map[string][]string{
// "physics": {"structures", "fluids", "electronics", "structural mechanics", "discovery", "optics", "photonics", "python", "scade", "materials", "stem", "student", "fluid dynamics", "semiconductors"},
// "type_of_asset": {"aic", "km", "documentation", "youtube", "general_faq", "alh", "article", "white-paper", "brochure"},
// "product": {"additive prep", "additive print", "autodyn", "avxcelerate", "cfx", "cfx pre", "cfx solver", "cfx turbogrid", "clock jitter flow", "cloud direct", "composite cure sim", "composite preppost", "designmodeler", "designxplorer", "diakopto", "discovery", "embedded software", "ensight", "exalto", "fluent", "forte", "gateway", "granta", "hfss", "icem cfd", "icepak", "ls-dyna", "lsdyna", "lumerical", "maxwell", "mechanical", "mechanical apdl", "medini", "meshing", "minerva", "motion", "ncode designlife", "pathfinder", "pathfinder-sc", "powerartist", "pragonx", "primex", "raptorh", "raptorx", "redhawk-sc", "redhawk-sc electrothermal", "redhawk-sc security", "rocky", "scade", "sherlock", "siwave", "spaceclaim", "spaceclaim directmodeler", "stk", "totem", "totem-sc", "twin builder", "velocerf", "voltage-timing", "workbench platform"},
// }

// indexNames := []string{"granular-ansysgpt", "ansysgpt-documentation-2023r2", "scade-documentation-2023r2", "ansys-dot-com-marketing", "ibp-app-brief", "ansysgpt-alh", "ansysgpt-scbu", "lsdyna-documentation-r14"}
indexNames := []string{indexName}
Expand All @@ -38,9 +37,11 @@ func TestAnsysGPTACSSearchIndex(indexName string, query string) {
acsEndpoint := ""
acsApiKey := ""
acsApiVersion := ""
physics := []string{}

// Extract fields from the query
filter := externalfunctions.AnsysGPTExtractFieldsFromQuery(query, filedValues, defaultFields)
output := externalfunctions.AnsysGPTACSSemanticHybridSearchs(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, indexNames, filter, 10)
// filter := externalfunctions.AnsysGPTExtractFieldsFromQuery(query, filedValues, defaultFields)
// output := externalfunctions.AnsysGPTACSSemanticHybridSearchs(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, indexNames, filter, 10)
output := externalfunctions.AisAcsSemanticHybridSearchs(acsEndpoint, acsApiKey, acsApiVersion, query, embeddedQuery, indexNames, physics, 10)
fmt.Println(len(output))
}

0 comments on commit 9f67ca2

Please sign in to comment.