-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5081de0
commit 744d66e
Showing
16 changed files
with
514 additions
and
405 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package dto | ||
|
||
type PromptLLMDto struct { | ||
Message string `json:"message" validate:"required"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package handlers | ||
|
||
import ( | ||
"assistant/api/dto" | ||
"assistant/api/utils" | ||
"assistant/llm" | ||
"encoding/json" | ||
"net/http" | ||
"sync" | ||
) | ||
|
||
type ConversationHandler struct { | ||
conversation []llm.Message | ||
mu sync.Mutex | ||
} | ||
|
||
func NewConversationHandler(targetContracts string) *ConversationHandler { | ||
return &ConversationHandler{ | ||
conversation: []llm.Message{llm.InitialPrompt(targetContracts)}, | ||
} | ||
} | ||
|
||
func (ch *ConversationHandler) GetConversation(w http.ResponseWriter, r *http.Request) { | ||
ch.mu.Lock() | ||
defer ch.mu.Unlock() | ||
|
||
response := map[string][]llm.Message{"conversation": ch.conversation[1:]} | ||
writeJSONResponse(w, http.StatusOK, response) | ||
} | ||
|
||
func (ch *ConversationHandler) PromptLLM(w http.ResponseWriter, r *http.Request) { | ||
var data dto.PromptLLMDto | ||
|
||
if err := utils.DecodeAndValidateRequestBody(r, &data); err != nil { | ||
writeJSONResponse(w, http.StatusBadRequest, errorResponse(err.Error())) | ||
return | ||
} | ||
|
||
ch.mu.Lock() | ||
ch.conversation = append(ch.conversation, llm.Message{ | ||
Role: "user", | ||
Content: data.Message, | ||
}) | ||
ch.mu.Unlock() | ||
|
||
response, err := llm.AskGPT4Turbo(ch.conversation) | ||
if err != nil { | ||
writeJSONResponse(w, http.StatusInternalServerError, errorResponse(err.Error())) | ||
return | ||
} | ||
|
||
ch.mu.Lock() | ||
ch.conversation = append(ch.conversation, llm.Message{ | ||
Role: "system", | ||
Content: response, | ||
}) | ||
ch.mu.Unlock() | ||
|
||
writeJSONResponse(w, http.StatusOK, map[string]string{"response": response}) | ||
} | ||
|
||
func (ch *ConversationHandler) ResetConversation(w http.ResponseWriter, r *http.Request) { | ||
ch.mu.Lock() | ||
ch.conversation = ch.conversation[0:1] // Keep the first prompt | ||
ch.mu.Unlock() | ||
|
||
writeJSONResponse(w, http.StatusOK, messageResponse("Conversation reset successfully")) | ||
} | ||
|
||
func messageResponse(message string) map[string]string { | ||
return map[string]string{"message": message} | ||
} | ||
|
||
func errorResponse(errMessage string) map[string]string { | ||
return map[string]string{"error": errMessage} | ||
} | ||
|
||
func writeJSONResponse(w http.ResponseWriter, statusCode int, data interface{}) { | ||
w.WriteHeader(statusCode) | ||
w.Header().Set("Access-Control-Allow-Origin", "*") | ||
|
||
if err := json.NewEncoder(w).Encode(data); err != nil { | ||
http.Error(w, "Error encoding response", http.StatusInternalServerError) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package api | ||
|
||
import ( | ||
"assistant/api/handlers" | ||
"github.com/gorilla/mux" | ||
) | ||
|
||
func attachConversationRoutes(router *mux.Router, targetContracts string) { | ||
ch := handlers.NewConversationHandler(targetContracts) | ||
conversationRoute := "/conversation" | ||
router.HandleFunc(conversationRoute, ch.GetConversation).Methods("GET") | ||
router.HandleFunc(conversationRoute, ch.PromptLLM).Methods("POST") | ||
router.HandleFunc(conversationRoute, ch.ResetConversation).Methods("DELETE") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package utils | ||
|
||
import ( | ||
"encoding/json" | ||
"net/http" | ||
) | ||
|
||
func DecodeRequestBody(r *http.Request, v interface{}) error { | ||
err := json.NewDecoder(r.Body).Decode(v) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package utils | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"github.com/go-playground/validator/v10" | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
func DecodeAndValidateRequestBody(r *http.Request, data interface{}) error { | ||
// Read the request body | ||
err := json.NewDecoder(r.Body).Decode(data) | ||
if err != nil { | ||
return errors.New("invalid request body") | ||
} | ||
|
||
// Validate the struct | ||
err = ValidateData(data) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func ValidateData(v interface{}) error { | ||
validate := validator.New(validator.WithRequiredStructEnabled()) | ||
|
||
err := validate.Struct(v) | ||
|
||
if err != nil { | ||
var validationErrors validator.ValidationErrors | ||
ok := errors.As(err, &validationErrors) | ||
if !ok { | ||
// Handle unexpected validation error type | ||
return fmt.Errorf("unexpected validation error: %v", err) | ||
} | ||
|
||
// Construct an error message with validation errors | ||
var errorMsgs []string | ||
for _, e := range validationErrors { | ||
errorMsgs = append(errorMsgs, fmt.Sprintf("%s: %s", e.Field(), e.Tag())) | ||
} | ||
return fmt.Errorf("validation errors: %s", strings.Join(errorMsgs, "; ")) | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.