diff --git a/go.mod b/go.mod index 66b95d5..dfbdeaa 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.22.0 toolchain go1.22.5 require ( - github.com/ansys/allie-sharedtypes v0.0.0-20240904075127-5291d92348f7 + github.com/ansys/allie-sharedtypes v0.0.0-20240906091513-a52185c93d7b github.com/google/go-github/v56 v56.0.0 github.com/google/uuid v1.6.0 github.com/pandodao/tokenizer-go v0.2.0 diff --git a/go.sum b/go.sum index 7989b3f..b4ef71d 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/ansys/allie-sharedtypes v0.0.0-20240902100814-0bdc50fadbec h1:l2fkYSi github.com/ansys/allie-sharedtypes v0.0.0-20240902100814-0bdc50fadbec/go.mod h1:tp0CyD2VVrFzR6BzgcXAShhSGEcFWGr/+EMvyb1JJqM= github.com/ansys/allie-sharedtypes v0.0.0-20240904075127-5291d92348f7 h1:gCamyW5nSvIyevlpuT/PHpY9yNcKTJWHJGmMOR+jD1g= github.com/ansys/allie-sharedtypes v0.0.0-20240904075127-5291d92348f7/go.mod h1:tp0CyD2VVrFzR6BzgcXAShhSGEcFWGr/+EMvyb1JJqM= +github.com/ansys/allie-sharedtypes v0.0.0-20240906091513-a52185c93d7b h1:qnxBqYvjM8shO1SeUbcthOxIT6KM1LAxamY0UXAjK5Q= +github.com/ansys/allie-sharedtypes v0.0.0-20240906091513-a52185c93d7b/go.mod h1:tp0CyD2VVrFzR6BzgcXAShhSGEcFWGr/+EMvyb1JJqM= github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= diff --git a/main.go b/main.go index f0498a5..19e0395 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,18 @@ var externalFunctionsFile string //go:embed pkg/externalfunctions/dataextraction.go var dataExtractionFile string +//go:embed pkg/externalfunctions/generic.go +var genericFile string + +//go:embed pkg/externalfunctions/knowledgedb.go +var knowledgeDBFile string + +//go:embed pkg/externalfunctions/llmhandler.go +var llmHandlerFile string + +//go:embed pkg/externalfunctions/ansysgpt.go +var ansysGPTFile string + func init() { // 
initialize config config.InitConfig([]string{"EXTERNALFUNCTIONS_GRPC_PORT", "LLM_HANDLER_ENDPOINT"}, map[string]interface{}{ @@ -37,14 +49,21 @@ func main() { // Initialize internal states internalstates.InitializeInternalStates() - // Load function definitions - err := functiondefinitions.ExtractFunctionDefinitionsFromPackage(externalFunctionsFile) - if err != nil { - logging.Log.Fatalf(internalstates.Ctx, "Error extracting function definitions from package: %v", err) + // Create file list + files := map[string]string{ + "data_extraction": dataExtractionFile, + "generic": genericFile, + "knowledge_db": knowledgeDBFile, + "llm_handler": llmHandlerFile, + "ansys_gpt": ansysGPTFile, } - err = functiondefinitions.ExtractFunctionDefinitionsFromPackage(dataExtractionFile) - if err != nil { - logging.Log.Fatalf(internalstates.Ctx, "Error extracting function definitions from package: %v", err) + + // Load function definitions + for category, file := range files { + err := functiondefinitions.ExtractFunctionDefinitionsFromPackage(file, category) + if err != nil { + logging.Log.Fatalf(internalstates.Ctx, "Error extracting function definitions from package: %v", err) + } } // Log the version of the system diff --git a/pkg/externalfunctions/ansysgpt.go b/pkg/externalfunctions/ansysgpt.go new file mode 100644 index 0000000..9173a90 --- /dev/null +++ b/pkg/externalfunctions/ansysgpt.go @@ -0,0 +1,532 @@ +package externalfunctions + +import ( + "fmt" + "regexp" + "sort" + "strings" + + "github.com/ansys/allie-flowkit/pkg/internalstates" + "github.com/ansys/allie-sharedtypes/pkg/config" + "github.com/ansys/allie-sharedtypes/pkg/logging" + "github.com/ansys/allie-sharedtypes/pkg/sharedtypes" + "github.com/texttheater/golang-levenshtein/levenshtein" +) + +// AnsysGPTCheckProhibitedWords checks the user query for prohibited words +// +// Tags: +// - @displayName: Check Prohibited Words +// +// Parameters: +// - query: the user query +// - prohibitedWords: the list of prohibited 
words +// - errorResponseMessage: the error response message +// +// Returns: +// - foundProhibited: the flag indicating whether prohibited words were found +// - responseMessage: the response message +func AnsysGPTCheckProhibitedWords(query string, prohibitedWords []string, errorResponseMessage string) (foundProhibited bool, responseMessage string) { + // Check if all words in the value are present as whole words in the query + queryLower := strings.ToLower(query) + queryLower = strings.ReplaceAll(queryLower, ".", "") + for _, prohibitedValue := range prohibitedWords { + allWordsMatch := true + for _, fieldWord := range strings.Fields(strings.ToLower(prohibitedValue)) { + pattern := `\b` + regexp.QuoteMeta(fieldWord) + `\b` + match, _ := regexp.MatchString(pattern, queryLower) + if !match { + allWordsMatch = false + break + } + } + if allWordsMatch { + return true, errorResponseMessage + } + } + + // Check for prohibited words using fuzzy matching + cutoff := 0.9 + for _, prohibitedValue := range prohibitedWords { + wordMatchCount := 0 + for _, fieldWord := range strings.Fields(strings.ToLower(prohibitedValue)) { + for _, word := range strings.Fields(queryLower) { + distance := levenshtein.RatioForStrings([]rune(word), []rune(fieldWord), levenshtein.DefaultOptions) + if distance >= cutoff { + wordMatchCount++ + break + } + } + } + + if wordMatchCount == len(strings.Fields(prohibitedValue)) { + return true, errorResponseMessage + } + + // If multiple words are present in the field , also check for the whole words without spaces + if strings.Contains(prohibitedValue, " ") { + for _, word := range strings.Fields(queryLower) { + distance := levenshtein.RatioForStrings([]rune(word), []rune(prohibitedValue), levenshtein.DefaultOptions) + if distance >= cutoff { + return true, errorResponseMessage + } + } + } + } + + return false, "" +} + +// AnsysGPTExtractFieldsFromQuery extracts the fields from the user query +// +// Tags: +// - @displayName: Extract Fields +// +// 
Parameters: +// - query: the user query +// - fieldValues: the field values that the user query can contain +// - defaultFields: the default fields that the user query can contain +// +// Returns: +// - fields: the extracted fields +func AnsysGPTExtractFieldsFromQuery(query string, fieldValues map[string][]string, defaultFields []sharedtypes.AnsysGPTDefaultFields) (fields map[string]string) { + // Initialize the fields map + fields = make(map[string]string) + + // Check each field + for field, values := range fieldValues { + // Initializing the field with None + fields[field] = "" + + // Sort the values by length in descending order + sort.Slice(values, func(i, j int) bool { + return len(values[i]) > len(values[j]) + }) + + // Check if all words in the value are present as whole words in the query + lowercaseQuery := strings.ToLower(query) + for _, fieldValue := range values { + allWordsMatch := true + for _, fieldWord := range strings.Fields(strings.ToLower(fieldValue)) { + pattern := `\b` + regexp.QuoteMeta(fieldWord) + `\b` + match, _ := regexp.MatchString(pattern, lowercaseQuery) + if !match { + allWordsMatch = false + break + } + } + + if allWordsMatch { + fields[field] = fieldValue + break + } + } + + // Split the query into words + words := strings.Fields(lowercaseQuery) + + // If no exact match found, use fuzzy matching + if fields[field] == "" { + cutoff := 0.75 + for _, fieldValue := range values { + for _, fieldWord := range strings.Fields(fieldValue) { + for _, queryWord := range words { + distance := levenshtein.RatioForStrings([]rune(fieldWord), []rune(queryWord), levenshtein.DefaultOptions) + if distance >= cutoff { + fields[field] = fieldValue + break + } + } + } + } + } + } + + // If default value is found, use it + for _, defaultField := range defaultFields { + value, ok := fields[defaultField.FieldName] + if ok && value == "" { + if strings.Contains(strings.ToLower(query), strings.ToLower(defaultField.QueryWord)) { + 
fields[defaultField.FieldName] = defaultField.FieldDefaultValue + } + } + } + + return fields +} + +// AnsysGPTPerformLLMRephraseRequest performs a rephrase request to LLM +// +// Tags: +// - @displayName: Rephrase Request +// +// Parameters: +// - template: the template for the rephrase request +// - query: the user query +// - history: the conversation history +// +// Returns: +// - rephrasedQuery: the rephrased query +func AnsysGPTPerformLLMRephraseRequest(template string, query string, history []sharedtypes.HistoricMessage) (rephrasedQuery string) { + logging.Log.Debugf(internalstates.Ctx, "Performing LLM rephrase request") + + historyMessages := "" + + if len(history) >= 1 { + historyMessages += "user:" + history[len(history)-2].Content + "\n" + } else { + return query + } + + // Create map for the data to be used in the template + dataMap := make(map[string]string) + dataMap["query"] = query + dataMap["chat_history"] = historyMessages + + // Format the template + userTemplate := formatTemplate(template, dataMap) + logging.Log.Debugf(internalstates.Ctx, "User template: %v", userTemplate) + + // Perform the general request + rephrasedQuery, _, err := performGeneralRequest(userTemplate, nil, false, "You are AnsysGPT, a technical support assistant that is professional, friendly and multilingual that generates a clear and concise answer") + if err != nil { + panic(err) + } + + logging.Log.Debugf(internalstates.Ctx, "Rephrased query: %v", rephrasedQuery) + + return rephrasedQuery +} + +// AnsysGPTPerformLLMRephraseRequestOld performs a rephrase request to LLM +// +// Tags: +// - @displayName: Rephrase Request Old +// +// Parameters: +// - template: the template for the rephrase request +// - query: the user query +// - history: the conversation history +// +// Returns: +// - rephrasedQuery: the rephrased query +func AnsysGPTPerformLLMRephraseRequestOld(template string, query string, history []sharedtypes.HistoricMessage) (rephrasedQuery string) { + 
logging.Log.Debugf(internalstates.Ctx, "Performing LLM rephrase request") + + historyMessages := "" + for _, entry := range history { + switch entry.Role { + case "user": + historyMessages += "HumanMessage(content):" + entry.Content + "\n" + case "assistant": + historyMessages += "AIMessage(content):" + entry.Content + "\n" + } + } + + // Create map for the data to be used in the template + dataMap := make(map[string]string) + dataMap["query"] = query + dataMap["chat_history"] = historyMessages + + // Format the template + systemTemplate := formatTemplate(template, dataMap) + logging.Log.Debugf(internalstates.Ctx, "System template: %v", systemTemplate) + + // Perform the general request + rephrasedQuery, _, err := performGeneralRequest(query, nil, false, systemTemplate) + if err != nil { + panic(err) + } + + logging.Log.Debugf(internalstates.Ctx, "Rephrased query: %v", rephrasedQuery) + + return rephrasedQuery +} + +// AnsysGPTBuildFinalQuery builds the final query for Ansys GPT +// +// Tags: +// - @displayName: Build Final Query +// +// Parameters: +// - refrasedQuery: the rephrased query +// - context: the context +// +// Returns: +// - finalQuery: the final query +func AnsysGPTBuildFinalQuery(refrasedQuery string, context []sharedtypes.ACSSearchResponse) (finalQuery string, errorResponse string, displayFixedMessageToUser bool) { + logging.Log.Debugf(internalstates.Ctx, "Building final query for Ansys GPT with context of length: %v", len(context)) + + // check if there is no context + if len(context) == 0 { + errorResponse = "Sorry, I could not find any knowledge from Ansys that can answer your question. Please try and revise your query by asking in a different way or adding more details." 
+ return "", errorResponse, true + } + + // Build the final query using the KnowledgeDB response and the original request + finalQuery = "Based on the following examples:\n\n--- INFO START ---\n" + for _, example := range context { + finalQuery += fmt.Sprintf("%v", example) + "\n" + } + finalQuery += "--- INFO END ---\n\n" + refrasedQuery + "\n" + + return finalQuery, "", false +} + +// AnsysGPTPerformLLMRequest performs a request to Ansys GPT +// +// Tags: +// - @displayName: LLM Request +// +// Parameters: +// - finalQuery: the final query +// - history: the conversation history +// - systemPrompt: the system prompt +// +// Returns: +// - stream: the stream channel +func AnsysGPTPerformLLMRequest(finalQuery string, history []sharedtypes.HistoricMessage, systemPrompt string, isStream bool) (message string, stream *chan string) { + // get the LLM handler endpoint + llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT + + // Set up WebSocket connection with LLM and send chat request + responseChannel := sendChatRequest(finalQuery, "general", history, 0, systemPrompt, llmHandlerEndpoint, nil) + + // If isStream is true, create a stream channel and return asap + if isStream { + // Create a stream channel + streamChannel := make(chan string, 400) + + // Start a goroutine to transfer the data from the response channel to the stream channel + go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false) + + // Return the stream channel + return "", &streamChannel + } + + // else Process all responses + var responseAsStr string + for response := range responseChannel { + // Check if the response is an error + if response.Type == "error" { + panic(response.Error) + } + + // Accumulate the responses + responseAsStr += *(response.ChatData) + + // If we are at the last message, break the loop + if *(response.IsLast) { + break + } + } + + // Close the response channel + close(responseChannel) + + // Return the response + return responseAsStr, 
nil +} + +// AnsysGPTReturnIndexList returns the index list for Ansys GPT +// +// Tags: +// - @displayName: List Indexes +// +// Parameters: +// - indexGroups: the index groups +// +// Returns: +// - indexList: the index list +func AnsysGPTReturnIndexList(indexGroups []string) (indexList []string) { + indexList = make([]string, 0) + // iterate through indexGroups and append to indexList + for _, indexGroup := range indexGroups { + switch indexGroup { + case "Ansys Learning": + indexList = append(indexList, "granular-ansysgpt") + indexList = append(indexList, "ansysgpt-alh") + case "Ansys Products": + indexList = append(indexList, "lsdyna-documentation-r14") + indexList = append(indexList, "ansysgpt-documentation-2023r2") + indexList = append(indexList, "scade-documentation-2023r2") + indexList = append(indexList, "ansys-dot-com-marketing") + // indexList = append(indexList, "ibp-app-brief") + // indexList = append(indexList, "pyansys_help_documentation") + // indexList = append(indexList, "pyansys-examples") + case "Ansys Semiconductor": + // indexList = append(indexList, "ansysgpt-scbu") + default: + logging.Log.Warnf(internalstates.Ctx, "Invalid indexGroup: %v\n", indexGroup) + return + } + } + + return indexList +} + +// AnsysGPTACSSemanticHybridSearchs performs a semantic hybrid search in ACS +// +// Tags: +// - @displayName: ACS Semantic Hybrid Search +// +// Parameters: +// - query: the query string +// - embeddedQuery: the embedded query +// - indexList: the index list +// - typeOfAsset: the type of asset +// - physics: the physics +// - product: the product +// - productMain: the main product +// - filter: the filter +// - filterAfterVectorSearch: the flag to define the filter order +// - returnedProperties: the properties to be returned +// - topK: the number of results to be returned from vector search +// - searchedEmbeddedFields: the ACS fields to be searched +// +// Returns: +// - output: the search results +func AnsysGPTACSSemanticHybridSearchs( + 
query string, + embeddedQuery []float32, + indexList []string, + filter map[string]string, + topK int) (output []sharedtypes.ACSSearchResponse) { + + output = make([]sharedtypes.ACSSearchResponse, 0) + for _, indexName := range indexList { + partOutput := ansysGPTACSSemanticHybridSearch(query, embeddedQuery, indexName, filter, topK) + output = append(output, partOutput...) + } + + return output +} + +// AnsysGPTRemoveNoneCitationsFromSearchResponse removes none citations from search response +// +// Tags: +// - @displayName: Remove None Citations +// +// Parameters: +// - semanticSearchOutput: the search response +// - citations: the citations +// +// Returns: +// - reducedSemanticSearchOutput: the reduced search response +func AnsysGPTRemoveNoneCitationsFromSearchResponse(semanticSearchOutput []sharedtypes.ACSSearchResponse, citations []sharedtypes.AnsysGPTCitation) (reducedSemanticSearchOutput []sharedtypes.ACSSearchResponse) { + // iterate through search response and keep matches to citations + reducedSemanticSearchOutput = make([]sharedtypes.ACSSearchResponse, len(citations)) + for _, value := range semanticSearchOutput { + for _, citation := range citations { + if value.SourceURLLvl2 == citation.Title { + reducedSemanticSearchOutput = append(reducedSemanticSearchOutput, value) + } else if value.SourceURLLvl2 == citation.URL { + reducedSemanticSearchOutput = append(reducedSemanticSearchOutput, value) + } else if value.SearchRerankerScore == citation.Relevance { + reducedSemanticSearchOutput = append(reducedSemanticSearchOutput, value) + } + } + } + + return reducedSemanticSearchOutput +} + +// AnsysGPTReorderSearchResponseAndReturnOnlyTopK reorders the search response +// +// Tags: +// - @displayName: Reorder Search Response +// +// Parameters: +// - semanticSearchOutput: the search response +// - topK: the number of results to be returned +// +// Returns: +// - reorderedSemanticSearchOutput: the reordered search response +func 
AnsysGPTReorderSearchResponseAndReturnOnlyTopK(semanticSearchOutput []sharedtypes.ACSSearchResponse, topK int) (reorderedSemanticSearchOutput []sharedtypes.ACSSearchResponse) { + logging.Log.Debugf(internalstates.Ctx, "Reordering search response of length %v based on reranker_score and returning only top %v results", len(semanticSearchOutput), topK) + // Sorting by Weight * SearchRerankerScore in descending order + sort.Slice(semanticSearchOutput, func(i, j int) bool { + return semanticSearchOutput[i].Weight*semanticSearchOutput[i].SearchRerankerScore > semanticSearchOutput[j].Weight*semanticSearchOutput[j].SearchRerankerScore + }) + + // Return only topK results + if len(semanticSearchOutput) > topK { + return semanticSearchOutput[:topK] + } + + return semanticSearchOutput +} + +// AnsysGPTGetSystemPrompt returns the system prompt for Ansys GPT +// +// Tags: +// - @displayName: Get System Prompt +// +// Parameters: +// - rephrasedQuery: the rephrased query +// +// Returns: +// - systemPrompt: the system prompt +func AnsysGPTGetSystemPrompt(rephrasedQuery string) string { + return `Orders: You are AnsysGPT, a technical support assistant that is professional, friendly and multilingual that generates a clear and concise answer to the user question adhering to these strict guidelines: \n + You must always answer user queries using the provided 'context' and 'chat_history' only. If you cannot find an answer in the 'context' or the 'chat_history', never use your base knowledge to generate a response. \n + + You are a multilingual expert that will *always reply the user in the same language as that of their 'query' in ` + rephrasedQuery + `*. If the 'query' is in Japanese, your response must be in Japanese. If the 'query' is in Cantonese, your response must be in Cantonese. If the 'query' is in English, your response must be in English. You *must always* be consistent in your multilingual ability. 
\n + + You have the capability to learn or *remember information from past three interactions* with the user. \n + + You are a smart Technical support assistant that can distinguish between a fresh independent query and a follow-up query based on 'chat_history'. \n + + If you find the user's 'query' to be a follow-up question, consider the 'chat_history' while generating responses. Use the information from the 'chat_history' to provide contextually relevant responses. When answering follow-up questions that can be answered using the 'chat_history' alone, do not provide any references. \n + + *Always* your answer must include the 'content', 'sourceURL_lvl3' of all the chunks in 'context' that are relevant to the user's query in 'query'. But, never cite 'sourceURL_lvl3' under the heading 'References'. \n + + The 'content' and 'sourceURL_lvl3' must be included together in your answer, with the 'sourceTitle_lvl2', 'sourceURL_lvl2' and '@search.reranker_score' serving as a citation for the 'content'. Include 'sourceURL_lvl3' directly in the answer in-line with the source, not in the references section. \n + + In your response follow a style of citation where each source is assigned a number, for example '[1]', that corresponds to the 'sourceURL_lvl3', 'sourceTitle_lvl2' and 'sourceURL_lvl2' in the 'context'. \n + + Make sure you always provide 'URL: Extract the value of 'sourceURL_lvl3'' in line with every source in your answer. For example 'You will learn to find the total drag and lift on a solar car in Ansys Fluent in this course. URL: [1] https://courses.ansys.com/index.php/courses/aerodynamics-of-a-solar-car/'. \n + + Never mention the position of chunk in your response for example 'chunk 1 / chunk 4'/ first chunk / third chunk'. \n + + **Always** aim to make your responses conversational and engaging, while still providing accurate and helpful information. \n + + If the user greets you, you must *always* reply them in a polite and friendly manner. 
You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n + + If the user acknowledges you, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n + + If the user asks about your purpose, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n + + If the user asks who are you?, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n + + When providing information from a source, try to introduce it in a *conversational manner*. For example, instead of saying 'In the chunk titled...', you could say 'I found a great resource titled... that explains...'. \n + + If a chunk has empty fields in it's 'sourceTitle_lvl2' and 'sourceURL_lvl2', you *must never* cite that chunk under references in your response. \n + + You must never provide JSON format in your answer and never cite references in JSON format.\n + + Strictly provide your response everytime in the below format: + + Your answer + Always provide 'URL: Extract the value of 'sourceURL_lvl3'' *inline right next to each source* and *not at the end of your answer*. + References: + [1] Title: Extract the value of 'sourceTitle_lvl2', URL: Extract the value of 'sourceURL_lvl2', Relevance: Extract the value of '@search.reranker_score' /4.0. + *Always* provide References for all the chunks in 'context'. + Do not provide 'sourceTitle_lvl3' in your response. + When answering follow-up questions that can be answered using the 'chat_history' alone, *do not provide any references*. + **Never** cite chunk that has empty fields in it's 'sourceTitle_lvl2' and 'sourceURL_lvl2' under References. 
+ **Never** provide the JSON format in your response and References. + + Only provide a reference if it was found in the "context". Under no circumstances should you create your own references from your base knowledge or the internet. \n + + Here's an example of how you should structure your response: \n + + Designing an antenna involves several steps, and Ansys provides a variety of tools to assist you in this process. \n + The Ansys HFSS Antenna Toolkit, for instance, can automatically create the geometry of your antenna design with boundaries and excitations assigned. It also sets up the solution and generates post-processing reports for several popular antenna elements. Over 60 standard antenna topologies are available in the toolkit, and all the antenna models generated are ready to simulate. You can run a quick analysis of any antenna of your choosing [1]. URL: [1] https://www.youtube.com/embed/mhM6U2xn0Q0?start=25&end=123 \n + In another example, a rectangular edge fed patch antenna is created using the HFSS antenna toolkit. The antenna is synthesized for 3.5 GHz and the geometry model is already created for you. After analyzing the model, you can view the results generated from the toolkit. The goal is to fold or bend the antenna so that it fits onto the sidewall of a smartphone. After folding the antenna and reanalyzing, you can view the results such as return loss, input impedance, and total radiated power of the antenna [2]. URL: [2] https://www.youtube.com/embed/h0QttEmQ88E?start=94&end=186 \n + Lastly, Ansys Electronics Desktop integrates rigorous electromagnetic analysis with system and circuit simulation in a comprehensive, easy-to-use design platform. This platform is used to automatically create antenna geometries with materials, boundaries, excitations, solution setups, and post-processing reports [3]. 
URL: [3] https://ansyskm.ansys.com/forums/topic/ansys-hfss-antenna-synthesis-from-hfss-antenna-toolkit-part-2/ \n + I hope this helps you in your antenna design process. If you have any more questions, feel free to ask! \n + References: + [1] Title: "ANSYS HFSS: Antenna Synthesis from HFSS Antenna Toolkit - Part 2", URL: https://ansyskm.ansys.com/forums/topic/ansys-hfss-antenna-synthesis-from-hfss-antenna-toolkit-part-2/, Relevance: 3.53/4.0 + [2] Title: "Cosimulation Using Ansys HFSS and Circuit - Lesson 2 - ANSYS Innovation Courses", URL: https://courses.ansys.com/index.php/courses/cosimulation-using-ansys-hfss/lessons/cosimulation-using-ansys-hfss-and-circuit-lesson-2/, Relevance: 2.54/4.0` +} diff --git a/pkg/externalfunctions/dataextraction.go b/pkg/externalfunctions/dataextraction.go index 4796c57..244b362 100644 --- a/pkg/externalfunctions/dataextraction.go +++ b/pkg/externalfunctions/dataextraction.go @@ -12,7 +12,6 @@ import ( "sync" "github.com/ansys/allie-flowkit/pkg/internalstates" - "github.com/ansys/allie-sharedtypes/pkg/config" "github.com/ansys/allie-sharedtypes/pkg/logging" "github.com/ansys/allie-sharedtypes/pkg/sharedtypes" "github.com/google/go-github/v56/github" @@ -25,6 +24,9 @@ import ( // DataExtractionGetGithubFilesToExtract gets all files from github that need to be extracted. // +// Tags: +// - @displayName: List Github Files +// // Parameters: // - githubRepoName: name of the github repository. // - githubRepoOwner: owner of the github repository. @@ -74,6 +76,9 @@ func DataExtractionGetGithubFilesToExtract(githubRepoName string, githubRepoOwne // DataExtractionGetLocalFilesToExtract gets all files from local that need to be extracted. // +// Tags: +// - @displayName: List Local Files +// // Parameters: // - localPath: path to the local directory. // - localFileExtensions: local file extensions. 
@@ -117,6 +122,9 @@ func DataExtractionGetLocalFilesToExtract(localPath string, localFileExtensions // DataExtractionAppendStringSlices creates a new slice by appending all elements of the provided slices. // +// Tags: +// - @displayName: Append String Slices +// // Parameters: // - slice1, slice2, slice3, slice4, slice5: slices to append. // @@ -137,6 +145,9 @@ func DataExtractionAppendStringSlices(slice1, slice2, slice3, slice4, slice5 []s // DataExtractionDownloadGithubFileContent downloads file content from github and returns checksum and content. // +// Tags: +// - @displayName: Download Github File Content +// // Parameters: // - githubRepoName: name of the github repository. // - githubRepoOwner: owner of the github repository. @@ -178,6 +189,10 @@ func DataExtractionDownloadGithubFileContent(githubRepoName string, githubRepoOw } // DataExtractionGetLocalFileContent reads local file and returns checksum and content. +// +// Tags: +// - @displayName: Get Local File Content +// // Parameters: // - localFilePath: path to file. // @@ -215,6 +230,9 @@ func DataExtractionGetLocalFileContent(localFilePath string) (checksum string, c // DataExtractionGetDocumentType returns the document type of a file. // +// Tags: +// - @displayName: Get Document Type +// // Parameters: // - filePath: path to file. // @@ -230,6 +248,9 @@ func DataExtractionGetDocumentType(filePath string) (documentType string) { // DataExtractionLangchainSplitter splits content into chunks using langchain. // +// Tags: +// - @displayName: Split Content +// // Parameters: // - content: content to split. // - documentType: type of document. @@ -315,6 +336,9 @@ func DataExtractionLangchainSplitter(content string, documentType string, chunkS // DataExtractionGenerateDocumentTree generates a tree structure from the document chunks. // +// Tags: +// - @displayName: Document Tree +// // Parameters: // - documentName: name of the document. // - documentId: id of the document. 
@@ -502,63 +526,3 @@ func DataExtractionGenerateDocumentTree(documentName string, documentId string, return returnedDocumentData } - -// DataExtractionAddDataRequest sends a request to the add_data endpoint. -// -// Parameters: -// - collectionName: name of the collection the request is sent to. -// - data: the data to add. -func DataExtractionAddDataRequest(collectionName string, documentData []sharedtypes.DbData) { - // Create the AddDataInput object - requestObject := sharedtypes.DbAddDataInput{ - CollectionName: collectionName, - Data: documentData, - } - - // Create the URL - url := fmt.Sprintf("%s/%s", config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT, "add_data") - - // Send the HTTP POST request - var response sharedtypes.DbAddDataOutput - err, _ := createPayloadAndSendHttpRequest(url, requestObject, &response) - if err != nil { - errorMessage := fmt.Sprintf("Error sending request to add_data endpoint: %v", err) - logging.Log.Error(internalstates.Ctx, errorMessage) - panic(errorMessage) - } - - logging.Log.Debugf(internalstates.Ctx, "Added data to collection: %s \n", collectionName) - - return -} - -// DataExtractionCreateCollectionRequest sends a request to the collection endpoint. -// -// Parameters: -// - collectionName: the name of the collection to create. 
-func DataExtractionCreateCollectionRequest(collectionName string) { - // Create the CreateCollectionInput object - requestObject := sharedtypes.DbCreateCollectionInput{ - CollectionName: collectionName, - } - - // Create the URL - url := fmt.Sprintf("%s/%s", config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT, "create_collection") - - // Send the HTTP POST request - var response sharedtypes.DbCreateCollectionOutput - err, statusCode := createPayloadAndSendHttpRequest(url, requestObject, &response) - if err != nil { - if statusCode == 409 { - logging.Log.Warn(internalstates.Ctx, "Collection already exists") - } else { - errorMessage := fmt.Sprintf("Error sending request to create_collection endpoint: %v", err) - logging.Log.Error(internalstates.Ctx, errorMessage) - panic(errorMessage) - } - } - - logging.Log.Debugf(internalstates.Ctx, "Created collection: %s \n", collectionName) - - return -} diff --git a/pkg/externalfunctions/externalfunctions.go b/pkg/externalfunctions/externalfunctions.go index ca4232f..6c2149d 100644 --- a/pkg/externalfunctions/externalfunctions.go +++ b/pkg/externalfunctions/externalfunctions.go @@ -1,46 +1,35 @@ package externalfunctions -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "regexp" - "sort" - "strings" - - "github.com/ansys/allie-flowkit/pkg/internalstates" - "github.com/ansys/allie-sharedtypes/pkg/config" - "github.com/ansys/allie-sharedtypes/pkg/logging" - "github.com/ansys/allie-sharedtypes/pkg/sharedtypes" - "github.com/texttheater/golang-levenshtein/levenshtein" -) - var ExternalFunctionsMap = map[string]interface{}{ - "PerformVectorEmbeddingRequest": PerformVectorEmbeddingRequest, - "PerformBatchEmbeddingRequest": PerformBatchEmbeddingRequest, - "PerformKeywordExtractionRequest": PerformKeywordExtractionRequest, - "PerformGeneralRequest": PerformGeneralRequest, - "PerformCodeLLMRequest": PerformCodeLLMRequest, - "BuildLibraryContext": BuildLibraryContext, - "SendVectorsToKnowledgeDB": 
SendVectorsToKnowledgeDB, - "GetListCollections": GetListCollections, - "RetrieveDependencies": RetrieveDependencies, - "GeneralNeo4jQuery": GeneralNeo4jQuery, - "GeneralQuery": GeneralQuery, - "BuildFinalQueryForGeneralLLMRequest": BuildFinalQueryForGeneralLLMRequest, - "BuildFinalQueryForCodeLLMRequest": BuildFinalQueryForCodeLLMRequest, - "SimilaritySearch": SimilaritySearch, - "CreateKeywordsDbFilter": CreateKeywordsDbFilter, - "CreateTagsDbFilter": CreateTagsDbFilter, - "CreateMetadataDbFilter": CreateMetadataDbFilter, - "CreateDbFilter": CreateDbFilter, - "AppendMessageHistory": AppendMessageHistory, + // llm handler + "PerformVectorEmbeddingRequest": PerformVectorEmbeddingRequest, + "PerformBatchEmbeddingRequest": PerformBatchEmbeddingRequest, + "PerformKeywordExtractionRequest": PerformKeywordExtractionRequest, + "PerformGeneralRequest": PerformGeneralRequest, + "PerformGeneralRequestSpecificModel": PerformGeneralRequestSpecificModel, + "PerformCodeLLMRequest": PerformCodeLLMRequest, + "BuildLibraryContext": BuildLibraryContext, + "BuildFinalQueryForGeneralLLMRequest": BuildFinalQueryForGeneralLLMRequest, + "BuildFinalQueryForCodeLLMRequest": BuildFinalQueryForCodeLLMRequest, + "AppendMessageHistory": AppendMessageHistory, + + // knowledge db + "SendVectorsToKnowledgeDB": SendVectorsToKnowledgeDB, + "GetListCollections": GetListCollections, + "RetrieveDependencies": RetrieveDependencies, + "GeneralNeo4jQuery": GeneralNeo4jQuery, + "GeneralQuery": GeneralQuery, + "SimilaritySearch": SimilaritySearch, + "CreateKeywordsDbFilter": CreateKeywordsDbFilter, + "CreateTagsDbFilter": CreateTagsDbFilter, + "CreateMetadataDbFilter": CreateMetadataDbFilter, + "CreateDbFilter": CreateDbFilter, + + // ansys gpt "AnsysGPTCheckProhibitedWords": AnsysGPTCheckProhibitedWords, "AnsysGPTExtractFieldsFromQuery": AnsysGPTExtractFieldsFromQuery, "AnsysGPTPerformLLMRephraseRequest": AnsysGPTPerformLLMRephraseRequest, + "AnsysGPTPerformLLMRephraseRequestOld": 
AnsysGPTPerformLLMRephraseRequestOld, "AnsysGPTBuildFinalQuery": AnsysGPTBuildFinalQuery, "AnsysGPTPerformLLMRequest": AnsysGPTPerformLLMRequest, "AnsysGPTReturnIndexList": AnsysGPTReturnIndexList, @@ -48,1742 +37,20 @@ var ExternalFunctionsMap = map[string]interface{}{ "AnsysGPTRemoveNoneCitationsFromSearchResponse": AnsysGPTRemoveNoneCitationsFromSearchResponse, "AnsysGPTReorderSearchResponseAndReturnOnlyTopK": AnsysGPTReorderSearchResponseAndReturnOnlyTopK, "AnsysGPTGetSystemPrompt": AnsysGPTGetSystemPrompt, - "DataExtractionGetGithubFilesToExtract": DataExtractionGetGithubFilesToExtract, - "DataExtractionGetLocalFilesToExtract": DataExtractionGetLocalFilesToExtract, - "DataExtractionAppendStringSlices": DataExtractionAppendStringSlices, - "DataExtractionDownloadGithubFileContent": DataExtractionDownloadGithubFileContent, - "DataExtractionGetLocalFileContent": DataExtractionGetLocalFileContent, - "DataExtractionGetDocumentType": DataExtractionGetDocumentType, - "DataExtractionLangchainSplitter": DataExtractionLangchainSplitter, - "DataExtractionGenerateDocumentTree": DataExtractionGenerateDocumentTree, - "DataExtractionAddDataRequest": DataExtractionAddDataRequest, - "DataExtractionCreateCollectionRequest": DataExtractionCreateCollectionRequest, - "PerformGeneralRequestSpecificModel": PerformGeneralRequestSpecificModel, - "AssignStringToString": AssignStringToString, -} - -// PerformVectorEmbeddingRequest performs a vector embedding request to LLM -// -// Parameters: -// - input: the input string -// -// Returns: -// - embeddedVector: the embedded vector in float32 format -func PerformVectorEmbeddingRequest(input string) (embeddedVector []float32) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send embeddings request - responseChannel := sendEmbeddingsRequest(input, llmHandlerEndpoint, nil) - - // Process the first response and close the channel - var 
embedding32 []float32 - var err error - for response := range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Log LLM response - logging.Log.Debugf(internalstates.Ctx, "Received embeddings response.") - - // Get embedded vector array - interfaceArray, ok := response.EmbeddedData.([]interface{}) - if !ok { - errMessage := "error converting embedded data to interface array" - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - embedding32, err = convertToFloat32Slice(interfaceArray) - if err != nil { - errMessage := fmt.Sprintf("error converting embedded data to float32 slice: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Mark that the first response has been received - firstResponseReceived := true - - // Exit the loop after processing the first response - if firstResponseReceived { - break - } - } - - // Close the response channel - close(responseChannel) - - return embedding32 -} - -// PerformBatchEmbeddingRequest performs a batch vector embedding request to LLM -// -// Parameters: -// - input: the input strings -// -// Returns: -// - embeddedVectors: the embedded vectors in float32 format -func PerformBatchEmbeddingRequest(input []string) (embeddedVectors [][]float32) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send embeddings request - responseChannel := sendEmbeddingsRequest(input, llmHandlerEndpoint, nil) - - // Process the first response and close the channel - embedding32Array := make([][]float32, len(input)) - for response := range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Log LLM response - logging.Log.Debugf(internalstates.Ctx, "Received embeddings response.") - - // Get embedded vector array - interfaceArray, ok := 
response.EmbeddedData.([]interface{}) - if !ok { - errMessage := "error converting embedded data to interface array" - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - for i, interfaceArrayElement := range interfaceArray { - lowerInterfaceArray, ok := interfaceArrayElement.([]interface{}) - if !ok { - errMessage := "error converting embedded data to interface array" - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - embedding32, err := convertToFloat32Slice(lowerInterfaceArray) - if err != nil { - errMessage := fmt.Sprintf("error converting embedded data to float32 slice: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - embedding32Array[i] = embedding32 - } - - // Mark that the first response has been received - firstResponseReceived := true - - // Exit the loop after processing the first response - if firstResponseReceived { - break - } - } - - // Close the response channel - close(responseChannel) - - return embedding32Array -} - -// PerformKeywordExtractionRequest performs a keywords extraction request to LLM -// -// Parameters: -// - input: the input string -// - maxKeywordsSearch: the maximum number of keywords to search for -// -// Returns: -// - keywords: the keywords extracted from the input string as a slice of strings -func PerformKeywordExtractionRequest(input string, maxKeywordsSearch uint32) (keywords []string) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send chat request - responseChannel := sendChatRequestNoHistory(input, "keywords", maxKeywordsSearch, llmHandlerEndpoint, nil) - - // Process all responses - var responseAsStr string - for response := range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Accumulate the responses - responseAsStr += *(response.ChatData) - - // If we 
are at the last message, break the loop - if *(response.IsLast) { - break - } - } - - logging.Log.Debugf(internalstates.Ctx, "Received keywords response.") - - // Close the response channel - close(responseChannel) - - // Unmarshal JSON data into the result variable - err := json.Unmarshal([]byte(responseAsStr), &keywords) - if err != nil { - errMessage := fmt.Sprintf("Error unmarshalling keywords response from allie-llm: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Return the response - return keywords -} - -// PerformSummaryRequest performs a summary request to LLM -// -// Parameters: -// - input: the input string -// -// Returns: -// - summary: the summary extracted from the input string -func PerformSummaryRequest(input string) (summary string) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send chat request - responseChannel := sendChatRequestNoHistory(input, "summary", 1, llmHandlerEndpoint, nil) - - // Process all responses - var responseAsStr string - for response := range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Accumulate the responses - responseAsStr += *(response.ChatData) - - // If we are at the last message, break the loop - if *(response.IsLast) { - break - } - } - - logging.Log.Debugf(internalstates.Ctx, "Received summary response.") - - // Close the response channel - close(responseChannel) - - // Return the response - return responseAsStr -} - -// PerformGeneralRequest performs a general chat completion request to LLM -// -// Parameters: -// - input: the input string -// - history: the conversation history -// - isStream: the stream flag -// - systemPrompt: the system prompt -// -// Returns: -// - message: the generated message -// - stream: the stream channel -func PerformGeneralRequest(input string, history 
[]sharedtypes.HistoricMessage, isStream bool, systemPrompt string) (message string, stream *chan string) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send chat request - responseChannel := sendChatRequest(input, "general", history, 0, systemPrompt, llmHandlerEndpoint, nil) - // If isStream is true, create a stream channel and return asap - if isStream { - // Create a stream channel - streamChannel := make(chan string, 400) - - // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false) - - // Return the stream channel - return "", &streamChannel - } - - // else Process all responses - var responseAsStr string - for response := range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Accumulate the responses - responseAsStr += *(response.ChatData) - - // If we are at the last message, break the loop - if *(response.IsLast) { - break - } - } - - // Close the response channel - close(responseChannel) - - // Return the response - return responseAsStr, nil -} - -// PerformCodeLLMRequest performs a code generation request to LLM -// -// Parameters: -// - input: the input string -// - history: the conversation history -// - isStream: the stream flag -// -// Returns: -// - message: the generated code -// - stream: the stream channel -func PerformCodeLLMRequest(input string, history []sharedtypes.HistoricMessage, isStream bool, validateCode bool) (message string, stream *chan string) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send chat request - responseChannel := sendChatRequest(input, "code", history, 0, "", llmHandlerEndpoint, nil) - - // If isStream is true, create a stream 
channel and return asap - if isStream { - // Create a stream channel - streamChannel := make(chan string, 400) - - // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, validateCode) - - // Return the stream channel - return "", &streamChannel - } - - // else Process all responses - var responseAsStr string - for response := range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Accumulate the responses - responseAsStr += *(response.ChatData) - - // If we are at the last message, break the loop - if *(response.IsLast) { - break - } - } - - // Close the response channel - close(responseChannel) - - // Code validation - if validateCode { - - // Extract the code from the response - pythonCode, err := extractPythonCode(responseAsStr) - if err != nil { - logging.Log.Errorf(internalstates.Ctx, "Error extracting Python code: %v", err) - } else { - - // Validate the Python code - valid, warnings, err := validatePythonCode(pythonCode) - if err != nil { - logging.Log.Errorf(internalstates.Ctx, "Error validating Python code: %v", err) - } else { - if valid { - if warnings { - responseAsStr += "\nCode has warnings." - } else { - responseAsStr += "\nCode is valid." - } - } else { - responseAsStr += "\nCode is invalid." 
- } - } - } - } - - // Return the response - return responseAsStr, nil -} - -// BuildLibraryContext builds the context string for the query -// -// Parameters: -// - message: the message string -// - libraryContext: the library context string -// -// Returns: -// - messageWithContext: the message with context -func BuildLibraryContext(message string, libraryContext string) (messageWithContext string) { - // Check if "pyansys" is in the library context - message = libraryContext + message - - return message -} - -// SendVectorsToKnowledgeDB sends the given vector to the KnowledgeDB and -// returns the most relevant data. The number of results is specified in the -// config file. The keywords are used to filter the results. The min score -// filter is also specified in the config file. If it is not specified, the -// default value is used. -// -// The function returns the most relevant data. -// -// Parameters: -// - vector: the vector to be sent to the KnowledgeDB -// - keywords: the keywords to be used to filter the results -// - keywordsSearch: the flag to enable the keywords search -// - collection: the collection name -// - similaritySearchResults: the number of results to be returned -// - similaritySearchMinScore: the minimum score for the results -// -// Returns: -// - databaseResponse: an array of the most relevant data -func SendVectorsToKnowledgeDB(vector []float32, keywords []string, keywordsSearch bool, collection string, similaritySearchResults int, similaritySearchMinScore float64) (databaseResponse []sharedtypes.DbResponse) { - // get the KnowledgeDB endpoint - knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT - - // Log the request - logging.Log.Debugf(internalstates.Ctx, "Connecting to the KnowledgeDB.") - - // Build filters - var filters sharedtypes.DbFilters - - // -- Add the keywords filter if needed - if keywordsSearch { - filters.KeywordsFilter = sharedtypes.DbArrayFilter{ - NeedAll: false, - FilterData: keywords, - } - } - - // 
-- Add the level filter - filters.LevelFilter = []string{"leaf"} - - // Create a new resource instance - requestInput := similaritySearchInput{ - CollectionName: collection, - EmbeddedVector: vector, - MaxRetrievalCount: similaritySearchResults, - Filters: filters, - MinScore: similaritySearchMinScore, - OutputFields: []string{ - "guid", - "document_id", - "document_name", - "summary", - "keywords", - "text", - }, - } - - // Convert the resource instance to JSON. - jsonData, err := json.Marshal(requestInput) - if err != nil { - errMessage := fmt.Sprintf("Error marshalling JSON data of POST /similarity_search request for allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Specify the target endpoint. - requestURL := knowledgeDbEndpoint + "/similarity_search" - - // Create a new HTTP request with the JSON data. - req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) - if err != nil { - errMessage := fmt.Sprintf("Error creating POST /similarity_search request for allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Set the appropriate content type for the request. - req.Header.Set("Content-Type", "application/json") - - // Send the HTTP request using the default HTTP client. - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - errMessage := fmt.Sprintf("Error sending POST /similarity_search request to allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - defer resp.Body.Close() - - // Read and display the response body. 
- body, err := io.ReadAll(resp.Body) - if err != nil { - errMessage := fmt.Sprintf("Error reading response body of POST /similarity_search request from allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Log the similarity search response - logging.Log.Debugf(internalstates.Ctx, "Knowledge DB response: %v", string(body)) - logging.Log.Debugf(internalstates.Ctx, "Knowledge DB response received!") - - // Unmarshal the response body to the appropriate struct. - var response similaritySearchOutput - err = json.Unmarshal(body, &response) - if err != nil { - errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /similarity_search response from allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - var mostRelevantData []sharedtypes.DbResponse - var count int = 1 - for _, element := range response.SimilarityResult { - // Log the result - logging.Log.Debugf(internalstates.Ctx, "Result #%d:", count) - logging.Log.Debugf(internalstates.Ctx, "Similarity score: %v", element.Score) - logging.Log.Debugf(internalstates.Ctx, "Similarity file id: %v", element.Data.DocumentId) - logging.Log.Debugf(internalstates.Ctx, "Similarity file name: %v", element.Data.DocumentName) - logging.Log.Debugf(internalstates.Ctx, "Similarity summary: %v", element.Data.Summary) - - // Add the result to the list - mostRelevantData = append(mostRelevantData, element.Data) - - // Check whether we have enough results - if count >= similaritySearchResults { - break - } else { - count++ - } - } - - // Return the most relevant data - return mostRelevantData -} - -// GetListCollections retrieves the list of collections from the KnowledgeDB. -// -// The function returns the list of collections. 
-// -// Parameters: -// - knowledgeDbEndpoint: the KnowledgeDB endpoint -// -// Returns: -// - collectionsList: the list of collections -func GetListCollections() (collectionsList []string) { - // get the KnowledgeDB endpoint - knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT - - // Specify the target endpoint. - requestURL := knowledgeDbEndpoint + "/list_collections" - - // Create a new HTTP request with the JSON data. - req, err := http.NewRequest("GET", requestURL, nil) - if err != nil { - errMessage := fmt.Sprintf("Error creating GET /list_collections request for allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Set the appropriate content type for the request. - req.Header.Set("Content-Type", "application/json") - - // Send the HTTP request using the default HTTP client. - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - errMessage := fmt.Sprintf("Error sending GET /list_collections request to allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - defer resp.Body.Close() - - // Read and display the response body. - body, err := io.ReadAll(resp.Body) - if err != nil { - errMessage := fmt.Sprintf("Error reading response body of GET /list_collections request from allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Unmarshal the response body to the appropriate struct. 
- var response sharedtypes.DBListCollectionsOutput - err = json.Unmarshal(body, &response) - if err != nil { - errMessage := fmt.Sprintf("Error unmarshalling JSON data of GET /list_collections response from allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Log the result and return the list of collections - if !response.Success { - errMessage := "Failed to retrieve list of collections from allie-db" - logging.Log.Warn(internalstates.Ctx, errMessage) - panic(errMessage) - } else { - logging.Log.Debugf(internalstates.Ctx, "List collections response received!") - logging.Log.Debugf(internalstates.Ctx, "Collections: %v", response.Collections) - return response.Collections - } -} - -// RetrieveDependencies retrieves the dependencies of the specified source node. -// -// The function returns the list of dependencies. -// -// Parameters: -// - collectionName: the name of the collection to which the data objects will be added. -// - relationshipName: the name of the relationship to retrieve dependencies for. -// - relationshipDirection: the direction of the relationship to retrieve dependencies for. -// - sourceDocumentId: the document ID of the source node. -// - nodeTypesFilter: filter based on node types. -// - maxHopsNumber: maximum number of hops to traverse. 
-// -// Returns: -// - dependenciesIds: the list of dependencies -func RetrieveDependencies( - collectionName string, - relationshipName string, - relationshipDirection string, - sourceDocumentId string, - nodeTypesFilter sharedtypes.DbArrayFilter, - maxHopsNumber int) (dependenciesIds []string) { - // get the KnowledgeDB endpoint - knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT - - // Create the URL - requestURL := knowledgeDbEndpoint + "/retrieve_dependencies" - - // Create the retrieveDependenciesInput object - requestInput := retrieveDependenciesInput{ - CollectionName: collectionName, - RelationshipName: relationshipName, - RelationshipDirection: relationshipDirection, - SourceDocumentId: sourceDocumentId, - NodeTypesFilter: nodeTypesFilter, - MaxHopsNumber: maxHopsNumber, - } - - // Convert the resource instance to JSON. - jsonData, err := json.Marshal(requestInput) - if err != nil { - errMessage := fmt.Sprintf("Error marshalling JSON data of POST /retrieve_dependencies request for allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Create a new HTTP request with the JSON data. - req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) - if err != nil { - errMessage := fmt.Sprintf("Error creating POST /retrieve_dependencies request for allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Set the appropriate content type for the request. - req.Header.Set("Content-Type", "application/json") - - // Send the HTTP request using the default HTTP client. - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - errMessage := fmt.Sprintf("Error sending POST /retrieve_dependencies request to allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - defer resp.Body.Close() - - // Read and display the response body. 
- body, err := io.ReadAll(resp.Body) - if err != nil { - errMessage := fmt.Sprintf("Error reading response body of POST /retrieve_dependencies request from allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - logging.Log.Debugf(internalstates.Ctx, "Knowledge DB RetrieveDependencies response received!") - - // Unmarshal the response body to the appropriate struct. - var response retrieveDependenciesOutput - err = json.Unmarshal(body, &response) - if err != nil { - errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /retrieve_dependencies response from allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - return response.DependenciesIds -} - -// GeneralNeo4jQuery executes the given Neo4j query and returns the response. -// -// The function returns the neo4j response. -// -// Parameters: -// - query: the Neo4j query to be executed. -// -// Returns: -// - databaseResponse: the Neo4j response -func GeneralNeo4jQuery(query string) (databaseResponse sharedtypes.Neo4jResponse) { - // get the KnowledgeDB endpoint - knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT - - // Create the URL - requestURL := knowledgeDbEndpoint + "/general_neo4j_query" - - // Create the retrieveDependenciesInput object - requestInput := sharedtypes.GeneralNeo4jQueryInput{ - Query: query, - } - - // Convert the resource instance to JSON. - jsonData, err := json.Marshal(requestInput) - if err != nil { - errMessage := fmt.Sprintf("Error marshalling JSON data of POST /general_neo4j_query request for allie-db: %v", err) - logging.Log.Error(internalstates.Ctx, errMessage) - panic(errMessage) - } - - // Create a new HTTP request with the JSON data. 
	// Create a new HTTP request with the JSON data.
	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
	if err != nil {
		errMessage := fmt.Sprintf("Error creating POST /general_neo4j_query request for allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// Set the appropriate content type for the request.
	req.Header.Set("Content-Type", "application/json")

	// Send the HTTP request using the default HTTP client.
	// NOTE(review): the client has no timeout, so this call can block indefinitely — confirm intended.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		errMessage := fmt.Sprintf("Error sending POST /general_neo4j_query request to allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}
	defer resp.Body.Close()

	// Read and display the response body.
	// NOTE(review): resp.StatusCode is never checked; non-2xx bodies are parsed as success payloads.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		errMessage := fmt.Sprintf("Error reading response body of POST /general_neo4j_query request from allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	logging.Log.Debugf(internalstates.Ctx, "Knowledge DB GeneralNeo4jQuery response received!")

	// Unmarshal the response body to the appropriate struct.
	var response sharedtypes.GeneralNeo4jQueryOutput
	err = json.Unmarshal(body, &response)
	if err != nil {
		errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /general_neo4j_query response from allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	return response.Response
}

// GeneralQuery performs a general query in the KnowledgeDB.
//
// The function returns the query results.
//
// Parameters:
// - collectionName: the name of the collection to which the data objects will be added.
// - maxRetrievalCount: the maximum number of results to be retrieved.
// - outputFields: the fields to be included in the output.
// - filters: the filter for the query.
//
// Returns:
// - databaseResponse: the query results
func GeneralQuery(collectionName string, maxRetrievalCount int, outputFields []string, filters sharedtypes.DbFilters) (databaseResponse []sharedtypes.DbResponse) {
	// get the KnowledgeDB endpoint
	knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT

	// Create the URL
	requestURL := knowledgeDbEndpoint + "/query"

	// Create the queryInput object
	requestInput := queryInput{
		CollectionName:    collectionName,
		MaxRetrievalCount: maxRetrievalCount,
		OutputFields:      outputFields,
		Filters:           filters,
	}

	// Convert the resource instance to JSON.
	jsonData, err := json.Marshal(requestInput)
	if err != nil {
		errMessage := fmt.Sprintf("Error marshalling JSON data of POST /query request for allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// Create a new HTTP request with the JSON data.
	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
	if err != nil {
		errMessage := fmt.Sprintf("Error creating POST /query request for allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// Set the appropriate content type for the request.
	req.Header.Set("Content-Type", "application/json")

	// Send the HTTP request using the default HTTP client.
	// NOTE(review): no timeout configured — the call may block indefinitely.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		errMessage := fmt.Sprintf("Error sending POST /query request to allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}
	defer resp.Body.Close()

	// Read and display the response body.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		errMessage := fmt.Sprintf("Error reading response body of POST /query request from allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	logging.Log.Debugf(internalstates.Ctx, "Knowledge DB GeneralQuery response received!")

	// Unmarshal the response body to the appropriate struct.
	var response queryOutput
	err = json.Unmarshal(body, &response)
	if err != nil {
		errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /query response from allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	return response.QueryResult
}

// BuildFinalQueryForGeneralLLMRequest builds the final query for a general
// request to LLM. The final query is a markdown string that contains the
// original request and the examples from the KnowledgeDB.
//
// Parameters:
// - request: the original request
// - knowledgedbResponse: the KnowledgeDB response
//
// Returns:
// - finalQuery: the final query
func BuildFinalQueryForGeneralLLMRequest(request string, knowledgedbResponse []sharedtypes.DbResponse) (finalQuery string) {

	// If there is no response from the KnowledgeDB, return the original request
	if len(knowledgedbResponse) == 0 {
		return request
	}

	// Build the final query using the KnowledgeDB response and the original request
	finalQuery = "Based on the following examples:\n\n--- INFO START ---\n"
	for _, example := range knowledgedbResponse {
		finalQuery += example.Text + "\n"
	}
	finalQuery += "--- INFO END ---\n\n" + request + "\n"

	// Return the final query
	return finalQuery
}

// BuildFinalQueryForCodeLLMRequest builds the final query for a code generation
// request to LLM. The final query is a markdown string that contains the
// original request and the code examples from the KnowledgeDB.
//
// Parameters:
// - request: the original request
// - knowledgedbResponse: the KnowledgeDB response
//
// Returns:
// - finalQuery: the final query
func BuildFinalQueryForCodeLLMRequest(request string, knowledgedbResponse []sharedtypes.DbResponse) (finalQuery string) {
	// Build the final query using the KnowledgeDB response and the original request
	// We have to use the text from the DB response and the original request.
	//
	// The prompt should be in the following format:
	//
	// ******************************************************************************
	// Based on the following examples:
	//
	// --- START EXAMPLE {response_n}---
	// >>> Summary:
	// {knowledge_db_response_n_summary}
	//
	// >>> Code snippet:
	// ```python
	// {knowledge_db_response_n_text}
	// ```
	// --- END EXAMPLE {response_n}---
	//
	// --- START EXAMPLE {response_n}---
	// ...
	// --- END EXAMPLE {response_n}---
	//
	// Generate the Python code for the following request:
	//
	// >>> Request:
	// {original_request}
	// ******************************************************************************

	// Only add examples when the KnowledgeDB returned results; otherwise only the
	// request itself is sent.
	if len(knowledgedbResponse) > 0 {
		// Initial request
		finalQuery = "Based on the following examples:\n\n"

		for i, element := range knowledgedbResponse {
			// Add the example number
			finalQuery += "--- START EXAMPLE " + fmt.Sprint(i+1) + "---\n"
			finalQuery += ">>> Summary:\n" + element.Summary + "\n\n"
			finalQuery += ">>> Code snippet:\n```python\n" + element.Text + "\n```\n"
			finalQuery += "--- END EXAMPLE " + fmt.Sprint(i+1) + "---\n\n"
		}
	}

	// Pass in the original request
	finalQuery += "Generate the Python code for the following request:\n>>> Request:\n" + request + "\n"

	// Return the final query
	return finalQuery
}

// SimilaritySearch performs a similarity search in the KnowledgeDB.
//
// The function returns the similarity search results.
//
// Parameters:
// - collectionName: the name of the collection to which the data objects will be added.
// - embeddedVector: the embedded vector used for searching.
// - maxRetrievalCount: the maximum number of results to be retrieved.
// - outputFields: the fields to be included in the output.
// - filters: the filter for the query.
// - minScore: the minimum score filter.
// - getLeafNodes: flag to indicate whether to retrieve all the leaf nodes in the result node branch.
// - getSiblings: flag to indicate whether to retrieve the previous and next node to the result nodes.
// - getParent: flag to indicate whether to retrieve the parent object.
// - getChildren: flag to indicate whether to retrieve the children objects.
//
// Returns:
// - databaseResponse: the similarity search results
func SimilaritySearch(
	collectionName string,
	embeddedVector []float32,
	maxRetrievalCount int,
	outputFields []string,
	filters sharedtypes.DbFilters,
	minScore float64,
	getLeafNodes bool,
	getSiblings bool,
	getParent bool,
	getChildren bool) (databaseResponse []sharedtypes.DbResponse) {
	// get the KnowledgeDB endpoint
	knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT

	// Create the URL
	requestURL := knowledgeDbEndpoint + "/similarity_search"

	// Create the similaritySearchInput object
	requestInput := similaritySearchInput{
		CollectionName:    collectionName,
		EmbeddedVector:    embeddedVector,
		MaxRetrievalCount: maxRetrievalCount,
		OutputFields:      outputFields,
		Filters:           filters,
		MinScore:          minScore,
		GetLeafNodes:      getLeafNodes,
		GetSiblings:       getSiblings,
		GetParent:         getParent,
		GetChildren:       getChildren,
	}

	// Convert the resource instance to JSON.
	jsonData, err := json.Marshal(requestInput)
	if err != nil {
		errMessage := fmt.Sprintf("Error marshalling JSON data of POST /similarity_search request for allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// Create a new HTTP request with the JSON data.
	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
	if err != nil {
		errMessage := fmt.Sprintf("Error creating POST /similarity_search request for allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// Set the appropriate content type for the request.
	req.Header.Set("Content-Type", "application/json")

	// Send the HTTP request using the default HTTP client.
	// NOTE(review): no timeout configured — the call may block indefinitely.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		errMessage := fmt.Sprintf("Error sending POST /similarity_search request to allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}
	defer resp.Body.Close()

	// Read and display the response body.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		errMessage := fmt.Sprintf("Error reading response body of POST /similarity_search request from allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	logging.Log.Debugf(internalstates.Ctx, "Knowledge DB SimilaritySearch response received!")

	// Unmarshal the response body to the appropriate struct.
	var response similaritySearchOutput
	err = json.Unmarshal(body, &response)
	if err != nil {
		errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /similarity_search response from allie-db: %v", err)
		logging.Log.Error(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// Flatten the per-result wrappers down to the Data payloads.
	var similarityResults []sharedtypes.DbResponse
	for _, element := range response.SimilarityResult {
		similarityResults = append(similarityResults, element.Data)
	}

	return similarityResults
}

// CreateKeywordsDbFilter creates a keywords filter for the KnowledgeDB.
//
// The function returns the keywords filter.
//
// Parameters:
// - keywords: the keywords to be used for the filter
// - needAll: flag to indicate whether all keywords are needed
//
// Returns:
// - databaseFilter: the keywords filter
func CreateKeywordsDbFilter(keywords []string, needAll bool) (databaseFilter sharedtypes.DbArrayFilter) {
	var keywordsFilters sharedtypes.DbArrayFilter

	// -- Add the keywords filter if needed; an empty keyword list yields the zero-value filter.
	if len(keywords) > 0 {
		keywordsFilters = createDbArrayFilter(keywords, needAll)
	}

	return keywordsFilters
}

// CreateTagsDbFilter creates a tags filter for the KnowledgeDB.
//
// The function returns the tags filter.
//
// Parameters:
// - tags: the tags to be used for the filter
// - needAll: flag to indicate whether all tags are needed
//
// Returns:
// - databaseFilter: the tags filter
func CreateTagsDbFilter(tags []string, needAll bool) (databaseFilter sharedtypes.DbArrayFilter) {
	var tagsFilters sharedtypes.DbArrayFilter

	// -- Add the tags filter if needed; an empty tag list yields the zero-value filter.
	if len(tags) > 0 {
		tagsFilters = createDbArrayFilter(tags, needAll)
	}

	return tagsFilters
}

// CreateMetadataDbFilter creates a metadata filter for the KnowledgeDB.
//
// The function returns the metadata filter.
//
// Parameters:
// - fieldName: the name of the field
// - fieldType: the type of the field
// - filterData: the filter data
// - needAll: flag to indicate whether all data is needed
//
// Returns:
// - databaseFilter: the metadata filter
func CreateMetadataDbFilter(fieldName string, fieldType string, filterData []string, needAll bool) (databaseFilter sharedtypes.DbJsonFilter) {
	return createDbJsonFilter(fieldName, fieldType, filterData, needAll)
}

// CreateDbFilter creates a filter for the KnowledgeDB.
//
// The function returns the filter. Each sub-filter is only set when its input
// is non-empty, so unused criteria stay at their zero value.
//
// Parameters:
// - guid: the guid filter
// - documentId: the document ID filter
// - documentName: the document name filter
// - level: the level filter
// - tags: the tags filter
// - keywords: the keywords filter
// - metadata: the metadata filter
//
// Returns:
// - databaseFilter: the filter
func CreateDbFilter(
	guid []string,
	documentId []string,
	documentName []string,
	level []string,
	tags sharedtypes.DbArrayFilter,
	keywords sharedtypes.DbArrayFilter,
	metadata []sharedtypes.DbJsonFilter) (databaseFilter sharedtypes.DbFilters) {
	var filters sharedtypes.DbFilters

	// -- Add the guid filter if needed
	if len(guid) > 0 {
		filters.GuidFilter = guid
	}

	// -- Add the document ID filter if needed
	if len(documentId) > 0 {
		filters.DocumentIdFilter = documentId
	}

	// -- Add the document name filter if needed
	if len(documentName) > 0 {
		filters.DocumentNameFilter = documentName
	}

	// -- Add the level filter if needed
	if len(level) > 0 {
		filters.LevelFilter = level
	}

	// -- Add the tags filter if needed
	if len(tags.FilterData) > 0 {
		filters.TagsFilter = tags
	}

	// -- Add the keywords filter if needed
	if len(keywords.FilterData) > 0 {
		filters.KeywordsFilter = keywords
	}

	// -- Add the metadata filter if needed
	if len(metadata) > 0 {
		filters.MetadataFilter = metadata
	}

	return filters
}
// AppendMessageHistoryRole represents the role a message appended via
// AppendMessageHistory is attributed to.
type AppendMessageHistoryRole string

const (
	user      AppendMessageHistoryRole = "user"
	assistant AppendMessageHistoryRole = "assistant"
	system    AppendMessageHistoryRole = "system"
)

// AppendMessageHistory appends a new message to the conversation history
//
// Parameters:
// - newMessage: the new message
// - role: the role of the message (must be user, assistant or system; any other value panics)
// - history: the conversation history
//
// Returns:
// - updatedHistory: the updated conversation history
func AppendMessageHistory(newMessage string, role AppendMessageHistoryRole, history []sharedtypes.HistoricMessage) (updatedHistory []sharedtypes.HistoricMessage) {
	// Validate the role: the known roles fall through as empty cases; anything
	// else is rejected.
	switch role {
	case user:
	case assistant:
	case system:
	default:
		errMessage := fmt.Sprintf("Invalid role used for 'AppendMessageHistory': %v", role)
		logging.Log.Warn(internalstates.Ctx, errMessage)
		panic(errMessage)
	}

	// skip for empty messages
	if newMessage == "" {
		return history
	}

	// Create a new HistoricMessage
	newMessageHistory := sharedtypes.HistoricMessage{
		Role:    string(role),
		Content: newMessage,
	}

	// Append the new message to the history
	history = append(history, newMessageHistory)

	return history
}

// AnsysGPTCheckProhibitedWords checks the user query for prohibited words
//
// Parameters:
// - query: the user query
// - prohibitedWords: the list of prohibited words
// - errorResponseMessage: the error response message
//
// Returns:
// - foundProhibited: the flag indicating whether prohibited words were found
// - responseMessage: the response message
func AnsysGPTCheckProhibitedWords(query string, prohibitedWords []string, errorResponseMessage string) (foundProhibited bool, responseMessage string) {
	// Check if all words in the value are present as whole words in the query
	queryLower := strings.ToLower(query)
	queryLower = strings.ReplaceAll(queryLower, ".", "")
	for _, prohibitedValue := range prohibitedWords {
		allWordsMatch := true
		for _, fieldWord := range strings.Fields(strings.ToLower(prohibitedValue)) {
			// Whole-word match: the word must appear at word boundaries.
			pattern := `\b` + regexp.QuoteMeta(fieldWord) + `\b`
			match, _ := regexp.MatchString(pattern, queryLower)
			if !match {
				allWordsMatch = false
				break
			}
		}
		if allWordsMatch {
			return true, errorResponseMessage
		}
	}

	// Check for prohibited words using fuzzy matching
	cutoff := 0.9
	for _, prohibitedValue := range prohibitedWords {
		wordMatchCount := 0
		for _, fieldWord := range strings.Fields(strings.ToLower(prohibitedValue)) {
			for _, word := range strings.Fields(queryLower) {
				distance := levenshtein.RatioForStrings([]rune(word), []rune(fieldWord), levenshtein.DefaultOptions)
				if distance >= cutoff {
					wordMatchCount++
					break
				}
			}
		}

		// A prohibited value counts as found when every one of its words fuzzily
		// matched some query word.
		if wordMatchCount == len(strings.Fields(prohibitedValue)) {
			return true, errorResponseMessage
		}

		// If multiple words are present in the field, also fuzzily compare each
		// query word against the entire multi-word value.
		// NOTE(review): the value still contains its spaces here — confirm the
		// intent was not to strip them before comparing.
		if strings.Contains(prohibitedValue, " ") {
			for _, word := range strings.Fields(queryLower) {
				distance := levenshtein.RatioForStrings([]rune(word), []rune(prohibitedValue), levenshtein.DefaultOptions)
				if distance >= cutoff {
					return true, errorResponseMessage
				}
			}
		}
	}

	return false, ""
}

// AnsysGPTExtractFieldsFromQuery extracts the fields from the user query
//
// Parameters:
// - query: the user query
// - fieldValues: the field values that the user query can contain
// - defaultFields: the default fields that the user query can contain
//
// Returns:
// - fields: the extracted fields
func AnsysGPTExtractFieldsFromQuery(query string, fieldValues map[string][]string, defaultFields []sharedtypes.AnsysGPTDefaultFields) (fields map[string]string) {
	// Initialize the fields map
	fields = make(map[string]string)

	// Check each field
	for field, values := range fieldValues {
		// Initializing the field with None
		fields[field] = ""

		// Sort the values by length in descending order
by length in descending order - sort.Slice(values, func(i, j int) bool { - return len(values[i]) > len(values[j]) - }) - - // Check if all words in the value are present as whole words in the query - lowercaseQuery := strings.ToLower(query) - for _, fieldValue := range values { - allWordsMatch := true - for _, fieldWord := range strings.Fields(strings.ToLower(fieldValue)) { - pattern := `\b` + regexp.QuoteMeta(fieldWord) + `\b` - match, _ := regexp.MatchString(pattern, lowercaseQuery) - if !match { - allWordsMatch = false - break - } - } - - if allWordsMatch { - fields[field] = fieldValue - break - } - } - - // Split the query into words - words := strings.Fields(lowercaseQuery) - - // If no exact match found, use fuzzy matching - if fields[field] == "" { - cutoff := 0.75 - for _, fieldValue := range values { - for _, fieldWord := range strings.Fields(fieldValue) { - for _, queryWord := range words { - distance := levenshtein.RatioForStrings([]rune(fieldWord), []rune(queryWord), levenshtein.DefaultOptions) - if distance >= cutoff { - fields[field] = fieldValue - break - } - } - } - } - } - } - - // If default value is found, use it - for _, defaultField := range defaultFields { - value, ok := fields[defaultField.FieldName] - if ok && value == "" { - if strings.Contains(strings.ToLower(query), strings.ToLower(defaultField.QueryWord)) { - fields[defaultField.FieldName] = defaultField.FieldDefaultValue - } - } - } - - return fields -} - -// AnsysGPTPerformLLMRephraseRequest performs a rephrase request to LLM -// -// Parameters: -// - template: the template for the rephrase request -// - query: the user query -// - history: the conversation history -// -// Returns: -// - rephrasedQuery: the rephrased query -func AnsysGPTPerformLLMRephraseRequest(template string, query string, history []sharedtypes.HistoricMessage) (rephrasedQuery string) { - logging.Log.Debugf(internalstates.Ctx, "Performing LLM rephrase request") - - historyMessages := "" - - if len(history) >= 1 { - 
historyMessages += "user:" + history[len(history)-2].Content + "\n" - } else { - return query - } - - // Create map for the data to be used in the template - dataMap := make(map[string]string) - dataMap["query"] = query - dataMap["chat_history"] = historyMessages - - // Format the template - userTemplate := formatTemplate(template, dataMap) - logging.Log.Debugf(internalstates.Ctx, "User template: %v", userTemplate) - - // Perform the general request - rephrasedQuery, _, err := performGeneralRequest(userTemplate, nil, false, "You are AnsysGPT, a technical support assistant that is professional, friendly and multilingual that generates a clear and concise answer") - if err != nil { - panic(err) - } - - logging.Log.Debugf(internalstates.Ctx, "Rephrased query: %v", rephrasedQuery) - - return rephrasedQuery -} - -// AnsysGPTBuildFinalQuery builds the final query for Ansys GPT -// -// Parameters: -// - refrasedQuery: the refrased query -// - context: the context -// -// Returns: -// - finalQuery: the final query -func AnsysGPTBuildFinalQuery(refrasedQuery string, context []sharedtypes.ACSSearchResponse) (finalQuery string, errorResponse string, displayFixedMessageToUser bool) { - logging.Log.Debugf(internalstates.Ctx, "Building final query for Ansys GPT with context of length: %v", len(context)) - - // check if there is no context - if len(context) == 0 { - errorResponse = "Sorry, I could not find any knowledge from Ansys that can answer your question. Please try and revise your query by asking in a different way or adding more details." 
		return "", errorResponse, true
	}

	// Build the final query using the KnowledgeDB response and the original request
	finalQuery = "Based on the following examples:\n\n--- INFO START ---\n"
	for _, example := range context {
		// NOTE(review): fmt %v prints the whole struct with Go syntax — confirm
		// this is the intended serialization for prompt context.
		finalQuery += fmt.Sprintf("%v", example) + "\n"
	}
	finalQuery += "--- INFO END ---\n\n" + refrasedQuery + "\n"

	return finalQuery, "", false
}

// AnsysGPTPerformLLMRequest performs a request to Ansys GPT
//
// Parameters:
// - finalQuery: the final query
// - history: the conversation history
// - systemPrompt: the system prompt
// - isStream: when true, the response is streamed through the returned channel
//
// Returns:
// - message: the complete response (empty when streaming)
// - stream: the stream channel (nil when not streaming)
func AnsysGPTPerformLLMRequest(finalQuery string, history []sharedtypes.HistoricMessage, systemPrompt string, isStream bool) (message string, stream *chan string) {
	// get the LLM handler endpoint
	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT

	// Set up WebSocket connection with LLM and send chat request
	responseChannel := sendChatRequest(finalQuery, "general", history, 0, systemPrompt, llmHandlerEndpoint, nil)

	// If isStream is true, create a stream channel and return asap
	if isStream {
		// Create a stream channel
		streamChannel := make(chan string, 400)

		// Start a goroutine to transfer the data from the response channel to the stream channel
		go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)

		// Return the stream channel
		return "", &streamChannel
	}

	// else Process all responses
	var responseAsStr string
	for response := range responseChannel {
		// Check if the response is an error
		if response.Type == "error" {
			panic(response.Error)
		}

		// Accumulate the responses
		responseAsStr += *(response.ChatData)

		// If we are at the last message, break the loop
		if *(response.IsLast) {
			break
		}
	}

	// Close the response channel
	// NOTE(review): this closes a channel the function only receives from —
	// confirm sendChatRequest expects the consumer to close it.
	close(responseChannel)

	// Return the response
	return responseAsStr, nil
}

// AnsysGPTReturnIndexList returns the index list for Ansys GPT
//
// Parameters:
// - indexGroups: the index groups
//
// Returns:
// - indexList: the index list
func AnsysGPTReturnIndexList(indexGroups []string) (indexList []string) {
	indexList = make([]string, 0)
	// iterate through indexGroups and append to indexList
	for _, indexGroup := range indexGroups {
		switch indexGroup {
		case "Ansys Learning":
			indexList = append(indexList, "granular-ansysgpt")
			indexList = append(indexList, "ansysgpt-alh")
		case "Ansys Products":
			indexList = append(indexList, "lsdyna-documentation-r14")
			indexList = append(indexList, "ansysgpt-documentation-2023r2")
			indexList = append(indexList, "scade-documentation-2023r2")
			indexList = append(indexList, "ansys-dot-com-marketing")
			indexList = append(indexList, "ibp-app-brief")
			// indexList = append(indexList, "pyansys_help_documentation")
			// indexList = append(indexList, "pyansys-examples")
		case "Ansys Semiconductor":
			indexList = append(indexList, "ansysgpt-scbu")
		default:
			// NOTE(review): an invalid group aborts the loop and returns the
			// indexes collected so far — confirm partial results are intended.
			logging.Log.Warnf(internalstates.Ctx, "Invalid indexGroup: %v\n", indexGroup)
			return
		}
	}

	return indexList
}

// AnsysGPTACSSemanticHybridSearchs performs a semantic hybrid search in ACS
// across every index in indexList, concatenating the per-index results.
//
// Parameters:
// - query: the query string
// - embeddedQuery: the embedded query
// - indexList: the index list
// - filter: the filter
// - topK: the number of results to be returned from vector search
//
// Returns:
// - output: the search results
func AnsysGPTACSSemanticHybridSearchs(
	query string,
	embeddedQuery []float32,
	indexList []string,
	filter map[string]string,
	topK int) (output []sharedtypes.ACSSearchResponse) {

	output =
make([]sharedtypes.ACSSearchResponse, 0) - for _, indexName := range indexList { - partOutput := ansysGPTACSSemanticHybridSearch(query, embeddedQuery, indexName, filter, topK) - output = append(output, partOutput...) - } - - return output -} - -// AnsysGPTRemoveNoneCitationsFromSearchResponse removes none citations from search response -// -// Parameters: -// - semanticSearchOutput: the search response -// - citations: the citations -// -// Returns: -// - reducedSemanticSearchOutput: the reduced search response -func AnsysGPTRemoveNoneCitationsFromSearchResponse(semanticSearchOutput []sharedtypes.ACSSearchResponse, citations []sharedtypes.AnsysGPTCitation) (reducedSemanticSearchOutput []sharedtypes.ACSSearchResponse) { - // iterate throught search response and keep matches to citations - reducedSemanticSearchOutput = make([]sharedtypes.ACSSearchResponse, len(citations)) - for _, value := range semanticSearchOutput { - for _, citation := range citations { - if value.SourceURLLvl2 == citation.Title { - reducedSemanticSearchOutput = append(reducedSemanticSearchOutput, value) - } else if value.SourceURLLvl2 == citation.URL { - reducedSemanticSearchOutput = append(reducedSemanticSearchOutput, value) - } else if value.SearchRerankerScore == citation.Relevance { - reducedSemanticSearchOutput = append(reducedSemanticSearchOutput, value) - } - } - } - - return reducedSemanticSearchOutput -} - -// AnsysGPTReorderSearchResponseAndReturnOnlyTopK reorders the search response -// -// Parameters: -// - semanticSearchOutput: the search response -// - topK: the number of results to be returned -// -// Returns: -// - reorderedSemanticSearchOutput: the reordered search response -func AnsysGPTReorderSearchResponseAndReturnOnlyTopK(semanticSearchOutput []sharedtypes.ACSSearchResponse, topK int) (reorderedSemanticSearchOutput []sharedtypes.ACSSearchResponse) { - logging.Log.Debugf(internalstates.Ctx, "Reordering search response of length %v based on reranker_score and returning only top 
	// Sorting by Weight * SearchRerankerScore in descending order
	sort.Slice(semanticSearchOutput, func(i, j int) bool {
		return semanticSearchOutput[i].Weight*semanticSearchOutput[i].SearchRerankerScore > semanticSearchOutput[j].Weight*semanticSearchOutput[j].SearchRerankerScore
	})

	// Return only topK results
	if len(semanticSearchOutput) > topK {
		return semanticSearchOutput[:topK]
	}

	return semanticSearchOutput
}

// AnsysGPTGetSystemPrompt returns the system prompt for Ansys GPT
//
// Parameters:
// - rephrasedQuery: the rephrased user query; interpolated into the prompt so the
//   assistant mirrors the query's language
//
// Returns:
// - systemPrompt: the system prompt
func AnsysGPTGetSystemPrompt(rephrasedQuery string) string {
	// NOTE(review): the raw-string line layout below was reconstructed from a
	// collapsed source; confirm the exact line breaks against the original file.
	return `Orders: You are AnsysGPT, a technical support assistant that is professional, friendly and multilingual that generates a clear and concise answer to the user question adhering to these strict guidelines: \n
	You must always answer user queries using the provided 'context' and 'chat_history' only. If you cannot find an answer in the 'context' or the 'chat_history', never use your base knowledge to generate a response. \n
	- You are a multilingual expert that will *always reply the user in the same language as that of their 'query' in ` + rephrasedQuery + `*. If the 'query' is in Japanese, your response must be in Japanese. If the 'query' is in Cantonese, your response must be in Cantonese. If the 'query' is in English, your response must be in English. You *must always* be consistent in your multilingual ability. \n
	- You have the capability to learn or *remember information from past three interactions* with the user. \n
	- You are a smart Technical support assistant that can distingush between a fresh independent query and a follow-up query based on 'chat_history'. \n
	- If you find the user's 'query' to be a follow-up question, consider the 'chat_history' while generating responses. Use the information from the 'chat_history' to provide contextually relevant responses. When answering follow-up questions that can be answered using the 'chat_history' alone, do not provide any references. \n
	- *Always* your answer must include the 'content', 'sourceURL_lvl3' of all the chunks in 'context' that are relevant to the user's query in 'query'. But, never cite 'sourceURL_lvl3' under the heading 'References'. \n
	- The 'content' and 'sourceURL_lvl3' must be included together in your answer, with the 'sourceTitle_lvl2', 'sourceURL_lvl2' and '@search.reranker_score' serving as a citation for the 'content'. Include 'sourceURL_lvl3' directly in the answer in-line with the source, not in the references section. \n
	- In your response follow a style of citation where each source is assigned a number, for example '[1]', that corresponds to the 'sourceURL_lvl3', 'sourceTitle_lvl2' and 'sourceURL_lvl2' in the 'context'. \n
	- Make sure you always provide 'URL: Extract the value of 'sourceURL_lvl3'' in line with every source in your answer. For example 'You will learn to find the total drag and lift on a solar car in Ansys Fluent in this course. URL: [1] https://courses.ansys.com/index.php/courses/aerodynamics-of-a-solar-car/'. \n
	- Never mention the position of chunk in your response for example 'chunk 1 / chunk 4'/ first chunk / third chunk'. \n
	- **Always** aim to make your responses conversational and engaging, while still providing accurate and helpful information. \n
	- If the user greets you, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n
	- If the user acknowledges you, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n
	- If the user asks about your purpose, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n
	- If the user asks who are you?, you must *always* reply them in a polite and friendly manner. You *must never* reply "I'm sorry, could you please provide more details or ask a different question?" in this case. \n
	- When providing information from a source, try to introduce it in a *conversational manner*. For example, instead of saying 'In the chunk titled...', you could say 'I found a great resource titled... that explains...'. \n
	- If a chunk has empty fields in it's 'sourceTitle_lvl2' and 'sourceURL_lvl2', you *must never* cite that chunk under references in your response. \n
	- You must never provide JSON format in your answer and never cite references in JSON format.\n
	- Strictly provide your response everytime in the below format:
	- Your answer
	Always provide 'URL: Extract the value of 'sourceURL_lvl3'' *inline right next to each source* and *not at the end of your answer*.
	References:
	[1] Title: Extract the value of 'sourceTitle_lvl2', URL: Extract the value of 'sourceURL_lvl2', Relevance: Extract the value of '@search.reranker_score' /4.0.
	*Always* provide References for all the chunks in 'context'.
	Do not provide 'sourceTitle_lvl3' in your response.
	When answering follow-up questions that can be answered using the 'chat_history' alone, *do not provide any references*.
	**Never** cite chunk that has empty fields in it's 'sourceTitle_lvl2' and 'sourceURL_lvl2' under References.
	**Never** provide the JSON format in your response and References.
	- Only provide a reference if it was found in the "context". Under no circumstances should you create your own references from your base knowledge or the internet. \n
	- Here's an example of how you should structure your response: \n
	- Designing an antenna involves several steps, and Ansys provides a variety of tools to assist you in this process. \n
	The Ansys HFSS Antenna Toolkit, for instance, can automatically create the geometry of your antenna design with boundaries and excitations assigned. It also sets up the solution and generates post-processing reports for several popular antenna elements. Over 60 standard antenna topologies are available in the toolkit, and all the antenna models generated are ready to simulate. You can run a quick analysis of any antenna of your choosing [1]. URL: [1] https://www.youtube.com/embed/mhM6U2xn0Q0?start=25&end=123 \n
	In another example, a rectangular edge fed patch antenna is created using the HFSS antenna toolkit. The antenna is synthesized for 3.5 GHz and the geometry model is already created for you. After analyzing the model, you can view the results generated from the toolkit. The goal is to fold or bend the antenna so that it fits onto the sidewall of a smartphone. After folding the antenna and reanalyzing, you can view the results such as return loss, input impedance, and total radiated power of the antenna [2]. URL: [2] https://www.youtube.com/embed/h0QttEmQ88E?start=94&end=186 \n
	Lastly, Ansys Electronics Desktop integrates rigorous electromagnetic analysis with system and circuit simulation in a comprehensive, easy-to-use design platform. This platform is used to automatically create antenna geometries with materials, boundaries, excitations, solution setups, and post-processing reports [3]. URL: [3] https://ansyskm.ansys.com/forums/topic/ansys-hfss-antenna-synthesis-from-hfss-antenna-toolkit-part-2/ \n
	I hope this helps you in your antenna design process. If you have any more questions, feel free to ask! \n
	References:
	[1] Title: "ANSYS HFSS: Antenna Synthesis from HFSS Antenna Toolkit - Part 2", URL: https://ansyskm.ansys.com/forums/topic/ansys-hfss-antenna-synthesis-from-hfss-antenna-toolkit-part-2/, Relevance: 3.53/4.0
	[2] Title: "Cosimulation Using Ansys HFSS and Circuit - Lesson 2 - ANSYS Innovation Courses", URL: https://courses.ansys.com/index.php/courses/cosimulation-using-ansys-hfss/lessons/cosimulation-using-ansys-hfss-and-circuit-lesson-2/, Relevance: 2.54/4.0`
}

// SendRestAPICall sends an API call to the specified URL with the specified headers and query parameters.
//
// Parameters:
// - requestType: the type of the request (GET, POST, PUT, PATCH, DELETE)
// - endpoint: the URL to send the request to
// - header: the headers to include in the request
// - query: the query parameters to include in the request
// - jsonBody: the body of the request as a JSON string
//
// Returns:
// - success: a boolean indicating whether the request was successful
// - returnJsonBody: the JSON body of the response as a string
func SendRestAPICall(requestType string, endpoint string, header map[string]string, query map[string]string, jsonBody string) (success bool, returnJsonBody string) {
	// verify correct request type
	if requestType != "GET" && requestType != "POST" && requestType != "PUT" && requestType != "PATCH" && requestType != "DELETE" {
		panic(fmt.Sprintf("Invalid request type: %v", requestType))
	}

	// Parse the URL and add query parameters
	parsedURL, err := url.Parse(endpoint)
	if err != nil {
		panic(fmt.Sprintf("Error parsing URL: %v", err))
	}

	q := parsedURL.Query()
	for key, value := range query {
		q.Add(key, value)
	}
	parsedURL.RawQuery = q.Encode()

	// Create the HTTP request
	var req *http.Request
	if jsonBody != "" {
		req, err = http.NewRequest(requestType, parsedURL.String(), bytes.NewBuffer([]byte(jsonBody)))
	} else {
		req, err = http.NewRequest(requestType, parsedURL.String(), nil)
	}
	if err !=
nil { - panic(fmt.Sprintf("Error creating request: %v", err)) - } - - // Add headers - for key, value := range header { - req.Header.Add(key, value) - } - - // Execute the request - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - panic(fmt.Sprintf("Error executing request: %v", err)) - } - defer resp.Body.Close() - - // Read the response body - body, err := io.ReadAll(resp.Body) - if err != nil { - panic(fmt.Sprintf("Error reading response body: %v", err)) - } - - // Check if the response code is successful (2xx) - success = resp.StatusCode >= 200 && resp.StatusCode < 300 - - return success, string(body) -} - -// PerformGeneralRequestSpecificModel performs a general request to LLM with a specific model -// -// Parameters: -// - input: the user input -// - history: the conversation history -// - isStream: the flag to indicate whether the response should be streamed -// - systemPrompt: the system prompt -// - modelId: the model ID -// -// Returns: -// - message: the response message -// - stream: the stream channel -func PerformGeneralRequestSpecificModel(input string, history []sharedtypes.HistoricMessage, isStream bool, systemPrompt string, modelIds []string) (message string, stream *chan string) { - // get the LLM handler endpoint - llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT - - // Set up WebSocket connection with LLM and send chat request - responseChannel := sendChatRequest(input, "general", history, 0, systemPrompt, llmHandlerEndpoint, modelIds) - - // If isStream is true, create a stream channel and return asap - if isStream { - // Create a stream channel - streamChannel := make(chan string, 400) - - // Start a goroutine to transfer the data from the response channel to the stream channel - go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false) - - // Return the stream channel - return "", &streamChannel - } - - // else Process all responses - var responseAsStr string - for response := 
range responseChannel { - // Check if the response is an error - if response.Type == "error" { - panic(response.Error) - } - - // Accumulate the responses - responseAsStr += *(response.ChatData) - - // If we are at the last message, break the loop - if *(response.IsLast) { - break - } - } - - // Close the response channel - close(responseChannel) - - // Return the response - return responseAsStr, nil -} -// AssignStringToString assigns a string to another string -// -// Parameters: -// - inputString: the input string -// -// Returns: -// - outputString: the output string -func AssignStringToString(inputString string) (outputString string) { - return inputString + // data extraction + "DataExtractionGetGithubFilesToExtract": DataExtractionGetGithubFilesToExtract, + "DataExtractionGetLocalFilesToExtract": DataExtractionGetLocalFilesToExtract, + "DataExtractionAppendStringSlices": DataExtractionAppendStringSlices, + "DataExtractionDownloadGithubFileContent": DataExtractionDownloadGithubFileContent, + "DataExtractionGetLocalFileContent": DataExtractionGetLocalFileContent, + "DataExtractionGetDocumentType": DataExtractionGetDocumentType, + "DataExtractionLangchainSplitter": DataExtractionLangchainSplitter, + "DataExtractionGenerateDocumentTree": DataExtractionGenerateDocumentTree, + "DataExtractionAddDataRequest": DataExtractionAddDataRequest, + "DataExtractionCreateCollectionRequest": DataExtractionCreateCollectionRequest, + + // generic + "AssignStringToString": AssignStringToString, + "SendRestAPICall": SendRestAPICall, } diff --git a/pkg/externalfunctions/generic.go b/pkg/externalfunctions/generic.go new file mode 100644 index 0000000..74c10a5 --- /dev/null +++ b/pkg/externalfunctions/generic.go @@ -0,0 +1,92 @@ +package externalfunctions + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/url" +) + +// SendAPICall sends an API call to the specified URL with the specified headers and query parameters. 
+// +// Tags: +// - @displayName: REST Call +// +// Parameters: +// - requestType: the type of the request (GET, POST, PUT, PATCH, DELETE) +// - urlString: the URL to send the request to +// - headers: the headers to include in the request +// - query: the query parameters to include in the request +// - jsonBody: the body of the request as a JSON string +// +// Returns: +// - success: a boolean indicating whether the request was successful +// - returnJsonBody: the JSON body of the response as a string +func SendRestAPICall(requestType string, endpoint string, header map[string]string, query map[string]string, jsonBody string) (success bool, returnJsonBody string) { + // verify correct request type + if requestType != "GET" && requestType != "POST" && requestType != "PUT" && requestType != "PATCH" && requestType != "DELETE" { + panic(fmt.Sprintf("Invalid request type: %v", requestType)) + } + + // Parse the URL and add query parameters + parsedURL, err := url.Parse(endpoint) + if err != nil { + panic(fmt.Sprintf("Error parsing URL: %v", err)) + } + + q := parsedURL.Query() + for key, value := range query { + q.Add(key, value) + } + parsedURL.RawQuery = q.Encode() + + // Create the HTTP request + var req *http.Request + if jsonBody != "" { + req, err = http.NewRequest(requestType, parsedURL.String(), bytes.NewBuffer([]byte(jsonBody))) + } else { + req, err = http.NewRequest(requestType, parsedURL.String(), nil) + } + if err != nil { + panic(fmt.Sprintf("Error creating request: %v", err)) + } + + // Add headers + for key, value := range header { + req.Header.Add(key, value) + } + + // Execute the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + panic(fmt.Sprintf("Error executing request: %v", err)) + } + defer resp.Body.Close() + + // Read the response body + body, err := io.ReadAll(resp.Body) + if err != nil { + panic(fmt.Sprintf("Error reading response body: %v", err)) + } + + // Check if the response code is successful (2xx) + 
success = resp.StatusCode >= 200 && resp.StatusCode < 300 + + return success, string(body) +} + +// AssignStringToString assigns a string to another string +// +// Tags: +// - @displayName: Assign String to String +// +// Parameters: +// - inputString: the input string +// +// Returns: +// - outputString: the output string +func AssignStringToString(inputString string) (outputString string) { + return inputString +} diff --git a/pkg/externalfunctions/knowledgedb.go b/pkg/externalfunctions/knowledgedb.go new file mode 100644 index 0000000..b57afd5 --- /dev/null +++ b/pkg/externalfunctions/knowledgedb.go @@ -0,0 +1,774 @@ +package externalfunctions + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/ansys/allie-flowkit/pkg/internalstates" + "github.com/ansys/allie-sharedtypes/pkg/config" + "github.com/ansys/allie-sharedtypes/pkg/logging" + "github.com/ansys/allie-sharedtypes/pkg/sharedtypes" +) + +// SendVectorsToKnowledgeDB sends the given vector to the KnowledgeDB and +// returns the most relevant data. The number of results is specified in the +// config file. The keywords are used to filter the results. The min score +// filter is also specified in the config file. If it is not specified, the +// default value is used. +// +// The function returns the most relevant data. 
+// +// Tags: +// - @displayName: Similarity Search +// +// Parameters: +// - vector: the vector to be sent to the KnowledgeDB +// - keywords: the keywords to be used to filter the results +// - keywordsSearch: the flag to enable the keywords search +// - collection: the collection name +// - similaritySearchResults: the number of results to be returned +// - similaritySearchMinScore: the minimum score for the results +// +// Returns: +// - databaseResponse: an array of the most relevant data +func SendVectorsToKnowledgeDB(vector []float32, keywords []string, keywordsSearch bool, collection string, similaritySearchResults int, similaritySearchMinScore float64) (databaseResponse []sharedtypes.DbResponse) { + // get the KnowledgeDB endpoint + knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT + + // Log the request + logging.Log.Debugf(internalstates.Ctx, "Connecting to the KnowledgeDB.") + + // Build filters + var filters sharedtypes.DbFilters + + // -- Add the keywords filter if needed + if keywordsSearch { + filters.KeywordsFilter = sharedtypes.DbArrayFilter{ + NeedAll: false, + FilterData: keywords, + } + } + + // -- Add the level filter + filters.LevelFilter = []string{"leaf"} + + // Create a new resource instance + requestInput := similaritySearchInput{ + CollectionName: collection, + EmbeddedVector: vector, + MaxRetrievalCount: similaritySearchResults, + Filters: filters, + MinScore: similaritySearchMinScore, + OutputFields: []string{ + "guid", + "document_id", + "document_name", + "summary", + "keywords", + "text", + }, + } + + // Convert the resource instance to JSON. + jsonData, err := json.Marshal(requestInput) + if err != nil { + errMessage := fmt.Sprintf("Error marshalling JSON data of POST /similarity_search request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Specify the target endpoint. 
+ requestURL := knowledgeDbEndpoint + "/similarity_search" + + // Create a new HTTP request with the JSON data. + req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + errMessage := fmt.Sprintf("Error creating POST /similarity_search request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Set the appropriate content type for the request. + req.Header.Set("Content-Type", "application/json") + + // Send the HTTP request using the default HTTP client. + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errMessage := fmt.Sprintf("Error sending POST /similarity_search request to allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + defer resp.Body.Close() + + // Read and display the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + errMessage := fmt.Sprintf("Error reading response body of POST /similarity_search request from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Log the similarity search response + logging.Log.Debugf(internalstates.Ctx, "Knowledge DB response: %v", string(body)) + logging.Log.Debugf(internalstates.Ctx, "Knowledge DB response received!") + + // Unmarshal the response body to the appropriate struct. 
+ var response similaritySearchOutput + err = json.Unmarshal(body, &response) + if err != nil { + errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /similarity_search response from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + var mostRelevantData []sharedtypes.DbResponse + var count int = 1 + for _, element := range response.SimilarityResult { + // Log the result + logging.Log.Debugf(internalstates.Ctx, "Result #%d:", count) + logging.Log.Debugf(internalstates.Ctx, "Similarity score: %v", element.Score) + logging.Log.Debugf(internalstates.Ctx, "Similarity file id: %v", element.Data.DocumentId) + logging.Log.Debugf(internalstates.Ctx, "Similarity file name: %v", element.Data.DocumentName) + logging.Log.Debugf(internalstates.Ctx, "Similarity summary: %v", element.Data.Summary) + + // Add the result to the list + mostRelevantData = append(mostRelevantData, element.Data) + + // Check whether we have enough results + if count >= similaritySearchResults { + break + } else { + count++ + } + } + + // Return the most relevant data + return mostRelevantData +} + +// GetListCollections retrieves the list of collections from the KnowledgeDB. +// +// Tags: +// - @displayName: List Collections +// +// The function returns the list of collections. +// +// Parameters: +// - knowledgeDbEndpoint: the KnowledgeDB endpoint +// +// Returns: +// - collectionsList: the list of collections +func GetListCollections() (collectionsList []string) { + // get the KnowledgeDB endpoint + knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT + + // Specify the target endpoint. + requestURL := knowledgeDbEndpoint + "/list_collections" + + // Create a new HTTP request with the JSON data. 
+ req, err := http.NewRequest("GET", requestURL, nil) + if err != nil { + errMessage := fmt.Sprintf("Error creating GET /list_collections request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Set the appropriate content type for the request. + req.Header.Set("Content-Type", "application/json") + + // Send the HTTP request using the default HTTP client. + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errMessage := fmt.Sprintf("Error sending GET /list_collections request to allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + defer resp.Body.Close() + + // Read and display the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + errMessage := fmt.Sprintf("Error reading response body of GET /list_collections request from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Unmarshal the response body to the appropriate struct. + var response sharedtypes.DBListCollectionsOutput + err = json.Unmarshal(body, &response) + if err != nil { + errMessage := fmt.Sprintf("Error unmarshalling JSON data of GET /list_collections response from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Log the result and return the list of collections + if !response.Success { + errMessage := "Failed to retrieve list of collections from allie-db" + logging.Log.Warn(internalstates.Ctx, errMessage) + panic(errMessage) + } else { + logging.Log.Debugf(internalstates.Ctx, "List collections response received!") + logging.Log.Debugf(internalstates.Ctx, "Collections: %v", response.Collections) + return response.Collections + } +} + +// RetrieveDependencies retrieves the dependencies of the specified source node. +// +// The function returns the list of dependencies. 
+// +// Tags: +// - @displayName: Retrieve Dependencies +// +// Parameters: +// - collectionName: the name of the collection to which the data objects will be added. +// - relationshipName: the name of the relationship to retrieve dependencies for. +// - relationshipDirection: the direction of the relationship to retrieve dependencies for. +// - sourceDocumentId: the document ID of the source node. +// - nodeTypesFilter: filter based on node types. +// - maxHopsNumber: maximum number of hops to traverse. +// +// Returns: +// - dependenciesIds: the list of dependencies +func RetrieveDependencies( + collectionName string, + relationshipName string, + relationshipDirection string, + sourceDocumentId string, + nodeTypesFilter sharedtypes.DbArrayFilter, + maxHopsNumber int) (dependenciesIds []string) { + // get the KnowledgeDB endpoint + knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT + + // Create the URL + requestURL := knowledgeDbEndpoint + "/retrieve_dependencies" + + // Create the retrieveDependenciesInput object + requestInput := retrieveDependenciesInput{ + CollectionName: collectionName, + RelationshipName: relationshipName, + RelationshipDirection: relationshipDirection, + SourceDocumentId: sourceDocumentId, + NodeTypesFilter: nodeTypesFilter, + MaxHopsNumber: maxHopsNumber, + } + + // Convert the resource instance to JSON. + jsonData, err := json.Marshal(requestInput) + if err != nil { + errMessage := fmt.Sprintf("Error marshalling JSON data of POST /retrieve_dependencies request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Create a new HTTP request with the JSON data. 
+ req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + errMessage := fmt.Sprintf("Error creating POST /retrieve_dependencies request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Set the appropriate content type for the request. + req.Header.Set("Content-Type", "application/json") + + // Send the HTTP request using the default HTTP client. + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errMessage := fmt.Sprintf("Error sending POST /retrieve_dependencies request to allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + defer resp.Body.Close() + + // Read and display the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + errMessage := fmt.Sprintf("Error reading response body of POST /retrieve_dependencies request from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + logging.Log.Debugf(internalstates.Ctx, "Knowledge DB RetrieveDependencies response received!") + + // Unmarshal the response body to the appropriate struct. + var response retrieveDependenciesOutput + err = json.Unmarshal(body, &response) + if err != nil { + errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /retrieve_dependencies response from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + return response.DependenciesIds +} + +// GeneralNeo4jQuery executes the given Neo4j query and returns the response. +// +// The function returns the neo4j response. +// +// Tags: +// - @displayName: General Neo4J Query +// +// Parameters: +// - query: the Neo4j query to be executed. 
+// +// Returns: +// - databaseResponse: the Neo4j response +func GeneralNeo4jQuery(query string) (databaseResponse sharedtypes.Neo4jResponse) { + // get the KnowledgeDB endpoint + knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT + + // Create the URL + requestURL := knowledgeDbEndpoint + "/general_neo4j_query" + + // Create the retrieveDependenciesInput object + requestInput := sharedtypes.GeneralNeo4jQueryInput{ + Query: query, + } + + // Convert the resource instance to JSON. + jsonData, err := json.Marshal(requestInput) + if err != nil { + errMessage := fmt.Sprintf("Error marshalling JSON data of POST /general_neo4j_query request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Create a new HTTP request with the JSON data. + req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + errMessage := fmt.Sprintf("Error creating POST /general_neo4j_query request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Set the appropriate content type for the request. + req.Header.Set("Content-Type", "application/json") + + // Send the HTTP request using the default HTTP client. + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errMessage := fmt.Sprintf("Error sending POST /general_neo4j_query request to allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + defer resp.Body.Close() + + // Read and display the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + errMessage := fmt.Sprintf("Error reading response body of POST /general_neo4j_query request from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + logging.Log.Debugf(internalstates.Ctx, "Knowledge DB GeneralNeo4jQuery response received!") + + // Unmarshal the response body to the appropriate struct. 
+ var response sharedtypes.GeneralNeo4jQueryOutput + err = json.Unmarshal(body, &response) + if err != nil { + errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /general_neo4j_query response from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + return response.Response +} + +// GeneralQuery performs a general query in the KnowledgeDB. +// +// The function returns the query results. +// +// Tags: +// - @displayName: Query +// +// Parameters: +// - collectionName: the name of the collection to which the data objects will be added. +// - maxRetrievalCount: the maximum number of results to be retrieved. +// - outputFields: the fields to be included in the output. +// - filters: the filter for the query. +// +// Returns: +// - databaseResponse: the query results +func GeneralQuery(collectionName string, maxRetrievalCount int, outputFields []string, filters sharedtypes.DbFilters) (databaseResponse []sharedtypes.DbResponse) { + // get the KnowledgeDB endpoint + knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT + + // Create the URL + requestURL := knowledgeDbEndpoint + "/query" + + // Create the queryInput object + requestInput := queryInput{ + CollectionName: collectionName, + MaxRetrievalCount: maxRetrievalCount, + OutputFields: outputFields, + Filters: filters, + } + + // Convert the resource instance to JSON. + jsonData, err := json.Marshal(requestInput) + if err != nil { + errMessage := fmt.Sprintf("Error marshalling JSON data of POST /query request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Create a new HTTP request with the JSON data. 
+ req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + errMessage := fmt.Sprintf("Error creating POST /query request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Set the appropriate content type for the request. + req.Header.Set("Content-Type", "application/json") + + // Send the HTTP request using the default HTTP client. + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errMessage := fmt.Sprintf("Error sending POST /query request to allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + defer resp.Body.Close() + + // Read and display the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + errMessage := fmt.Sprintf("Error reading response body of POST /query request from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + logging.Log.Debugf(internalstates.Ctx, "Knowledge DB GeneralQuery response received!") + + // Unmarshal the response body to the appropriate struct. + var response queryOutput + err = json.Unmarshal(body, &response) + if err != nil { + errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /query response from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + return response.QueryResult +} + +// SimilaritySearch performs a similarity search in the KnowledgeDB. +// +// The function returns the similarity search results. +// +// Tags: +// - @displayName: Similarity Search (Filtered) +// +// Parameters: +// - collectionName: the name of the collection to which the data objects will be added. +// - embeddedVector: the embedded vector used for searching. +// - maxRetrievalCount: the maximum number of results to be retrieved. +// - outputFields: the fields to be included in the output. +// - filters: the filter for the query. 
+// - minScore: the minimum score filter. +// - getLeafNodes: flag to indicate whether to retrieve all the leaf nodes in the result node branch. +// - getSiblings: flag to indicate whether to retrieve the previous and next node to the result nodes. +// - getParent: flag to indicate whether to retrieve the parent object. +// - getChildren: flag to indicate whether to retrieve the children objects. +// +// Returns: +// - databaseResponse: the similarity search results +func SimilaritySearch( + collectionName string, + embeddedVector []float32, + maxRetrievalCount int, + outputFields []string, + filters sharedtypes.DbFilters, + minScore float64, + getLeafNodes bool, + getSiblings bool, + getParent bool, + getChildren bool) (databaseResponse []sharedtypes.DbResponse) { + // get the KnowledgeDB endpoint + knowledgeDbEndpoint := config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT + + // Create the URL + requestURL := knowledgeDbEndpoint + "/similarity_search" + + // Create the retrieveDependenciesInput object + requestInput := similaritySearchInput{ + CollectionName: collectionName, + EmbeddedVector: embeddedVector, + MaxRetrievalCount: maxRetrievalCount, + OutputFields: outputFields, + Filters: filters, + MinScore: minScore, + GetLeafNodes: getLeafNodes, + GetSiblings: getSiblings, + GetParent: getParent, + GetChildren: getChildren, + } + + // Convert the resource instance to JSON. + jsonData, err := json.Marshal(requestInput) + if err != nil { + errMessage := fmt.Sprintf("Error marshalling JSON data of POST /similarity_search request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Create a new HTTP request with the JSON data. 
+ req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) + if err != nil { + errMessage := fmt.Sprintf("Error creating POST /similarity_search request for allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + // Set the appropriate content type for the request. + req.Header.Set("Content-Type", "application/json") + + // Send the HTTP request using the default HTTP client. + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + errMessage := fmt.Sprintf("Error sending POST /similarity_search request to allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + defer resp.Body.Close() + + // Read and display the response body. + body, err := io.ReadAll(resp.Body) + if err != nil { + errMessage := fmt.Sprintf("Error reading response body of POST /similarity_search request from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + logging.Log.Debugf(internalstates.Ctx, "Knowledge DB SimilaritySearch response received!") + + // Unmarshal the response body to the appropriate struct. + var response similaritySearchOutput + err = json.Unmarshal(body, &response) + if err != nil { + errMessage := fmt.Sprintf("Error unmarshalling JSON data of POST /similarity_search response from allie-db: %v", err) + logging.Log.Error(internalstates.Ctx, errMessage) + panic(errMessage) + } + + var similarityResults []sharedtypes.DbResponse + for _, element := range response.SimilarityResult { + similarityResults = append(similarityResults, element.Data) + } + + return similarityResults +} + +// CreateKeywordsDbFilter creates a keywords filter for the KnowledgeDB. +// +// The function returns the keywords filter. 
+// +// Tags: +// - @displayName: Keywords Filter +// +// Parameters: +// - keywords: the keywords to be used for the filter +// - needAll: flag to indicate whether all keywords are needed +// +// Returns: +// - databaseFilter: the keywords filter +func CreateKeywordsDbFilter(keywords []string, needAll bool) (databaseFilter sharedtypes.DbArrayFilter) { + var keywordsFilters sharedtypes.DbArrayFilter + + // -- Add the keywords filter if needed + if len(keywords) > 0 { + keywordsFilters = createDbArrayFilter(keywords, needAll) + } + + return keywordsFilters +} + +// CreateTagsDbFilter creates a tags filter for the KnowledgeDB. +// +// The function returns the tags filter. +// +// Tags: +// - @displayName: Tags Filter +// +// Parameters: +// - tags: the tags to be used for the filter +// - needAll: flag to indicate whether all tags are needed +// +// Returns: +// - databaseFilter: the tags filter +func CreateTagsDbFilter(tags []string, needAll bool) (databaseFilter sharedtypes.DbArrayFilter) { + var tagsFilters sharedtypes.DbArrayFilter + + // -- Add the tags filter if needed + if len(tags) > 0 { + tagsFilters = createDbArrayFilter(tags, needAll) + } + + return tagsFilters +} + +// CreateMetadataDbFilter creates a metadata filter for the KnowledgeDB. +// +// The function returns the metadata filter. +// +// Tags: +// - @displayName: Metadata Filter +// +// Parameters: +// - fieldName: the name of the field +// - fieldType: the type of the field +// - filterData: the filter data +// - needAll: flag to indicate whether all data is needed +// +// Returns: +// - databaseFilter: the metadata filter +func CreateMetadataDbFilter(fieldName string, fieldType string, filterData []string, needAll bool) (databaseFilter sharedtypes.DbJsonFilter) { + return createDbJsonFilter(fieldName, fieldType, filterData, needAll) +} + +// CreateDbFilter creates a filter for the KnowledgeDB. +// +// The function returns the filter. 
+// +// Tags: +// - @displayName: Create Filter +// +// Parameters: +// - guid: the guid filter +// - documentId: the document ID filter +// - documentName: the document name filter +// - level: the level filter +// - tags: the tags filter +// - keywords: the keywords filter +// - metadata: the metadata filter +// +// Returns: +// - databaseFilter: the filter +func CreateDbFilter( + guid []string, + documentId []string, + documentName []string, + level []string, + tags sharedtypes.DbArrayFilter, + keywords sharedtypes.DbArrayFilter, + metadata []sharedtypes.DbJsonFilter) (databaseFilter sharedtypes.DbFilters) { + var filters sharedtypes.DbFilters + + // -- Add the guid filter if needed + if len(guid) > 0 { + filters.GuidFilter = guid + } + + // -- Add the document ID filter if needed + if len(documentId) > 0 { + filters.DocumentIdFilter = documentId + } + + // -- Add the document name filter if needed + if len(documentName) > 0 { + filters.DocumentNameFilter = documentName + } + + // -- Add the level filter if needed + if len(level) > 0 { + filters.LevelFilter = level + } + + // -- Add the tags filter if needed + if len(tags.FilterData) > 0 { + filters.TagsFilter = tags + } + + // -- Add the keywords filter if needed + if len(keywords.FilterData) > 0 { + filters.KeywordsFilter = keywords + } + + // -- Add the metadata filter if needed + if len(metadata) > 0 { + filters.MetadataFilter = metadata + } + + return filters +} + +// DataExtractionAddDataRequest sends a request to the add_data endpoint. +// +// Tags: +// - @displayName: Add Data +// +// Parameters: +// - collectionName: name of the collection the request is sent to. +// - data: the data to add. 
+func DataExtractionAddDataRequest(collectionName string, documentData []sharedtypes.DbData) {
+	// Create the AddDataInput object
+	requestObject := sharedtypes.DbAddDataInput{
+		CollectionName: collectionName,
+		Data:           documentData,
+	}
+
+	// Create the URL
+	url := fmt.Sprintf("%s/%s", config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT, "add_data")
+
+	// Send the HTTP POST request
+	var response sharedtypes.DbAddDataOutput
+	// NOTE(review): the helper returns (error, statusCode) — error-first is unidiomatic
+	// Go (error should be the last return value); worth fixing at the helper's definition.
+	err, _ := createPayloadAndSendHttpRequest(url, requestObject, &response)
+	if err != nil {
+		errorMessage := fmt.Sprintf("Error sending request to add_data endpoint: %v", err)
+		logging.Log.Error(internalstates.Ctx, errorMessage)
+		// Panics on failure (no error return in the signature); callers cannot recover locally.
+		panic(errorMessage)
+	}
+
+	logging.Log.Debugf(internalstates.Ctx, "Added data to collection: %s \n", collectionName)
+
+	// NOTE(review): bare trailing return in a void function is redundant.
+	return
+}
+
+// DataExtractionCreateCollectionRequest sends a request to the collection endpoint.
+//
+// Tags:
+// - @displayName: Create Collection
+//
+// Parameters:
+// - collectionName: the name of the collection to create.
+func DataExtractionCreateCollectionRequest(collectionName string) {
+	// Create the CreateCollectionInput object
+	requestObject := sharedtypes.DbCreateCollectionInput{
+		CollectionName: collectionName,
+	}
+
+	// Create the URL
+	url := fmt.Sprintf("%s/%s", config.GlobalConfig.KNOWLEDGE_DB_ENDPOINT, "create_collection")
+
+	// Send the HTTP POST request
+	var response sharedtypes.DbCreateCollectionOutput
+	err, statusCode := createPayloadAndSendHttpRequest(url, requestObject, &response)
+	if err != nil {
+		// 409 Conflict means the collection already exists; treat as benign.
+		if statusCode == 409 {
+			logging.Log.Warn(internalstates.Ctx, "Collection already exists")
+		} else {
+			errorMessage := fmt.Sprintf("Error sending request to create_collection endpoint: %v", err)
+			logging.Log.Error(internalstates.Ctx, errorMessage)
+			panic(errorMessage)
+		}
+	}
+
+	// NOTE(review): this "Created collection" debug line also runs on the 409 path,
+	// where nothing was created.
+	logging.Log.Debugf(internalstates.Ctx, "Created collection: %s \n", collectionName)
+
+	return
+}
diff --git a/pkg/externalfunctions/llmhandler.go b/pkg/externalfunctions/llmhandler.go
new file mode 100644
index 0000000..14a55d3
--- /dev/null
+++ b/pkg/externalfunctions/llmhandler.go
@@ -0,0 +1,588 @@
+package externalfunctions
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/ansys/allie-flowkit/pkg/internalstates"
+	"github.com/ansys/allie-sharedtypes/pkg/config"
+	"github.com/ansys/allie-sharedtypes/pkg/logging"
+	"github.com/ansys/allie-sharedtypes/pkg/sharedtypes"
+)
+
+// PerformVectorEmbeddingRequest performs a vector embedding request to LLM
+//
+// Tags:
+// - @displayName: Embeddings
+//
+// Parameters:
+// - input: the input string
+//
+// Returns:
+// - embeddedVector: the embedded vector in float32 format
+func PerformVectorEmbeddingRequest(input string) (embeddedVector []float32) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send embeddings request
+	responseChannel := sendEmbeddingsRequest(input, llmHandlerEndpoint, nil)
+
+	// Process the first response and close the channel
+	var embedding32 []float32
+	var err error
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Log LLM response
+		logging.Log.Debugf(internalstates.Ctx, "Received embeddings response.")
+
+		// Get embedded vector array
+		interfaceArray, ok := response.EmbeddedData.([]interface{})
+		if !ok {
+			errMessage := "error converting embedded data to interface array"
+			logging.Log.Error(internalstates.Ctx, errMessage)
+			panic(errMessage)
+		}
+		embedding32, err = convertToFloat32Slice(interfaceArray)
+		if err != nil {
+			errMessage := fmt.Sprintf("error converting embedded data to float32 slice: %v", err)
+			logging.Log.Error(internalstates.Ctx, errMessage)
+			panic(errMessage)
+		}
+
+		// Mark that the first response has been received
+		// NOTE(review): firstResponseReceived is always true here — the flag and the
+		// guard below are dead code, equivalent to an unconditional break.
+		firstResponseReceived := true
+
+		// Exit the loop after processing the first response
+		if firstResponseReceived {
+			break
+		}
+	}
+
+	// Close the response channel
+	// NOTE(review): this function is the *receiver* of responseChannel; in Go only the
+	// sender should close a channel. If the goroutine behind sendEmbeddingsRequest sends
+	// (or closes) after this close, the program panics — confirm channel ownership there.
+	close(responseChannel)
+
+	return embedding32
+}
+
+// PerformBatchEmbeddingRequest performs a batch vector embedding request to LLM
+//
+// Tags:
+// - @displayName: Batch Embeddings
+//
+// Parameters:
+// - input: the input strings
+//
+// Returns:
+// - embeddedVectors: the embedded vectors in float32 format
+func PerformBatchEmbeddingRequest(input []string) (embeddedVectors [][]float32) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send embeddings request
+	responseChannel := sendEmbeddingsRequest(input, llmHandlerEndpoint, nil)
+
+	// Process the first response and close the channel
+	// NOTE(review): result slot i is taken from position i of the response — assumes the
+	// handler returns one vector per input, in input order; TODO confirm.
+	embedding32Array := make([][]float32, len(input))
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Log LLM response
+		logging.Log.Debugf(internalstates.Ctx, "Received embeddings response.")
+
+		// Get embedded vector array
+		interfaceArray, ok := response.EmbeddedData.([]interface{})
+		if !ok {
+			errMessage := "error converting embedded data to interface array"
+			logging.Log.Error(internalstates.Ctx, errMessage)
+			panic(errMessage)
+		}
+
+		for i, interfaceArrayElement := range interfaceArray {
+			lowerInterfaceArray, ok := interfaceArrayElement.([]interface{})
+			if !ok {
+				errMessage := "error converting embedded data to interface array"
+				logging.Log.Error(internalstates.Ctx, errMessage)
+				panic(errMessage)
+			}
+			embedding32, err := convertToFloat32Slice(lowerInterfaceArray)
+			if err != nil {
+				errMessage := fmt.Sprintf("error converting embedded data to float32 slice: %v", err)
+				logging.Log.Error(internalstates.Ctx, errMessage)
+				panic(errMessage)
+			}
+			embedding32Array[i] = embedding32
+		}
+
+		// Mark that the first response has been received
+		// NOTE(review): same dead-flag pattern as PerformVectorEmbeddingRequest — a plain
+		// break would do.
+		firstResponseReceived := true
+
+		// Exit the loop after processing the first response
+		if firstResponseReceived {
+			break
+		}
+	}
+
+	// Close the response channel
+	// NOTE(review): receiver-side close — see the ownership note in
+	// PerformVectorEmbeddingRequest above.
+	close(responseChannel)
+
+	return embedding32Array
+}
+
+// PerformKeywordExtractionRequest performs a keywords extraction request to LLM
+//
+// Tags:
+// - @displayName: Keyword Extraction
+//
+// Parameters:
+// - input: the input string
+// - maxKeywordsSearch: the maximum number of keywords to search for
+//
+// Returns:
+// - keywords: the keywords extracted from the input string as a slice of strings
+func PerformKeywordExtractionRequest(input string, maxKeywordsSearch uint32) (keywords []string) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send chat request
+	responseChannel := sendChatRequestNoHistory(input, "keywords", maxKeywordsSearch, llmHandlerEndpoint, nil)
+
+	// Process all responses
+	// Streamed chunks are concatenated until the handler flags the last message.
+	var responseAsStr string
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Accumulate the responses
+		// NOTE(review): ChatData/IsLast are dereferenced without nil checks — assumes the
+		// handler always populates both on non-error messages; confirm.
+		responseAsStr += *(response.ChatData)
+
+		// If we are at the last message, break the loop
+		if *(response.IsLast) {
+			break
+		}
+	}
+
+	logging.Log.Debugf(internalstates.Ctx, "Received keywords response.")
+
+	// Close the response channel
+	// NOTE(review): receiver-side close — see ownership note in PerformVectorEmbeddingRequest.
+	close(responseChannel)
+
+	// Unmarshal JSON data into the result variable
+	// The accumulated response is expected to be a JSON array of strings.
+	err := json.Unmarshal([]byte(responseAsStr), &keywords)
+	if err != nil {
+		errMessage := fmt.Sprintf("Error unmarshalling keywords response from allie-llm: %v", err)
+		logging.Log.Error(internalstates.Ctx, errMessage)
+		panic(errMessage)
+	}
+
+	// Return the response
+	return keywords
+}
+
+// PerformSummaryRequest performs a summary request to LLM
+//
+// Tags:
+// - @displayName: Summary
+//
+// Parameters:
+// - input: the input string
+//
+// Returns:
+// - summary: the summary extracted from the input string
+func PerformSummaryRequest(input string) (summary string) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send chat request
+	responseChannel := sendChatRequestNoHistory(input, "summary", 1, llmHandlerEndpoint, nil)
+
+	// Process all responses
+	var responseAsStr string
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Accumulate the responses
+		responseAsStr += *(response.ChatData)
+
+		// If we are at the last message, break the loop
+		if *(response.IsLast) {
+			break
+		}
+	}
+
+	logging.Log.Debugf(internalstates.Ctx, "Received summary response.")
+
+	// Close the response channel
+	close(responseChannel)
+
+	// Return the response
+	return responseAsStr
+}
+
+// PerformGeneralRequest performs a general chat completion request to LLM
+//
+// Tags:
+// - @displayName: General LLM Request
+//
+// Parameters:
+// - input: the input string
+// - history: the conversation history
+// - isStream: the stream flag
+// - systemPrompt: the system prompt
+//
+// Returns:
+// - message: the generated message (empty when isStream is true)
+// - stream: the stream channel (nil when isStream is false)
+func PerformGeneralRequest(input string, history []sharedtypes.HistoricMessage, isStream bool, systemPrompt string) (message string, stream *chan string) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send chat request
+	responseChannel := sendChatRequest(input, "general", history, 0, systemPrompt, llmHandlerEndpoint, nil)
+	// If isStream is true, create a stream channel and return asap
+	if isStream {
+		// Create a stream channel
+		// NOTE(review): *chan string is unidiomatic — channels are already reference
+		// types; a plain chan string return would suffice (interface change, so left as-is).
+		streamChannel := make(chan string, 400)
+
+		// Start a goroutine to transfer the data from the response channel to the stream channel
+		go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
+
+		// Return the stream channel
+		return "", &streamChannel
+	}
+
+	// else Process all responses
+	var responseAsStr string
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Accumulate the responses
+		responseAsStr += *(response.ChatData)
+
+		// If we are at the last message, break the loop
+		if *(response.IsLast) {
+			break
+		}
+	}
+
+	// Close the response channel
+	close(responseChannel)
+
+	// Return the response
+	return responseAsStr, nil
+}
+
+// PerformGeneralRequestSpecificModel performs a general request to LLM with a specific model
+//
+// Tags:
+// - @displayName: General LLM Request (Specific Models)
+//
+// Parameters:
+// - input: the user input
+// - history: the conversation history
+// - isStream: the flag to indicate whether the response should be streamed
+// - systemPrompt: the system prompt
+// - modelIds: the model IDs
+//
+// Returns:
+// - message: the response message
+// - stream: the stream channel
+func PerformGeneralRequestSpecificModel(input string, history []sharedtypes.HistoricMessage, isStream bool, systemPrompt string, modelIds []string) (message string, stream *chan string) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send chat request
+	responseChannel := sendChatRequest(input, "general", history, 0, systemPrompt, llmHandlerEndpoint, modelIds)
+
+	// If isStream is true, create a stream channel and return asap
+	if isStream {
+		// Create a stream channel
+		streamChannel := make(chan string, 400)
+
+		// Start a goroutine to transfer the data from the response channel to the stream channel
+		go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
+
+		// Return the stream channel
+		return "", &streamChannel
+	}
+
+	// else Process all responses
+	var responseAsStr string
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Accumulate the responses
+		responseAsStr += *(response.ChatData)
+
+		// If we are at the last message, break the loop
+		if *(response.IsLast) {
+			break
+		}
+	}
+
+	// Close the response channel
+	close(responseChannel)
+
+	// Return the response
+	return responseAsStr, nil
+}
+
+// PerformCodeLLMRequest performs a code generation request to LLM
+//
+// Tags:
+// - @displayName: Code LLM Request
+//
+// Parameters:
+// - input: the input string
+// - history: the conversation history
+// - isStream: the stream flag
+// - validateCode: flag to run Python validation on the generated code
+//
+// Returns:
+// - message: the generated code
+// - stream: the stream channel
+func PerformCodeLLMRequest(input string, history []sharedtypes.HistoricMessage, isStream bool, validateCode bool) (message string, stream *chan string) {
+	// get the LLM handler endpoint
+	llmHandlerEndpoint := config.GlobalConfig.LLM_HANDLER_ENDPOINT
+
+	// Set up WebSocket connection with LLM and send chat request
+	responseChannel := sendChatRequest(input, "code", history, 0, "", llmHandlerEndpoint, nil)
+
+	// If isStream is true, create a stream channel and return asap
+	if isStream {
+		// Create a stream channel
+		streamChannel := make(chan string, 400)
+
+		// Start a goroutine to transfer the data from the response channel to the stream channel
+		// In the streaming path, code validation is delegated to the transfer goroutine.
+		go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, validateCode)
+
+		// Return the stream channel
+		return "", &streamChannel
+	}
+
+	// else Process all responses
+	var responseAsStr string
+	for response := range responseChannel {
+		// Check if the response is an error
+		if response.Type == "error" {
+			panic(response.Error)
+		}
+
+		// Accumulate the responses
+		responseAsStr += *(response.ChatData)
+
+		// If we are at the last message, break the loop
+		if *(response.IsLast) {
+			break
+		}
+	}
+
+	// Close the response channel
+	close(responseChannel)
+
+	// Code validation
+	// Validation problems are logged and appended as text; they never fail the request.
+	if validateCode {
+
+		// Extract the code from the response
+		pythonCode, err := extractPythonCode(responseAsStr)
+		if err != nil {
+			logging.Log.Errorf(internalstates.Ctx, "Error extracting Python code: %v", err)
+		} else {
+
+			// Validate the Python code
+			valid, warnings, err := validatePythonCode(pythonCode)
+			if err != nil {
+				logging.Log.Errorf(internalstates.Ctx, "Error validating Python code: %v", err)
+			} else {
+				if valid {
+					if warnings {
+						responseAsStr += "\nCode has warnings."
+					} else {
+						responseAsStr += "\nCode is valid."
+					}
+				} else {
+					responseAsStr += "\nCode is invalid."
+				}
+			}
+		}
+	}
+
+	// Return the response
+	return responseAsStr, nil
+}
+
+// BuildLibraryContext builds the context string for the query
+//
+// Tags:
+// - @displayName: Library Context
+//
+// Parameters:
+// - message: the message string
+// - libraryContext: the library context string
+//
+// Returns:
+// - messageWithContext: the message with context
+func BuildLibraryContext(message string, libraryContext string) (messageWithContext string) {
+	// Prepend the library context to the message.
+	message = libraryContext + message
+
+	return message
+}
+
+// BuildFinalQueryForGeneralLLMRequest builds the final query for a general
+// request to LLM. The final query is a markdown string that contains the
+// original request and the examples from the KnowledgeDB.
+//
+// Tags:
+// - @displayName: Final Query (General LLM Request)
+//
+// Parameters:
+// - request: the original request
+// - knowledgedbResponse: the KnowledgeDB response
+//
+// Returns:
+// - finalQuery: the final query
+func BuildFinalQueryForGeneralLLMRequest(request string, knowledgedbResponse []sharedtypes.DbResponse) (finalQuery string) {
+
+	// If there is no response from the KnowledgeDB, return the original request
+	if len(knowledgedbResponse) == 0 {
+		return request
+	}
+
+	// Build the final query using the KnowledgeDB response and the original request
+	finalQuery = "Based on the following examples:\n\n--- INFO START ---\n"
+	for _, example := range knowledgedbResponse {
+		finalQuery += example.Text + "\n"
+	}
+	finalQuery += "--- INFO END ---\n\n" + request + "\n"
+
+	// Return the final query
+	return finalQuery
+}
+
+// BuildFinalQueryForCodeLLMRequest builds the final query for a code generation
+// request to LLM. The final query is a markdown string that contains the
+// original request and the code examples from the KnowledgeDB.
+//
+// Tags:
+// - @displayName: Final Query (Code LLM Request)
+//
+// Parameters:
+// - request: the original request
+// - knowledgedbResponse: the KnowledgeDB response
+//
+// Returns:
+// - finalQuery: the final query
+func BuildFinalQueryForCodeLLMRequest(request string, knowledgedbResponse []sharedtypes.DbResponse) (finalQuery string) {
+	// Build the final query using the KnowledgeDB response and the original request
+	// We have to use the text from the DB response and the original request.
+	//
+	// The prompt should be in the following format:
+	//
+	// ******************************************************************************
+	// Based on the following examples:
+	//
+	// --- START EXAMPLE {response_n}---
+	// >>> Summary:
+	// {knowledge_db_response_n_summary}
+	//
+	// >>> Code snippet:
+	// ```python
+	// {knowledge_db_response_n_text}
+	// ```
+	// --- END EXAMPLE {response_n}---
+	//
+	// --- START EXAMPLE {response_n}---
+	// ...
+	// --- END EXAMPLE {response_n}---
+	//
+	// Generate the Python code for the following request:
+	//
+	// >>> Request:
+	// {original_request}
+	// ******************************************************************************
+
+	// Add the examples preamble only when the KnowledgeDB returned results; the
+	// request itself is appended below in either case.
+	if len(knowledgedbResponse) > 0 {
+		// Initial request
+		finalQuery = "Based on the following examples:\n\n"
+
+		for i, element := range knowledgedbResponse {
+			// Add the example number
+			finalQuery += "--- START EXAMPLE " + fmt.Sprint(i+1) + "---\n"
+			finalQuery += ">>> Summary:\n" + element.Summary + "\n\n"
+			finalQuery += ">>> Code snippet:\n```python\n" + element.Text + "\n```\n"
+			finalQuery += "--- END EXAMPLE " + fmt.Sprint(i+1) + "---\n\n"
+		}
+	}
+
+	// Pass in the original request
+	finalQuery += "Generate the Python code for the following request:\n>>> Request:\n" + request + "\n"
+
+	// Return the final query
+	return finalQuery
+}
+
+// AppendMessageHistoryRole is the set of roles accepted by AppendMessageHistory.
+type AppendMessageHistoryRole string
+
+const (
+	user      AppendMessageHistoryRole = "user"
+	assistant AppendMessageHistoryRole = "assistant"
+	system    AppendMessageHistoryRole = "system"
+)
+
+// AppendMessageHistory appends a new message to the conversation history
+//
+// Tags:
+// - @displayName: Append Message History
+//
+// Parameters:
+// - newMessage: the new message
+// - role: the role of the message
+// - history: the conversation history
+//
+// Returns:
+// - updatedHistory: the updated conversation history
+func AppendMessageHistory(newMessage string, role AppendMessageHistoryRole, history []sharedtypes.HistoricMessage) (updatedHistory []sharedtypes.HistoricMessage) {
+	// Validate the role: Go switch cases do not fall through, so the empty case
+	// bodies accept the known roles; anything else panics below.
+	switch role {
+	case user:
+	case assistant:
+	case system:
+	default:
+		errMessage := fmt.Sprintf("Invalid role used for 'AppendMessageHistory': %v", role)
+		logging.Log.Warn(internalstates.Ctx, errMessage)
+		panic(errMessage)
+	}
+
+	// skip for empty messages
+	if newMessage == "" {
+		return history
+	}
+
+	// Create a new HistoricMessage
+	newMessageHistory := sharedtypes.HistoricMessage{
+		Role:    string(role),
+		Content: newMessage,
+	}
+
+	// Append the new message to the history
+	// NOTE(review): append may mutate the caller's backing array when capacity allows;
+	// callers must use the returned slice, not the argument.
+	history = append(history, newMessageHistory)
+
+	return history
+}
diff --git a/pkg/functiondefinitions/functiondefinitions.go b/pkg/functiondefinitions/functiondefinitions.go
index aafab5f..d2bfede 100644
--- a/pkg/functiondefinitions/functiondefinitions.go
+++ b/pkg/functiondefinitions/functiondefinitions.go
@@ -2,6 +2,7 @@ package functiondefinitions
 
 import (
 	"bytes"
+	"fmt"
 	"go/ast"
 	"go/format"
 	"go/parser"
@@ -28,7 +29,7 @@ import (
 //
 // Returns:
 // - error: an error if the file cannot be parsed.
-func ExtractFunctionDefinitionsFromPackage(content string) error {
+func ExtractFunctionDefinitionsFromPackage(content string, category string) error {
 	fset := token.NewFileSet() // positions are relative to fset
 
 	// Parse the file given by filePath
@@ -93,9 +94,15 @@ func ExtractFunctionDefinitionsFromPackage(content string) error {
 		if fn, isFn := decl.(*ast.FuncDecl); isFn {
 			// Check if the function is exported
 			if fn.Name.IsExported() {
+				// Extract docstring text
+				description := fn.Doc.Text()
+				displayName := extractTagValue(description, "@displayName")
+
 				funcDef := &allieflowkitgrpc.FunctionDefinition{
 					Name:        fn.Name.Name,
-					Description: fn.Doc.Text(),
+					DisplayName: displayNameOrDefault(displayName, fn.Name.Name),
+					Description: description,
+					Category:    category,
 					Input:       []*allieflowkitgrpc.FunctionInputDefinition{},
 					Output:      []*allieflowkitgrpc.FunctionOutputDefinition{},
 				}
@@ -239,3 +246,41 @@ func typeExprToString(expr ast.Expr) string {
 
 	return cleanedTypeStr
 }
+
+// extractTagValue extracts the value of a tag from a docstring.
+// The tag value is expected to be in the format "- tag: value".
+//
+// Parameters:
+// - docText: the docstring text to extract the tag value from.
+// - tag: the tag to extract the value of.
+//
+// Returns:
+// - string: the value of the tag, or an empty string if the tag is not found.
+func extractTagValue(docText, tag string) string {
+	// Define a regex to extract the value of the tag in the new format
+	// NOTE(review): tag is interpolated into the pattern unescaped; "@displayName"
+	// contains no regex metacharacters, but regexp.QuoteMeta(tag) would be safer if
+	// other tags are ever used. The regex is also recompiled on every call — hoisting
+	// a package-level template or caching would avoid that.
+	re := regexp.MustCompile(fmt.Sprintf(`- %s:\s*(.+)`, tag))
+	matches := re.FindStringSubmatch(docText)
+
+	// If a match is found, return the tag value
+	if len(matches) > 1 {
+		return strings.TrimSpace(matches[1])
+	}
+
+	// If no match, return an empty string
+	return ""
+}
+
+// displayNameOrDefault returns the displayName if it is not empty, otherwise it returns the defaultName.
+//
+// Parameters:
+// - displayName: the display name to check.
+// - defaultName: the default name to use if the display name is empty.
+//
+// Returns:
+// - string: the display name if it is not empty, otherwise the default name.
+func displayNameOrDefault(displayName, defaultName string) string {
+	if displayName != "" {
+		return displayName
+	}
+	// Fall back to the function's own name when no @displayName tag was found.
+	return defaultName
+}