Skip to content

Commit

Permalink
fix: Fix CORS issue
Browse files Browse the repository at this point in the history
  • Loading branch information
damilolaedwards committed Jul 26, 2024
1 parent 5081de0 commit 744d66e
Show file tree
Hide file tree
Showing 16 changed files with 514 additions and 405 deletions.
55 changes: 42 additions & 13 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
"github.com/rs/cors"
"github.com/rs/zerolog"
"net"
"net/http"
Expand All @@ -15,7 +16,25 @@ import (
"syscall"
)

func Start(config *config.ProjectConfig, slitherOutput *types.SlitherOutput) {
type API struct {
targetContracts string
slitherOutput *types.SlitherOutput
logger *logging.Logger
}

func InitializeAPI(targetContracts string, slitherOutput *types.SlitherOutput) *API {
// Create sub-logger for api module
logger := logging.NewLogger(zerolog.InfoLevel)
logger.AddWriter(os.Stdout, logging.UNSTRUCTURED, true)

return &API{
targetContracts: targetContracts,
slitherOutput: slitherOutput,
logger: logger,
}
}

func (api *API) Start(config *config.ProjectConfig) {
var port string

if config.Port == 0 {
Expand All @@ -31,19 +50,16 @@ func Start(config *config.ProjectConfig, slitherOutput *types.SlitherOutput) {
// Create a new router
router := mux.NewRouter()

// Attach middleware
attachMiddleware(router)

// Serve the contracts on a sub-router
router.HandleFunc("/contracts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(slitherOutput)
if err != nil {
logger.Error("Failed to encode contracts: ", err)
return
}
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://localhost:5173"},
AllowCredentials: true,
})

// Attach routes
api.attachRoutes(router)

handler := c.Handler(router)

var listener net.Listener
var err error

Expand Down Expand Up @@ -72,7 +88,7 @@ func Start(config *config.ProjectConfig, slitherOutput *types.SlitherOutput) {
// Start the server in a separate goroutine
serverErrorChan := make(chan error, 1)
go func() {
serverErrorChan <- http.Serve(listener, router)
serverErrorChan <- http.Serve(listener, handler)
}()

// Gracefully shutdown the server if a server error is encountered
Expand All @@ -89,6 +105,19 @@ func Start(config *config.ProjectConfig, slitherOutput *types.SlitherOutput) {
}
}

func (api *API) attachRoutes(router *mux.Router) {
router.HandleFunc("/contracts", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
err := json.NewEncoder(w).Encode(api.slitherOutput)
if err != nil {
api.logger.Error("Failed to encode contracts: ", err)
return
}
})

attachConversationRoutes(router, api.targetContracts)
}

func incrementPort(port string) string {
var portNum int

Expand Down
5 changes: 5 additions & 0 deletions api/dto/dto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package dto

type PromptLLMDto struct {
Message string `json:"message" validate:"required"`
}
85 changes: 85 additions & 0 deletions api/handlers/handlers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package handlers

import (
"assistant/api/dto"
"assistant/api/utils"
"assistant/llm"
"encoding/json"
"net/http"
"sync"
)

type ConversationHandler struct {
conversation []llm.Message
mu sync.Mutex
}

func NewConversationHandler(targetContracts string) *ConversationHandler {
return &ConversationHandler{
conversation: []llm.Message{llm.InitialPrompt(targetContracts)},
}
}

func (ch *ConversationHandler) GetConversation(w http.ResponseWriter, r *http.Request) {
ch.mu.Lock()
defer ch.mu.Unlock()

response := map[string][]llm.Message{"conversation": ch.conversation[1:]}
writeJSONResponse(w, http.StatusOK, response)
}

func (ch *ConversationHandler) PromptLLM(w http.ResponseWriter, r *http.Request) {
var data dto.PromptLLMDto

if err := utils.DecodeAndValidateRequestBody(r, &data); err != nil {
writeJSONResponse(w, http.StatusBadRequest, errorResponse(err.Error()))
return
}

ch.mu.Lock()
ch.conversation = append(ch.conversation, llm.Message{
Role: "user",
Content: data.Message,
})
ch.mu.Unlock()

response, err := llm.AskGPT4Turbo(ch.conversation)
if err != nil {
writeJSONResponse(w, http.StatusInternalServerError, errorResponse(err.Error()))
return
}

ch.mu.Lock()
ch.conversation = append(ch.conversation, llm.Message{
Role: "system",
Content: response,
})
ch.mu.Unlock()

writeJSONResponse(w, http.StatusOK, map[string]string{"response": response})
}

func (ch *ConversationHandler) ResetConversation(w http.ResponseWriter, r *http.Request) {
ch.mu.Lock()
ch.conversation = ch.conversation[0:1] // Keep the first prompt
ch.mu.Unlock()

writeJSONResponse(w, http.StatusOK, messageResponse("Conversation reset successfully"))
}

func messageResponse(message string) map[string]string {
return map[string]string{"message": message}
}

func errorResponse(errMessage string) map[string]string {
return map[string]string{"error": errMessage}
}

func writeJSONResponse(w http.ResponseWriter, statusCode int, data interface{}) {
w.WriteHeader(statusCode)
w.Header().Set("Access-Control-Allow-Origin", "*")

if err := json.NewEncoder(w).Encode(data); err != nil {
http.Error(w, "Error encoding response", http.StatusInternalServerError)
}
}
21 changes: 12 additions & 9 deletions api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,23 @@ import (
"net/http"
)

func setHeaders(next http.Handler) http.Handler {
func enableCORS(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set default headers
w.Header().Set("Content-Type", "application/json")

// Handle CORS headers
// Set CORS headers
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")

// If the request method is OPTIONS, return a 200 status (pre-flight request)
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
}

next.ServeHTTP(w, r)
handler.ServeHTTP(w, r)
})
}

func attachMiddleware(router *mux.Router) {
router.Use(setHeaders)
router.Use(enableCORS)
}
14 changes: 14 additions & 0 deletions api/routes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package api

import (
"assistant/api/handlers"
"github.com/gorilla/mux"
)

func attachConversationRoutes(router *mux.Router, targetContracts string) {
ch := handlers.NewConversationHandler(targetContracts)
conversationRoute := "/conversation"
router.HandleFunc(conversationRoute, ch.GetConversation).Methods("GET")
router.HandleFunc(conversationRoute, ch.PromptLLM).Methods("POST")
router.HandleFunc(conversationRoute, ch.ResetConversation).Methods("DELETE")
}
14 changes: 14 additions & 0 deletions api/utils/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package utils

import (
"encoding/json"
"net/http"
)

func DecodeRequestBody(r *http.Request, v interface{}) error {
err := json.NewDecoder(r.Body).Decode(v)
if err != nil {
return err
}
return nil
}
50 changes: 50 additions & 0 deletions api/utils/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package utils

import (
"encoding/json"
"errors"
"fmt"
"github.com/go-playground/validator/v10"
"net/http"
"strings"
)

func DecodeAndValidateRequestBody(r *http.Request, data interface{}) error {
// Read the request body
err := json.NewDecoder(r.Body).Decode(data)
if err != nil {
return errors.New("invalid request body")
}

// Validate the struct
err = ValidateData(data)
if err != nil {
return err
}

return nil
}

func ValidateData(v interface{}) error {
validate := validator.New(validator.WithRequiredStructEnabled())

err := validate.Struct(v)

if err != nil {
var validationErrors validator.ValidationErrors
ok := errors.As(err, &validationErrors)
if !ok {
// Handle unexpected validation error type
return fmt.Errorf("unexpected validation error: %v", err)
}

// Construct an error message with validation errors
var errorMsgs []string
for _, e := range validationErrors {
errorMsgs = append(errorMsgs, fmt.Sprintf("%s: %s", e.Field(), e.Tag()))
}
return fmt.Errorf("validation errors: %s", strings.Join(errorMsgs, "; "))
}

return nil
}
16 changes: 5 additions & 11 deletions cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"assistant/config"
"assistant/logging/colors"
"assistant/slither"
"assistant/utils"
"fmt"
"github.com/spf13/cobra"
"os"
Expand Down Expand Up @@ -119,21 +120,14 @@ func cmdRunGenerate(cmd *cobra.Command, args []string) error {
}
cmdLogger.Info("Successfully ran Slither on the target contracts directory")

// Write contracts to a file
file, err := os.OpenFile("contracts.json", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
targetContracts, err := utils.ReadDirectoryContents(projectConfig.TargetContracts.Dir, projectConfig.TargetContracts.ExcludePaths...)
if err != nil {
cmdLogger.Error("Failed to run the generate command", err)
return err
}

defer func(file *os.File) {
err := file.Close()
if err != nil {
cmdLogger.Error("Error closing contracts.json", err)
}
}(file)

// Start the API to serve
api.Start(projectConfig, slitherOutput)
// Start the API
api.InitializeAPI(targetContracts, slitherOutput).Start(projectConfig)

return nil
}
10 changes: 9 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module assistant

go 1.19
go 1.20

require (
github.com/rs/zerolog v1.33.0
Expand All @@ -11,12 +11,20 @@ require (

require (
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.22.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/mux v1.8.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.7 // indirect
github.com/rs/cors v1.11.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/text v0.16.0 // indirect
)
Loading

0 comments on commit 744d66e

Please sign in to comment.