Skip to content

Commit

Permalink
Merge branch 'main' into chore/get_auth_info
Browse files Browse the repository at this point in the history
  • Loading branch information
soneda-yuya committed Feb 6, 2025
2 parents 1aa23f7 + 7860571 commit 878e678
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 75 deletions.
41 changes: 12 additions & 29 deletions appx/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package appx
import (
"context"
"net/http"
"strings"

"github.com/google/uuid"
"github.com/reearth/reearthx/log"
Expand All @@ -12,15 +11,15 @@ import (
type requestIDKey struct{}

func ContextMiddleware(key, value any) func(http.Handler) http.Handler {
return ContextMiddlewareBy(func(r *http.Request) context.Context {
return ContextMiddlewareBy(func(w http.ResponseWriter, r *http.Request) context.Context {
return context.WithValue(r.Context(), key, value)
})
}

func ContextMiddlewareBy(c func(*http.Request) context.Context) func(http.Handler) http.Handler {
func ContextMiddlewareBy(c func(http.ResponseWriter, *http.Request) context.Context) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ctx := c(r); ctx == nil {
if ctx := c(w, r); ctx == nil {
next.ServeHTTP(w, r)
} else {
next.ServeHTTP(w, r.WithContext(ctx))
Expand All @@ -30,35 +29,21 @@ func ContextMiddlewareBy(c func(*http.Request) context.Context) func(http.Handle
}

func RequestIDMiddleware() func(http.Handler) http.Handler {
return ContextMiddlewareBy(func(r *http.Request) context.Context {
return ContextMiddlewareBy(func(w http.ResponseWriter, r *http.Request) context.Context {
ctx := r.Context()
reqid := getHeader(r,
"X-Request-ID",
"X-Amzn-Trace-Id", // AWS
"X-Cloud-Trace-Context", // GCP
"X-ARR-LOG-ID", // Azure
)
reqid := log.GetReqestID(w, r)
if reqid == "" {
reqid = uuid.NewString()
}
ctx = context.WithValue(ctx, requestIDKey{}, reqid)
w.Header().Set("X-Request-ID", reqid)

logger := log.GetLoggerFromContextOrDefault(ctx).SetPrefix(reqid)
ctx = log.AttachLoggerToContext(ctx, logger)
return ctx
})
}

func GetRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if reqid, ok := ctx.Value(requestIDKey{}).(string); ok {
return reqid
}
return ""
}

func GetAuthInfo(ctx context.Context, key any) *AuthInfo {
if ctx == nil {
return nil
Expand All @@ -69,14 +54,12 @@ func GetAuthInfo(ctx context.Context, key any) *AuthInfo {
return nil
}

func getHeader(r *http.Request, keys ...string) string {
for _, k := range keys {
if v := r.Header.Get(k); v != "" {
return v
}
if v := r.Header.Get(strings.ToLower(k)); v != "" {
return v
}
func GetRequestIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
if reqid, ok := ctx.Value(requestIDKey{}).(string); ok {
return reqid
}
return ""
}
7 changes: 4 additions & 3 deletions appx/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ func TestContextMiddleware(t *testing.T) {
}

func TestContextMiddlewareBy(t *testing.T) {
key := struct{}{}
ts := httptest.NewServer(ContextMiddlewareBy(func(r *http.Request) context.Context {
type keys struct{}
key := keys{}
ts := httptest.NewServer(ContextMiddlewareBy(func(w http.ResponseWriter, r *http.Request) context.Context {
if r.Method == http.MethodPost {
return context.WithValue(r.Context(), key, "aaa")
}
Expand All @@ -50,7 +51,7 @@ func TestContextMiddlewareBy(t *testing.T) {

func TestRequestIDMiddleware(t *testing.T) {
ts := httptest.NewServer(RequestIDMiddleware()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(GetRequestID(r.Context())))
_, _ = w.Write([]byte(GetRequestIDFromContext(r.Context())))
})))
defer ts.Close()

Expand Down
5 changes: 4 additions & 1 deletion appx/gql.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/99designs/gqlgen/graphql/handler"
"github.com/99designs/gqlgen/graphql/handler/extension"
"github.com/ravilushqa/otelgqlgen"
"github.com/reearth/reearthx/log"
"github.com/vektah/gqlparser/v2/gqlerror"
)

Expand All @@ -33,8 +34,10 @@ func GraphQLHandler(c GraphQLHandlerConfig) http.Handler {
srv.SetErrorPresenter(
// show more detailed error messgage in debug mode
func(ctx context.Context, e error) *gqlerror.Error {
path := graphql.GetFieldContext(ctx).Path()
log.Debugfc(ctx, "gql error: %v: %v", path, e)
if c.Dev {
return gqlerror.ErrorPathf(graphql.GetFieldContext(ctx).Path(), e.Error())
return gqlerror.ErrorPathf(path, "%v", e)
}
return graphql.DefaultErrorPresenter(ctx, e)
},
Expand Down
49 changes: 33 additions & 16 deletions log/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,66 +25,83 @@ func GetLoggerFromContextOrDefault(ctx context.Context) *Logger {
return globalLogger
}

func UpdateContext(ctx context.Context, f func(logger *Logger) *Logger) context.Context {
return AttachLoggerToContext(ctx, f(GetLoggerFromContextOrDefault(ctx)))
}

func WithPrefixMessage(ctx context.Context, prefix string) context.Context {
return UpdateContext(ctx, func(logger *Logger) *Logger {
return logger.AppendPrefixMessage(prefix)
})
}

func Tracefc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Debugf(format, args...)
getLoggerFromContextOrDefault(ctx).Debugf(format, args...)
}

func Debugfc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Debugf(format, args...)
getLoggerFromContextOrDefault(ctx).Debugf(format, args...)
}

func Infofc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Infof(format, args...)
getLoggerFromContextOrDefault(ctx).Infof(format, args...)
}

func Printfc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Infof(format, args...)
getLoggerFromContextOrDefault(ctx).Infof(format, args...)
}

func Warnfc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Warnf(format, args...)
getLoggerFromContextOrDefault(ctx).Warnf(format, args...)
}

func Errorfc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Errorf(format, args...)
getLoggerFromContextOrDefault(ctx).Errorf(format, args...)
}

func Fatalfc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Fatalf(format, args...)
getLoggerFromContextOrDefault(ctx).Fatalf(format, args...)
}

func Panicfc(ctx context.Context, format string, args ...any) {
GetLoggerFromContextOrDefault(ctx).Panicf(format, args...)
getLoggerFromContextOrDefault(ctx).Panicf(format, args...)
}

func Tracec(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Debug(args...)
getLoggerFromContextOrDefault(ctx).Debug(args...)
}

func Debugc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Debug(args...)
getLoggerFromContextOrDefault(ctx).Debug(args...)
}

func Infoc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Info(args...)
getLoggerFromContextOrDefault(ctx).Info(args...)
}

func Printc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Info(args...)
getLoggerFromContextOrDefault(ctx).Info(args...)
}

func Warnc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Warn(args...)
getLoggerFromContextOrDefault(ctx).Warn(args...)
}

func Errorc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Error(args...)
getLoggerFromContextOrDefault(ctx).Error(args...)
}

func Fatalc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Fatal(args...)
getLoggerFromContextOrDefault(ctx).Fatal(args...)
}

func Panicc(ctx context.Context, args ...any) {
GetLoggerFromContextOrDefault(ctx).Panic(args...)
getLoggerFromContextOrDefault(ctx).Panic(args...)
}

func getLoggerFromContextOrDefault(ctx context.Context) *Logger {
if logger := GetLoggerFromContext(ctx); logger != nil {
return logger.AddCallerSkip(1)
}
return globalLogger
}
9 changes: 5 additions & 4 deletions log/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@ func TestContextLogger(t *testing.T) {
SetOutput(DefaultOutput)
})

l := NewWithOutput(w).SetPrefix("test")
l := NewWithOutput(w).SetPrefix("prefix")
ctx := AttachLoggerToContext(context.Background(), l)

Infofc(ctx, "hoge %s", "fuga")
Infofc(context.Background(), "hoge %s", "fuga2")
//nolint:staticcheck // test nil context
Infofc(nil, "hoge %s", "fuga3")

scanner := bufio.NewScanner(w)
assert.True(t, scanner.Scan())
assert.Contains(t, scanner.Text(), "test\thoge fuga")
assert.Contains(t, scanner.Text(), "\tprefix\t")
assert.True(t, scanner.Scan())
assert.Contains(t, scanner.Text(), "hoge fuga2")
assert.NotContains(t, scanner.Text(), "test")
assert.NotContains(t, scanner.Text(), "\tprefix\t")
assert.True(t, scanner.Scan())
assert.Contains(t, scanner.Text(), "hoge fuga3")
assert.NotContains(t, scanner.Text(), "test")
assert.NotContains(t, scanner.Text(), "\tprefix\t")
assert.False(t, scanner.Scan())
}
89 changes: 76 additions & 13 deletions log/echo.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package log

import (
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/labstack/echo/v4"
Expand All @@ -23,6 +26,12 @@ func NewEcho() *Echo {
}

func NewEchoWith(logger *Logger) *Echo {
return &Echo{
logger: logger.AddCallerSkip(1),
}
}

func NewEchoWithRaw(logger *Logger) *Echo {
return &Echo{
logger: logger,
}
Expand Down Expand Up @@ -203,29 +212,54 @@ func (l *Echo) AccessLogger() echo.MiddlewareFunc {
req := c.Request()
res := c.Response()
start := time.Now()
if err := next(c); err != nil {
c.Error(err)
}
stop := time.Now()

logger := GetLoggerFromContext(c.Request().Context())
if logger == nil {
logger = l.logger
}
logger.Infow(
"Handled request",
reqid := GetReqestID(res, req)
args := []any{
"time_unix", start.Unix(),
"remote_ip", c.RealIP(),
"host", req.Host,
"uri", req.RequestURI,
"method", req.Method,
"path", req.URL.Path,
"protocol", req.Proto,
"referer", req.Referer(),
"user_agent", req.UserAgent(),
"status", res.Status,
"latency", stop.Sub(start).Microseconds(),
"latency_human", stop.Sub(start).String(),
"bytes_in", req.ContentLength,
"request_id", reqid,
"route", c.Path(),
}

logger := GetLoggerFromContext(c.Request().Context())
if logger == nil {
logger = l.logger
}
logger = logger.WithCaller(false)

// incoming log
logger.Infow(
fmt.Sprintf("<-- %s %s", req.Method, req.URL.Path),
args...,
)

if err := next(c); err != nil {
c.Error(err)
}

res = c.Response()
stop := time.Now()
latency := stop.Sub(start)
latencyHuman := latency.String()
args = append(args,
"status", res.Status,
"bytes_out", res.Size,
"letency", latency.Microseconds(),
"latency_human", latencyHuman,
)

// outcoming log
logger.Infow(
fmt.Sprintf("--> %s %d %s %s", req.Method, res.Status, req.URL.Path, latencyHuman),
args...,
)
return nil
}
Expand All @@ -240,3 +274,32 @@ func fromMap(m map[string]any) (res []any) {
}
return
}

func GetReqestID(w http.ResponseWriter, r *http.Request) string {
if reqid := getHeader(r,
"X-Request-ID",
"X-Cloud-Trace-Context", // Google Cloud
"X-Amzn-Trace-Id", // AWS
"X-ARR-LOG-ID", // Azure
); reqid != "" {
return reqid
}

if reqid := w.Header().Get("X-Request-ID"); reqid != "" {
return reqid
}

return ""
}

func getHeader(r *http.Request, keys ...string) string {
for _, k := range keys {
if v := r.Header.Get(k); v != "" {
return v
}
if v := r.Header.Get(strings.ToLower(k)); v != "" {
return v
}
}
return ""
}
4 changes: 2 additions & 2 deletions log/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ func TestEcho(t *testing.T) {
w := &bytes.Buffer{}
l := NewEcho()
l.SetOutput(w)
l.SetPrefix("test")
l.SetPrefix("prefix")
l.Infof("hoge %s", "fuga")

scanner := bufio.NewScanner(w)
assert.True(t, scanner.Scan())
assert.Contains(t, scanner.Text(), "test\thoge fuga")
assert.Contains(t, scanner.Text(), "\tprefix\t")
assert.False(t, scanner.Scan())
}
Loading

0 comments on commit 878e678

Please sign in to comment.