diff --git a/config/section.go b/config/section.go index 41fddd0..34f5f7a 100644 --- a/config/section.go +++ b/config/section.go @@ -199,6 +199,10 @@ func (r *section) SetConfig(c Config) error { r.lockObj.Lock() defer r.lockObj.Unlock() + if reflect.TypeOf(c).Kind() != reflect.Ptr { + return fmt.Errorf("config must be a Pointer") + } + if !DeepEqual(r.config, c) { r.config = c r.isDirty.Store(true) diff --git a/contextutils/context.go b/contextutils/context.go index b5a9c00..e5be11e 100644 --- a/contextutils/context.go +++ b/contextutils/context.go @@ -4,22 +4,24 @@ package contextutils import ( "context" "fmt" + "runtime/pprof" ) type Key string const ( - AppNameKey Key = "app_name" - NamespaceKey Key = "ns" - TaskTypeKey Key = "tasktype" - ProjectKey Key = "project" - DomainKey Key = "domain" - WorkflowIDKey Key = "wf" - NodeIDKey Key = "node" - TaskIDKey Key = "task" - ExecIDKey Key = "exec_id" - JobIDKey Key = "job_id" - PhaseKey Key = "phase" + AppNameKey Key = "app_name" + NamespaceKey Key = "ns" + TaskTypeKey Key = "tasktype" + ProjectKey Key = "project" + DomainKey Key = "domain" + WorkflowIDKey Key = "wf" + NodeIDKey Key = "node" + TaskIDKey Key = "task" + ExecIDKey Key = "exec_id" + JobIDKey Key = "job_id" + PhaseKey Key = "phase" + RoutineLabelKey Key = "routine" ) func (k Key) String() string { @@ -35,6 +37,7 @@ var logKeys = []Key{ WorkflowIDKey, TaskTypeKey, PhaseKey, + RoutineLabelKey, } // Gets a new context with namespace set. @@ -98,6 +101,14 @@ func WithTaskType(ctx context.Context, taskType string) context.Context { return context.WithValue(ctx, TaskTypeKey, taskType) } +// Gets a new context with Go Routine label key set and a label assigned to the context using pprof.Labels. +// You can then call pprof.SetGoroutineLabels(ctx) to annotate the current go-routine and have that show up in +// pprof analysis. +func WithGoroutineLabel(ctx context.Context, routineLabel string) context.Context { + ctx = pprof.WithLabels(ctx, pprof.Labels(RoutineLabelKey.String(), routineLabel)) + return context.WithValue(ctx, RoutineLabelKey, routineLabel) +} + func addFieldIfNotNil(ctx context.Context, m map[string]interface{}, fieldKey Key) { val := ctx.Value(fieldKey) if val != nil { @@ -110,6 +121,7 @@ func addStringFieldWithDefaults(ctx context.Context, m map[string]string, fieldK if val == nil { val = "" } + m[fieldKey.String()] = val.(string) } @@ -120,6 +132,7 @@ func GetLogFields(ctx context.Context) map[string]interface{} { for _, k := range logKeys { addFieldIfNotNil(ctx, res, k) } + return res } @@ -128,6 +141,7 @@ func Value(ctx context.Context, key Key) string { if val != nil { return val.(string) } + return "" } @@ -136,5 +150,6 @@ func Values(ctx context.Context, keys ...Key) map[string]string { for _, k := range keys { addStringFieldWithDefaults(ctx, res, k) } + return res } diff --git a/contextutils/context_test.go b/contextutils/context_test.go index 99d29d3..d65c2a2 100644 --- a/contextutils/context_test.go +++ b/contextutils/context_test.go @@ -2,6 +2,7 @@ package contextutils import ( "context" + "runtime/pprof" "testing" "github.com/stretchr/testify/assert" @@ -111,3 +112,12 @@ func TestValues(t *testing.T) { assert.Equal(t, "flyte", m[WorkflowIDKey.String()]) assert.Equal(t, "", m[ProjectKey.String()]) } + +func TestWithGoroutineLabel(t *testing.T) { + ctx := context.Background() + ctx = WithGoroutineLabel(ctx, "my_routine_123") + pprof.SetGoroutineLabels(ctx) + m := Values(ctx, RoutineLabelKey) + assert.Equal(t, 1, len(m)) + assert.Equal(t, "my_routine_123", m[RoutineLabelKey.String()]) +} diff --git a/logger/config.go b/logger/config.go index 1c665ac..993c624 100644 --- a/logger/config.go +++ b/logger/config.go @@ -21,12 +21,18 @@ const ( jsonDataKey string = "json" ) -var defaultConfig = &Config{ - Formatter: FormatterConfig{ - Type: FormatterJSON, - }, - Level: InfoLevel, -} +var ( + defaultConfig = &Config{ + Formatter: FormatterConfig{ + Type: FormatterJSON, + }, + Level: InfoLevel, + } + + configSection = config.MustRegisterSectionWithUpdates(configSectionKey, defaultConfig, func(ctx context.Context, newValue config.Config) { + onConfigUpdated(*newValue.(*Config)) + }) +) // Global logger config. type Config struct { @@ -47,13 +53,18 @@ type FormatterConfig struct { Type FormatterType `json:"type" pflag:",Sets logging format type."` } -var globalConfig = Config{} - // Sets global logger config -func SetConfig(cfg Config) { - globalConfig = cfg +func SetConfig(cfg *Config) error { + if err := configSection.SetConfig(cfg); err != nil { + return err + } - onConfigUpdated(cfg) + onConfigUpdated(*cfg) + return nil +} + +func GetConfig() *Config { + return configSection.GetConfig().(*Config) } // Level type. @@ -78,9 +89,3 @@ const ( // DebugLevel level. Usually only enabled when debugging. Very verbose logging. DebugLevel ) - -func init() { - config.MustRegisterSectionWithUpdates(configSectionKey, defaultConfig, func(ctx context.Context, newValue config.Config) { - SetConfig(*newValue.(*Config)) - }) -} diff --git a/logger/config_test.go b/logger/config_test.go index 7d2d378..08be8de 100644 --- a/logger/config_test.go +++ b/logger/config_test.go @@ -1,10 +1,14 @@ package logger -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestSetConfig(t *testing.T) { type args struct { - cfg Config + cfg *Config } tests := []struct { name string @@ -14,7 +18,7 @@ func TestSetConfig(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetConfig(tt.args.cfg) + assert.NoError(t, SetConfig(tt.args.cfg)) }) } } diff --git a/logger/logger.go b/logger/logger.go index 29aadc8..3d8eccc 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -17,7 +17,12 @@ import ( //go:generate gotests -w -all $FILE -const indentLevelKey contextutils.Key = "LoggerIndentLevel" +const ( + indentLevelKey contextutils.Key = "LoggerIndentLevel" + sourceCodeKey string = "src" +) + +var noopLogger = NoopLogger{} func onConfigUpdated(cfg Config) { logrus.SetLevel(logrus.Level(cfg.Level)) @@ -44,51 +49,35 @@ func onConfigUpdated(cfg Config) { } func getSourceLocation() string { - if globalConfig.IncludeSourceCode { - _, file, line, ok := runtime.Caller(3) - if !ok { - file = "???" - line = 1 - } else { - slash := strings.LastIndex(file, "/") - if slash >= 0 { - file = file[slash+1:] - } + // The reason we pass 3 here: 0 means this function (getSourceLocation), 1 means the getLogger function (only caller + // to getSourceLocation, 2 means the logging function (e.g. Debugln), and 3 means the caller for the logging function. + _, file, line, ok := runtime.Caller(3) + if !ok { + file = "???" + line = 1 + } else { + slash := strings.LastIndex(file, "/") + if slash >= 0 { + file = file[slash+1:] } - - return fmt.Sprintf("[%v:%v] ", file, line) } - return "" + return fmt.Sprintf("%v:%v", file, line) } -func wrapHeader(ctx context.Context, args ...interface{}) []interface{} { - args = append([]interface{}{getIndent(ctx)}, args...) - - if globalConfig.IncludeSourceCode { - return append( - []interface{}{ - fmt.Sprintf("%v", getSourceLocation()), - }, - args...) +func getLogger(ctx context.Context) logrus.FieldLogger { + cfg := GetConfig() + if cfg.Mute { + return noopLogger } - return args -} - -func wrapHeaderForMessage(ctx context.Context, message string) string { - message = fmt.Sprintf("%v%v", getIndent(ctx), message) - - if globalConfig.IncludeSourceCode { - return fmt.Sprintf("%v%v", getSourceLocation(), message) + entry := logrus.WithFields(logrus.Fields(contextutils.GetLogFields(ctx))) + if cfg.IncludeSourceCode { + entry = entry.WithField(sourceCodeKey, getSourceLocation()) } - return message -} + entry.Level = logrus.Level(cfg.Level) -func getLogger(ctx context.Context) *logrus.Entry { - entry := logrus.WithFields(logrus.Fields(contextutils.GetLogFields(ctx))) - entry.Level = logrus.Level(globalConfig.Level) return entry } @@ -107,231 +96,218 @@ func getIndent(ctx context.Context) string { // Gets a value indicating whether logs at this level will be written to the logger. This is particularly useful to avoid // computing log messages unnecessarily. -func IsLoggable(ctx context.Context, level Level) bool { - return getLogger(ctx).Level >= logrus.Level(level) +func IsLoggable(_ context.Context, level Level) bool { + return GetConfig().Level >= level } // Debug logs a message at level Debug on the standard logger. func Debug(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Debug(wrapHeader(ctx, args)...) + getLogger(ctx).Debug(args...) } // Print logs a message at level Info on the standard logger. func Print(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Print(wrapHeader(ctx, args)...) + getLogger(ctx).Print(args...) } // Info logs a message at level Info on the standard logger. func Info(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Info(wrapHeader(ctx, args)...) + getLogger(ctx).Info(args...) } // Warn logs a message at level Warn on the standard logger. func Warn(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warn(wrapHeader(ctx, args)...) + getLogger(ctx).Warn(args...) } // Warning logs a message at level Warn on the standard logger. func Warning(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warning(wrapHeader(ctx, args)...) + getLogger(ctx).Warning(args...) } // Error logs a message at level Error on the standard logger. func Error(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Error(wrapHeader(ctx, args)...) + getLogger(ctx).Error(args...) } // Panic logs a message at level Panic on the standard logger. func Panic(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Panic(wrapHeader(ctx, args)...) + getLogger(ctx).Panic(args...) } // Fatal logs a message at level Fatal on the standard logger. func Fatal(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Fatal(wrapHeader(ctx, args)...) + getLogger(ctx).Fatal(args...) } // Debugf logs a message at level Debug on the standard logger. func Debugf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Debugf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Debugf(format, args...) } // Printf logs a message at level Info on the standard logger. func Printf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Printf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Printf(format, args...) } // Infof logs a message at level Info on the standard logger. func Infof(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Infof(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Infof(format, args...) } // InfofNoCtx logs a formatted message without context. func InfofNoCtx(format string, args ...interface{}) { - if globalConfig.Mute { - return - } - getLogger(context.TODO()).Infof(format, args...) } // Warnf logs a message at level Warn on the standard logger. func Warnf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warnf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Warnf(format, args...) } // Warningf logs a message at level Warn on the standard logger. func Warningf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warningf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Warningf(format, args...) } // Errorf logs a message at level Error on the standard logger. func Errorf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Errorf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Errorf(format, args...) } // Panicf logs a message at level Panic on the standard logger. func Panicf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Panicf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Panicf(format, args...) } // Fatalf logs a message at level Fatal on the standard logger. func Fatalf(ctx context.Context, format string, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Fatalf(wrapHeaderForMessage(ctx, format), args...) + getLogger(ctx).Fatalf(format, args...) } // Debugln logs a message at level Debug on the standard logger. func Debugln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Debugln(wrapHeader(ctx, args)...) + getLogger(ctx).Debugln(args...) } // Println logs a message at level Info on the standard logger. func Println(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Println(wrapHeader(ctx, args)...) + getLogger(ctx).Println(args...) } // Infoln logs a message at level Info on the standard logger. func Infoln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Infoln(wrapHeader(ctx, args)...) + getLogger(ctx).Infoln(args...) } // Warnln logs a message at level Warn on the standard logger. func Warnln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warnln(wrapHeader(ctx, args)...) + getLogger(ctx).Warnln(args...) } // Warningln logs a message at level Warn on the standard logger. func Warningln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Warningln(wrapHeader(ctx, args)...) + getLogger(ctx).Warningln(args...) } // Errorln logs a message at level Error on the standard logger. func Errorln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Errorln(wrapHeader(ctx, args)...) + getLogger(ctx).Errorln(args...) } // Panicln logs a message at level Panic on the standard logger. func Panicln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } - - getLogger(ctx).Panicln(wrapHeader(ctx, args)...) + getLogger(ctx).Panicln(args...) } // Fatalln logs a message at level Fatal on the standard logger. func Fatalln(ctx context.Context, args ...interface{}) { - if globalConfig.Mute { - return - } + getLogger(ctx).Fatalln(args...) +} + +type NoopLogger struct { +} + +func (NoopLogger) WithField(key string, value interface{}) *logrus.Entry { + return nil +} + +func (NoopLogger) WithFields(fields logrus.Fields) *logrus.Entry { + return nil +} + +func (NoopLogger) WithError(err error) *logrus.Entry { + return nil +} + +func (NoopLogger) Debugf(format string, args ...interface{}) { +} + +func (NoopLogger) Infof(format string, args ...interface{}) { +} + +func (NoopLogger) Warnf(format string, args ...interface{}) { +} + +func (NoopLogger) Warningf(format string, args ...interface{}) { +} + +func (NoopLogger) Errorf(format string, args ...interface{}) { +} + +func (NoopLogger) Debug(args ...interface{}) { +} + +func (NoopLogger) Info(args ...interface{}) { +} + +func (NoopLogger) Warn(args ...interface{}) { +} + +func (NoopLogger) Warning(args ...interface{}) { +} + +func (NoopLogger) Error(args ...interface{}) { +} + +func (NoopLogger) Debugln(args ...interface{}) { +} + +func (NoopLogger) Infoln(args ...interface{}) { +} + +func (NoopLogger) Warnln(args ...interface{}) { +} + +func (NoopLogger) Warningln(args ...interface{}) { +} + +func (NoopLogger) Errorln(args ...interface{}) { +} + +func (NoopLogger) Print(...interface{}) { +} + +func (NoopLogger) Printf(string, ...interface{}) { +} + +func (NoopLogger) Println(...interface{}) { +} + +func (NoopLogger) Fatal(...interface{}) { +} + +func (NoopLogger) Fatalf(string, ...interface{}) { +} + +func (NoopLogger) Fatalln(...interface{}) { +} + +func (NoopLogger) Panic(...interface{}) { +} + +func (NoopLogger) Panicf(string, ...interface{}) { +} - getLogger(ctx).Fatalln(wrapHeader(ctx, args)...) +func (NoopLogger) Panicln(...interface{}) { } diff --git a/logger/logger_test.go b/logger/logger_test.go index 75c73a9..ce81002 100644 --- a/logger/logger_test.go +++ b/logger/logger_test.go @@ -6,7 +6,6 @@ package logger import ( "context" "reflect" - "strings" "testing" "github.com/sirupsen/logrus" @@ -14,46 +13,11 @@ import ( ) func init() { - SetConfig(Config{ + if err := SetConfig(&Config{ Level: InfoLevel, IncludeSourceCode: true, - }) -} - -func Test_getSourceLocation(t *testing.T) { - tests := []struct { - name string - want string - }{ - {"current", " "}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getSourceLocation(); !strings.HasSuffix(got, tt.want) { - t.Errorf("getSourceLocation() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_wrapHeaderForMessage(t *testing.T) { - type args struct { - message string - } - tests := []struct { - name string - args args - want string - }{ - {"no args", args{message: ""}, " "}, - {"1 arg", args{message: "hello"}, " hello"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := wrapHeaderForMessage(context.TODO(), tt.args.message); !strings.HasSuffix(got, tt.want) { - t.Errorf("wrapHeaderForMessage() = %v, want %v", got, tt.want) - } - }) + }); err != nil { + panic(err) } } @@ -488,27 +452,6 @@ func TestPanicln(t *testing.T) { } } -func Test_wrapHeader(t *testing.T) { - type args struct { - ctx context.Context - args []interface{} - } - tests := []struct { - name string - args args - want []interface{} - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := wrapHeader(tt.args.ctx, tt.args.args...); !reflect.DeepEqual(got, tt.want) { - t.Errorf("wrapHeader() = %v, want %v", got, tt.want) - } - }) - } -} - func Test_getLogger(t *testing.T) { type args struct { ctx context.Context diff --git a/profutils/server.go b/profutils/server.go index fe0157f..a1681c1 100644 --- a/profutils/server.go +++ b/profutils/server.go @@ -61,7 +61,14 @@ func healtcheckHandler(w http.ResponseWriter, req *http.Request) { // Handler that returns a JSON response indicating the Build Version information (refer to #version module) func versionHandler(w http.ResponseWriter, req *http.Request) { - err := WriteJSONResponse(w, http.StatusOK, BuildVersion{Build: version.Build, Version: version.Version, Timestamp: version.BuildTime}) + err := WriteJSONResponse( + w, + http.StatusOK, + BuildVersion{ + Build: version.Build, + Version: version.Version, + Timestamp: version.BuildTime, + }) if err != nil { panic(err) }