Skip to content

Commit

Permalink
Use concrete slog.Logger instead of Logger interface now that we …
Browse files Browse the repository at this point in the history
…don't need a zap adapter for Caddy.
  • Loading branch information
sbruens committed Sep 23, 2024
1 parent 9b1b801 commit 3af803b
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 60 deletions.
19 changes: 3 additions & 16 deletions service/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,10 @@
package service

import (
"context"
"io"
"log/slog"
)

type Logger interface {
Enabled(ctx context.Context, level slog.Level) bool
LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr)
}

type noopLogger struct {
}

var _ Logger = (*noopLogger)(nil)

func (l *noopLogger) Enabled(ctx context.Context, level slog.Level) bool {
return false
}

func (l *noopLogger) LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) {
func noopLogger() *slog.Logger {
return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.LevelError}))
}
7 changes: 4 additions & 3 deletions service/shadowsocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package service

import (
"context"
"log/slog"
"net"
"time"

Expand Down Expand Up @@ -50,7 +51,7 @@ type Service interface {
type Option func(s *ssService)

type ssService struct {
logger Logger
logger *slog.Logger
metrics ServiceMetrics
ciphers CipherList
natTimeout time.Duration
Expand All @@ -74,7 +75,7 @@ func NewShadowsocksService(opts ...Option) (Service, error) {
}
// If no logger is provided via options, use a noop logger.
if s.logger == nil {
s.logger = &noopLogger{}
s.logger = noopLogger()
}

// TODO: Register initial data metrics at zero.
Expand All @@ -92,7 +93,7 @@ func NewShadowsocksService(opts ...Option) (Service, error) {

// WithLogger can be used to provide a custom log target. If not provided,
// the service uses a noop logger (i.e., no logging).
func WithLogger(l Logger) Option {
func WithLogger(l *slog.Logger) Option {
return func(s *ssService) {
s.logger = l
}
Expand Down
32 changes: 16 additions & 16 deletions service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func remoteIP(conn net.Conn) netip.Addr {
}

// Wrapper for slog.Debug during TCP access key searches.
func debugTCP(l Logger, template string, cipherID string, attr slog.Attr) {
func debugTCP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
// This is an optimization to reduce unnecessary allocations due to an interaction
// between Go's inlining/escape analysis and varargs functions like slog.Debug.
if l != nil && l.Enabled(nil, slog.LevelDebug) {
Expand All @@ -72,7 +72,7 @@ func debugTCP(l Logger, template string, cipherID string, attr slog.Attr) {
// required = saltSize + 2 + cipher.TagSize, the number of bytes needed to authenticate the connection.
const bytesForKeyFinding = 50

func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList CipherList, l *slog.Logger) (*CipherEntry, io.Reader, []byte, time.Duration, error) {
// We snapshot the list because it may be modified while we use it.
ciphers := cipherList.SnapshotForClientIP(clientIP)
firstBytes := make([]byte, bytesForKeyFinding)
Expand All @@ -95,7 +95,7 @@ func findAccessKey(clientReader io.Reader, clientIP netip.Addr, cipherList Ciphe
}

// Implements a trial decryption search. This assumes that all ciphers are AEAD.
func findEntry(firstBytes []byte, ciphers []*list.Element, l Logger) (*CipherEntry, *list.Element) {
func findEntry(firstBytes []byte, ciphers []*list.Element, l *slog.Logger) (*CipherEntry, *list.Element) {
// To hold the decrypted chunk length.
chunkLenBuf := [2]byte{}
for ci, elt := range ciphers {
Expand All @@ -116,12 +116,12 @@ type StreamAuthenticateFunc func(clientConn transport.StreamConn) (string, trans

// NewShadowsocksStreamAuthenticator creates a stream authenticator that uses Shadowsocks.
// TODO(fortuna): Offer alternative transports.
func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l Logger) StreamAuthenticateFunc {
func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCache, metrics ShadowsocksConnMetrics, l *slog.Logger) StreamAuthenticateFunc {
if metrics == nil {
metrics = &NoOpShadowsocksConnMetrics{}
}
if l == nil {
l = &noopLogger{}
l = noopLogger()
}
return func(clientConn transport.StreamConn) (string, transport.StreamConn, *onet.ConnectionError) {
// Find the cipher and acess key id.
Expand Down Expand Up @@ -157,7 +157,7 @@ func NewShadowsocksStreamAuthenticator(ciphers CipherList, replayCache *ReplayCa
}

type streamHandler struct {
l Logger
logger *slog.Logger
listenerId string
readTimeout time.Duration
authenticate StreamAuthenticateFunc
Expand All @@ -167,7 +167,7 @@ type streamHandler struct {
// NewStreamHandler creates a StreamHandler
func NewStreamHandler(authenticate StreamAuthenticateFunc, timeout time.Duration) StreamHandler {
return &streamHandler{
l: &noopLogger{},
logger: noopLogger(),
readTimeout: timeout,
authenticate: authenticate,
dialer: defaultDialer,
Expand All @@ -187,16 +187,16 @@ func makeValidatingTCPStreamDialer(targetIPValidator onet.TargetIPValidator) tra
type StreamHandler interface {
Handle(ctx context.Context, conn transport.StreamConn, connMetrics TCPConnMetrics)
// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
SetLogger(l Logger)
SetLogger(l *slog.Logger)
// SetTargetDialer sets the [transport.StreamDialer] to be used to connect to target addresses.
SetTargetDialer(dialer transport.StreamDialer)
}

func (s *streamHandler) SetLogger(l Logger) {
func (s *streamHandler) SetLogger(l *slog.Logger) {
if l == nil {
l = &noopLogger{}
l = noopLogger()
}
s.l = l
s.logger = l
}

func (s *streamHandler) SetTargetDialer(dialer transport.StreamDialer) {
Expand Down Expand Up @@ -271,11 +271,11 @@ func (h *streamHandler) Handle(ctx context.Context, clientConn transport.StreamC
status := "OK"
if connError != nil {
status = connError.Status
h.l.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Error", slog.String("msg", connError.Message), slog.Any("cause", connError.Cause))
}
connMetrics.AddClosed(status, proxyMetrics, connDuration)
measuredClientConn.Close() // Closing after the metrics are added aids integration testing.
h.l.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
h.logger.LogAttrs(nil, slog.LevelDebug, "TCP: Done.", slog.String("status", status), slog.Duration("duration", connDuration))
}

func getProxyRequest(clientConn transport.StreamConn) (string, error) {
Expand All @@ -290,7 +290,7 @@ func getProxyRequest(clientConn transport.StreamConn) (string, error) {
return tgtAddr.String(), nil
}

func proxyConnection(l Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
func proxyConnection(l *slog.Logger, ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError {
tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr)
if dialErr != nil {
// We don't drain so dial errors and invalid addresses are communicated quickly.
Expand Down Expand Up @@ -365,7 +365,7 @@ func (h *streamHandler) handleConnection(ctx context.Context, outerConn transpor
tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy)
return tgtConn, nil
})
return proxyConnection(h.l, ctx, dialer, tgtAddr, innerConn)
return proxyConnection(h.logger, ctx, dialer, tgtAddr, innerConn)
}

// Keep the connection open until we hit the authentication deadline to protect against probing attacks
Expand All @@ -374,7 +374,7 @@ func (h *streamHandler) absorbProbe(clientConn io.ReadCloser, connMetrics TCPCon
// This line updates proxyMetrics.ClientProxy before it's used in AddTCPProbe.
_, drainErr := io.Copy(io.Discard, clientConn) // drain socket
drainResult := drainErrToString(drainErr)
h.l.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
h.logger.LogAttrs(nil, slog.LevelDebug, "Drain error.", slog.Any("err", drainErr), slog.String("result", drainResult))
connMetrics.AddProbe(status, drainResult, proxyMetrics.ClientProxy)
}

Expand Down
4 changes: 2 additions & 2 deletions service/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func BenchmarkTCPFindCipherFail(b *testing.B) {
}
clientIP := clientConn.RemoteAddr().(*net.TCPAddr).AddrPort().Addr()
b.StartTimer()
findAccessKey(clientConn, clientIP, cipherList, &noopLogger{})
findAccessKey(clientConn, clientIP, cipherList, noopLogger())
b.StopTimer()
}
}
Expand Down Expand Up @@ -205,7 +205,7 @@ func BenchmarkTCPFindCipherRepeat(b *testing.B) {
cipher := cipherEntries[cipherNumber].CryptoKey
go shadowsocks.NewWriter(writer, cipher).Write(makeTestPayload(50))
b.StartTimer()
_, _, _, _, err := findAccessKey(&c, clientIP, cipherList, &noopLogger{})
_, _, _, _, err := findAccessKey(&c, clientIP, cipherList, noopLogger())
b.StopTimer()
if err != nil {
b.Error(err)
Expand Down
36 changes: 18 additions & 18 deletions service/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,23 @@ type UDPMetrics interface {
const serverUDPBufferSize = 64 * 1024

// Wrapper for slog.Debug during UDP proxying.
func debugUDP(l Logger, template string, cipherID string, attr slog.Attr) {
func debugUDP(l *slog.Logger, template string, cipherID string, attr slog.Attr) {
// This is an optimization to reduce unnecessary allocations due to an interaction
// between Go's inlining/escape analysis and varargs functions like slog.Debug.
if l.Enabled(nil, slog.LevelDebug) {
l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("ID", cipherID), attr)
}
}

func debugUDPAddr(l Logger, template string, addr net.Addr, attr slog.Attr) {
func debugUDPAddr(l *slog.Logger, template string, addr net.Addr, attr slog.Attr) {
if l.Enabled(nil, slog.LevelDebug) {
l.LogAttrs(nil, slog.LevelDebug, fmt.Sprintf("UDP: %s", template), slog.String("address", addr.String()), attr)
}
}

// Decrypts src into dst. It tries each cipher until it finds one that authenticates
// correctly. dst and src must not overlap.
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList, l Logger) ([]byte, string, *shadowsocks.EncryptionKey, error) {
func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherList, l *slog.Logger) ([]byte, string, *shadowsocks.EncryptionKey, error) {
// Try each cipher until we find one that authenticates successfully. This assumes that all ciphers are AEAD.
// We snapshot the list because it may be modified while we use it.
snapshot := cipherList.SnapshotForClientIP(clientIP)
Expand All @@ -80,7 +80,7 @@ func findAccessKeyUDP(clientIP netip.Addr, dst, src []byte, cipherList CipherLis
}

type packetHandler struct {
l Logger
logger *slog.Logger
natTimeout time.Duration
ciphers CipherList
m UDPMetrics
Expand All @@ -97,7 +97,7 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr
ssMetrics = &NoOpShadowsocksConnMetrics{}
}
return &packetHandler{
l: &noopLogger{},
logger: noopLogger(),
natTimeout: natTimeout,
ciphers: cipherList,
m: m,
Expand All @@ -109,18 +109,18 @@ func NewPacketHandler(natTimeout time.Duration, cipherList CipherList, m UDPMetr
// PacketHandler is a running UDP shadowsocks proxy that can be stopped.
type PacketHandler interface {
// SetLogger sets the logger used to log messages. Uses a no-op logger if nil.
SetLogger(l Logger)
SetLogger(l *slog.Logger)
// SetTargetIPValidator sets the function to be used to validate the target IP addresses.
SetTargetIPValidator(targetIPValidator onet.TargetIPValidator)
// Handle returns after clientConn closes and all the sub goroutines return.
Handle(clientConn net.PacketConn)
}

func (h *packetHandler) SetLogger(l Logger) {
func (h *packetHandler) SetLogger(l *slog.Logger) {
if l == nil {
l = &noopLogger{}
l = noopLogger()
}
h.l = l
h.logger = l
}

func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPValidator) {
Expand All @@ -132,7 +132,7 @@ func (h *packetHandler) SetTargetIPValidator(targetIPValidator onet.TargetIPVali
func (h *packetHandler) Handle(clientConn net.PacketConn) {
var running sync.WaitGroup

nm := newNATmap(h.natTimeout, h.m, &running, h.l)
nm := newNATmap(h.natTimeout, h.m, &running, h.logger)
defer nm.Close()
cipherBuf := make([]byte, serverUDPBufferSize)
textBuf := make([]byte, serverUDPBufferSize)
Expand Down Expand Up @@ -160,7 +160,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
return onet.NewConnectionError("ERR_READ", "Failed to read from client", err)
}
defer slog.LogAttrs(nil, slog.LevelDebug, "UDP: Done", slog.String("address", clientAddr.String()))
debugUDPAddr(h.l, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes))
debugUDPAddr(h.logger, "Outbound packet.", clientAddr, slog.Int("bytes", clientProxyBytes))

cipherData := cipherBuf[:clientProxyBytes]
var payload []byte
Expand All @@ -171,7 +171,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
var textData []byte
var cryptoKey *shadowsocks.EncryptionKey
unpackStart := time.Now()
textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.l)
textData, keyID, cryptoKey, err = findAccessKeyUDP(ip, textBuf, cipherData, h.ciphers, h.logger)
timeToCipher := time.Since(unpackStart)
h.ssm.AddCipherSearch(err == nil, timeToCipher)

Expand Down Expand Up @@ -208,7 +208,7 @@ func (h *packetHandler) Handle(clientConn net.PacketConn) {
}
}

debugUDPAddr(h.l, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr()))
debugUDPAddr(h.logger, "Proxy exit.", clientAddr, slog.Any("target", targetConn.LocalAddr()))
proxyTargetBytes, err = targetConn.WriteTo(payload, tgtUDPAddr) // accept only UDPAddr despite the signature
if err != nil {
return onet.NewConnectionError("ERR_WRITE", "Failed to write to target", err)
Expand Down Expand Up @@ -317,14 +317,14 @@ func (c *natconn) ReadFrom(buf []byte) (int, net.Addr, error) {
type natmap struct {
sync.RWMutex
keyConn map[string]*natconn
l Logger
logger *slog.Logger
timeout time.Duration
metrics UDPMetrics
running *sync.WaitGroup
}

func newNATmap(timeout time.Duration, sm UDPMetrics, running *sync.WaitGroup, l Logger) *natmap {
m := &natmap{l: l, metrics: sm, running: running}
func newNATmap(timeout time.Duration, sm UDPMetrics, running *sync.WaitGroup, l *slog.Logger) *natmap {
m := &natmap{logger: l, metrics: sm, running: running}
m.keyConn = make(map[string]*natconn)
m.timeout = timeout
return m
Expand Down Expand Up @@ -370,7 +370,7 @@ func (m *natmap) Add(clientAddr net.Addr, clientConn net.PacketConn, cryptoKey *

m.running.Add(1)
go func() {
timedCopy(clientAddr, clientConn, entry, keyID, m.l)
timedCopy(clientAddr, clientConn, entry, keyID, m.logger)
connMetrics.RemoveNatEntry()
if pc := m.del(clientAddr.String()); pc != nil {
pc.Close()
Expand Down Expand Up @@ -399,7 +399,7 @@ func (m *natmap) Close() error {
var maxAddrLen int = len(socks.ParseAddr("[2001:db8::1]:12345"))

// copy from target to client until read timeout
func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l Logger) {
func timedCopy(clientAddr net.Addr, clientConn net.PacketConn, targetConn *natconn, keyID string, l *slog.Logger) {
// pkt is used for in-place encryption of downstream UDP packets, with the layout
// [padding?][salt][address][body][tag][extra]
// Padding is only used if the address is IPv4.
Expand Down
10 changes: 5 additions & 5 deletions service/udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ func assertAlmostEqual(t *testing.T, a, b time.Time) {
}

func TestNATEmpty(t *testing.T) {
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, &noopLogger{})
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, noopLogger())
if nat.Get("foo") != nil {
t.Error("Expected nil value from empty NAT map")
}
}

func setupNAT() (*fakePacketConn, *fakePacketConn, *natconn) {
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, &noopLogger{})
nat := newNATmap(timeout, &natTestMetrics{}, &sync.WaitGroup{}, noopLogger())
clientConn := makePacketConn()
targetConn := makePacketConn()
nat.Add(&clientAddr, clientConn, natCryptoKey, targetConn, "key id")
Expand Down Expand Up @@ -409,7 +409,7 @@ func BenchmarkUDPUnpackFail(b *testing.B) {
testIP := netip.MustParseAddr("192.0.2.1")
b.ResetTimer()
for n := 0; n < b.N; n++ {
findAccessKeyUDP(testIP, textBuf, testPayload, cipherList, &noopLogger{})
findAccessKeyUDP(testIP, textBuf, testPayload, cipherList, noopLogger())
}
}

Expand Down Expand Up @@ -439,7 +439,7 @@ func BenchmarkUDPUnpackRepeat(b *testing.B) {
cipherNumber := n % numCiphers
ip := ips[cipherNumber]
packet := packets[cipherNumber]
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, &noopLogger{})
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, noopLogger())
if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -468,7 +468,7 @@ func BenchmarkUDPUnpackSharedKey(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
ip := ips[n%numIPs]
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, &noopLogger{})
_, _, _, err := findAccessKeyUDP(ip, testBuf, packet, cipherList, noopLogger())
if err != nil {
b.Error(err)
}
Expand Down

0 comments on commit 3af803b

Please sign in to comment.