Skip to content

Commit

Permalink
Implement net.Listener hand-over during config reload
Browse files Browse the repository at this point in the history
Log library refactored a little to make it easier to enable debug logging.
  • Loading branch information
foxcpp committed Jan 29, 2025
1 parent 3ce6ebf commit f82742b
Show file tree
Hide file tree
Showing 19 changed files with 394 additions and 44 deletions.
67 changes: 39 additions & 28 deletions framework/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ import (
// No serialization is provided by Logger, its log.Output responsibility to
// ensure goroutine-safety if necessary.
type Logger struct {
Parent *Logger

Out Output
Name string
Debug bool
Expand All @@ -51,30 +53,37 @@ type Logger struct {
Fields map[string]interface{}
}

func (l Logger) Zap() *zap.Logger {
func (l *Logger) Zap() *zap.Logger {
// TODO: Migrate to using zap natively.
return zap.New(zapLogger{L: l})
}

func (l Logger) Debugf(format string, val ...interface{}) {
if !l.Debug {
func (l *Logger) IsDebug() bool {
if l.Parent == nil {
return l.Debug
}
return l.Debug || l.Parent.IsDebug()
}

func (l *Logger) Debugf(format string, val ...interface{}) {
if !l.IsDebug() {
return
}
l.log(true, l.formatMsg(fmt.Sprintf(format, val...), nil))
}

func (l Logger) Debugln(val ...interface{}) {
if !l.Debug {
func (l *Logger) Debugln(val ...interface{}) {
if !l.IsDebug() {
return
}
l.log(true, l.formatMsg(strings.TrimRight(fmt.Sprintln(val...), "\n"), nil))
}

func (l Logger) Printf(format string, val ...interface{}) {
func (l *Logger) Printf(format string, val ...interface{}) {
l.log(false, l.formatMsg(fmt.Sprintf(format, val...), nil))
}

func (l Logger) Println(val ...interface{}) {
func (l *Logger) Println(val ...interface{}) {
l.log(false, l.formatMsg(strings.TrimRight(fmt.Sprintln(val...), "\n"), nil))
}

Expand All @@ -87,13 +96,13 @@ func (l Logger) Println(val ...interface{}) {
// followed by corresponding values. That is, for example, []interface{"key",
// "value", "key2", "value2"}.
//
// If value in fields implements LogFormatter, it will be represented by the
// If value in fields implements Formatter, it will be represented by the
// string returned by FormatLog method. Same goes for fmt.Stringer and error
// interfaces.
//
// Additionally, time.Time is written as a string in ISO 8601 format.
// time.Duration follows fmt.Stringer rule above.
func (l Logger) Msg(msg string, fields ...interface{}) {
func (l *Logger) Msg(msg string, fields ...interface{}) {
m := make(map[string]interface{}, len(fields)/2)
fieldsToMap(fields, m)
l.log(false, l.formatMsg(msg, m))
Expand All @@ -112,7 +121,7 @@ func (l Logger) Msg(msg string, fields ...interface{}) {
// In the context of Error method, "msg" typically indicates the top-level
// context in which the error is *handled*. For example, if error leads to
// rejection of SMTP DATA command, msg will probably be "DATA error".
func (l Logger) Error(msg string, err error, fields ...interface{}) {
func (l *Logger) Error(msg string, err error, fields ...interface{}) {
if err == nil {
return
}
Expand All @@ -133,8 +142,8 @@ func (l Logger) Error(msg string, err error, fields ...interface{}) {
l.log(false, l.formatMsg(msg, allFields))
}

func (l Logger) DebugMsg(kind string, fields ...interface{}) {
if !l.Debug {
func (l *Logger) DebugMsg(kind string, fields ...interface{}) {
if !l.IsDebug() {
return
}
m := make(map[string]interface{}, len(fields)/2)
Expand Down Expand Up @@ -162,7 +171,7 @@ func fieldsToMap(fields []interface{}, out map[string]interface{}) {
}
}

func (l Logger) formatMsg(msg string, fields map[string]interface{}) string {
func (l *Logger) formatMsg(msg string, fields map[string]interface{}) string {
formatted := strings.Builder{}

formatted.WriteString(msg)
Expand All @@ -184,30 +193,31 @@ func (l Logger) formatMsg(msg string, fields map[string]interface{}) string {
return formatted.String()
}

type LogFormatter interface {
type Formatter interface {
FormatLog() string
}

// Write implements io.Writer, all bytes sent
// to it will be written as a separate log messages.
// No line-buffering is done.
func (l Logger) Write(s []byte) (int, error) {
func (l *Logger) Write(s []byte) (int, error) {
if !l.IsDebug() {
return len(s), nil
}
l.log(false, strings.TrimRight(string(s), "\n"))
return len(s), nil
}

// DebugWriter returns a writer that will act like Logger.Write
// but will use debug flag on messages. If Logger.Debug is false,
// Write method of returned object will be no-op.
func (l Logger) DebugWriter() io.Writer {
if !l.Debug {
return io.Discard
}
l.Debug = true
return &l
func (l *Logger) DebugWriter() io.Writer {
l2 := l.Sublogger("")
l2.Debug = true
return l2
}

func (l Logger) log(debug bool, s string) {
func (l *Logger) log(debug bool, s string) {
if l.Name != "" {
s = l.Name + ": " + s
}
Expand All @@ -224,14 +234,15 @@ func (l Logger) log(debug bool, s string) {
// Logging is disabled - do nothing.
}

func (l Logger) Sublogger(name string) Logger {
if l.Name != "" {
func (l *Logger) Sublogger(name string) *Logger {
if l.Name != "" && name != "" {
name = l.Name + "/" + name
}
return Logger{
Out: l.Out,
Name: name,
Debug: l.Debug,
return &Logger{
Parent: l,
Out: l.Out,
Name: name,
Debug: l.Debug,
}
}

Expand Down
2 changes: 1 addition & 1 deletion framework/log/orderedjson.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func marshalOrderedJSON(output *strings.Builder, m map[string]interface{}) error
val = casted.Format("2006-01-02T15:04:05.000")
case time.Duration:
val = casted.String()
case LogFormatter:
case Formatter:
val = casted.FormatLog()
case fmt.Stringer:
val = casted.String()
Expand Down
2 changes: 1 addition & 1 deletion framework/log/zap.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
// TODO: Migrate to using actual zapcore to improve logging performance

type zapLogger struct {
L Logger
L *Logger
}

func (l zapLogger) Enabled(level zapcore.Level) bool {
Expand Down
4 changes: 2 additions & 2 deletions framework/module/lifetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type ReloadModule interface {
}

type LifetimeTracker struct {
logger log.Logger
logger *log.Logger
instances []*struct {
mod LifetimeModule
started bool
Expand Down Expand Up @@ -114,7 +114,7 @@ func (lt *LifetimeTracker) StopAll() error {
return nil
}

func NewLifetime(log log.Logger) *LifetimeTracker {
func NewLifetime(log *log.Logger) *LifetimeTracker {
return &LifetimeTracker{
logger: log,
}
Expand Down
4 changes: 2 additions & 2 deletions framework/module/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ type registryEntry struct {
}

type Registry struct {
logger log.Logger
logger *log.Logger
instances map[string]registryEntry
initialized map[string]struct{}
started map[string]struct{}
aliases map[string]string
}

func NewRegistry(log log.Logger) *Registry {
func NewRegistry(log *log.Logger) *Registry {
return &Registry{
logger: log,
instances: make(map[string]registryEntry),
Expand Down
27 changes: 27 additions & 0 deletions framework/resource/netresource/dup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package netresource

import "net"

func dupTCPListener(l *net.TCPListener) (*net.TCPListener, error) {
f, err := l.File()
if err != nil {
return nil, err
}
l2, err := net.FileListener(f)
if err != nil {
return nil, err
}
return l2.(*net.TCPListener), nil
}

func dupUnixListener(l *net.UnixListener) (*net.UnixListener, error) {
f, err := l.File()
if err != nil {
return nil, err
}
l2, err := net.FileListener(f)
if err != nil {
return nil, err
}
return l2.(*net.UnixListener), nil
}
47 changes: 47 additions & 0 deletions framework/resource/netresource/fd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package netresource

import (
"errors"
"fmt"
"net"
"os"
"strconv"
"strings"
)

func ListenFD(fd uint) (net.Listener, error) {
file := os.NewFile(uintptr(fd), strconv.FormatUint(uint64(fd), 10))
defer file.Close()
return net.FileListener(file)
}

func ListenFDName(name string) (net.Listener, error) {
listenPDStr := os.Getenv("LISTEN_PID")
if listenPDStr == "" {
return nil, errors.New("$LISTEN_PID is not set")
}
listenPid, err := strconv.Atoi(listenPDStr)
if err != nil {
return nil, errors.New("$LISTEN_PID is not integer")
}
if listenPid != os.Getpid() {
return nil, fmt.Errorf("$LISTEN_PID (%d) is not our PID (%d)", listenPid, os.Getpid())
}

names := strings.Split(os.Getenv("LISTEN_FDNAMES"), ":")
fd := uintptr(0)
for i, fdName := range names {
if fdName == name {
fd = uintptr(3 + i)
break
}
}

if fd == 0 {
return nil, fmt.Errorf("name %s not found in $LISTEN_FDNAMES", name)
}

file := os.NewFile(3+fd, name)
defer file.Close()
return net.FileListener(file)
}
38 changes: 38 additions & 0 deletions framework/resource/netresource/listen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package netresource

import (
"fmt"
"net"
"strconv"

"github.com/foxcpp/maddy/framework/log"
)

var (
tracker = NewListenerTracker(log.DefaultLogger.Sublogger("netresource"))
)

func CloseUnusedListeners() error {
return tracker.Close()
}

func ResetListenersUsage() {
tracker.ResetUsage()
}

func Listen(network, addr string) (net.Listener, error) {
switch network {
case "fd":
fd, err := strconv.ParseUint(addr, 10, strconv.IntSize)
if err != nil {
return nil, fmt.Errorf("invalid FD number: %v", addr)
}
return ListenFD(uint(fd))
case "fdname":
return ListenFDName(addr)
case "tcp", "tcp4", "tcp6", "unix":
return tracker.Get(network, addr)
default:
return nil, fmt.Errorf("unsupported network: %v", network)
}
}
Loading

0 comments on commit f82742b

Please sign in to comment.