diff --git a/managed/cmd/pmm-managed/main.go b/managed/cmd/pmm-managed/main.go index f125609629..928c3bdb7a 100644 --- a/managed/cmd/pmm-managed/main.go +++ b/managed/cmd/pmm-managed/main.go @@ -338,6 +338,7 @@ func runHTTP1Server(ctx context.Context, deps *http1ServerDeps) { proxyMux := grpc_gateway.NewServeMux( grpc_gateway.WithMarshalerOption(grpc_gateway.MIMEWildcard, marshaller), grpc_gateway.WithErrorHandler(pmmerrors.PMMHTTPErrorHandler), + grpc_gateway.WithRoutingErrorHandler(pmmerrors.PMMRoutingErrorHandler), ) opts := []grpc.DialOption{ diff --git a/qan-api2/main.go b/qan-api2/main.go index 63a9e4602d..17efc46910 100644 --- a/qan-api2/main.go +++ b/qan-api2/main.go @@ -31,7 +31,6 @@ import ( "sync" "time" - "github.com/gogo/status" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/validator" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus" @@ -44,7 +43,6 @@ import ( "golang.org/x/sys/unix" "google.golang.org/grpc" channelz "google.golang.org/grpc/channelz/service" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/reflection" @@ -56,6 +54,7 @@ import ( aservice "github.com/percona/pmm/qan-api2/services/analytics" rservice "github.com/percona/pmm/qan-api2/services/receiver" "github.com/percona/pmm/qan-api2/utils/interceptors" + pmmerrors "github.com/percona/pmm/utils/errors" "github.com/percona/pmm/utils/logger" "github.com/percona/pmm/utils/sqlmetrics" "github.com/percona/pmm/version" @@ -150,7 +149,7 @@ func runJSONServer(ctx context.Context, grpcBindF, jsonBindF string) { proxyMux := grpc_gateway.NewServeMux( grpc_gateway.WithMarshalerOption(grpc_gateway.MIMEWildcard, marshaller), - grpc_gateway.WithRoutingErrorHandler(handleRoutingError), + grpc_gateway.WithRoutingErrorHandler(pmmerrors.PMMRoutingErrorHandler), ) opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())} @@ -247,23 +246,6 @@ func runDebugServer(ctx context.Context, debugBindF string) { cancel() } -// handleRoutingError customized the http status code for routes that can't be found (i.e. 404). -func handleRoutingError(ctx context.Context, mux *grpc_gateway.ServeMux, marshaler grpc_gateway.Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) { - if httpStatus != http.StatusNotFound { - grpc_gateway.DefaultRoutingErrorHandler(ctx, mux, marshaler, w, r, httpStatus) - return - } - - // Use HTTPStatusError to customize the DefaultHTTPErrorHandler status code - msg := fmt.Sprintf("Endpoint not found: %s, http method: %s", r.URL.Path, r.Method) - err := &grpc_gateway.HTTPStatusError{ - HTTPStatus: httpStatus, - Err: status.Error(codes.NotFound, msg), - } - - grpc_gateway.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err) -} - func main() { log.SetFlags(0) log.SetPrefix("stdlog: ") diff --git a/utils/errors/errors.go b/utils/errors/errors.go index 6d281040a0..d7e60944e8 100644 --- a/utils/errors/errors.go +++ b/utils/errors/errors.go @@ -142,3 +142,20 @@ func handleForwardResponseTrailer(w http.ResponseWriter, md runtime.ServerMetada } } } + +// PMMRoutingErrorHandler customizes the http status code for routes that can't be found (i.e. 404). +func PMMRoutingErrorHandler(ctx context.Context, mux *runtime.ServeMux, marshaler runtime.Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) { + if httpStatus != http.StatusNotFound { + runtime.DefaultRoutingErrorHandler(ctx, mux, marshaler, w, r, httpStatus) + return + } + + // Use HTTPStatusError to customize the DefaultHTTPErrorHandler status code + msg := fmt.Sprintf("Endpoint not found: %s, http method: %s", r.URL.Path, r.Method) + err := &runtime.HTTPStatusError{ + HTTPStatus: httpStatus, + Err: status.Error(codes.NotFound, msg), + } + + runtime.DefaultHTTPErrorHandler(ctx, mux, marshaler, w, r, err) +}